Proxy Tuning

Based on the paper “Tuning Language Models by Proxy” (Liu et al., 2024).

Setup

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import re
from tqdm.auto import tqdm
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer
import datasets
/mnt/ssd/arjun/miniconda3/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

We are going to work with three models:

  1. Meta Llama 3.1 8B, which we used for the Math Word Problems homework
  2. Meta Llama 3.2 1B, a much smaller and less capable model. However, it uses exactly the same tokenizer as (1), which proxy tuning requires because it combines the models' logits token by token.
  3. A fine-tuned version of (2) that we fine-tuned on the GSM8K training set.
# Checkpoint paths for the three models.
# NOTE: despite "3P1" in its name, LLAMA3P1_1B_BASE points at a Llama 3.2 1B
# checkpoint ("llama3p2_1b_base"). LLAMA3P1_1B_FINETUNED is the GSM8K
# fine-tuned copy of that 1B model.
LLAMA3P1_1B_BASE = "/mnt/ssd/arjun/models/llama3p2_1b_base"
LLAMA3P1_8B_BASE = "/mnt/ssd/arjun/models/llama3p1_8b_base"
LLAMA3P1_1B_FINETUNED = "/mnt/ssd/arjun/repos/engineering-llm-systems/notes/proxy_tuning/checkpoint_final"

# Evaluation problems. Each row is read below via problem["question"] and
# problem["answer"].
validation = datasets.load_dataset("nuprl/llm-systems-math-word-problems", split="test")

# A single tokenizer is shared by all three models (the 1B and 8B models use
# the same tokenizer). Left padding keeps every prompt's last token in the
# final column, so generated tokens start at the same position in every row
# of a batch.
tokenizer = AutoTokenizer.from_pretrained(
    LLAMA3P1_1B_BASE,
    padding_side="left",
    clean_up_tokenization_spaces=False
)
# Reuse EOS as the padding token (presumably the base tokenizer defines no pad
# token — verify). Consequence: code that masks pad tokens also masks EOS.
tokenizer.pad_token = tokenizer.eos_token

# The big base model: Llama 3.1 8B, in bfloat16 on the GPU.
llama3p1_8b_base = AutoModelForCausalLM.from_pretrained(
    LLAMA3P1_8B_BASE,
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)

# The little base model (Llama 3.2 1B, see NOTE above); the "anti-expert"
# in proxy tuning.
llama3p1_1b_base = AutoModelForCausalLM.from_pretrained(
    LLAMA3P1_1B_BASE,
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)

# The little fine-tuned model; the "expert" in proxy tuning.
llama3p1_1b_finetuned = AutoModelForCausalLM.from_pretrained(
    LLAMA3P1_1B_FINETUNED,
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.12it/s]

Zero-Shot Accuracy on Math Word Problems

The code below computes zero-shot accuracy on the Math Word Problems dataset using greedy decoding. It is an adaptation of a solution to the first part of the Math Word Problems homework. The only change is that we use a model loaded locally with Hugging Face Transformers, instead of querying a remotely hosted model through the OpenAI API.

def query_model(model, prompts: List[str], **kwargs) -> List[str]:
    """
    Generate a completion for each prompt with ``model.generate``.

    Extra keyword arguments are forwarded to ``generate`` unchanged. Each
    returned string holds only the newly generated text: the prompt tokens
    are sliced off and special tokens are dropped during decoding.
    """
    encoded = tokenizer(prompts, padding=True, return_tensors="pt").to(model.device)
    # With left padding every row's prompt occupies the same leading columns,
    # so a single slice strips the prompt from every row.
    prompt_width = len(encoded.input_ids[0])
    with torch.no_grad():
        generated = model.generate(**encoded, **kwargs)
    new_tokens = generated[:, prompt_width:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
    
    
# Regex to match a number: an optional sign followed by either a decimal
# ("3.5", ".5") or an integer. The non-capturing group makes the sign apply
# to both alternatives, so "-7" is matched as "-7" rather than "7".
NUMBER_RE = re.compile(r"[-+]?(?:\d*\.\d+|\d+)")

def prompt_zero_shot(problem: str) -> str:
    """Wrap ``problem`` in the zero-shot template, which ends in an open "Answer:" slot."""
    template = "Problem: {}\nAnswer:"
    return template.format(problem)


def extract_zero_shot(completion: str, errors: List[str]) -> Optional[int]:
    """
    Extract the first number in ``completion`` as an integer answer.

    Decimals are truncated toward zero via ``int(float(...))``.

    Args:
        completion: generated text to parse.
        errors: list that receives ``completion`` when no usable number is
            found (mutated in place).

    Returns:
        The extracted integer, or ``None`` when extraction fails. ``None`` is
        used rather than a numeric sentinel so that a failed extraction can
        never compare equal to a real answer.
    """
    match = NUMBER_RE.search(completion)
    if match is None:
        errors.append(completion)
        return None
    try:
        return int(float(match.group()))
    except (ValueError, OverflowError):
        # A very long digit run parses to float("inf"), and int(inf) raises
        # OverflowError; ValueError is caught defensively.
        errors.append(completion)
        return None


def solve_zero_shot(model, problem: str, errors: List[str]) -> Optional[int]:
    """
    Answer one problem with ``model`` using greedy decoding, then parse it.

    Generation is capped at 10 new tokens. ``top_p`` and ``temperature`` are
    passed explicitly as ``None`` (presumably to override any sampling values
    in the model's generation config — verify). Parse failures are recorded
    in ``errors`` by ``extract_zero_shot``.
    """
    greedy_settings = dict(
        do_sample=False,
        top_p=None,
        temperature=None,
        max_new_tokens=10,
        pad_token_id=tokenizer.eos_token_id,
    )
    completions = query_model(model, prompt_zero_shot(problem), **greedy_settings)
    return extract_zero_shot(completions[0], errors)


def accuracy(model, problems: List[dict]) -> Tuple[float, List[str]]:
    """
    Zero-shot greedy accuracy of ``model`` on ``problems``.

    Each problem must provide "question" and "answer" keys. Returns the
    fraction of problems whose extracted prediction equals the answer,
    together with the completions from which no answer could be extracted.
    """
    errors: List[str] = []
    correct = sum(
        1
        for problem in tqdm(problems)
        if solve_zero_shot(model, problem["question"], errors) == problem["answer"]
    )
    return correct / len(problems), errors
# Baselines: zero-shot greedy accuracy of each model on its own, evaluated on
# the full validation split (one generation call per problem per model).
accuracy_8b, errors_8b = accuracy(llama3p1_8b_base, validation)
accuracy_1b, errors_1b = accuracy(llama3p1_1b_base, validation)
accuracy_1b_finetuned, errors_1b_finetuned = accuracy(llama3p1_1b_finetuned, validation)

print(f"Accuracy 8B: {accuracy_8b:.2f}")
print(f"Accuracy 1B: {accuracy_1b:.2f}")
print(f"Accuracy 1B Finetuned: {accuracy_1b_finetuned:.2f}")

  0%|          | 0/50 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
100%|██████████| 50/50 [00:12<00:00,  4.00it/s]
100%|██████████| 50/50 [00:05<00:00,  8.59it/s]
100%|██████████| 50/50 [00:06<00:00,  8.23it/s]

Accuracy 8B: 0.12
Accuracy 1B: 0.06
Accuracy 1B Finetuned: 0.12
# Per model: how many completions contained no number that could be extracted.
len(errors_8b), len(errors_1b), len(errors_1b_finetuned)
(3, 0, 0)

Proxy Tuning

The key idea is to run three forward passes — with (1) the big model, (2) the “expert” (the fine-tuned little model), and (3) the “anti-expert” (the little model before fine-tuning) — and combine their logits:

logits = big_logits + expert_logits - anti_expert_logits

This shifts the big model's next-token logits in the direction that fine-tuning moved the little model's logits. That is all that proxy_forward does.

def proxy_forward(big_model, little_expert_model, little_anti_expert_model, input_ids, pad_token_id=None):
    """
    Compute proxy-tuned logits for every position of ``input_ids``.

    Runs one forward pass through each model and returns
    ``big + expert - anti_expert``, computed in float32.

    Args:
        big_model: the large base model whose output is steered.
        little_expert_model: the small fine-tuned ("expert") model.
        little_anti_expert_model: the small untuned ("anti-expert") model.
        input_ids: token ids, presumably shaped (batch, seq_len). All three
            models must share a vocabulary for the elementwise sum to be
            meaningful.
        pad_token_id: token treated as padding in the attention mask.
            Defaults to ``tokenizer.pad_token_id``.

    Returns:
        A float32 tensor with the same shape as the models' logits.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    # 1 for real tokens, 0 for padding (integer mask, the conventional dtype).
    # NOTE(review): no position_ids are passed; with left-padded batches,
    # verify that the models derive positions from attention_mask.
    attention_mask = (input_ids != pad_token_id).long()
    with torch.no_grad():
        big_logits = big_model(input_ids, attention_mask=attention_mask).logits
        expert_logits = little_expert_model(input_ids, attention_mask=attention_mask).logits
        anti_expert_logits = little_anti_expert_model(input_ids, attention_mask=attention_mask).logits
    # The models are loaded in bfloat16, and their logits keep that dtype.
    # bfloat16 carries only 8 bits of precision, so adding a small
    # expert/anti-expert difference to large big-model logits in bfloat16 can
    # round it away entirely. Combine in float32 instead.
    return big_logits.float() + expert_logits.float() - anti_expert_logits.float()

def proxy_generate(big_model, little_expert_model, little_anti_expert_model, input_ids, max_new_tokens=10):
    """
    Greedy decoding driven by ``proxy_forward``.

    Appends up to ``max_new_tokens`` tokens to each row of ``input_ids``
    (presumably shaped (batch, seq_len)). Returns the decoded continuations,
    with the prompt removed and special tokens skipped.

    Once a row produces ``tokenizer.eos_token_id``, it keeps receiving that
    token for the remaining steps. The loop stops early when every row has
    produced it.

    Note: there is no KV cache, so every step re-runs all three models over
    the whole sequence.
    """
    prompt_len = input_ids.shape[1]
    eos_id = tokenizer.eos_token_id
    # eot_reached[i] becomes True once row i has produced EOS. It is only ever
    # OR-ed, so it stays True for the rest of the loop.
    eot_reached = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
    for _ in range(max_new_tokens):
        logits = proxy_forward(big_model, little_expert_model, little_anti_expert_model, input_ids)
        # Greedy choice: argmax over the vocabulary at the last position.
        # softmax is monotonic, so it cannot change the argmax. In low
        # precision it can only introduce ties, so it is not applied.
        next_token_id = torch.argmax(logits[:, -1], dim=-1)
        eot_reached = eot_reached | (next_token_id == eos_id)
        if eot_reached.all():
            break
        # Rows that already finished keep emitting EOS instead of the model's choice.
        next_token_id = next_token_id.masked_fill(eot_reached, eos_id)
        # next_token_id has shape (batch,). unsqueeze(-1) turns it into a
        # (batch, 1) column so each row gets its own token; unsqueeze(0) would
        # give (1, batch), which only lines up when batch == 1.
        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
    return tokenizer.batch_decode(input_ids[:, prompt_len:], skip_special_tokens=True)

def proxy_solve(problem: str, errors: List[str]) -> Optional[int]:
    """
    Solve one math word problem with proxy-tuned greedy decoding.

    Uses the roles described above:
      * big model: the 8B base model;
      * expert: the fine-tuned 1B model;
      * anti-expert: the 1B base model.

    The completion is parsed with ``extract_zero_shot``; parse failures are
    appended to ``errors``.
    """
    big_model = llama3p1_8b_base
    # Place the prompt on the big model's device rather than hard-coding
    # "cuda", the same way query_model places its inputs.
    input_ids = tokenizer(prompt_zero_shot(problem), return_tensors="pt").input_ids.to(big_model.device)
    resp = proxy_generate(
        big_model,
        llama3p1_1b_finetuned,
        llama3p1_1b_base,
        input_ids,
    )
    return extract_zero_shot(resp[0], errors)


def proxy_accuracy(problems: List[dict]) -> Tuple[float, List[str]]:
    """
    Zero-shot accuracy of proxy-tuned greedy decoding on ``problems``.

    Each problem must provide "question" and "answer" keys. Returns the
    fraction answered correctly, together with the completions from which no
    answer could be extracted.
    """
    failed_extractions: List[str] = []
    hits = sum(
        1
        for item in tqdm(problems)
        if proxy_solve(item["question"], failed_extractions) == item["answer"]
    )
    return hits / len(problems), failed_extractions
# Proxy-tuned accuracy on the validation split, shown as (accuracy, errors).
proxy_accuracy(validation)
100%|██████████| 50/50 [00:24<00:00,  2.07it/s]





(0.14, [])