[2305.09781] SpecInfer: Accelerating Generative Large Language Model Serving with Tree-based Speculative Inference and Verification

 

SpecInfer: Accelerating Generative Large Language Model Serving with Tree-based Speculative Inference and Verification

This paper introduces SpecInfer, a system that accelerates generative large language model (LLM) serving with tree-based speculative inference and verification. The key idea behind SpecInfer is leveraging small speculative models to predict the LLM's outpu

arxiv.org

 

이 글은 전에 리뷰한 Fast Inference from Transformers via Speculative Decoding 논문을 읽고 오시는 걸 추천합니다.

 

이 글 작성 시점에는 이 논문을 한글로 리뷰하거나 설명한 글이 없기에 다른 한국인 분들이 이 논문을 읽을 때 편하게 읽으시라고 최초로 한글로 정리해서 올립니다. 다른 개인적인 목적을 위해서도 올립니다. 아무쪼록 논문 읽는데 도움 되시길 바랍니다.

 

Abstract

이 논문은 기존의 방식, 근사 모델 = Approximation Model = $M_{q}$ = SSM(Small speculative model)을 이용해 하나의 토큰 시퀀스만 생성했던 것과 다르게 top-k 방식을 이용해 토큰 트리를 생성하고 더 나아가 여러 개의 SSM을 이용해 여러 개의 토큰 트리를 병합하는 방식으로 토큰 트리를 생성해 이것을 검증하는 방식인 Tree-based Speculative Inference and Verification (트리기반 추측 추론 및 검증)을 소개하는 논문입니다. 이 글에서 근사 모델을 지칭하는 단어는 이 논문과 똑같이 SSM 으로 통일하겠습니다.

 

 

1. Introduction

 

이 챕터에서는 이전의 방식과 이 논문의 Tree-based 방식을 비교하고있습니다.

 

Incremental Decoding (증분적, 서서히 증가하는 디코딩) 은 기존의 LLM 이 토큰을 생성하는 방식입니다. 이 방식은 토큰을 하나 생성할 때마다 Autoregressive(자기 회귀) 방식으로 생성하기 때문에 최적화되지 않은 런타임 성능 및 GPU 활용 때문에 컴퓨팅 자원이 남습니다.

 

Sequence-based Speculative Inference는 이 전 논문에서 사용했던 방식인 기존의 하나의 시퀀스만 고려한 Speculative Inference 방식입니다. SSM(Small Speculative Model)과 LLM(Large Language Model) 간의 모델 용량 차이로 
인해 두 모델이 생성한 토큰 시퀀스가 잘 정렬(align) 되지 않는다.라고 말하고 있습니다.

여기서 "정렬되지 않는다" 라는 소리는 SSM은 메모리와 계산 요구를 낮추기 위해 LLM보다 훨씬 작은 모델입니다. 그러나, 이로 인해 SSM이 생성한 토큰 시퀀스와 LLM이 생성한 토큰 시퀀스 간의 일치율이 낮다는 뜻입니다. 다시 말해 SSM이 생성한 단일 토큰 시퀀스가 LLM과 잘 맞지 않는 경우, SSM이 추측한 결과는 LLM의 실제 결과와 크게 다를 수 있습니다. 이로 인해 추론 성능이 저하된다는 뜻입니다.

 

Tree-based Speculative Inference 가 이 논문에서 새롭게 제안하는 트리기반의  Speculative Inference 방식입니다. 다양한 추측 후보들을 고려하기에 이러한 정렬문제를 해결합니다. 

 

Tree-based Speculative Inference를 구현하기 위해 두 가지 도전과제가 있습니다.

  1. SpecInfer는 매우 큰 검색 공간($502724^4 ≈ 6 × 10^{18}$)을 탐색해야 합니다.
  2. SpecInfer는 Incremental Decoding과 동일한 확률 분포를 사용하여 Stochastic decoding ( 확률적 디코딩 )에서 추론된 토큰을 검증해야 합니다. 

 

1번 에서의 502724는 논문에서 예시로든 OPT 모델군의 어휘 50272개 를 뜻 합니다. 모델이 처리할 수 있는 고유한 토큰(단어 또는 기호)의 총 수를 의미합니다.

2번은 다시 말해 기존이 Incremental Decoding에서 사용은 확률적 디코딩 방식 (ex top-k) 방식으로 내뱉는 확률 분포와 동일한 분포를 가지면서 토큰 트리를 검증해야 된다는 뜻입니다.

확률적 디코딩 방식 은 ChatGPT 등 생성형 언어모델을 제공하는 서비스들은 같은 프롬프트가 들어와도 Greedy decoding을 써 같은 답변이 계속 나오는 것이 아니라 Stochastic decoding (확률적 디코딩)을 써서 답변이 조금씩 랜덤 하게 달라지는 것이 예시로 들 수 있겠습니다.

 

 

2. SpecInfor’s Overview

SpecInfor의 전체적인 알고리즘입니다. 처음 input으로 "machine" 토큰이 들어오면 Expansion-based 방식이나 Merge-based 방식 중 하나의 방식으로 Token Tree를 생성합니다. 이 Token Tree를 Tree-based Parallel Decoding을 이용해 병렬적으로 각 노드의 다음 토큰의 분포를 얻습니다. 각각의 분포를 이용해 트리의 각 노드를 검증합니다.

 

 

3. Learning-based Speculator - How to generate a token tree

논문에서는 3번째 챕터 제목을 Learning-based Speculator로 해놨습니다. 이 챕터에서는 토큰 트리를 어떻게 생성하는지 설명하고 있습니다.토큰 트리를 생성하는 방식은 Expansion-based 방식과 Merge-based 방식, 2가지가 있습니다. 그전에 정의를 하고 넘어가겠습니다.

정의 3.1
예시 1

$N$ 은 토큰트리이며 $u$ 는 토큰트리에 속한 노드 들입니다. $t_u$ $u$ 노드가 표현하는 토큰입니다. $P_u$는 $u$ 노드의 부모노드를 가리킵니다. $S_u$ 는 $u$ 노드 시점에 토큰과 그전에 선택된 모든 토큰의 합이고 이는 하나의 토큰 시퀀스를 나타냅니다.

 

 

 

Expansion-based token tree construction

각 step 마다 top-k를 쓰는 방식입니다. 위의 그림 예제는 step 0, step 1에서는 top-2를 썼고 step 2에서는 top-1을 써 토큰 트리를 생성했습니다. 이를 논문에서는 간결하게 Expansion configuration <2,2,1>이라고 표기하고 있습니다. 그림 예제의 알고리즘을 순서대로 설명하겠습니다.

 

1. "machine" 토큰을 SSM모델에 input으로 넣고, output으로 생성된 분포에서 가장 나올 확률이 높은 토큰 두 개  즉, top-2 인 "learning", "translation" 토큰을 생성

2. "learning", "translation"을 각각 SSM모델 input으로 넣어 top-2 토큰을 생성

3. top-1 이기에 가장 확률이 높은 토큰 하나만 생성

 

k 높을수록 후보가 많이 생성되니 수락률이 높아집니다만 그만큼 컴퓨팅 리소스도 많이 잡아먹습니다.

 

 

Merge-based token tree construction

 

Merge-based 방식은 위의 Expansion-based 방식을 병렬적으로 여러 개 실행시켜 여러 개의 토큰 트리를 생성해 병합하는 방식입니다. 각 SSM 은 각각 토큰 트리를 내뱉습니다. 다만 가중치가 같은 SSM 모델로 여러개의 토큰트리를 생성한다 한들 같은 분포를 생성하기에 다양성이 부족합니다. 따라서 저자는 adaptive boosting 방식으로 재학습시킵니다. 알고리즘은 다음과 같습니다.

 

  1. 말뭉치(corpus)에서 프롬프트 샘플을 생성합니다.
  2. 프롬프트 샘플을 LLM input으로 사용하여 토큰 시퀀스를 생성합니다.
  3. SSM 0을 미세 조정(fine-tune)합니다. 이때 SSM 0과 LLM이 동일한 후속 토큰을 생성한 모든 프롬프트 샘플을 표시합니다.
  4. 표시되지 않은 프롬프트 샘플을 사용하여 SSM 1을 미세 조정(fine-tune)합니다.
  5. 이 과정을 모든 SSM에 대해 반복합니다.

이를 통해 분포의 다양성을 확보합니다. 학습 이후 추론과정에서 각각의 SSM 들은 병렬적으로 실행되어 그림의 예시로는 3개의 토큰 트리가 생성됩니다. 이 3개의 토큰트리는 다음과 같은 정의로 병합됩니다.

 

정의 3.2

 

 

$\mathcal{M}$은 모든 토큰 트리의 병합 트리입니다. 모든 토큰 트리의 노드의 토큰 시퀀스는 $\mathcal{M}$에 존재하는 노드의 토큰 시퀀스와 같아야 하고 그 역도 성립해야 합니다.

직관적으로, 각 토큰 트리는 토큰 시퀀스 집합을 나타냅니다. 여러 토큰 트리를 병합하면 원래의 모든 트리의 토큰 시퀀스를 포함하는 새로운 트리가 생성된다는 뜻입니다. 여러 개의 트리가 주어졌을 때 우리가 흔히 생각할 수 있는 병합 방식입니다.

 

4. Token Tree Verifier

토큰 트리를 생성하였다면 이를 검증해야 합니다. 이 챕터에서는 그것을 설명합니다.

 

 

4.1 Tree Attention

검증을 하려면 토큰 트리의 모든 토큰 시퀀스를 LLM에 input으로 넣고 분포를 얻어야 합니다. 다시 말해 "모든 토큰 시퀀스는 LLM의 self-attention 구조를 통과해야 된다는 뜻"이고 이는 "LLM에서 모든 토큰 시퀀스의 attention output을 구해야 한다는 뜻"과 같습니다. 이를 위해 Tree Attention을 정의합니다.

정의 4.1

Tree Attention 은 토큰 트리 $N$ 의 모든 노드들의 토큰 시퀀스들의 attention입니다.

 

 

4.2 Tree-based Parallel Decoding

이 챕터에서는 LLM에서 Tree Attention을 계산하는 2가지 방식을  소개합니다.

 

여기 예시의 토큰 트리가 주어지고 Sequence-based Parallel Decoding와 Tree-based Parallel Decoding으로 Tree Attention를 구하는 그림이 있습니다. 

 

Sequence-based Parallel Decoding

가장 기본적으로 생각할 수 있는 방식입니다. 예시의 토큰 트리는 리프노드 3개 즉, 자식 노드가 없는 t5, t7 t9 노드의 토큰 시퀀스를 3개의 kernel을 두고  Autoregressive 하게 처리합니다. 이는 t2, t3 토큰시점의 KV-cache 가 중복되는 것을 볼 수 있고 이는 연산의 중복이라고도 볼 수 있습니다. 

 

Tree-based Parallel Decoding

따라서 저자는 Depth-first search to update key-value cache, Topology-aware causal mask을 제안하고 이를 이용한 방식인 Tree-based Parallel Decoding을 제안합니다. 

 

Depth-first search to update key-value cache는 흔히 알고 있는 깊이 우선 탐색의 순서로 연산해 중복 cache를 피합니다. Figure 4의 토큰 트리의 번호 순서가 깊이 우선 탐색 방식으로 순서를 매긴 것 과 같습니다. 이 순서로 하나의 Kernal에서Autoregressive 하게 연산을 처리하면 중복된 KV-cache 및 연산이 없습니다.

 

Topology-aware causal mask는 그림만 봐도 대충 이해하실 텐데 $t_6$, $t_7$ 토큰은 예시 토큰 트리에서 보면 $t_5$ 와 다른 분기이기에 기존의 causal mask의 역할과 비슷하게 다른 분기의 attention weight에 마이너스 무한대를 곱해줍니다. 이는 트리 기반 attention 계산을 하나의 Kernal에서 수행하면서 인과 관계를 유지하게 해 줍니다.

 

 

 

이 챕터를 읽을때 저는 병렬적으로 처리하는 부분의 설명이 정확히 되어있지 않아 여러 의문점이 드는 부분이 많았습니다. 논문 설명에 관련된 내용이 아니니 읽지 않으셔도 됩니다.

 

더보기

의문점 1. Figure 4를 보면 Sequence-based Parallel Decoding 은 연산 순서가 화살표로 표시되어 있고 Tree-based Parallel Decoding는 컴마로 나눠져 있다. 병렬처리의 다름을 설명하는 것인가?

 

의문점 2. 병렬적으로 처리한다고 해도 여러 방식이 존재할 텐데 정확히 어떠한 방식으로 병렬적으로 알고리즘이 돌아가는가?

 

이는 논문에서도 쓰여있지 않습니다. 따라서 저는 3개의 병렬처리 알고리즘을 예상했고 이는 다음과 같습니다.

 

1번은 모든 토큰을 한 번에 병렬적으로 연산합니다. 이러면 KV-cache를 활용을 못하며 연산도 중복됩니다. 추측하건대 이는 Fast Inference from Transformers via Speculative Decoding 논문의 병렬처리 방식 같습니다. SpecInfer 논문에서는 KV-cache 활용을 계속 언급하고 있기에 이런 방식을 고려하지 않는 것 같습니다. 따라서 제외.

 

2번은 Tree-based Parallel Decoding에서 각 토큰을 컴마 표기로 나눠져 있기에 "이렇게 연산한다는 건가?" 가정이 들어 제안했습니다. 다만 1번과 동일하게 이러면 KV-cache 활용을 못합니다. 또한 t2 나 t3 토큰 시점 노드의 후보 토큰 분포는 어떻게 얻고 검증하는지 의문입니다. 모델 output 구조를 바꾸면 이러한 input 이 가능해 보입니다. 어찌 보면 가장 병렬적으로 연산중복 없이 연산하는 거겠지요. 이는 새로운 연구 방향성 일수도 있겠습니다. 이 논문에서 이러한 input을 쓴다는 것 인지 확실치 않지만 일단 이 논문은 KV-cache 활용한다는 점은 명확하기에 이것도 제외.

 

3번이 가장 유력해 보입니다. KV-cache 도 활용가능하며 Depth-first search to update key-value cache, Topology-aware causal mask 두 개를 써 Kernal 1개로 통합하면 중복된 연산과 KV-cache 도 없습니다. 다만 "KV-cache을 활용하기 위해서 Autoregressive 하게 처리한다는 소리인데 이것을 병렬 처리라고 말하고 Parallel Decoding이라고 부를 수 있는 것인가?"라는 큰 의문점이 있습니다.

 

의문점 3. 3번의 방식으로 병렬처리를 한다고 치고 Kernal 하나로 통합해 연산한다고 가정한다면 또다시 두 가지 방식을 생각해 볼 수 있습니다.

 

1번은 말 그대로 그냥 순서대로 Autoregressive 하게 처리합니다.

2번은  Autoregressive 하게 처리함으로써 KV-cache 도 활용하되 연산 중복이 안 되는 선에서 병렬적으로 처리합니다.

 

이러한 의문점들을 해결하고 정확한 병렬처리 알고리즘이 궁금하신 분은 오픈소스로 코드가 공개되어 있으니 그 코드를 분석해야 할 것 같습니다. 

 

 

 

4.3 Token Verification

효율적으로 Tree Attention을 구했다면 우리는 최종적으로 각각의 노드들을 검증할 수 있는 분포를 를 얻습니다. 예시 1 사진으로 설명하자면 $u_1$ 에서  "<start>" 토큰을 LLM input으로 넣었을 때 다음에 어떤 토큰이 나올지 후보 토큰들을 표현하는 확률 분포를 얻게 되겠죠. "The" 토큰이 확률이 높을 수도 "The", "She"가 아닌 "He"라는 토큰이 나올 확률이 90% 일수도 있겠습니다. $u_1$ 말고도 $u_2$, $u_3$ ... 등등 각각 노드의 LLM 확률 분포를 얻게 됩니다. 이를 가지고 저자는 두 가지 검증 방식을 소개합니다.

 

VerifyGreedy (탐욕적 검증)

Greedy 방식으로 검증하는 방식은 간단하게 LLM 확률 분포에서의 top-1 토큰이 자식 노드에 있으면 수락하고 없으면 LLM 확률 분포에서의 top-1 토큰을 수락해 리턴합니다.

예시 1

그림으로 예시를 들자면 $u_1$ 노드에서 LLM이 다음 토큰을 "The"가 가장 확률이 높게 나올 토큰으로 예측하면 "The" 토큰을 수락하고 자식노드 검증 반복 합니다.

"The", "She"가 아닌 "He"가 가장 확률이 높게 나올 토큰으로 예측하면 "He" 토큰을 수락하고 검증 종료 후 SSM을 이용해 토큰 트리를 다시 생성합니다.

 

VerifyStochastic (확률적 검증)

VerifyStochastic 함수는 Multi-step Speculative Sampling 즉, 각각의 노드를 for 문으로 돌면서 Speculative Sampling을 하는 알고리즘과 같습니다.

이는 Fast Inference from Transformers via Speculative Decoding 논문을 읽고 Speculative Sampling 알고리즘을 아셔야 이해할 수 있습니다. (Fast Inference from Transformers via Speculative Decoding 글의  2. Speculative Decoding 부분 참고)

 

Speculative Sampling 알고리즘
두 로직의 유사성

Fast Inference from Transformers via Speculative Decoding 논문의 Speculative Decoding 알고리즘 중 Speculative Sampling 부분과 유사함을 그림으로 표현했습니다.

Speculative Decoding의 Speculative Sampling 은 거부 될 때 집합에 포함되므로 $  r_i >  \frac{p(x)}{q(x)} $ 이고 VerifyStochastic의 Speculative Sampling은 수락 될때를 if 문으로 표현했기에 $  r_i  \leq  \frac{p(x)}{q(x)} $ 입니다.

즉, 거부될때를 표현한 식, 수락할때 표현한 식 이므로 부호가 반대되어있습니다. 또한 $  r_i  \leq  \frac{p(x)}{q(x)} $ 와 $  r_i  > 1 - \frac{p(x)}{q(x)} $ 는 확률적으로 선택하는 식으로써 확률이 같습니다. (Fast Inference from Transformers via Speculative Decoding 글의  2. Speculative Decoding 부분 참고)

 

결론적으로 표현한 수식만 다르고 본질적인 알고리즘은 동일합니다. 

 

여기서 중요한 점은 Multi-step으로  Speculative Sampling을 하기 때문에 Speculative Sampling과 동일하게 검증할 때 분포가 바뀌지 않고 유지됩니다. 이는 Introduction에서 언급한 2번째 도전과제를 Multi-step Speculative Sampling으로 해결한 것 이라고도 볼수있습니다.

 

 

5. System Design and Implementation

 

 이 챕터에서는 SpecInfer의 4가지 특이사항을 소개합니다.

 

 

  • SpecInfer는 Megatron-LM에서 도입된 하이브리드 병렬화를 사용하여 LLM을 제공합니다.
  • SpecInfer는 Orca에서 도입된 연속 배치(Continuous Batching)를 사용합니다.
  • SpecInfer는 FlexFlow 위에 구현되었습니다.
  • SpecInfer는 FasterTransformer를 기반으로 한 맞춤형 CUDA 커널을 사용합니다.

그리고 몇 가지 SpecInfer을 사용함으로써 발생되는 Overhead를 말하며 반박합니다.

 

메모리 오버헤드

  • 하나 이상의 SSM(Small Speculative Model)의 매개변수를 저장하기 위해 메모리를 할당해야 합니다. - SSM은 전체 메모리 요구량을 1% 미만으로 증가시키기에 무시할 수 있습니다.
  • LLM으로 단일 토큰을 디코딩하는 대신 토큰 트리를 검증하는 데 사용됩니다. - 매우 긴 시퀀스 길이에 대한 키-값 캐시와 비교할 때, 토큰 트리의 메모리 오버헤드는 무시할 수 있습니다.

계산 오버헤드

  • SSM을 Incremental Decoding 모드로 실행해야 하며  전체 토큰 트리에 대해 Tree Attention 출력을 계산해야 합니다.  - Incremental Decoding에서 LLM(Large Language Model)을 제공할 때 GPU의 계산 자원이 남으므로 계산 오버헤드는 무시할 수 있습니다.

 

 

6. Evaluation

분산 컴퓨팅에서 모델들 비교
Offloading 방식에서 모델들 비교

 

성능 비교표입니다. 분산 컴퓨팅 비교를 보시면 GPU와 node 가 증가해 병렬성이 증가할수록 SpecInfer의 효율성이 극대화됩니다. Expansion configuration 은 ⟨1,1,3,1,1,1,1,1⟩입니다.

 

 

VerifyGreedy 함수와 VerifyStochastic 함수 비교
Batch-size 별로 Tree width 비교 Expansion configuration⟨1,1,𝑘,1,1,1,1,1⟩ 에서 k = Tree width

 

Tree Attention 계산 방식 비교

 

7. RelatedWork, 8. Conclusion 은 별다른 내용이 없기에 넘어가겠습니다. A Artifact Appendix 부분은 흥미로운데 관심 있으신 분들은 읽어보세요.

[2211.17192] Fast Inference from Transformers via Speculative Decoding

 

Fast Inference from Transformers via Speculative Decoding

Inference from large autoregressive models like Transformers is slow - decoding K tokens takes K serial runs of the model. In this work we introduce speculative decoding - an algorithm to sample from autoregressive models faster without any changes to the

arxiv.org

 

대학원 합격 후 이 논문 발표 할 일이 생겼는데 이후 검색해 보니 이 글 쓴 시점에서는 처음부터 끝까지 한국어로 잘 정리된 글이 없는 것 같아 다른 한국인 분들이 이 논문을 읽을 때 편하게 읽으시라고 정리해서 올립니다. 다른 개인적인 목적을 위해서도 올립니다. 아무쪼록 논문 읽는데 도움 되시길 바랍니다.

Abstract

 

이 논문은 출력에 변화 없이  Autoregressive(자기 회귀) 모델에서 더 빠르게 샘플링할 수 있는 Speculative Decoding (추측 디코딩) 알고리즘 소개합니다.

 

1. Introduction

 

이 논문은 처음에 대형 자동회귀 모델(대형 트랜스포머, LLM)의 추론 속도를 높이기 위한 여러 접근 방식을 소개합니다.

  1. 모든 입력에 대해 동일하게 추론 비용을 줄이는 것을 목표로 하는 방식
  2. 더 쉬운 추론 단계에서는 더 적은 계산 리소스를 사용하는 것을 목표로 하는 방식

그러나 이러한 접근 방식들은 일반적으로 모델 아키텍처를 변경하거나, 학습 절차를 수정하고, 모델을 재학습해야 하며, 동일한 출력을 유지하지 못하는 경우가 많기에 안 좋다고 주장하며 동시에 이러한 이전 연구 논문에서 얻은 통찰을 말해줍니다.

  1. 추론 중 일부는 "어렵고", 일부는 "쉽습니다. 
  2. LLM의 병목 현상은 산술연산이 아닌 메모리 대역폭과 통신에 있습니다. 따라서, 추가적인 계산 리소스가 사용 가능할 수 있습니다.

부연설명을 해보자면 1. 에서말하는 일부 추론이 쉽다는 뜻은 언어모델이 "The car is blue"를 예측할 때 "The car"이란 앞단 토큰은 쉬운 추론. 뒷단의 "is blue"는 어려운 추론이라고 말하고 있습니다.. 어찌 보면 당연합니다.

그리고 2. 에서 말하는 통신은 여러 개의 GPU를 사용하거나 분산컴퓨팅을 하게 되면 각 GPU/컴퓨터 간 통신을 해야 하는데 그것을 뜻합니다. 보통의 대형 모델은 여러 GPU나 여러 컴퓨터를 이용해 연산합니다.

 

두 개의 통찰은 다음과 같은 직관으로 연결됩니다.

  1. 적응적인 계산량을 사용하면 더 빠른 추론이 가능합니다.
  2. 남은 컴퓨팅 리소스를 사용하여 병렬 처리를 늘릴 수 있습니다

또한 "이 두 개의 직관의 실현을 달성할 수 있는 것 은 Speculative execution 기술이다."이다 말하며 Speculative execution을 기반한 Speculative Decoding을 소개합니다.

 

여기서 Speculative execution 이란?

다음과 같은 코드가 있고 x > 10 인지 판단하기 위해 x 값을 읽는 데 많은 시간이 필요하다고 가정 (병목현상), 이때 CPU를 그냥 내버려 두는 것이 아니라 x 값을 읽을 때 y = 3 을 먼저 실행시켜 CPU의 자원이 놀지 않게 하는 것이 Speculative execution 기술입니다.

여기선 y = 3 이란 코드 단 한줄로 예시를 들었지만 극단적으로 if 문 뒤에 엄청나게 많은 코드들이 존재하고 x 값이 엄청나게 큰 리스트라 읽는 데에 엄청 많은 시간이 필요하다고 생각하면 됩니다.

그래서 이걸 기반한 Speculative Decoding 은 뭐냐?

그림 1. 각 선은 알고리즘의 한 반복을 나타냅니다. 초록색 토큰은 근사 모델(6M 파라미터를 가진 GPT-like Transformer 디코더)이 제안한 토큰으로, 목표 모델(97M 파라미터를 가진 GPT-like Transformer 디코더)이 받아들인 것입니다. 반면 빨간색파란색 토큰은 거부된 제안과 그에 대한 수정 사항을 나타냅니다. 예를 들어 첫 번째 줄에서 타겟 모델은 한 번만 실행되었고, 5개의 토큰이 생성되었습니다.

 

다음 그림과 같은데 첫번째줄을 설명하자면

"[START]" 토큰 들어오고 Autoregressive(자기 회귀) 기법으로 근사 모델을 돌려 "[START] japan' s benchmark bond"까지 토큰을 5개 생성합니다. 근사 모델은 매우 작은 Transformer 모델이기 때문에 매우 빠르게 5개의 토큰 생성합니다.

그리고 매우 큰 모델(목표 모델)에게

  1. [START]
  2. [START] japan
  3. [START] japan'
  4. [START] japan' s
  5. [START] japan' benchmark
  6. [START] japan' benchmark bond

다음과 같은 6개의 문장을 병렬적으로 넣습니다. 그러면 총 6개의 분포가 나옵니다.

 

첫번째 "[START]" 를 목표 모델 input 으로 넣고 나온 output  분포로 2번째  토큰이 " japan " 이 맞는지 검사합니다.

두번째 "[START] japan" 를 목표 모델 input 으로 넣고 나온 output 분포로 3번째 토큰이 " ' " 이 맞는지 검사합니다.

 

도중에 검사가 "틀린 것" 이 있으면 그 이후의 모든 토큰은 "틀린 것" 으로 간주합니다.

그림 예시로는 5번째 input 으로 얻은 output 분포로 "bond" 토큰이 틀렸다고 나왔고 이후 토큰이 없으니 output 분포의 정답 토큰인 "n" 토큰을 생성합니다.

임의 예시로 4번째 input 으로 얻은 output 분포로 "benchmark" 토큰이 틀렸다고 나오면 " bond" 토큰도 자동적으로 "틀린 것" 으로 간주하고  output 분포의 정답 토큰을 생성합니다.

 

틀린지 맞는지 검사 는 "Speculative Sampling" 으로 검사함 로직은 다음 챕터에서 설명하겠습니다.

여기서 알아야할 매우 중요한 3가지는 

 

  1. 이미 어떤 토큰을 검사할지 다 알고있기 때문에 병렬적으로 검사 가능. ( 목표모델이 Autoregressive 하게 돌아가지않고 병렬적으로 돌아감)
  2. 모델 아키텍처를 변경하지 않고 훈련 절차를 변경하거나 모델을 재훈련할 필요도 없으며 모델 출력 분포를 변경하지 않음
  3. 초록색 토큰 즉, 앞단의 토큰들은 "쉬운 작업" 이기에 작은 근사모델로도 충분히 예측 가능함

 

 

2. Speculative Decoding

 

Speculative Decoding 알고리즘 중 Speculative Sampling 알고리즘이 포함되어있고 이 Sampling 이 중요하기에 먼저 설명하겠습니다. 수식 정의도 하겠습니다.

정의

 

$p(x)$, $q(x)$ 는 각 모델에 $t$ 시점 이전의 모든 토큰이 input 일때 $t$ 시점 토큰의 확률분포 입니다. 이 정의는 계속 쓰이니 기억해두시길 바랍니다.

 

Speculative Sampling 알고리즘
예시

 

예시를 차근차근 보시면 충분히 이해하실것 같습니다. p와 q의 값이 차이가 많을수록 ( 근사모델이 잘 예측 못할수록) 토큰을 거절 할 확률은 높아지는것을 볼수있습니다. 거절 됐을때 변형된 분포에서 토큰을 생성(sampling) 하게 되는데 그 분포를 보시면 이미 거절한 토큰 "He" 는 0% 으로 만들어 다시 나오지 못하게 합니다.

 

토큰 거절의 로직이 저런 이유는 output 분포를 유지시키기 위해서 저런 방식을 쓰고있습니다. 논문 맨 뒤 11페이지를 보시면 설명하길  Rejection Sampling 을 기반한것 같습니다.

 

Speculative Decoding 알고리즘

 

이제 논문의 Speculative Decoding 알고리즘 을 보시면 이제 이해가 쉽게 되실겁니다.

  1. 근사 모델로 $ \gamma $ 만큼  Autoregressive 하게 토큰 및 분포 생성
  2. 생성된 토큰들을 prefix(이전 토큰) 들과 합해 각각 병렬적으로 목표 모델에 넣어 분포 생성
  3. 근사 모델 분포와 목표 모델 분포로 Speculative Sampling 을 돌려 토큰 검사 및 토큰 생성

중간의 Determine the number of accepted guesses n. 부분은 오른쪽 예시를 보시면 알겠지만 Speculative Sampling 알고리즘 중 일부를 수학적으로 표현한 것 입니다.

$  r_i >  \frac{p(x)}{q(x)} $ 부분이 이해가 안되실수있는데

 

$r_{i} $는 균등분포로 0~1 사이의 값을 뽑는거니 수학적으로 랜덤 선택을 뜻 합니다. 또한 Speculative Sampling 은 $ 1-\frac{p(x)}{q(x)} $ 확률로 토큰 거절을 합니다. 이것을 수학적으로 표현하면 $ r_i \leq 1- \frac{p(x)}{q(x)} $ 입니다.

 

$ r_i \leq 1- \frac{p(x)}{q(x)} $ 과  $  r_i >  \frac{p(x)}{q(x)} $ 은 확률적으로 선택하는 식으로써 확률이 같습니다.

 

 

3. Analysis

이 챕터에서는 두 모델이 주어졌을때 가장 효율적으로 Speculative Decoding 을 쓰기위해 최적의 $\gamma$ 값을 찾아가는 수학적 여정을 다룹니다. 따라서 다양한 기대값을 정의하고 Speculative Decoding 사용할때와 사용하지않았을때의 연산량 및 실행시간 비율을 구합니다.

 

정의
예시

 

$\beta$ 는 "수락률" $\alpha$ 는 "수락률의 기대값" 이라고 정의하고있습니다.

"생성될 토큰의 기대값" 인 $E(\#generated \ tokens)$는 제한된 기하 분포(Capped Geometric Distribution) 의 기대값 에서 유래됬습니다. 이게 뭔소리나면

 

Speculative Decoding 알고리즘은 

  1. 최대 𝛾개의 토큰을 병렬로 예측합니다.
  2. 모든 토큰이 실패 없이 승인되면 최대 𝛾+1개의 토큰을 생성합니다.
  3. 첫 번째 실패(거부)가 발생할 때까지 성공적으로 승인된 토큰의 개수를 셉니다.

입니다. 이를 수식으로 표현하면 다음과 같고

이를 정리하면

 

우리가 보던 식이 나옵니다.

 

3.3. Walltime Improvement

 

 

이 챕터에서는 Speculative Decoding 사용할때와 사용하지않았을때의 실행시간 개선 비율을 구합니다.

 

 

정의
정리

 

실행시간 개선 비율을 구하는 수식은 위와 같은데 이를 풀면 아래와 같습니다.

재해석

결국에는 "(Speculative decoding 알고리즘 시간) 대비 (원본 디코딩 알고리즘 시간) 이 몇배인가?" 를 물어보는 수식입니다. 1배 초과이면 Speculative decoding을 사용함으로써 개선이 된거고 1배 미만이면 Speculative decoding을 사용함으로써 오히려 악화 된것이라고 볼수있습니다.

 

증명

 

논문에서는 위와같은 증명을 써놨는데 저가 느끼기에 설명이 와닿지 않습니다. 따라서 2~3일 고민끝에 저만의 방식으로 재해석 한것이니 참고하세요.

 

3.4. Number of Arithmetic Operations

 

 

이 챕터에서는 Speculative Decoding 사용할때와 사용하지않았을때의 산술연산 비율을 구합니다.

 

 

정의
정리
재해석

 

정리 3.11 식 또한 풀면 "(원본 디코딩 산술연산) 대비 (Speculative decoding 산술연산) 이 몇배인가?" 를 물어보는 수식입니다. 3.8 정리와 반대로 1배 초과이면 Speculative decoding을 사용함으로써 연산량이 늘어난것이고 1배 미만이면 Speculative decoding을 사용함으로써 연산량이 줄어들었다는 의미입니다.

 

증명

똑같습니다. 논문에서는 위와같이 증명이 써져있습니다. 다만 잘 와닿지 않습니다. 따라서 2~3일 고민끝에 저만의 방식으로 재해석 한것이니 참고하세요.

3.5. Choosing $\gamma$

 

 

 

결국에는 목표 모델 과 근사 모델이 주어졌을때 $T_p$, $T_q$, $\hat{T_p}$, $\hat{T_q}$ 을 구하고 여러 output 분포 표본을 이용해 $\alpha$ 값을 구한뒤 3.8 과 3.11 수식으로 최적의 $\gamma$ ( optimal $\gamma$ )를 구할수있습니다.

 

 

또한 위와같은 실행 시간 타임라인 표현한 그림 예시를 보여줍니다. 다만 저가 임의로 아래와같이 좀 더 정확하게 표현해봤습니다.

 

M_p decoder 의 연산량은 그림에 표현된 크기와 상관없이 동일

 

이 챕터 마지막에서는 실행 중 $\gamma$ 값을 동적으로 조정함으로써 성능을 더욱 최적화할 수 있다고 말합니다. 이게 뭔의미냐면

$\alpha$ (수락 확률 기대값)는
전체적으로 근사 모델 $M_{q}$가 목표 모델 $M_{p}$에 의해 수락될 확률의 평균값입니다.
특정한 문맥에 상관없이, 모델이 얼마나 잘 추측하는지를 나타내는 전반적인 평가 기준입니다.
즉, 고정된 값으로 간주될 수 있습니다. 이 논문에서는 $\alpha$ 을 기반으로 $\gamma$ 를 정했습니다. 
$\beta$ (수락 확률)는
특정 입력 문맥에서 실행 도중 측정되는 실제 수락률입니다.
문맥에 따라 다를 수 있으며, 문장의 초반부와 후반부에서 다르게 나타날 수 있습니다.
즉, 동적으로 변화하는 값입니다.

실행 중에 $\gamma$ 값을 조정한다는 의미는 모델이 실행 중에 현재 문맥에 대한 실제 수락률 $\beta$ 값을 모니터링하고,이 값을 기반으로 $\gamma$ 값을 동적으로 조정하여 성능을 더 최적화 할수있다고 하며 추가연구 사항이라고 말합니다.


예제
초기: 모델의 수락률이 높음 ($\beta$≈0.9) 모델이 정확하므로 한 번에 더 많은 토큰을 병렬로 예측 가능 → $\gamma$ 증가
중반: 문맥이 복잡해져서 수락률이 떨어짐 ($\beta$≈0.6) 불확실성이 증가하므로 병렬 샘플링을 줄이고 신중하게 진행 → $\gamma$ 감소
후반: 모델의 수락률이 다시 안정됨 ($\beta$≈0.8) 병렬 샘플링을 다시 증가시켜 속도를 최적화 → $\gamma$ 재조정
이러한 방식으로 실행 중 동적으로 조정하는 것이 고정된 $\gamma$ 값보다 더 나은 성능을 제공할 수 있습니다.

 

4. Experiments

결국에는 trade-off  형식으로 성능 향상이 되었다는 것을 확인가능하네요.

 

2025-02-02 경 논문 읽고 난 뒤 글쓴이가 제안한 추가 아이디어

 

 
 

$ \ alpha $$

 

 

1. Confusion Matrix (혼돈 행렬)

https://en.wikipedia.org/wiki/Confusion_matrix

True Positivie(TP) : 실제값이 Positivie 인데, 예측값 Positivie라고 둘 값이 같은 경우

(고양이 사진을 보여주고 모델이 고양이라고 추측함)

False Positive(FP) 실제값이 Negative 인데, 예측값 Positivie라고 둘 값이 다른 경우

(개 사진을 보여주고 모델이 고양이라고 추측함)

True Negative(TN) 실제값이 Negative 인데, 예측값 Negative라고 둘 값이 같은 경우

(개 사진을 보여주고 모델이 라고 추측함)

False Negative(FN) 실제값이 Positivie 인데 예측값 Negative  라고 둘 값이 다른 경우

(고양이 사진을 보여주고 모델이 라고 추측함)

 

고양이를 Positivie 개를 Negative 라고 정의한 것은 임의로 정한것임

 

개를 Positivie 고양이를 Negative

고양이가 존재하면 Positivie 고양이가 없으면 Negative

암이 있으면 Positivie 암이 없으면 Negative

 

여러 방법으로 이진분류를 임의로 정할수 있음.

 

암이 있으면 Positivie암이 없으면 Negative 예시

True Positivie(TP) : A에게 실제로 암이 있는데 모델이 암이 있다고 추측한 경우

False Positive(FP) : A에게 실제로 암이 없는데 모델이 암이 있다고 추측한 경우

True Negative(TN) :A에게 실제로 암이 없는데 모델이 암이 없다고 추측한 경우

False Negative(FN) : A에게 실제로 암이 있는데 모델이 암이 없다고 추측한 경우

 

 

True Positivie(TP) False Positive(FP) True Negative(TN) False Negative(FN) 의 앞의 값 True False 가 의미하는것은 실제 값과 예측 값을 비교했을때 같냐 아니냐 를 뜻함.

 

실제값 = 예측값 : True

실제값 ≠ 예측값 : False

 

ex

실제값 Positivie, 예측값 Positivie = True

실제값 Negative예측값 Negative = True

실제값 Positivie, 예측값 Negative = False

실제값 Negative예측값 Positivie = False

 

True Positivie(TP) False Positive(FP) True Negative(TN) False Negative(FN) 의 뒤의 값 Positive, Negative가 의미하는 것은 예측값을 그대로 반환

 

예측값이 Positive 일경우 Positive

예측값이 Negative 일경우 Negative

 

 

Confusion Matrix 정의해서 쓰는 TP,FP,TN,FN 단어들은 실제 값 나타내는 부분 없음

예시로 있다는 가정하에 맨뒤 3번째 값에 실제값을 반환하는 새로운 Version 을 정의하면

위에 표는 존재하지 않는 표로 임의로 만들어낸것임

이렇게 실제 값을 보기쉬운 New Confusion Matrix를 만들수있겠지만

다만 현재쓰는 TP,FP,TN,FN 정의는 실제 값 을 나타내는 부분은 없으므로

TP,FP,TN,FN 을 보고 실제값을 예측해야함

 

ex

False Negative(FN) 일 경우 실제 값을 예측해보면

뒤에 값 이 Negative 이므로 예측값은 Negative

앞의 값 이 False 이므로 실제값 ≠ 예측값

 

따라서 False Negative(FN)의 실제값은 Positive

 

다중분류에서의 Confusion Matrix

만약 개 고양이분류가 아닌 개 고양이 오리 말 4마리를 분류하는 다중분류 일 경우

이 모델의 Confusion Matrix 는 어떻게 표현할까?

다중 분류에서의  혼돈행렬은 개, 고양이, 오리, 말 중 무엇을 기준으로 두느냐의 따라 표를 4개 만들수있다.

(기준에 따라서 표를 다르게 보는 이유는 나중에 나오는 Accuracy, Recall, Precision 를 구할때

개 기준의 Accuracy, Recall, Precision , 고양이 기준의 Accuracy, Recall, Precision,

오리 기준의 Accuracy, Recall, Precision, 말 기준의 Accuracy, Recall, Precision 를 각각 구해서

 Accuracy, Recall, Precision 평균을 구해야 되기 때문.)

개를 기준으로 표현하면 "개 가 있다" 를 Positivie "개 의외의 동물이다" 를 Negative로 생각하면 된다.

오리를 기준이면 "오리 가 있다" 를 Positivie "오리 의외의 동물이다" 를 Negative로 생각하면 된다.

 

1번 표

그런데 실제로 다중분류 해야할 종류가 100개 10000개가 넘어가면 표를 10000개를 만들어야 될까?

또한 실제 수많은 경우의 수을 구하면 행렬로 수많은 경우의 수가 구해지므로 1번 표 로 표현하기엔 한계가 있으므로 보통은 아래와 같이 2번 표 형식으로 표현한다.

 

2번 표

표를 간략하게 표현 했는데

 

모델이 "실제 개 사진을 개로 추측한 것"이 9개

         "실제 오리 사진을 고양이로 추측한 것"이 2개

         "실제 오리 사진을 오리로 추측한 것"이 16개

         등등..

 

으로 보면 된다. 이 2번 표로 개 고양이 오리 말 각 기준에서의 TP,FP,TN,FN 쓰면

    개 기준                                                                                      고양이 기준
오리 기준                                                                                         말 기준

 

이렇게 되는데 이해하기 힘들다...

개 고양이 오리 말 4가지 경우의 수로 생각하고 보면 이해 하기 힘드니

앞서 보여준 1번 표처럼

 

'기준 동물'Positivie '기준 의외의 동물'Negative

2가지 경우의 수로 생각하고 봐야 제대로 표를 이해하고 볼수 있다.

 

  • 고양이 기준 표로 예시를 들면

"실제 사진을 모델이 라고 추측" 을

"실제 고양이 의외의 동물(Negative) 사진고양이 의외의 동물(Negative)로 추측" 으로 봐야한다. = TN

 

"실제 오리 사진을 모델이 라고 추측" 또한

"실제 고양이 의외의 동물(Negative) 사진고양이 의외의 동물(Negative)로 추측" 으로 동일 = TN

 

"실제 고양이 사진을 모델이 오리 라고 추측" 은

"실제 고양이(Positivie) 사진을 모델이 고양이 의외의 동물(Negative) 라고 추측" = FN

 

"실제 말 사진을 모델이 고양이 라고 추측" 

"실제 고양이(Positivie) 사진을 모델이 고양이(Positivie) 라고 추측" = TP

 

"실제 말 사진을 모델이 고양이 라고 추측" 

"실제 고양이 의외의 동물(Negative) 사진을 모델이 고양이(Positivie) 라고 추측" = FP

 

 

  • 말 기준 표 예시

"실제 오리 사진을 모델이 고양이 라고 추측" 

"실제 말 의외의 동물(Negative) 사진을 모델이  의외의 동물(Negative) 라고 추측" = TN

 

"실제 말 사진을 모델이 고양이 라고 추측" 

"실제 (Positivie) 사진을 모델이  의외의 동물(Negative) 라고 추측" = FN

 

"실제 개 사진을 모델이 말 이라고 추측" 

"실제  의외의 동물(Negative) 사진을 모델이 (Positivie) 이라고 추측" = FP

 

 

이제 이진분류, 다중분류 TP FP FN TN 을 정확히 구별 할수있다.

드디어 Accuracy, Recall, Precision 를 구할 수 있다.

2. Accuracy, Precision, Recall

암이 있으면 Positivie 암이 없으면 Negative 예시로 설명.

True Positivie(TP) : A에게 실제로 암이 있는데 모델이 암이 있다고 추측한 경우

False Positive(FP) : A에게 실제로 암이 없는데 모델이 암이 있다고 추측한 경우

True Negative(TN) :A에게 실제로 암이 없는데 모델이 암이 없다고 추측한 경우

False Negative(FN) : A에게 실제로  암이 있는데 모델이 암이 없다고 추측한 경우

2-1 Accuracy

 

$$Accuracy = \frac{TP + TN}{TP + TN + FP + FN} = \frac{정답}{모든 데이터}$$

True Positivie(TP) : A에게 실제로 암이 있는데 모델이 암이 있다고 추측한 경우

True Negative(TN) :A에게 실제로 암이 없는데 모델이 암이 없다고 추측한 경우

TP + TN = 실제 값이 모델의 예측 값 과 동일한 경우 = 정답인 경우

TP + TN + FP + FN = 모든 경우의 수 합 = 모든 데이터 = 모든 이미지 데이터

 

Accuracy 흔히들 아는 정확도. 다만 데이터가 불균형할 경우 제대로된 지표가 될수 없음.

예시로 100명인 사람중에 실제로 희귀병이 있는 사람이 1명일 경우

100명다 정상이라고 진단하면 정적인 희귀병 1명을 못찾아도 정확도는 99%가 됨.

2-2 Precision

$$Precision = \frac{TP}{TP + FP} = \frac{실제 암을 모델이 암이라고 제대로 예측한 경우}{모델이 암이라고 예측한 경우}$$

True Positivie(TP) : A에게 실제로 암이 있는데 모델이 암이 있다고 추측한 경우

False Positive(FP) : A에게 실제로 암이 없는데 모델이 암이 있다고 추측한 경우

Predicted condition positive 행 = positive라고 추측 한 행 = TP + FP

TP + FP = 모델이 암이 있다고 추측한 경우의 수

 

Precision 은 실제 값 즉, 암이 실제로 있는지 없는지 는 중요하지 않고 모델이 암이 있다고 추측한 경우의 수에 포커스를 두고 있다. 다시말해 있다고 추측한 경우의 수암이 실제로 있는데 있다고 추측한 경우 이므로 정밀도 라고 볼수있다. 

 

2-3 Recall

$$Recall = \frac{TP}{TP + FN} = \frac{실제 암을 모델이 암이라고 제대로 예측한 경우}{실제 암이 있는 경우}$$

 

True Positivie(TP) : A에게 실제로  암이 는데 모델이 암이 있다고 추측한 경우

False Negative(FN) : A에게 실제로 암이 는데 모델이 암이 없다고 추측한 경우

condition positive 열 = 실제 positive 인 열 = TP + FN

TP + FN = 실제로 암이 는 경우

 

Recall은  실제로 암이는 경우 에 포커스를 두고있다.

모델이 암이 있다고 추측 한 경우에 포커스를 두고있는 Precision 과는 완전 반대.

 

 

https://www.popit.kr/%EC%9A%A9%EC%96%B4-%EC%A0%95%EB%A6%AC-%EC%9E%85%EA%B0%9C%EB%B0%9C%EC%9E%90%EB%A5%BC-%EC%9C%84%ED%95%9C-accuracy-precision-recall/

 

[용어 정리] 입개발자를 위한 Accuracy, Precision, Recall | Popit

머신러닝하면 자주 등장하는 Accuracy, Recall, Precision라는 용어에 대해서 간단하게 살펴 보았습니다.

www.popit.kr

 

https://moons08.github.io/datascience/classification_score_basic/

 

다중 분류 문제 성능평가 [기본편]

어떤 모델, 혹은 방법을 쓰던 분류 문제는 그 의도에 따라 다양한 성능평가 방식을 사용합니다. 사람, 고양이, 개 3개의 클래스를 분류하는 다중 분류(multi label) 예제를 통해 정리해보겠습니다.

moons08.github.io

https://darkpgmr.tistory.com/162

 

precision, recall의 이해

자신이 어떤 기술을 개발하였다. 예를 들어 이미지에서 사람을 자동으로 찾아주는 영상 인식 기술이라고 하자. 이 때, 사람들에게 "이 기술의 검출율은 99.99%입니다"라고 말하면 사람들은 "오우..

darkpgmr.tistory.com

https://hoya012.github.io/blog/Tutorials-of-Object-Detection-Using-Deep-Learning-how-to-measure-performance-of-object-detection/

 

Tutorials of Object Detection using Deep Learning [4] How to measure performance of object detection

Deep Learning을 이용한 Object detection Tutorial - [4] How to measure performance of object detection

hoya012.github.io

 

+ Recent posts