# Speculative Decoding

Based on [Accelerating Large Language Model Decoding with Speculative Sampling](https://arxiv.org/abs/2302.01318).

## Setup

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

In [None]:
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer

We load two models: a small model that is fast, and a big model that is more capable. Both models share the same tokenizer, which allows us to easily compare their predictions and to use the output of one as the input to the other.

In [None]:
LITTLE_MODEL_PATH = "/mnt/ssd/arjun/models/llama3p2_1b_base"
BIG_MODEL_PATH = "/mnt/ssd/arjun/models/llama3p1_8b_base"

tokenizer = AutoTokenizer.from_pretrained(
    LITTLE_MODEL_PATH,
    padding_side="left",
    clean_up_tokenization_spaces=False
)
tokenizer.pad_token = tokenizer.eos_token

little_model = AutoModelForCausalLM.from_pretrained(
    LITTLE_MODEL_PATH,
    torch_dtype=torch.float16,
    device_map="cuda"
)

big_model = AutoModelForCausalLM.from_pretrained(
    BIG_MODEL_PATH,
    torch_dtype=torch.float16,
    device_map="cuda"
)

## Generate and Check

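As a warm-up, we will do the following: generate tokens with the little model, and then check that the big model would have generated the same tokens. The trick is that the check with the big model can be done in a single forward pass: a causal language model computes next-token logits at every position of its input at once, so one pass over the prompt plus the generated tokens yields the big model's prediction for every generated token.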
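The following plain-Python sketch illustrates why one forward pass suffices and how the slicing in `generate_and_check` lines up. The function `toy_next_token` is a hypothetical stand-in for the big model, and `predict_every_position` imitates a causal LM forward pass by producing one prediction per input position. Feeding the prompt plus all but the last draft token gives exactly one prediction per draft token, found in the last `len(draft)` positions.

```python
from typing import Callable, List


def predict_every_position(tokens: List[int], next_token: Callable[[List[int]], int]) -> List[int]:
    # Mimics a causal LM forward pass: entry i is the prediction after tokens[: i + 1].
    return [next_token(tokens[: i + 1]) for i in range(len(tokens))]


def verify_draft(prompt: List[int], draft: List[int], next_token: Callable[[List[int]], int]) -> List[bool]:
    # Drop the last draft token: nothing needs to be predicted after it.
    preds = predict_every_position(prompt + draft[:-1], next_token)
    # The last len(draft) predictions line up one-to-one with the draft tokens.
    aligned = preds[-len(draft):]
    return [d == p for d, p in zip(draft, aligned)]


# Hypothetical "big model": always predicts the previous token plus one.
def toy_next_token(prefix: List[int]) -> int:
    return prefix[-1] + 1


print(verify_draft([1, 2], [3, 4, 9], toy_next_token))  # [True, True, False]
```

In the notebook, the same alignment appears as feeding `little_model_output_tokens[:, :-1]` to the big model and keeping the last rows of `logits`.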

*Warning:* The `generate_and_check` function below will not do the right thing when the prompt produces an end-of-sequence (EOS) token (`tokenizer.eos_token`) before the `max_new_tokens` limit. The problem is that it forces the big and little models to agree on the tokens that appear after the EOS token, which is unnecessary. Moreover, decoding may output text after the EOS token, which is wrong.
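One way to address this is to compare, and decode, only up to and including the first EOS in each sequence. The sketch below works on plain lists with hypothetical helper names; in the notebook it would be applied per row of the token tensors. In a batch, `generate` fills the positions after a finished row's EOS with `pad_token_id`, which this notebook sets to the EOS id, so those positions are exactly the ones to ignore.

```python
from typing import List


def agree_up_to_eos(little_preds: List[int], big_preds: List[int], eos_id: int) -> bool:
    # Compare only up to and including the little model's first EOS;
    # anything after it (padding in a batch) is ignored.
    n = little_preds.index(eos_id) + 1 if eos_id in little_preds else len(little_preds)
    return little_preds[:n] == big_preds[:n]


def truncate_at_eos(tokens: List[int], eos_id: int) -> List[int]:
    # Drop everything after the first EOS before decoding.
    return tokens[: tokens.index(eos_id) + 1] if eos_id in tokens else tokens


EOS = 0  # hypothetical EOS id, for illustration only
print(agree_up_to_eos([5, 7, EOS, EOS], [5, 7, EOS, 3], EOS))  # True
print(agree_up_to_eos([5, 7, EOS], [5, 8, EOS], EOS))  # False
print(truncate_at_eos([5, 7, EOS, 4, 2], EOS))  # [5, 7, 0]
```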

In [4]:
def generate_and_check(prompts: List[str], max_new_tokens: int = 10) -> Optional[List[str]]:
    inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(little_model.device)

    with torch.no_grad():
        little_model_output_tokens = little_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            # The rest is to definitely turn off sampling
            do_sample=False,
            top_p=None,
            temperature=None,
        )
        # The little model may stop early if it produces eos_token_id.
        num_new_tokens = little_model_output_tokens.shape[1] - inputs.input_ids.shape[1]
        # Extend the prompt's attention mask (which hides the left padding)
        # over the newly generated tokens.
        attention_mask = torch.cat(
            [inputs.attention_mask, inputs.attention_mask.new_ones((len(prompts), num_new_tokens))],
            dim=1,
        )
        # The final token generated by the little model does not need to be used
        # as input to the big model. One forward pass gives the big model's
        # next-token logits at every position; keep those for the new tokens.
        big_model_logits = big_model(
            little_model_output_tokens[:, :-1],
            attention_mask=attention_mask[:, :-1],
        ).logits[:, -num_new_tokens:]

    little_model_preds = little_model_output_tokens[:, -num_new_tokens:]
    # Greedy choice: the argmax of the logits is also the argmax of the softmax.
    big_model_preds = big_model_logits.argmax(dim=-1)

    # Check if little and big produce the same predictions for each item in the batch
    are_predictions_same = torch.all(little_model_preds == big_model_preds, dim=1)
    # Check if all items in the batch have the same predictions
    all_same = torch.all(are_predictions_same).item()
    if not all_same:
        return None
    return tokenizer.batch_decode(little_model_preds, skip_special_tokens=True, clean_up_tokenization_spaces=False)

For any prompt, we can now check whether the two models agree on the first N tokens.

For how many tokens do the little and big models agree?
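`check_agreement` below reruns both models once per length. Because greedy decoding is deterministic (up to floating-point nondeterminism on the GPU), a single call with the largest budget already contains the answer: the number of agreeing tokens is the length of the common prefix of the two prediction sequences. A minimal sketch on plain lists (the helper name is hypothetical); on the notebook's tensors one would pass `little_model_preds[0].tolist()` and `big_model_preds[0].tolist()`.

```python
from typing import List


def agreed_prefix_length(little_preds: List[int], big_preds: List[int]) -> int:
    # Count positions until the first disagreement.
    n = 0
    for a, b in zip(little_preds, big_preds):
        if a != b:
            break
        n += 1
    return n


print(agreed_prefix_length([4, 8, 15, 16], [4, 8, 23, 16]))  # 2
```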

In [None]:
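def check_agreement(prompt: str, max_new_tokens: int = 10):
    # Print the greedy continuation for each length i = 1..max_new_tokens,
    # stopping at the first length where the two models disagree.
    for i in range(1, max_new_tokens + 1):
        result = generate_and_check([prompt], i)
        if result is None:
            return None
        print(i, repr(result[0]))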

PROMPT0 = '''
    def factorial(n):
        """
        Returns the factorial of n.
        """
'''.strip()

check_agreement(PROMPT0, 35)

In [None]:
check_agreement("def hello", 35)

## Greedy Speculative Decoding

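The basic idea is this: generate a "draft" of several tokens with the little model, then check with a single forward pass of the big model whether it agrees with the draft. If it agrees on every token, accept the whole draft. Otherwise, extend the input with: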

1. The prefix of the draft on which the two models agree.
2. The first output token from the big model that is different from the draft.

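Note that (1) is empty if the models disagree on the very first draft token. Thus (2) is essential for the algorithm to make progress: every round adds at least one token. Moreover, (2) is exactly the token that greedy decoding with the big model would have chosen, so in exact arithmetic the output matches greedy decoding with the big model alone, while needing fewer big-model forward passes.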
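Here is a minimal sketch of this procedure on toy integer "models". The functions `toy_little` and `toy_big` are hypothetical, chosen so that the two disagree only right after the token 5, and `draft_len` is an extra knob for illustration. Verification is written as a loop for clarity, but it is counted as one big-model pass because a real forward pass scores every draft position at once.

```python
from typing import Callable, List, Tuple

NextToken = Callable[[List[int]], int]  # maps a token prefix to the greedy next token


def greedy(prompt: List[int], model: NextToken, n: int) -> List[int]:
    # Plain greedy decoding: one model call per new token.
    tokens = list(prompt)
    for _ in range(n):
        tokens.append(model(tokens))
    return tokens


def greedy_speculative(prompt: List[int], little: NextToken, big: NextToken,
                       max_new_tokens: int, draft_len: int) -> Tuple[List[int], int]:
    tokens = list(prompt)
    big_passes = 0
    while len(tokens) - len(prompt) < max_new_tokens:
        k = min(draft_len, max_new_tokens - (len(tokens) - len(prompt)))
        draft = greedy(tokens, little, k)[len(tokens):]
        big_passes += 1  # a real forward pass scores all draft positions at once
        accepted: List[int] = []
        for tok in draft:
            target = big(tokens + accepted)  # big model's greedy token at this position
            accepted.append(target)  # the draft token if they agree, else the correction
            if target != tok:
                break
        tokens += accepted
    return tokens, big_passes


# Hypothetical toy models: they agree everywhere except right after the token 5.
def toy_big(prefix: List[int]) -> int:
    return (prefix[-1] + 1) % 10


def toy_little(prefix: List[int]) -> int:
    return 0 if prefix[-1] == 5 else (prefix[-1] + 1) % 10


tokens, passes = greedy_speculative([0], toy_little, toy_big, max_new_tokens=8, draft_len=4)
print(tokens, passes)  # [0, 1, 2, 3, 4, 5, 6, 7, 8] 3
```

Plain greedy decoding with the big model would need 8 calls here; the sketch needs 3 and produces the same tokens. Like the `speculate` function in the next cell, it does not take the extra token the big model predicts after a fully accepted draft.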

In [7]:
def speculate(input_ids, attention_mask, max_new_tokens: int):
    # Draft up to max_new_tokens with the little model, then verify the draft with a
    # single forward pass of the big model. Returns (True, prompt + draft) if the big
    # model agrees with the whole draft, otherwise
    # (False, prompt + agreed prefix + the big model's corrected token).
    with torch.no_grad():
        little_model_output_tokens = little_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            # The rest is to definitely turn off sampling
            do_sample=False,
            top_p=None,
            temperature=None,
        )

        # The little model may not generate max_new_tokens if it produces eos_token_id first.
        num_new_tokens = little_model_output_tokens.shape[1] - input_ids.shape[1]
        # Extend the attention mask (which hides any left padding) over the draft.
        draft_attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((input_ids.shape[0], num_new_tokens))], dim=1
        )

        # The final token generated by the little model does not need to be used
        # as input to the big model.
        big_model_logits = big_model(
            little_model_output_tokens[:, :-1],
            attention_mask=draft_attention_mask[:, :-1],
        ).logits[:, -num_new_tokens:]

    little_model_preds = little_model_output_tokens[:, -num_new_tokens:]
    big_model_preds = big_model_logits.argmax(dim=-1)

    # Draft positions where at least one sequence in the batch disagrees.
    diff_columns = (little_model_preds != big_model_preds).any(dim=0).nonzero(as_tuple=True)[0]
    if diff_columns.numel() == 0:
        # The big model agrees with the whole draft: accept all of it, keeping the prompt.
        return (True, little_model_output_tokens)
    # Before the first disagreement the big model's tokens equal the draft; at the
    # disagreement they are the big model's correction. Keep both.
    diff_index = diff_columns[0].item()
    valid_prefix = big_model_preds[:, :(diff_index + 1)]
    input_with_valid_prefix = torch.cat([input_ids, valid_prefix], dim=1)
    return (False, input_with_valid_prefix)


# NOTE: This won't do the right thing when a prompt produces an eos_token_id before
# the max_new_tokens limit.
def speculative_decoding(prompts: List[str], max_new_tokens: int = 10) -> Tuple[List[str], int]:
    inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(little_model.device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    prompt_length = input_ids.shape[1]
    counter = 0  # number of forward passes of the big model
    ok = False
    while not ok and input_ids.shape[1] - prompt_length < max_new_tokens:
        counter += 1
        # Draft no more tokens than remain in the max_new_tokens budget.
        remaining = max_new_tokens - (input_ids.shape[1] - prompt_length)
        (ok, generated_ids) = speculate(input_ids, attention_mask, remaining)
        # Keep the prompt's padding mask and attend to every newly accepted token.
        num_accepted = generated_ids.shape[1] - input_ids.shape[1]
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_accepted))], dim=1
        )
        input_ids = generated_ids
    return (tokenizer.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False), counter)
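With several prompts in a batch, the accepted prefix has to be cut at the earliest disagreement in any row, not at the first disagreement encountered when scanning row by row. A plain-Python sketch with hypothetical helper names: before that column every row agrees, and at that column each row's big-model token is either the draft token or the big model's correction, so keeping the big model's tokens through that column is valid for the whole batch.

```python
from typing import List, Optional


def first_mismatch_column(little: List[List[int]], big: List[List[int]]) -> Optional[int]:
    # Smallest column at which any row of the batch disagrees; None if all agree.
    cols = [j for row_l, row_b in zip(little, big)
            for j, (a, b) in enumerate(zip(row_l, row_b)) if a != b]
    return min(cols) if cols else None


def accepted_tokens(little: List[List[int]], big: List[List[int]]) -> List[List[int]]:
    d = first_mismatch_column(little, big)
    if d is None:
        return [list(row) for row in little]  # the whole draft is accepted
    # Keep the big model's tokens up to and including column d in every row.
    return [row[: d + 1] for row in big]


little = [[1, 2, 3, 4], [1, 2, 3, 4]]
big = [[1, 2, 3, 9], [1, 7, 3, 4]]
print(accepted_tokens(little, big))  # [[1, 2], [1, 7]]
```

Scanning row by row would find the mismatch at column 3 in the first row and accept the second row's token 7 at column 1 followed by tokens the big model predicted after the rejected draft token.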

In [None]:
(code, counter) = speculative_decoding(["def hello"], 20)
print(f"Forward passes of the big model: {counter}")
print(code[0])


In [None]:
(code, counter) = speculative_decoding(["def shortest_path(graph):\n\t"], 180)
print(f"Forward passes of the big model: {counter}")
print(code[0])


## Speculative Decoding with Sampling

The approach above uses greedy decoding. How can we modify it to sample from the little model, and still preserve the distribution of the big model? See [Accelerating Large Language Model Decoding with Speculative Sampling](https://arxiv.org/abs/2302.01318) for a solution.
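The linked paper answers this with a modified rejection sampling step. For a draft token x sampled from the little model's distribution q, accept x with probability min(1, p(x)/q(x)), where p is the big model's distribution at the same position; on rejection, sample instead from the normalized residual max(0, p - q). The plain-Python sketch below (function names are our own) covers only this single-token step and also computes the exact distribution of the emitted token.

```python
import random
from typing import List


def residual(p_big: List[float], q_little: List[float]) -> List[float]:
    # norm(max(0, p - q)): where to resample after a rejection.
    r = [max(0.0, p - q) for p, q in zip(p_big, q_little)]
    total = sum(r)
    return [x / total for x in r]


def verify_sampled_token(x: int, p_big: List[float], q_little: List[float], rng: random.Random) -> int:
    # Accept the little model's token x with probability min(1, p(x) / q(x)).
    if rng.random() < min(1.0, p_big[x] / q_little[x]):
        return x
    # Otherwise sample a replacement from the residual distribution.
    return rng.choices(range(len(p_big)), weights=residual(p_big, q_little))[0]


def output_distribution(p_big: List[float], q_little: List[float]) -> List[float]:
    # Exact probability of each emitted token when x itself was drawn from q_little.
    accept = [min(p, q) for p, q in zip(p_big, q_little)]  # q(x) * min(1, p(x)/q(x))
    reject_mass = 1.0 - sum(accept)
    if reject_mass <= 1e-12:  # p == q: every draft token is accepted
        return accept
    return [a + reject_mass * r for a, r in zip(accept, residual(p_big, q_little))]


print(output_distribution([0.5, 0.3, 0.2], [0.2, 0.2, 0.6]))  # approximately [0.5, 0.3, 0.2]
```

The output equals p because the rejection mass is the sum over y of max(0, p(y) - q(y)), so each token x is emitted with probability min(p(x), q(x)) + max(0, p(x) - q(x)) = p(x).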