Speculative Decoding

Based on Accelerating Large Language Model Decoding with Speculative Sampling (Chen et al., 2023).

Setup

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer

We load two models: a small model that is fast, and a big model that is more capable. Both models share the same tokenizer, which lets us easily compare their predictions and use the output of one as the input to the other.

LITTLE_MODEL_PATH = "/mnt/ssd/arjun/models/llama3p2_1b_base"
BIG_MODEL_PATH = "/mnt/ssd/arjun/models/llama3p1_8b_base"

tokenizer = AutoTokenizer.from_pretrained(
    LITTLE_MODEL_PATH,
    padding_side="left",
    clean_up_tokenization_spaces=False
)
tokenizer.pad_token = tokenizer.eos_token

little_model = AutoModelForCausalLM.from_pretrained(
    LITTLE_MODEL_PATH,
    torch_dtype=torch.float16,
    device_map="cuda"
)

big_model = AutoModelForCausalLM.from_pretrained(
    BIG_MODEL_PATH,
    torch_dtype=torch.float16,
    device_map="cuda"
)
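
Everything below relies on the two checkpoints assigning the same ids to the same tokens. As a quick, optional sanity check (assuming the big-model checkpoint ships its own tokenizer files), we can compare the two vocabularies:

# Optional sanity check: speculative decoding only makes sense if token ids are
# interchangeable between the draft and target models.
big_tokenizer = AutoTokenizer.from_pretrained(BIG_MODEL_PATH)
assert tokenizer.get_vocab() == big_tokenizer.get_vocab(), "tokenizers differ"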

Generate and Check

As a warmup, we will generate tokens with the little model and then check that the big model would have generated the same tokens. The trick is that the check with the big model can be done in a single forward pass: feeding the little model's output to the big model gives, at every position, the big model's prediction for the next token, so we can compare it against every generated token in parallel.

Warning: The code below will not do the right thing when the prompt produces an end-of-sequence (EOS) token before the max_new_tokens limit. The problem is that it will force the big and little models to agree on the tokens that appear after the EOS token, which is unnecessary. Moreover, decoding will include the text after the EOS token, which is wrong.

def generate_and_check(prompts: List[str], max_new_tokens: int = 10) -> Optional[List[str]]:
    inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(little_model.device)

    with torch.no_grad():
        little_model_output_tokens = little_model.generate(
            **inputs, 
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            # The rest is to definitely turn off sampling
            do_sample=False,
            top_p=None,
            temperature=None,
        )
        # The final token generated by the little model does not need to be used
        # as input to the big model.
        big_model_logits = big_model(little_model_output_tokens[:, :-1]).logits[:, -max_new_tokens:]

    little_model_preds = little_model_output_tokens[:, -max_new_tokens:]
    big_model_dist = F.softmax(big_model_logits, dim=-1)
    big_model_preds = big_model_dist.argmax(dim=-1)

    # For each item in the batch, check whether the little and big models made
    # identical predictions.
    are_predictions_same = torch.all(little_model_preds == big_model_preds, dim=1)
    # Only report success if the models agree on every item in the batch.
    all_same = torch.all(are_predictions_same, dim=0).item()
    if not all_same:
        return None
    return tokenizer.batch_decode(little_model_preds, skip_special_tokens=True, clean_up_tokenization_spaces=False)
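
generate_and_check also works on a batch of prompts: it returns the decoded continuations only if the models agree on every prompt in the batch, and None otherwise. An illustrative call (prompts chosen arbitrarily, output not shown):

# Returns None unless the little and big models produce identical greedy tokens
# for both prompts.
generate_and_check(["def add(a, b):", "def sub(a, b):"], max_new_tokens=10)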

For any prompt, we can now check if the two models agree for up to N tokens.

For how many tokens do the little model and the big model agree?

def check_agreement(prompt: str, max_new_tokens: int = 10):
    for i in range(1, max_new_tokens):
        result = generate_and_check([prompt], i)
        if result is None:
            return None
        print(i, repr(result[0]))

PROMPT0 = '''
    def factorial(n):
        """
        Returns the factorial of n.
        """
'''.strip()

check_agreement(PROMPT0, 35)
1 ' \n'
2 ' \n       '
3 ' \n        if'
4 ' \n        if n'
5 ' \n        if n =='
6 ' \n        if n == '
7 ' \n        if n == 0'
8 ' \n        if n == 0:\n'
9 ' \n        if n == 0:\n           '
10 ' \n        if n == 0:\n            return'
11 ' \n        if n == 0:\n            return '
12 ' \n        if n == 0:\n            return 1'
13 ' \n        if n == 0:\n            return 1\n'
14 ' \n        if n == 0:\n            return 1\n       '
15 ' \n        if n == 0:\n            return 1\n        else'
16 ' \n        if n == 0:\n            return 1\n        else:\n'
17 ' \n        if n == 0:\n            return 1\n        else:\n           '
18 ' \n        if n == 0:\n            return 1\n        else:\n            return'
19 ' \n        if n == 0:\n            return 1\n        else:\n            return n'
20 ' \n        if n == 0:\n            return 1\n        else:\n            return n *'
21 ' \n        if n == 0:\n            return 1\n        else:\n            return n * factorial'
22 ' \n        if n == 0:\n            return 1\n        else:\n            return n * factorial(n'
23 ' \n        if n == 0:\n            return 1\n        else:\n            return n * factorial(n-'
24 ' \n        if n == 0:\n            return 1\n        else:\n            return n * factorial(n-1'
25 ' \n        if n == 0:\n            return 1\n        else:\n            return n * factorial(n-1)\n'
26 ' \n        if n == 0:\n            return 1\n        else:\n            return n * factorial(n-1)\n'
check_agreement("def hello", 35)
1 '_world'
2 '_world():\n'
3 '_world():\n   '
4 '_world():\n    """'
5 '_world():\n    """Print'
6 '_world():\n    """Prints'
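
When check_agreement stops, it can be useful to see what the big model would have produced at the point of disagreement. The sketch below is a hypothetical helper (not used elsewhere in this notebook) that reuses the same slicing trick as generate_and_check to report the first diverging token:

def first_disagreement(prompt: str, max_new_tokens: int = 40):
    inputs = tokenizer([prompt], return_tensors="pt").to(little_model.device)
    with torch.no_grad():
        draft = little_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=False,
            top_p=None,
            temperature=None,
        )
        num_new_tokens = draft.shape[1] - inputs.input_ids.shape[1]
        big_logits = big_model(draft[:, :-1]).logits[:, -num_new_tokens:]
    little_preds = draft[:, -num_new_tokens:]
    big_preds = big_logits.argmax(dim=-1)
    diffs = (little_preds != big_preds).nonzero(as_tuple=True)[1]
    if diffs.numel() == 0:
        return None  # The models agree on the entire draft.
    i = diffs[0].item()
    return i, tokenizer.decode(little_preds[0, i]), tokenizer.decode(big_preds[0, i])

For "def hello" above, the disagreement is at the seventh generated token, which is why check_agreement stopped after printing six lines.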

Greedy Speculative Decoding

The basic idea is this: generate a “draft” with the little model. Check if the big model agrees with the draft. If it does not, extend the input with:

  1. The prefix of the draft on which the two models agree.
  2. The first output token from the big model that is different from the draft.

Note that (1) may be empty if the models disagree on the very first drafted token. Thus (2) is essential: every forward pass of the big model extends the sequence by at least one token, so the algorithm always makes progress.

def speculate(input_ids, attention_mask, max_new_tokens: int):
    with torch.no_grad():
        little_model_output_tokens = little_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            # The rest is to definitely turn off sampling
            do_sample=False,
            top_p=None,
            temperature=None,
        )

        # The little model may not generate max_new_tokens if it produces eos_token_id first.
        num_new_tokens = little_model_output_tokens.shape[1] - input_ids.shape[1]

        # The final token generated by the little model does not need to be used
        # as input to the big model.
        big_model_logits = big_model(little_model_output_tokens[:, :-1]).logits[:, -num_new_tokens:]

    little_model_preds = little_model_output_tokens[:, -num_new_tokens:]
    big_model_dist = F.softmax(big_model_logits, dim=-1)
    big_model_preds = big_model_dist.argmax(dim=-1)

    diff_indices = (little_model_preds != big_model_preds).nonzero(as_tuple=True)[1]
    if diff_indices.numel() == 0:
        # The big model agrees with the entire draft, so return the full sequence
        # (prompt plus all drafted tokens).
        return (True, little_model_output_tokens)
    else:
        # The big model first disagrees at diff_index. Keep the agreed prefix and
        # append the big model's own token at the point of disagreement.
        diff_index = diff_indices[0].item()
        valid_prefix = big_model_preds[:, :(diff_index + 1)]
        input_with_valid_prefix = torch.cat([input_ids, valid_prefix], dim=1)
        return (False, input_with_valid_prefix)


# NOTE: This won't do the right thing when a prompt produces an eos_token_id before
# the max_new_tokens limit.
def speculative_decoding(prompts: List[str], max_new_tokens: int = 10) -> Tuple[List[str], int]:
    inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(little_model.device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    counter = 0
    ok = False
    # max_new_tokens acts as a cap on the total sequence length (prompt included)
    # and is also passed to speculate as the draft length for each round.
    while not ok and input_ids.shape[1] < max_new_tokens:
        counter += 1
        (ok, generated_ids) = speculate(input_ids, attention_mask, max_new_tokens)
        input_ids = generated_ids
        attention_mask = torch.ones_like(input_ids)
    return (tokenizer.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False), counter)
(code, counter) = speculative_decoding(["def hello"], 20)
print(f"Forward passes of the big model: {counter}")
print(code[0])

Forward passes of the big model: 2
Hello World!" to the console."""
    print("Hello World!")
(code, counter) = speculative_decoding(["def shortest_path(graph):\n\t"], 180)
print(f"Forward passes of the big model: {counter}")
print(code[0])

Forward passes of the big model: 15
def shortest_path(graph):
	"""
	Find the shortest path between two nodes in a graph.
	"""
	# Initialize variables
	start = None
	end = None
	path = []
	visited = set()

	# Find the start and end nodes
	for node in graph:
		if node == start:
			start = node
		elif node == end:
			end = node

	# Initialize the queue
	queue = [start]

	# Loop until the queue is empty
	while queue:
		# Get the current node
		current = queue.pop(0)

		# Check if the current node is the end node
		if current == end:
			break

		# Add the current node to the visited set
		visited.add(current)

		# Get the neighbors of the current node
		neighbors = graph[current]

		# Loop through the neighbors
		for neighbor in neighbors:
			# Check if the neighbor is
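
For reference, ordinary greedy decoding with the big model alone needs one forward pass per generated token (on the order of 180 here), compared with 15 verification passes above, though each verification pass processes a longer sequence and the little model still decodes token by token. Below is a rough sketch of a wall-clock comparison, assuming we simply time both generation paths end to end; actual timings depend on hardware and are not from the original run:

import time

def wall_clock(fn) -> float:
    # Crude timer: ignores warmup and CUDA synchronization details.
    start = time.time()
    fn()
    return time.time() - start

prompt = "def shortest_path(graph):\n\t"
baseline_inputs = tokenizer([prompt], return_tensors="pt").to(big_model.device)

# Note: speculative_decoding treats its limit as a cap on the total length, so
# the comparison is only approximate.
baseline_seconds = wall_clock(lambda: big_model.generate(
    **baseline_inputs,
    max_new_tokens=180,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=False,
    top_p=None,
    temperature=None,
))
speculative_seconds = wall_clock(lambda: speculative_decoding([prompt], 180))
print(f"big model alone: {baseline_seconds:.1f}s, speculative: {speculative_seconds:.1f}s")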

Speculative Decoding with Sampling

The approach above uses greedy decoding. How can we modify it to sample from the little model, and still preserve the distribution of the big model? See Accelerating Large Language Model Decoding with Speculative Sampling for a solution.
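
The core of the paper's answer is an acceptance rule applied to each drafted token: if q is the little model's distribution at a position and p is the big model's distribution at the same position, the drafted token x is accepted with probability min(1, p(x)/q(x)); on rejection, a replacement is drawn from the residual distribution max(0, p - q), renormalized. A minimal sketch of this rule for a single position (not wired into the code above):

def accept_or_resample(draft_token: int, p: torch.Tensor, q: torch.Tensor) -> Tuple[int, bool]:
    # p: the big model's probability distribution over the vocabulary at this position.
    # q: the little model's distribution, from which draft_token was sampled.
    accept_prob = torch.clamp(p[draft_token] / q[draft_token], max=1.0)
    if torch.rand(()) < accept_prob:
        return draft_token, True
    # Rejected: resample from the residual max(0, p - q), renormalized. This
    # correction makes the accepted-or-resampled token distributed exactly as p.
    residual = torch.clamp(p - q, min=0.0)
    residual = residual / residual.sum()
    return int(torch.multinomial(residual, num_samples=1).item()), False

Accepted tokens keep the speedup of the greedy version, while the rejection-and-resample step makes the overall output an exact sample from the big model's distribution.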