SpecInfer: Accelerating Generative Large Language Model Serving with Tree-based Speculative Inference and Verification
This paper introduces SpecInfer, a system that accelerates generative large language model (LLM) serving with tree-based speculative inference and verification. The key idea behind SpecInfer is leveraging small speculative models to predict the LLM's outpu
arxiv.org
이 글은 전에 리뷰한 Fast Inference from Transformers via Speculative Decoding 논문을 읽고 오시는 걸 추천합니다.
이 글 작성 시점에는 이 논문을 한글로 리뷰하거나 설명한 글이 없기에 다른 한국인 분들이 이 논문을 읽을 때 편하게 읽으시라고 최초로 한글로 정리해서 올립니다. 다른 개인적인 목적을 위해서도 올립니다. 아무쪼록 논문 읽는데 도움 되시길 바랍니다.
Abstract
이 논문은 기존의 방식, 근사 모델 = Approximation Model = $M_{q}$ = SSM(Small speculative model)을 이용해 하나의 토큰 시퀀스만 생성했던 것과 다르게 top-k 방식을 이용해 토큰 트리를 생성하고 더 나아가 여러 개의 SSM을 이용해 여러 개의 토큰 트리를 병합하는 방식으로 토큰 트리를 생성해 이것을 검증하는 방식인 Tree-based Speculative Inference and Verification (트리기반 추측 추론 및 검증)을 소개하는 논문입니다. 이 글에서 근사 모델을 지칭하는 단어는 이 논문과 똑같이 SSM 으로 통일하겠습니다.
1. Introduction
이 챕터에서는 이전의 방식과 이 논문의 Tree-based 방식을 비교하고있습니다.
Incremental Decoding (증분적, 서서히 증가하는 디코딩) 은 기존의 LLM 이 토큰을 생성하는 방식입니다. 이 방식은 토큰을 하나 생성할 때마다 Autoregressive(자기 회귀) 방식으로 생성하기 때문에 최적화되지 않은 런타임 성능 및 GPU 활용 때문에 컴퓨팅 자원이 남습니다.
Sequence-based Speculative Inference는 이 전 논문에서 사용했던 방식인 기존의 하나의 시퀀스만 고려한 Speculative Inference 방식입니다. SSM(Small Speculative Model)과 LLM(Large Language Model) 간의 모델 용량 차이로
인해 두 모델이 생성한 토큰 시퀀스가 잘 정렬(align) 되지 않는다.라고 말하고 있습니다.
여기서 "정렬되지 않는다" 라는 소리는 SSM은 메모리와 계산 요구를 낮추기 위해 LLM보다 훨씬 작은 모델입니다. 그러나, 이로 인해 SSM이 생성한 토큰 시퀀스와 LLM이 생성한 토큰 시퀀스 간의 일치율이 낮다는 뜻입니다. 다시 말해 SSM이 생성한 단일 토큰 시퀀스가 LLM과 잘 맞지 않는 경우, SSM이 추측한 결과는 LLM의 실제 결과와 크게 다를 수 있습니다. 이로 인해 추론 성능이 저하된다는 뜻입니다.
Tree-based Speculative Inference 가 이 논문에서 새롭게 제안하는 트리기반의 Speculative Inference 방식입니다. 다양한 추측 후보들을 고려하기에 이러한 정렬문제를 해결합니다.
Tree-based Speculative Inference를 구현하기 위해 두 가지 도전과제가 있습니다.
- SpecInfer는 매우 큰 검색 공간($502724^4 ≈ 6 × 10^{18}$)을 탐색해야 합니다.
- SpecInfer는 Incremental Decoding과 동일한 확률 분포를 사용하여 Stochastic decoding ( 확률적 디코딩 )에서 추론된 토큰을 검증해야 합니다.
1번 에서의 502724는 논문에서 예시로든 OPT 모델군의 어휘 50272개 를 뜻 합니다. 모델이 처리할 수 있는 고유한 토큰(단어 또는 기호)의 총 수를 의미합니다.
2번은 다시 말해 기존이 Incremental Decoding에서 사용은 확률적 디코딩 방식 (ex top-k) 방식으로 내뱉는 확률 분포와 동일한 분포를 가지면서 토큰 트리를 검증해야 된다는 뜻입니다.
확률적 디코딩 방식 은 ChatGPT 등 생성형 언어모델을 제공하는 서비스들은 같은 프롬프트가 들어와도 Greedy decoding을 써 같은 답변이 계속 나오는 것이 아니라 Stochastic decoding (확률적 디코딩)을 써서 답변이 조금씩 랜덤 하게 달라지는 것이 예시로 들 수 있겠습니다.
2. SpecInfor’s Overview
SpecInfor의 전체적인 알고리즘입니다. 처음 input으로 "machine" 토큰이 들어오면 Expansion-based 방식이나 Merge-based 방식 중 하나의 방식으로 Token Tree를 생성합니다. 이 Token Tree를 Tree-based Parallel Decoding을 이용해 병렬적으로 각 노드의 다음 토큰의 분포를 얻습니다. 각각의 분포를 이용해 트리의 각 노드를 검증합니다.
3. Learning-based Speculator - How to generate a token tree
논문에서는 3번째 챕터 제목을 Learning-based Speculator로 해놨습니다. 이 챕터에서는 토큰 트리를 어떻게 생성하는지 설명하고 있습니다.토큰 트리를 생성하는 방식은 Expansion-based 방식과 Merge-based 방식, 2가지가 있습니다. 그전에 정의를 하고 넘어가겠습니다.
$N$ 은 토큰트리이며 $u$ 는 토큰트리에 속한 노드 들입니다. $t_u$는 $u$ 노드가 표현하는 토큰입니다. $P_u$는 $u$ 노드의 부모노드를 가리킵니다. $S_u$ 는 $u$ 노드 시점에 토큰과 그전에 선택된 모든 토큰의 합이고 이는 하나의 토큰 시퀀스를 나타냅니다.
Expansion-based token tree construction
각 step 마다 top-k를 쓰는 방식입니다. 위의 그림 예제는 step 0, step 1에서는 top-2를 썼고 step 2에서는 top-1을 써 토큰 트리를 생성했습니다. 이를 논문에서는 간결하게 Expansion configuration <2,2,1>이라고 표기하고 있습니다. 그림 예제의 알고리즘을 순서대로 설명하겠습니다.
1. "machine" 토큰을 SSM모델에 input으로 넣고, output으로 생성된 분포에서 가장 나올 확률이 높은 토큰 두 개 즉, top-2 인 "learning", "translation" 토큰을 생성
2. "learning", "translation"을 각각 SSM모델 input으로 넣어 top-2 토큰을 생성
3. top-1 이기에 가장 확률이 높은 토큰 하나만 생성
k 높을수록 후보가 많이 생성되니 수락률이 높아집니다만 그만큼 컴퓨팅 리소스도 많이 잡아먹습니다.
Merge-based token tree construction
Merge-based 방식은 위의 Expansion-based 방식을 병렬적으로 여러 개 실행시켜 여러 개의 토큰 트리를 생성해 병합하는 방식입니다. 각 SSM 은 각각 토큰 트리를 내뱉습니다. 다만 가중치가 같은 SSM 모델로 여러개의 토큰트리를 생성한다 한들 같은 분포를 생성하기에 다양성이 부족합니다. 따라서 저자는 adaptive boosting 방식으로 재학습시킵니다. 알고리즘은 다음과 같습니다.
- 말뭉치(corpus)에서 프롬프트 샘플을 생성합니다.
- 프롬프트 샘플을 LLM input으로 사용하여 토큰 시퀀스를 생성합니다.
- SSM 0을 미세 조정(fine-tune)합니다. 이때 SSM 0과 LLM이 동일한 후속 토큰을 생성한 모든 프롬프트 샘플을 표시합니다.
- 표시되지 않은 프롬프트 샘플을 사용하여 SSM 1을 미세 조정(fine-tune)합니다.
- 이 과정을 모든 SSM에 대해 반복합니다.
이를 통해 분포의 다양성을 확보합니다. 학습 이후 추론과정에서 각각의 SSM 들은 병렬적으로 실행되어 그림의 예시로는 3개의 토큰 트리가 생성됩니다. 이 3개의 토큰트리는 다음과 같은 정의로 병합됩니다.
$\mathcal{M}$은 모든 토큰 트리의 병합 트리입니다. 모든 토큰 트리의 노드의 토큰 시퀀스는 $\mathcal{M}$에 존재하는 노드의 토큰 시퀀스와 같아야 하고 그 역도 성립해야 합니다.
직관적으로, 각 토큰 트리는 토큰 시퀀스 집합을 나타냅니다. 여러 토큰 트리를 병합하면 원래의 모든 트리의 토큰 시퀀스를 포함하는 새로운 트리가 생성된다는 뜻입니다. 여러 개의 트리가 주어졌을 때 우리가 흔히 생각할 수 있는 병합 방식입니다.
4. Token Tree Verifier
토큰 트리를 생성하였다면 이를 검증해야 합니다. 이 챕터에서는 그것을 설명합니다.
4.1 Tree Attention
검증을 하려면 토큰 트리의 모든 토큰 시퀀스를 LLM에 input으로 넣고 분포를 얻어야 합니다. 다시 말해 "모든 토큰 시퀀스는 LLM의 self-attention 구조를 통과해야 된다는 뜻"이고 이는 "LLM에서 모든 토큰 시퀀스의 attention output을 구해야 한다는 뜻"과 같습니다. 이를 위해 Tree Attention을 정의합니다.
Tree Attention 은 토큰 트리 $N$ 의 모든 노드들의 토큰 시퀀스들의 attention입니다.
4.2 Tree-based Parallel Decoding
이 챕터에서는 LLM에서 Tree Attention을 계산하는 2가지 방식을 소개합니다.
여기 예시의 토큰 트리가 주어지고 Sequence-based Parallel Decoding와 Tree-based Parallel Decoding으로 Tree Attention를 구하는 그림이 있습니다.
Sequence-based Parallel Decoding
가장 기본적으로 생각할 수 있는 방식입니다. 예시의 토큰 트리는 리프노드 3개 즉, 자식 노드가 없는 t5, t7 t9 노드의 토큰 시퀀스를 3개의 kernel을 두고 Autoregressive 하게 처리합니다. 이는 t2, t3 토큰시점의 KV-cache 가 중복되는 것을 볼 수 있고 이는 연산의 중복이라고도 볼 수 있습니다.
Tree-based Parallel Decoding
따라서 저자는 Depth-first search to update key-value cache, Topology-aware causal mask을 제안하고 이를 이용한 방식인 Tree-based Parallel Decoding을 제안합니다.
Depth-first search to update key-value cache는 흔히 알고 있는 깊이 우선 탐색의 순서로 연산해 중복 cache를 피합니다. Figure 4의 토큰 트리의 번호 순서가 깊이 우선 탐색 방식으로 순서를 매긴 것 과 같습니다. 이 순서로 하나의 Kernal에서Autoregressive 하게 연산을 처리하면 중복된 KV-cache 및 연산이 없습니다.
Topology-aware causal mask는 그림만 봐도 대충 이해하실 텐데 $t_6$, $t_7$ 토큰은 예시 토큰 트리에서 보면 $t_5$ 와 다른 분기이기에 기존의 causal mask의 역할과 비슷하게 다른 분기의 attention weight에 마이너스 무한대를 곱해줍니다. 이는 트리 기반 attention 계산을 하나의 Kernal에서 수행하면서 인과 관계를 유지하게 해 줍니다.
이 챕터를 읽을때 저는 병렬적으로 처리하는 부분의 설명이 정확히 되어있지 않아 여러 의문점이 드는 부분이 많았습니다. 논문 설명에 관련된 내용이 아니니 읽지 않으셔도 됩니다.
의문점 1. Figure 4를 보면 Sequence-based Parallel Decoding 은 연산 순서가 화살표로 표시되어 있고 Tree-based Parallel Decoding는 컴마로 나눠져 있다. 병렬처리의 다름을 설명하는 것인가?
의문점 2. 병렬적으로 처리한다고 해도 여러 방식이 존재할 텐데 정확히 어떠한 방식으로 병렬적으로 알고리즘이 돌아가는가?
이는 논문에서도 쓰여있지 않습니다. 따라서 저는 3개의 병렬처리 알고리즘을 예상했고 이는 다음과 같습니다.

1번은 모든 토큰을 한 번에 병렬적으로 연산합니다. 이러면 KV-cache를 활용을 못하며 연산도 중복됩니다. 추측하건대 이는 Fast Inference from Transformers via Speculative Decoding 논문의 병렬처리 방식 같습니다. SpecInfer 논문에서는 KV-cache 활용을 계속 언급하고 있기에 이런 방식을 고려하지 않는 것 같습니다. 따라서 제외.
2번은 Tree-based Parallel Decoding에서 각 토큰을 컴마 표기로 나눠져 있기에 "이렇게 연산한다는 건가?" 가정이 들어 제안했습니다. 다만 1번과 동일하게 이러면 KV-cache 활용을 못합니다. 또한 t2 나 t3 토큰 시점 노드의 후보 토큰 분포는 어떻게 얻고 검증하는지 의문입니다. 모델 output 구조를 바꾸면 이러한 input 이 가능해 보입니다. 어찌 보면 가장 병렬적으로 연산중복 없이 연산하는 거겠지요. 이는 새로운 연구 방향성 일수도 있겠습니다. 이 논문에서 이러한 input을 쓴다는 것 인지 확실치 않지만 일단 이 논문은 KV-cache 활용한다는 점은 명확하기에 이것도 제외.
3번이 가장 유력해 보입니다. KV-cache 도 활용가능하며 Depth-first search to update key-value cache, Topology-aware causal mask 두 개를 써 Kernal 1개로 통합하면 중복된 연산과 KV-cache 도 없습니다. 다만 "KV-cache을 활용하기 위해서 Autoregressive 하게 처리한다는 소리인데 이것을 병렬 처리라고 말하고 Parallel Decoding이라고 부를 수 있는 것인가?"라는 큰 의문점이 있습니다.
의문점 3. 3번의 방식으로 병렬처리를 한다고 치고 Kernal 하나로 통합해 연산한다고 가정한다면 또다시 두 가지 방식을 생각해 볼 수 있습니다.

1번은 말 그대로 그냥 순서대로 Autoregressive 하게 처리합니다.
2번은 Autoregressive 하게 처리함으로써 KV-cache 도 활용하되 연산 중복이 안 되는 선에서 병렬적으로 처리합니다.
이러한 의문점들을 해결하고 정확한 병렬처리 알고리즘이 궁금하신 분은 오픈소스로 코드가 공개되어 있으니 그 코드를 분석해야 할 것 같습니다.
4.3 Token Verification
효율적으로 Tree Attention을 구했다면 우리는 최종적으로 각각의 노드들을 검증할 수 있는 분포를 를 얻습니다. 예시 1 사진으로 설명하자면 $u_1$ 에서 "<start>" 토큰을 LLM input으로 넣었을 때 다음에 어떤 토큰이 나올지 후보 토큰들을 표현하는 확률 분포를 얻게 되겠죠. "The" 토큰이 확률이 높을 수도 "The", "She"가 아닌 "He"라는 토큰이 나올 확률이 90% 일수도 있겠습니다. $u_1$ 말고도 $u_2$, $u_3$ ... 등등 각각 노드의 LLM 확률 분포를 얻게 됩니다. 이를 가지고 저자는 두 가지 검증 방식을 소개합니다.
VerifyGreedy (탐욕적 검증)
Greedy 방식으로 검증하는 방식은 간단하게 LLM 확률 분포에서의 top-1 토큰이 자식 노드에 있으면 수락하고 없으면 LLM 확률 분포에서의 top-1 토큰을 수락해 리턴합니다.
그림으로 예시를 들자면 $u_1$ 노드에서 LLM이 다음 토큰을 "The"가 가장 확률이 높게 나올 토큰으로 예측하면 "The" 토큰을 수락하고 자식노드 검증 반복 합니다.
"The", "She"가 아닌 "He"가 가장 확률이 높게 나올 토큰으로 예측하면 "He" 토큰을 수락하고 검증 종료 후 SSM을 이용해 토큰 트리를 다시 생성합니다.
VerifyStochastic (확률적 검증)
VerifyStochastic 함수는 Multi-step Speculative Sampling 즉, 각각의 노드를 for 문으로 돌면서 Speculative Sampling을 하는 알고리즘과 같습니다.
이는 Fast Inference from Transformers via Speculative Decoding 논문을 읽고 Speculative Sampling 알고리즘을 아셔야 이해할 수 있습니다. (Fast Inference from Transformers via Speculative Decoding 글의 2. Speculative Decoding 부분 참고)
Fast Inference from Transformers via Speculative Decoding 논문의 Speculative Decoding 알고리즘 중 Speculative Sampling 부분과 유사함을 그림으로 표현했습니다.
Speculative Decoding의 Speculative Sampling 은 거부 될 때 집합에 포함되므로 $ r_i > \frac{p(x)}{q(x)} $ 이고 VerifyStochastic의 Speculative Sampling은 수락 될때를 if 문으로 표현했기에 $ r_i \leq \frac{p(x)}{q(x)} $ 입니다.
즉, 거부될때를 표현한 식, 수락할때 표현한 식 이므로 부호가 반대되어있습니다. 또한 $ r_i \leq \frac{p(x)}{q(x)} $ 와 $ r_i > 1 - \frac{p(x)}{q(x)} $ 는 확률적으로 선택하는 식으로써 확률이 같습니다. (Fast Inference from Transformers via Speculative Decoding 글의 2. Speculative Decoding 부분 참고)
결론적으로 표현한 수식만 다르고 본질적인 알고리즘은 동일합니다.
여기서 중요한 점은 Multi-step으로 Speculative Sampling을 하기 때문에 Speculative Sampling과 동일하게 검증할 때 분포가 바뀌지 않고 유지됩니다. 이는 Introduction에서 언급한 2번째 도전과제를 Multi-step Speculative Sampling으로 해결한 것 이라고도 볼수있습니다.
5. System Design and Implementation
이 챕터에서는 SpecInfer의 4가지 특이사항을 소개합니다.
- SpecInfer는 Megatron-LM에서 도입된 하이브리드 병렬화를 사용하여 LLM을 제공합니다.
- SpecInfer는 Orca에서 도입된 연속 배치(Continuous Batching)를 사용합니다.
- SpecInfer는 FlexFlow 위에 구현되었습니다.
- SpecInfer는 FasterTransformer를 기반으로 한 맞춤형 CUDA 커널을 사용합니다.
그리고 몇 가지 SpecInfer을 사용함으로써 발생되는 Overhead를 말하며 반박합니다.
메모리 오버헤드
- 하나 이상의 SSM(Small Speculative Model)의 매개변수를 저장하기 위해 메모리를 할당해야 합니다. - SSM은 전체 메모리 요구량을 1% 미만으로 증가시키기에 무시할 수 있습니다.
- LLM으로 단일 토큰을 디코딩하는 대신 토큰 트리를 검증하는 데 사용됩니다. - 매우 긴 시퀀스 길이에 대한 키-값 캐시와 비교할 때, 토큰 트리의 메모리 오버헤드는 무시할 수 있습니다.
계산 오버헤드
- SSM을 Incremental Decoding 모드로 실행해야 하며 전체 토큰 트리에 대해 Tree Attention 출력을 계산해야 합니다. - Incremental Decoding에서 LLM(Large Language Model)을 제공할 때 GPU의 계산 자원이 남으므로 계산 오버헤드는 무시할 수 있습니다.
6. Evaluation
성능 비교표입니다. 분산 컴퓨팅 비교를 보시면 GPU와 node 가 증가해 병렬성이 증가할수록 SpecInfer의 효율성이 극대화됩니다. Expansion configuration 은 ⟨1,1,3,1,1,1,1,1⟩입니다.
7. RelatedWork, 8. Conclusion 은 별다른 내용이 없기에 넘어가겠습니다. A Artifact Appendix 부분은 흥미로운데 관심 있으신 분들은 읽어보세요.
'AI > 논문' 카테고리의 다른 글
[논문 리뷰] Fast Inference from Transformers via Speculative Decoding (0) | 2025.02.13 |
---|