Speculative Decoding
Based on Accelerating Large Language Model Decoding with Speculative Sampling (Chen et al., 2023).
Setup
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer
We load two models: a small model that is fast, and a big model that is more capable. Both models share the same tokenizer, which makes it easy to compare their predictions and to use the output of one as the input to the other.
LITTLE_MODEL_PATH = "/mnt/ssd/arjun/models/llama3p2_1b_base"
BIG_MODEL_PATH = "/mnt/ssd/arjun/models/llama3p1_8b_base"
tokenizer = AutoTokenizer.from_pretrained(
LITTLE_MODEL_PATH,
padding_side="left",
clean_up_tokenization_spaces=False
)
tokenizer.pad_token = tokenizer.eos_token
little_model = AutoModelForCausalLM.from_pretrained(
LITTLE_MODEL_PATH,
torch_dtype=torch.float16,
device_map="cuda"
)
big_model = AutoModelForCausalLM.from_pretrained(
BIG_MODEL_PATH,
torch_dtype=torch.float16,
device_map="cuda"
)
Generate and Check
As a warmup, we will do the following: generate tokens with the little model, and then check that the big model would have generated the same tokens. The trick is that the check with the big model can be done in a single forward pass: because the big model is causal, one forward pass over the little model's output gives, at every position, the token that the big model would have predicted from the prefix up to that position.
Warning: The code below will not do the right thing when the prompt produces an <|endoftext|> token before the max_new_tokens limit. The problem is that it will force the big and little models to agree on the tokens that appear after the <|endoftext|> token, which is unnecessary. Moreover, the token decoding will output the text after the <|endoftext|> token, which is wrong.
def generate_and_check(prompts: List[str], max_new_tokens: int = 10) -> Optional[List[str]]:
inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(little_model.device)
with torch.no_grad():
little_model_output_tokens = little_model.generate(
**inputs,
max_new_tokens=max_new_tokens,
pad_token_id=tokenizer.eos_token_id,
        # Explicitly disable sampling so that decoding is greedy
do_sample=False,
top_p=None,
temperature=None,
)
# The final token generated by the little model does not need to be used
# as input to the big model.
big_model_logits = big_model.forward(little_model_output_tokens[:,:-1]).logits[:, -max_new_tokens:]
little_model_preds = little_model_output_tokens[:, -max_new_tokens:]
big_model_dist = F.softmax(big_model_logits, dim=-1)
big_model_preds = big_model_dist.argmax(dim=-1)
# Check if little and big produce the same predictions for each item in the batch
are_predictions_same = torch.all(little_model_preds == big_model_preds, dim=1)
    # Check that every item in the batch agrees with the big model
all_same = torch.all(are_predictions_same, dim=0).item()
if not all_same:
return None
return tokenizer.batch_decode(little_model_preds, skip_special_tokens=True, clean_up_tokenization_spaces=False)
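As an aside, one way to handle the warning above is to cut each sequence at its first <|endoftext|> token before comparing or decoding. The helper below is a minimal sketch of that idea and is not used elsewhere in this notebook; trim_at_eos is a hypothetical name.
def trim_at_eos(token_ids: torch.Tensor, eos_token_id: int) -> List[torch.Tensor]:
    # Cut each row at the first occurrence of eos_token_id (sketch only).
    trimmed = []
    for row in token_ids:
        eos_positions = (row == eos_token_id).nonzero(as_tuple=True)[0]
        if eos_positions.numel() > 0:
            row = row[:eos_positions[0].item()]
        trimmed.append(row)
    return trimmed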
For any prompt, we can now check if the two models agree for up to N tokens.
For how many tokens do the little model and the big model agree?
def check_agreement(prompt: str, max_new_tokens: int = 10):
for i in range(1, max_new_tokens):
result = generate_and_check([prompt], i)
if result is None:
return None
print(i, repr(result[0]))
PROMPT0 = '''
def factorial(n):
"""
Returns the factorial of n.
"""
'''.strip()
check_agreement(PROMPT0, 35)
1 ' \n'
2 ' \n '
3 ' \n if'
4 ' \n if n'
5 ' \n if n =='
6 ' \n if n == '
7 ' \n if n == 0'
8 ' \n if n == 0:\n'
9 ' \n if n == 0:\n '
10 ' \n if n == 0:\n return'
11 ' \n if n == 0:\n return '
12 ' \n if n == 0:\n return 1'
13 ' \n if n == 0:\n return 1\n'
14 ' \n if n == 0:\n return 1\n '
15 ' \n if n == 0:\n return 1\n else'
16 ' \n if n == 0:\n return 1\n else:\n'
17 ' \n if n == 0:\n return 1\n else:\n '
18 ' \n if n == 0:\n return 1\n else:\n return'
19 ' \n if n == 0:\n return 1\n else:\n return n'
20 ' \n if n == 0:\n return 1\n else:\n return n *'
21 ' \n if n == 0:\n return 1\n else:\n return n * factorial'
22 ' \n if n == 0:\n return 1\n else:\n return n * factorial(n'
23 ' \n if n == 0:\n return 1\n else:\n return n * factorial(n-'
24 ' \n if n == 0:\n return 1\n else:\n return n * factorial(n-1'
25 ' \n if n == 0:\n return 1\n else:\n return n * factorial(n-1)\n'
26 ' \n if n == 0:\n return 1\n else:\n return n * factorial(n-1)\n'
check_agreement("def hello", 35)
1 '_world'
2 '_world():\n'
3 '_world():\n '
4 '_world():\n """'
5 '_world():\n """Print'
6 '_world():\n """Prints'
Greedy Speculative Decoding
The basic idea is this: generate a “draft” with the little model. Check if the big model agrees with the draft. If it does not, extend the input with:
1. The prefix of the draft on which the two models agree.
2. The first output token from the big model that is different from the draft.
Note that (1) may be empty if the models disagree on the entire draft. Thus (2) is essential for the algorithm to make progress.
def speculate(input_ids, attention_mask, max_new_tokens: int):
with torch.no_grad():
little_model_output_tokens = little_model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_new_tokens,
pad_token_id=tokenizer.eos_token_id,
        # Explicitly disable sampling so that decoding is greedy
do_sample=False,
top_p=None,
temperature=None,
)
# The little model may not generate max_new_tokens if it produces eos_token_id first.
num_new_tokens = little_model_output_tokens.shape[1] - input_ids.shape[1]
# The final token generated by the little model does not need to be used
# as input to the big model.
big_model_logits = big_model.forward(little_model_output_tokens[:,:-1]).logits[:, -num_new_tokens:]
little_model_preds = little_model_output_tokens[:, -num_new_tokens:]
big_model_dist = F.softmax(big_model_logits, dim=-1)
big_model_preds = big_model_dist.argmax(dim=-1)
diff_indices = (little_model_preds != big_model_preds).nonzero(as_tuple=True)[1]
    # If there are no differences, the big model accepts the entire draft.
    # Note that in that case only the newly drafted tokens are returned.
    if diff_indices.numel() == 0:
        return (True, little_model_preds)
else:
diff_index = diff_indices[0].item()
        # Keep the agreed-upon prefix plus the big model's token at the first disagreement.
        valid_prefix = big_model_preds[:, :(diff_index+1)]
input_with_valid_prefix = torch.cat([input_ids, valid_prefix], dim=1)
return (False, input_with_valid_prefix)
# NOTE: This won't do the right thing when a prompt produces an eos_token_id before
# the max_new_tokens limit.
def speculative_decoding(prompts: List[str], max_new_tokens: int = 10) -> Tuple[List[str], int]:
inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(little_model.device)
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask
counter = 0
ok = False
    # Stop once the total sequence length (prompt plus accepted tokens) reaches max_new_tokens.
    while not ok and input_ids.shape[1] < max_new_tokens:
counter += 1
(ok, generated_ids) = speculate(input_ids, attention_mask, max_new_tokens)
input_ids = generated_ids
attention_mask = torch.ones_like(input_ids)
return (tokenizer.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False), counter)
(code, counter) = speculative_decoding(["def hello"], 20)
print(f"Forward passes of the big model: {counter}")
print(code[0])
Forward passes of the big model: 2
Hello World!" to the console."""
print("Hello World!")
(code, counter) = speculative_decoding(["def shortest_path(graph):\n\t"], 180)
print(f"Forward passes of the big model: {counter}")
print(code[0])
Forward passes of the big model: 15
def shortest_path(graph):
"""
Find the shortest path between two nodes in a graph.
"""
# Initialize variables
start = None
end = None
path = []
visited = set()
# Find the start and end nodes
for node in graph:
if node == start:
start = node
elif node == end:
end = node
# Initialize the queue
queue = [start]
# Loop until the queue is empty
while queue:
# Get the current node
current = queue.pop(0)
# Check if the current node is the end node
if current == end:
break
# Add the current node to the visited set
visited.add(current)
# Get the neighbors of the current node
neighbors = graph[current]
# Loop through the neighbors
for neighbor in neighbors:
# Check if the neighbor is
Speculative Decoding with Sampling
The approach above uses greedy decoding. How can we modify it to sample from the little model, and still preserve the distribution of the big model? See Accelerating Large Language Model Decoding with Speculative Sampling for a solution.
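For reference, the acceptance rule from that paper can be sketched for a single drafted token as follows. This is a minimal illustration, not a drop-in replacement for the code above: accept_or_resample is a hypothetical helper, little_dist and big_dist are assumed to be the two models' probability distributions over the vocabulary at the same position, and draft_token is the token sampled from the little model there. Accepting with probability min(1, q/p) and, on rejection, resampling from the renormalized residual max(0, q - p) preserves the big model's distribution.
def accept_or_resample(draft_token: int, little_dist: torch.Tensor, big_dist: torch.Tensor) -> Tuple[bool, int]:
    # p: little (draft) model probability of the drafted token; q: big (target) model probability.
    p = little_dist[draft_token]
    q = big_dist[draft_token]
    # Accept the drafted token with probability min(1, q / p).
    if torch.rand(()) < torch.clamp(q / p, max=1.0):
        return (True, draft_token)
    # On rejection, sample from the residual distribution max(0, q - p), renormalized.
    residual = torch.clamp(big_dist - little_dist, min=0.0)
    residual = residual / residual.sum()
    resampled = torch.multinomial(residual, num_samples=1).item()
    return (False, resampled)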