# Supervised Fine-Tuning

## Introduction

The goal of this tutorial is to show you the basics of training and generation
with a single GPU. The Transformers library has a lot of convenient methods
for training and generation. We are going to avoid most of them (we only use
`generate` to speed up evaluation) and instead work directly with the neural
network. This is the only way to understand the material.

To use this notebook, you will need:

1. A Python environment with PyTorch installed. I am not going to tell you how
   to set this up, because it varies considerably across machines.
2. A GPU with 16GB+ VRAM.
3. The following extra packages, which you can install with pip:
   ```
   pip3 install transformers datasets matplotlib tqdm flash_attn accelerate
   ```

If you have trouble installing `flash_attn`, visit the 
[Flash Attention 2](https://github.com/Dao-AILab/flash-attention) page.

We first load a couple of modules.

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer
import datasets
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
import re
import math
from typing import List, Optional
from IPython.display import clear_output


The following cell loads the model and tokenizer. You may need to modify
the `MODEL` variable below to load the model from a different path.
**If you mess up the model during training, consider re-running the cell to
reload the model.**

In [2]:
MODEL = "/mnt/ssd/arjun/models/llama3p2_1b_base"

tokenizer = AutoTokenizer.from_pretrained(MODEL, padding_side="left")
# I don't know why this isn't set by default, but you always want this.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    use_cache=False,
).to(0)

## Dataset and Metrics

In [3]:
test_dataset = datasets.load_dataset(
    "nuprl/engineering-llm-systems", "math_word_problems", split="test"
)
train_dataset = datasets.load_dataset(
    "nuprl/engineering-llm-systems", "gsm8k", split="train"
)

The function below does batched generation, which will help evaluation run
substantially faster.

In [None]:
def generate_text(
    prompts: List[str],
    max_new_tokens: int = 200,
    temperature: float = 0.2,
    stop: Optional[List[str]] = None,
) -> List[str]:
    # Left padding (set when loading the tokenizer) keeps every prompt flush
    # against the right edge, so generated tokens start at the same column.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id,
            top_p=None,  # Llama 3 sets top_p by default
            do_sample=temperature > 0,
        )

    # Decode only the newly generated tokens, dropping the prompt.
    generated_texts = tokenizer.batch_decode(
        outputs[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True
    )

    # Truncate each completion at the earliest occurrence of any stop sequence.
    trimmed_texts = []
    for generated_text in generated_texts:
        for stop_seq in stop or []:
            if stop_seq in generated_text:
                generated_text = generated_text[: generated_text.index(stop_seq)]
        trimmed_texts.append(generated_text.strip())

    return trimmed_texts

print([test_dataset[0]["answer"], test_dataset[1]["answer"]])
generate_text(
    [
        test_dataset[0]["question"] + "\nAnswer:",
        test_dataset[1]["question"] + "\nAnswer:",
    ],
    temperature=0.0,
    stop=["\n"],
)
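The `generate_text` function strips the prompt with a single slice, `outputs[:, inputs["input_ids"].shape[1] :]`. That slice is correct for every row only because the tokenizer pads on the left. The sketch below uses plain Python lists and a hypothetical `left_pad` helper (not part of Transformers) to show why: after left padding, every prompt ends at the same column, so the generated tokens start at the same column too.

```python
from typing import List, Tuple


def left_pad(seqs: List[List[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical helper: left-pad token ID lists and build the matching attention mask."""
    width = max(len(s) for s in seqs)
    input_ids = [[pad_id] * (width - len(s)) + s for s in seqs]
    attention_mask = [[0] * (width - len(s)) + [1] * len(s) for s in seqs]
    return input_ids, attention_mask


PAD = 0
prompts = [[5, 6, 7], [8, 9]]  # two made-up prompts of different lengths
input_ids, attention_mask = left_pad(prompts, PAD)
prompt_width = len(input_ids[0])

# Pretend the model appended two new tokens to each row.
outputs = [row + [100 + i, 200 + i] for i, row in enumerate(input_ids)]

# Slicing at the padded prompt width recovers exactly the new tokens in every row.
new_tokens = [row[prompt_width:] for row in outputs]
print(input_ids)       # [[5, 6, 7], [0, 8, 9]]
print(attention_mask)  # [[1, 1, 1], [0, 1, 1]]
print(new_tokens)      # [[100, 200], [101, 201]]
```

With right padding, the shorter prompt would be followed by pad tokens, and the model would have to continue generating after those pads instead of directly after the end of its prompt.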

In [None]:
# The optional sign applies to both alternatives, so "-5" parses as -5, not 5.
NUMBER_RE = re.compile(r"[-+]?(?:\d*\.\d+|\d+)")

def prompt_zero_shot(problem: str) -> str:
    return f"Problem: {problem}\nAnswer:"

def extract_zero_shot(completion: str) -> Optional[int]:
    # Use the first number in the completion; return None if there isn't one.
    matches = NUMBER_RE.findall(completion)
    if not matches:
        tqdm.write(f"Error processing {completion!r}")
        return None
    return int(float(matches[0]))


def solve_zero_shot(problem: str) -> Optional[int]:
    # generate_text takes and returns lists, so wrap the prompt and unwrap the result.
    resp = generate_text([prompt_zero_shot(problem)], temperature=0.2, max_new_tokens=10)[0]
    return extract_zero_shot(resp)


def accuracy(batch_size: int, problems: datasets.Dataset) -> float:
    num_correct = 0
    for batch_start_index in range(0, len(problems), batch_size):
        batch_end_index = min(batch_start_index + batch_size, len(problems))
        batch = problems.select(range(batch_start_index, batch_end_index))
        predictions = generate_text(
            [prompt_zero_shot(problem["question"]) for problem in batch],
            temperature=0,
            max_new_tokens=10,
            stop=["\n"],
        )
        for problem, prediction in zip(batch, predictions):
            if extract_zero_shot(prediction) == problem["answer"]:
                num_correct += 1
    return num_correct / len(problems)


accuracy(20, test_dataset)

## Model Basics

In this section we will use the model directly to get a sense of what it
does. In ordinary usage, we either train it in a loop or generate output in a
loop. There are no training or generation loops in this section: each cell
calls the model at most once.

In [None]:
print(train_dataset[0]["question"])
example_inputs = tokenizer([train_dataset[0]["question"]], return_tensors="pt").to(
    model.device
)
print(example_inputs)

The code below runs a *forward pass* with two simplifications:

1. We disable dropout (`model.eval()`).
2. We don't compute gradients (`torch.no_grad()`).

We'll enable both when we get to training.

Notice that the output from the tokenizer is conveniently structured so that we
can pass it as keyword arguments to `model.forward` by writing
`model.forward(**example_inputs)`. In our case, this is equivalent to:

```python
model.forward(
    input_ids=example_inputs.input_ids,
    attention_mask=example_inputs.attention_mask
)
```

In [None]:
model.eval()
with torch.no_grad():
    example_outputs = model.forward(**example_inputs)
print(example_outputs)

The result above has a lot of optional fields, but the only one that is set is `logits`.
Let's compare its shape to the shape of the input:

In [None]:
print(example_inputs.input_ids.shape)
print(example_outputs.logits.shape)

We have one output per input token, and each output is a tensor with one score
per vocabulary entry (~128,000 for Llama 3). Let's look at one of them.

In [None]:
example_outputs.logits[0, 1]

Each index in this tensor corresponds to a token ID in the vocabulary, and the
value at that index scores how likely that token is to come next. But look at
the numbers. There are plenty of negative numbers, so this is *not* a
probability distribution.

These are raw, unnormalized scores, or *logits*. We can turn them into a
probability distribution using the *softmax* function, and then sum the result
to verify that it adds up to 1:

In [None]:
example_dist_single = F.softmax(example_outputs.logits[0, 1], dim=0)
print(example_dist_single)
print(example_dist_single.sum())
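For intuition, softmax exponentiates each logit and divides by the sum of the exponentials, so every output is positive and the outputs sum to 1. Below is a minimal pure-Python sketch; `softmax` here is our own helper, not `F.softmax`. Subtracting the maximum logit first is a standard trick that avoids overflow without changing the result.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    # Subtracting the max leaves the result unchanged but keeps exp() from overflowing.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([2.0, -1.0, 0.5, -3.0])  # negative logits are fine
print(probs)
print(sum(probs))  # 1.0, up to floating-point rounding
```

Softmax preserves ordering: the largest logit always gets the largest probability, which is why `argmax` over logits and over probabilities picks the same token.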

Let's turn every output into a distribution:

In [None]:
# Copied from above
example_inputs = tokenizer([train_dataset[0]["question"]], return_tensors="pt").to(
    model.device
)
model.eval()
with torch.no_grad():
    example_outputs = model.forward(**example_inputs)

# Notice that we do .logits[0] and not .logits[0, 1] as above.
example_dist = F.softmax(example_outputs.logits[0], dim=1)
print(example_dist.shape)

We can also do this for a batch of inputs:

In [None]:
# Notice we have set padding=True.
example_inputs = tokenizer(
    [train_dataset[0]["question"], train_dataset[1]["question"]],
    padding=True,
    return_tensors="pt",
).to(model.device)
model.eval()
with torch.no_grad():
    example_outputs = model.forward(**example_inputs)

# Notice that we do .logits and not .logits[0] as above.
example_dist = F.softmax(example_outputs.logits, dim=2)
print(example_dist.shape)

How many of these predictions are correct? The goal is to predict the next token: the output at position `i` is a prediction for the input token at position `i + 1`. Let's check one.

In [None]:
torch.argmax(example_dist[1, 10]),  example_inputs.input_ids[1, 11]
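The index arithmetic above (output position 10 compared with input position 11) is the general rule, and it is the same shift that `logits[:, :-1]` and `input_ids[:, 1:]` express later on. A small pure-Python illustration with made-up tokens:

```python
tokens = ["<bos>", "The", "cat", "sat", "down"]

# Output position i is scored against input position i + 1: the last position
# has no target, and the first token is never a target.
pairs = list(zip(tokens[:-1], tokens[1:]))
for context_end, target in pairs:
    print(f"context ending in {context_end!r} -> target {target!r}")
```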

We can loop over every position and check them all.

In [None]:
num_correct = 0
num_total = 0
for batch_ix in range(example_dist.shape[0]):
    # We exclude the last token because we don't know what it should predict.
    for tok_ix in range(example_dist.shape[1] - 1):
        # Skip the padding tokens. Should probably skip the BOS token too.
        if example_inputs.input_ids[batch_ix, tok_ix] == tokenizer.eos_token_id:
            continue
        num_total += 1
        if torch.argmax(example_dist[batch_ix, tok_ix]).item() == example_inputs.input_ids[batch_ix, tok_ix + 1].item():
            num_correct += 1
        
print(num_correct / num_total)

A different question we can ask is: what probability did the model assign to the expected token? It should be high, but it may be low. The cell below averages the log of that probability over all non-padding targets.

In [None]:
preds = [ ]
for batch_ix in range(example_dist.shape[0]):
    # We exclude the last token because we don't know what it should predict.
    for tok_ix in range(example_dist.shape[1] - 1):
        # Skip the padding tokens. Should probably skip the BOS token too.
        if example_inputs.input_ids[batch_ix, tok_ix + 1] == tokenizer.eos_token_id:
            continue
        expected_token_id = example_inputs.input_ids[batch_ix, tok_ix + 1].item()
        expected_token_prob = example_dist[batch_ix, tok_ix, expected_token_id].item()
        preds.append(expected_token_prob)
print(sum([math.log(p) for p in preds]) / len(preds))


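We can compute the same quantity with the built-in cross-entropy loss function,
which returns the *negative* of the average log-probability above (up to
rounding). The predictions we want are from every output position except the
last one. The labels we want are from all the input positions except the first
one. Passing the padding token ID as `ignore_index` skips padded labels, just
as the loop above does.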
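Written out for a batch of $B$ sequences of length $T$, with input tokens $x_{b,1}, \dots, x_{b,T}$ and a mask $m_{b,t+1}$ that is 1 when the label $x_{b,t+1}$ is not padding and 0 otherwise, the loss is the negative average log-probability of each next token:

$$
\mathcal{L} = -\frac{1}{N} \sum_{b=1}^{B} \sum_{t=1}^{T-1} m_{b,t+1} \, \log p_\theta\left(x_{b,t+1} \mid x_{b,1}, \dots, x_{b,t}\right),
\qquad N = \sum_{b=1}^{B} \sum_{t=1}^{T-1} m_{b,t+1}
$$

Here $p_\theta(\cdot \mid x_{b,1}, \dots, x_{b,t})$ is the softmax of the logits at position $t$. The sum stops at $T - 1$ because the last output has no next token, and it starts its labels at $x_{b,2}$ because the first token is never predicted. The loop above prints $-\mathcal{L}$.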

In [None]:
all_logits = example_outputs.logits[:, :-1, :].reshape(-1, example_outputs.logits.shape[-1])
all_labels = example_inputs.input_ids[:, 1:].reshape(-1)
print(all_logits.shape, all_labels.shape)
F.cross_entropy(all_logits, all_labels,  ignore_index=tokenizer.eos_token_id)
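To see that the shift plus `ignore_index` is all that `F.cross_entropy` adds on top of the loop above, here is a pure-Python sketch on made-up logits for a single sequence. The helper names (`log_softmax`, `shifted_cross_entropy`) are our own, and `PAD` is a stand-in for the padding token ID.

```python
import math
from typing import List

PAD = 0  # stand-in for tokenizer.eos_token_id, which doubles as the pad token here


def log_softmax(logits: List[float]) -> List[float]:
    # log(softmax(x)) computed stably: x - logsumexp(x)
    m = max(logits)
    log_total = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_total for x in logits]


def shifted_cross_entropy(logits: List[List[float]], input_ids: List[int], ignore_index: int) -> float:
    """Mean negative log-probability of each next token, skipping ignored labels."""
    preds = logits[:-1]     # every output position except the last
    labels = input_ids[1:]  # every input position except the first
    losses = [
        -log_softmax(row)[label]
        for row, label in zip(preds, labels)
        if label != ignore_index
    ]
    return sum(losses) / len(losses)


# One left-padded sequence over a 4-token vocabulary; logits are made up.
input_ids = [PAD, PAD, 2, 3, 1]
logits = [
    [0.1, 0.2, 0.3, 0.4],  # label is PAD, so this position is skipped
    [0.5, 0.1, 2.0, 0.0],
    [0.0, 0.0, 0.0, 3.0],
    [0.0, 2.5, 0.0, 0.0],
    [1.0, 1.0, 1.0, 1.0],  # last position: no target, so it is dropped
]
print(shifted_cross_entropy(logits, input_ids, ignore_index=PAD))
```

With uniform logits over a vocabulary of size $V$, the loss is exactly $\log V$, a handy sanity check for an untrained model.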

Now let's compute the loss directly after the forward pass, this time *without* `torch.no_grad()`, so that PyTorch records what it needs to compute gradients.

In [None]:
model.eval()  # This is wrong, deliberately. Will fix later.
example_inputs = tokenizer([ 
        prompt_zero_shot(train_dataset[0]["question"]) + " " + train_dataset[0]["answer"],
        prompt_zero_shot(train_dataset[1]["question"]) + " " + train_dataset[1]["answer"]
    ], 
    padding=True, 
    return_tensors="pt").to(
    model.device
)
example_outputs = model.forward(**example_inputs)
all_logits = example_outputs.logits[:, :-1, :].reshape(-1, example_outputs.logits.shape[-1])
all_labels = example_inputs.input_ids[:, 1:].reshape(-1)
loss = F.cross_entropy(all_logits, all_labels,  ignore_index=tokenizer.eos_token_id)
print("Loss:", loss)

Notice that the loss printed above has a `grad_fn`, which records how the loss
was computed and makes *backpropagation* possible. Backpropagation is what
`loss.backward()` performs.
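A tiny standalone illustration with made-up tensors (not the notebook's model): a result computed normally carries a `grad_fn`, while the same computation under `torch.no_grad()` does not. This is why the loss must be computed outside `no_grad` before calling `.backward()`.

```python
import torch

w = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
x = torch.tensor([0.5, 0.5, 0.5])

y = (w * x).sum()        # recorded in the autograd graph
print(y.grad_fn)         # e.g. <SumBackward0 object at ...>

with torch.no_grad():
    y_frozen = (w * x).sum()
print(y_frozen.grad_fn)  # None: nothing was recorded

y.backward()             # d(sum(w * x)) / dw = x
print(w.grad)            # tensor([0.5000, 0.5000, 0.5000])
```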

In PyTorch, `backward` computes and stores gradients, but *does not
update model weights.* You can confirm this by adding `loss.backward()` to the
end of the cell above and re-running it repeatedly. You'll get exactly the same
loss each time, which indicates the model is not learning anything.

To actually update model weights, we need an optimizer. The textbook approach
to optimization is *stochastic gradient descent (SGD)*. We are going to use
*AdamW*, which is more sophisticated and works better with LLMs.
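To make the difference concrete, here is a pure-Python sketch of one update for a single scalar weight. `sgd_step` and `adamw_step` are our own illustrative helpers, not PyTorch's implementation. The AdamW version follows the usual algorithm: running averages of the gradient and squared gradient, bias correction, and weight decay applied directly to the weight. Its default hyperparameters are chosen to match the defaults of `torch.optim.AdamW`.

```python
import math


def sgd_step(w: float, g: float, lr: float) -> float:
    # Plain gradient descent: the step size is proportional to the gradient.
    return w - lr * g


def adamw_step(w: float, g: float, state: dict, lr: float = 1e-3,
               betas=(0.9, 0.999), eps: float = 1e-8, weight_decay: float = 1e-2) -> float:
    beta1, beta2 = betas
    state["t"] += 1
    t = state["t"]
    w = w * (1 - lr * weight_decay)  # decoupled weight decay
    state["m"] = beta1 * state["m"] + (1 - beta1) * g      # running mean of gradients
    state["v"] = beta2 * state["v"] + (1 - beta2) * g * g  # running mean of squared gradients
    m_hat = state["m"] / (1 - beta1 ** t)  # bias correction for the zero initialization
    v_hat = state["v"] / (1 - beta2 ** t)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps)


print(sgd_step(1.0, 0.5, lr=0.1))    # ~0.95
print(sgd_step(1.0, 100.0, lr=0.1))  # ~-9.0: the step scales with the gradient
print(adamw_step(1.0, 0.5, {"t": 0, "m": 0.0, "v": 0.0}, lr=0.1, weight_decay=0.0))    # ~0.9
print(adamw_step(1.0, 100.0, {"t": 0, "m": 0.0, "v": 0.0}, lr=0.1, weight_decay=0.0))  # also ~0.9
```

Because AdamW divides by the root of the squared-gradient average, its first step is roughly `lr` regardless of the gradient's scale. This makes the learning rate much easier to pick for a large model whose parameters see gradients of very different magnitudes.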

In [18]:
optimizer = AdamW(model.parameters(), lr=1e-5)

The code below uses the optimizer, with three new lines at the end:

1. `loss.backward()` computes gradients.
2. `optimizer.step()` updates weights using those gradients.
3. `optimizer.zero_grad()` clears the gradients that `.backward` computed.
   Gradients accumulate across calls to `.backward`, so *always* call
   `.zero_grad` immediately after `.step` for now.
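A standalone sketch with a toy tensor shows why `.zero_grad` matters. The names `w`, `loss_fn`, and `toy_optimizer` are illustrative and not from the notebook; `toy_optimizer` is deliberately not called `optimizer` so running this cell does not replace the notebook's optimizer. Calling `.backward` twice without clearing adds the gradients together, and the weight itself does not change until `.step` is called.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
toy_optimizer = torch.optim.SGD([w], lr=0.1)


def loss_fn() -> torch.Tensor:
    return (w ** 2).sum()  # gradient is 2 * w


loss_fn().backward()
first_grad = w.grad.clone()   # tensor([2., 4.])

loss_fn().backward()          # no zero_grad in between, so gradients add up
second_grad = w.grad.clone()  # tensor([4., 8.])

toy_optimizer.zero_grad()     # clears w.grad (PyTorch 2.x sets it to None by default)
print(first_grad, second_grad, w.grad)
```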

The cell below is what we call the *training cell*.

In [None]:
model.eval()
example_inputs = tokenizer([ 
        prompt_zero_shot(train_dataset[0]["question"]) + " " + train_dataset[0]["answer"],
        prompt_zero_shot(train_dataset[1]["question"]) + " " + train_dataset[1]["answer"]
    ], 
    padding=True, 
    return_tensors="pt").to(
    model.device
)
example_outputs = model.forward(**example_inputs)
all_logits = example_outputs.logits[:, :-1, :].reshape(-1, example_outputs.logits.shape[-1])
all_labels = example_inputs.input_ids[:, 1:].reshape(-1)
loss = F.cross_entropy(all_logits, all_labels,  ignore_index=tokenizer.eos_token_id)
print("Loss:", loss)
loss.backward()
optimizer.step()
optimizer.zero_grad()

Run the cell above several times. You will probably see the loss go down: the
model is learning! You can run the cell below to see the predictions. Once the
loss above has dropped significantly, the predicted tokens will increasingly
match the input sequence, shifted by one position.

In [None]:
model.eval()  # Do NOT change this
print(train_dataset[0]["question"])
example_inputs = tokenizer([ 
        prompt_zero_shot(train_dataset[0]["question"]) + " " + train_dataset[0]["answer"],
    ], 
    padding=True, 
    return_tensors="pt").to(
    model.device
)
with torch.no_grad():
    example_outputs = model.forward(**example_inputs)
example_dist = F.softmax(example_outputs.logits[0], dim=1)
output_tokens = torch.argmax(example_dist, dim=1)
print("Output tokens:", output_tokens)
for tok in output_tokens.cpu().tolist():
    print(tok, "->", repr(tokenizer.decode(tok)))

Finally, let's address the use of `model.eval()`. `model.eval()` disables the
*dropout* layers. Dropout is a common form of regularization, but it introduces
randomness. Enable it by changing `model.eval()` to `model.train()` in the
**training cell**, then re-run training.
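A standalone sketch with a single `nn.Dropout` layer (not the notebook's model) shows the difference between the two modes. In training mode, each element is zeroed with probability `p`, and the survivors are scaled by `1 / (1 - p)` so the expected value is unchanged. In eval mode, the layer is the identity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(8)

dropout.train()
train_out = dropout(x)  # each element is either 0 or scaled up to 1 / (1 - 0.5) = 2.0

dropout.eval()
eval_out = dropout(x)   # identity: dropout does nothing in eval mode

print(train_out)
print(eval_out)
```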

## Training A Model

Based on the code above, you are ready to write a training loop. The basic idea
is to do this:

1. Loop over the items in the training set. Pick a batch size that works with
   your GPU. With a 1B model and 48GB VRAM, you should be able to do 8 items.
2. Log loss to confirm that the model is learning the training data distribution.
3. Log accuracy on the test set to confirm that the model is generalizing to
   unseen data.

*Note:* Remember to call `model.train()` before training and `model.eval()`
before testing.
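The overall shape of such a loop can be sketched on a toy problem. The example below fits a tiny linear model to made-up data. `toy_model` and `toy_optimizer` are stand-ins, not the notebook's LLM, and are named so they do not overwrite `model` or `optimizer`. The sketch uses the same train/eval switching, the same `backward`/`step`/`zero_grad` sequence, and periodic evaluation on held-out data described above. Tokenizing the GSM8K examples and measuring accuracy on the test set are left to you.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)

# Made-up data: y = 3x + 1 plus a little noise, split into train and test halves.
xs = torch.linspace(-1, 1, 64).unsqueeze(1)
ys = 3 * xs + 1 + 0.01 * torch.randn_like(xs)
train_x, train_y = xs[::2], ys[::2]
test_x, test_y = xs[1::2], ys[1::2]

toy_model = nn.Linear(1, 1)
toy_optimizer = torch.optim.AdamW(toy_model.parameters(), lr=0.1)
batch_size = 8

train_losses = []
for epoch in range(50):
    toy_model.train()  # no effect on nn.Linear, but enables dropout in real models
    for start in range(0, len(train_x), batch_size):
        batch_x = train_x[start : start + batch_size]
        batch_y = train_y[start : start + batch_size]
        loss = F.mse_loss(toy_model(batch_x), batch_y)
        loss.backward()
        toy_optimizer.step()
        toy_optimizer.zero_grad()
        train_losses.append(loss.item())

    if epoch % 10 == 0:
        toy_model.eval()  # switch off training-only behavior before evaluating
        with torch.no_grad():
            test_loss = F.mse_loss(toy_model(test_x), test_y).item()
        print(f"epoch {epoch}: train loss {train_losses[-1]:.4f}, test loss {test_loss:.4f}")
```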

In [21]:
# Your code here.