[2211.17192] Fast Inference from Transformers via Speculative Decoding

 

Fast Inference from Transformers via Speculative Decoding

Inference from large autoregressive models like Transformers is slow - decoding K tokens takes K serial runs of the model. In this work we introduce speculative decoding - an algorithm to sample from autoregressive models faster without any changes to the

arxiv.org

 

대학원 합격 후 이 논문 발표 할 일이 생겼는데 이후 검색해 보니 이 글 쓴 시점에서는 처음부터 끝까지 한국어로 잘 정리된 글이 없는 것 같아 다른 한국인 분들이 이 논문을 읽을 때 편하게 읽으시라고 정리해서 올립니다. 다른 개인적인 목적을 위해서도 올립니다. 아무쪼록 논문 읽는데 도움 되시길 바랍니다.

Abstract

 

이 논문은 출력에 변화 없이  Autoregressive(자기 회귀) 모델에서 더 빠르게 샘플링할 수 있는 Speculative Decoding (추측 디코딩) 알고리즘 소개합니다.

 

1. Introduction

 

이 논문은 처음에 대형 자동회귀 모델(대형 트랜스포머, LLM)의 추론 속도를 높이기 위한 여러 접근 방식을 소개합니다.

  1. 모든 입력에 대해 동일하게 추론 비용을 줄이는 것을 목표로 하는 방식
  2. 더 쉬운 추론 단계에서는 더 적은 계산 리소스를 사용하는 것을 목표로 하는 방식

그러나 이러한 접근 방식들은 일반적으로 모델 아키텍처를 변경하거나, 학습 절차를 수정하고, 모델을 재학습해야 하며, 동일한 출력을 유지하지 못하는 경우가 많기에 안 좋다고 주장하며 동시에 이러한 이전 연구 논문에서 얻은 통찰을 말해줍니다.

  1. 추론 중 일부는 "어렵고", 일부는 "쉽습니다. 
  2. LLM의 병목 현상은 산술연산이 아닌 메모리 대역폭과 통신에 있습니다. 따라서, 추가적인 계산 리소스가 사용 가능할 수 있습니다.

부연설명을 해보자면 1. 에서말하는 일부 추론이 쉽다는 뜻은 언어모델이 "The car is blue"를 예측할 때 "The car"이란 앞단 토큰은 쉬운 추론. 뒷단의 "is blue"는 어려운 추론이라고 말하고 있습니다.. 어찌 보면 당연합니다.

그리고 2. 에서 말하는 통신은 여러 개의 GPU를 사용하거나 분산컴퓨팅을 하게 되면 각 GPU/컴퓨터 간 통신을 해야 하는데 그것을 뜻합니다. 보통의 대형 모델은 여러 GPU나 여러 컴퓨터를 이용해 연산합니다.

 

두 개의 통찰은 다음과 같은 직관으로 연결됩니다.

  1. 적응적인 계산량을 사용하면 더 빠른 추론이 가능합니다.
  2. 남은 컴퓨팅 리소스를 사용하여 병렬 처리를 늘릴 수 있습니다

또한 "이 두 개의 직관의 실현을 달성할 수 있는 것 은 Speculative execution 기술이다."이다 말하며 Speculative execution을 기반한 Speculative Decoding을 소개합니다.

 

여기서 Speculative execution 이란?

다음과 같은 코드가 있고 x > 10 인지 판단하기 위해 x 값을 읽는 데 많은 시간이 필요하다고 가정 (병목현상), 이때 CPU를 그냥 내버려 두는 것이 아니라 x 값을 읽을 때 y = 3 을 먼저 실행시켜 CPU의 자원이 놀지 않게 하는 것이 Speculative execution 기술입니다.

여기선 y = 3 이란 코드 단 한줄로 예시를 들었지만 극단적으로 if 문 뒤에 엄청나게 많은 코드들이 존재하고 x 값이 엄청나게 큰 리스트라 읽는 데에 엄청 많은 시간이 필요하다고 생각하면 됩니다.

그래서 이걸 기반한 Speculative Decoding 은 뭐냐?

그림 1. 각 선은 알고리즘의 한 반복을 나타냅니다. 초록색 토큰은 근사 모델(6M 파라미터를 가진 GPT-like Transformer 디코더)이 제안한 토큰으로, 목표 모델(97M 파라미터를 가진 GPT-like Transformer 디코더)이 받아들인 것입니다. 반면 빨간색파란색 토큰은 거부된 제안과 그에 대한 수정 사항을 나타냅니다. 예를 들어 첫 번째 줄에서 타겟 모델은 한 번만 실행되었고, 5개의 토큰이 생성되었습니다.

 

다음 그림과 같은데 첫번째줄을 설명하자면

"[START]" 토큰 들어오고 Autoregressive(자기 회귀) 기법으로 근사 모델을 돌려 "[START] japan' s benchmark bond"까지 토큰을 5개 생성합니다. 근사 모델은 매우 작은 Transformer 모델이기 때문에 매우 빠르게 5개의 토큰 생성합니다.

그리고 매우 큰 모델(목표 모델)에게

  1. [START]
  2. [START] japan
  3. [START] japan'
  4. [START] japan' s
  5. [START] japan' benchmark
  6. [START] japan' benchmark bond

다음과 같은 6개의 문장을 병렬적으로 넣습니다. 그러면 총 6개의 분포가 나옵니다.

 

첫번째 "[START]" 를 목표 모델 input 으로 넣고 나온 output  분포로 2번째  토큰이 " japan " 이 맞는지 검사합니다.

두번째 "[START] japan" 를 목표 모델 input 으로 넣고 나온 output 분포로 3번째 토큰이 " ' " 이 맞는지 검사합니다.

 

도중에 검사가 "틀린 것" 이 있으면 그 이후의 모든 토큰은 "틀린 것" 으로 간주합니다.

그림 예시로는 5번째 input 으로 얻은 output 분포로 "bond" 토큰이 틀렸다고 나왔고 이후 토큰이 없으니 output 분포의 정답 토큰인 "n" 토큰을 생성합니다.

임의 예시로 4번째 input 으로 얻은 output 분포로 "benchmark" 토큰이 틀렸다고 나오면 " bond" 토큰도 자동적으로 "틀린 것" 으로 간주하고  output 분포의 정답 토큰을 생성합니다.

 

틀린지 맞는지 검사 는 "Speculative Sampling" 으로 검사함 로직은 다음 챕터에서 설명하겠습니다.

여기서 알아야할 매우 중요한 3가지는 

 

  1. 이미 어떤 토큰을 검사할지 다 알고있기 때문에 병렬적으로 검사 가능. ( 목표모델이 Autoregressive 하게 돌아가지않고 병렬적으로 돌아감)
  2. 모델 아키텍처를 변경하지 않고 훈련 절차를 변경하거나 모델을 재훈련할 필요도 없으며 모델 출력 분포를 변경하지 않음
  3. 초록색 토큰 즉, 앞단의 토큰들은 "쉬운 작업" 이기에 작은 근사모델로도 충분히 예측 가능함

 

 

2. Speculative Decoding

 

Speculative Decoding 알고리즘 중 Speculative Sampling 알고리즘이 포함되어있고 이 Sampling 이 중요하기에 먼저 설명하겠습니다. 수식 정의도 하겠습니다.

정의

 

$p(x)$, $q(x)$ 는 각 모델에 $t$ 시점 이전의 모든 토큰이 input 일때 $t$ 시점 토큰의 확률분포 입니다. 이 정의는 계속 쓰이니 기억해두시길 바랍니다.

 

Speculative Sampling 알고리즘
예시

 

예시를 차근차근 보시면 충분히 이해하실것 같습니다. p와 q의 값이 차이가 많을수록 ( 근사모델이 잘 예측 못할수록) 토큰을 거절 할 확률은 높아지는것을 볼수있습니다. 거절 됐을때 변형된 분포에서 토큰을 생성(sampling) 하게 되는데 그 분포를 보시면 이미 거절한 토큰 "He" 는 0% 으로 만들어 다시 나오지 못하게 합니다.

 

토큰 거절의 로직이 저런 이유는 output 분포를 유지시키기 위해서 저런 방식을 쓰고있습니다. 논문 맨 뒤 11페이지를 보시면 설명하길  Rejection Sampling 을 기반한것 같습니다.

 

Speculative Decoding 알고리즘

 

이제 논문의 Speculative Decoding 알고리즘 을 보시면 이제 이해가 쉽게 되실겁니다.

  1. 근사 모델로 $ \gamma $ 만큼  Autoregressive 하게 토큰 및 분포 생성
  2. 생성된 토큰들을 prefix(이전 토큰) 들과 합해 각각 병렬적으로 목표 모델에 넣어 분포 생성
  3. 근사 모델 분포와 목표 모델 분포로 Speculative Sampling 을 돌려 토큰 검사 및 토큰 생성

중간의 Determine the number of accepted guesses n. 부분은 오른쪽 예시를 보시면 알겠지만 Speculative Sampling 알고리즘 중 일부를 수학적으로 표현한 것 입니다.

$  r_i >  \frac{p(x)}{q(x)} $ 부분이 이해가 안되실수있는데

 

$r_{i} $는 균등분포로 0~1 사이의 값을 뽑는거니 수학적으로 랜덤 선택을 뜻 합니다. 또한 Speculative Sampling 은 $ 1-\frac{p(x)}{q(x)} $ 확률로 토큰 거절을 합니다. 이것을 수학적으로 표현하면 $ r_i \leq 1- \frac{p(x)}{q(x)} $ 입니다.

 

$ r_i \leq 1- \frac{p(x)}{q(x)} $ 과  $  r_i >  \frac{p(x)}{q(x)} $ 은 확률적으로 선택하는 식으로써 확률이 같습니다.

 

 

3. Analysis

이 챕터에서는 두 모델이 주어졌을때 가장 효율적으로 Speculative Decoding 을 쓰기위해 최적의 $\gamma$ 값을 찾아가는 수학적 여정을 다룹니다. 따라서 다양한 기대값을 정의하고 Speculative Decoding 사용할때와 사용하지않았을때의 연산량 및 실행시간 비율을 구합니다.

 

정의
예시

 

$\beta$ 는 "수락률" $\alpha$ 는 "수락률의 기대값" 이라고 정의하고있습니다.

"생성될 토큰의 기대값" 인 $E(\#generated \ tokens)$는 제한된 기하 분포(Capped Geometric Distribution) 의 기대값 에서 유래됬습니다. 이게 뭔소리나면

 

Speculative Decoding 알고리즘은 

  1. 최대 𝛾개의 토큰을 병렬로 예측합니다.
  2. 모든 토큰이 실패 없이 승인되면 최대 𝛾+1개의 토큰을 생성합니다.
  3. 첫 번째 실패(거부)가 발생할 때까지 성공적으로 승인된 토큰의 개수를 셉니다.

입니다. 이를 수식으로 표현하면 다음과 같고

이를 정리하면

 

우리가 보던 식이 나옵니다.

 

3.3. Walltime Improvement

 

 

이 챕터에서는 Speculative Decoding 사용할때와 사용하지않았을때의 실행시간 개선 비율을 구합니다.

 

 

정의
정리

 

실행시간 개선 비율을 구하는 수식은 위와 같은데 이를 풀면 아래와 같습니다.

재해석

결국에는 "(Speculative decoding 알고리즘 시간) 대비 (원본 디코딩 알고리즘 시간) 이 몇배인가?" 를 물어보는 수식입니다. 1배 초과이면 Speculative decoding을 사용함으로써 개선이 된거고 1배 미만이면 Speculative decoding을 사용함으로써 오히려 악화 된것이라고 볼수있습니다.

 

증명

 

논문에서는 위와같은 증명을 써놨는데 저가 느끼기에 설명이 와닿지 않습니다. 따라서 2~3일 고민끝에 저만의 방식으로 재해석 한것이니 참고하세요.

 

3.4. Number of Arithmetic Operations

 

 

이 챕터에서는 Speculative Decoding 사용할때와 사용하지않았을때의 산술연산 비율을 구합니다.

 

 

정의
정리
재해석

 

정리 3.11 식 또한 풀면 "(원본 디코딩 산술연산) 대비 (Speculative decoding 산술연산) 이 몇배인가?" 를 물어보는 수식입니다. 3.8 정리와 반대로 1배 초과이면 Speculative decoding을 사용함으로써 연산량이 늘어난것이고 1배 미만이면 Speculative decoding을 사용함으로써 연산량이 줄어들었다는 의미입니다.

 

증명

똑같습니다. 논문에서는 위와같이 증명이 써져있습니다. 다만 잘 와닿지 않습니다. 따라서 2~3일 고민끝에 저만의 방식으로 재해석 한것이니 참고하세요.

3.5. Choosing $\gamma$

 

 

 

결국에는 목표 모델 과 근사 모델이 주어졌을때 $T_p$, $T_q$, $\hat{T_p}$, $\hat{T_q}$ 을 구하고 여러 output 분포 표본을 이용해 $\alpha$ 값을 구한뒤 3.8 과 3.11 수식으로 최적의 $\gamma$ ( optimal $\gamma$ )를 구할수있습니다.

 

 

또한 위와같은 실행 시간 타임라인 표현한 그림 예시를 보여줍니다. 다만 저가 임의로 아래와같이 좀 더 정확하게 표현해봤습니다.

 

M_p decoder 의 연산량은 그림에 표현된 크기와 상관없이 동일

 

이 챕터 마지막에서는 실행 중 $\gamma$ 값을 동적으로 조정함으로써 성능을 더욱 최적화할 수 있다고 말합니다. 이게 뭔의미냐면

$\alpha$ (수락 확률 기대값)는
전체적으로 근사 모델 $M_{q}$가 목표 모델 $M_{p}$에 의해 수락될 확률의 평균값입니다.
특정한 문맥에 상관없이, 모델이 얼마나 잘 추측하는지를 나타내는 전반적인 평가 기준입니다.
즉, 고정된 값으로 간주될 수 있습니다. 이 논문에서는 $\alpha$ 을 기반으로 $\gamma$ 를 정했습니다. 
$\beta$ (수락 확률)는
특정 입력 문맥에서 실행 도중 측정되는 실제 수락률입니다.
문맥에 따라 다를 수 있으며, 문장의 초반부와 후반부에서 다르게 나타날 수 있습니다.
즉, 동적으로 변화하는 값입니다.

실행 중에 $\gamma$ 값을 조정한다는 의미는 모델이 실행 중에 현재 문맥에 대한 실제 수락률 $\beta$ 값을 모니터링하고,이 값을 기반으로 $\gamma$ 값을 동적으로 조정하여 성능을 더 최적화 할수있다고 하며 추가연구 사항이라고 말합니다.


예제
초기: 모델의 수락률이 높음 ($\beta$≈0.9) 모델이 정확하므로 한 번에 더 많은 토큰을 병렬로 예측 가능 → $\gamma$ 증가
중반: 문맥이 복잡해져서 수락률이 떨어짐 ($\beta$≈0.6) 불확실성이 증가하므로 병렬 샘플링을 줄이고 신중하게 진행 → $\gamma$ 감소
후반: 모델의 수락률이 다시 안정됨 ($\beta$≈0.8) 병렬 샘플링을 다시 증가시켜 속도를 최적화 → $\gamma$ 재조정
이러한 방식으로 실행 중 동적으로 조정하는 것이 고정된 $\gamma$ 값보다 더 나은 성능을 제공할 수 있습니다.

 

4. Experiments

결국에는 trade-off  형식으로 성능 향상이 되었다는 것을 확인가능하네요.

 

2025-02-02 경 논문 읽고 난 뒤 글쓴이가 제안한 추가 아이디어

 

 
 

$ \ alpha $$

 

+ Recent posts