Proxy Tuning
Based on the paper Tuning Language Models by Proxy (Liu et al., 2024).
Setup
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import re
from tqdm.auto import tqdm
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer
import datasets
We are going to work with three models:
1. Meta Llama 3.1 8B, which we used for the Math Word Problems homework.
2. Meta Llama 3.2 1B, which is much smaller and less capable, but uses exactly the same tokenizer as (1).
3. A version of (2) that we fine-tuned on the GSM8K training set.
# Note: despite the "3P1" prefix in the variable names, the 1B models are Llama 3.2; only the 8B model is Llama 3.1.
LLAMA3P1_1B_BASE = "/mnt/ssd/arjun/models/llama3p2_1b_base"
LLAMA3P1_8B_BASE = "/mnt/ssd/arjun/models/llama3p1_8b_base"
LLAMA3P1_1B_FINETUNED = "/mnt/ssd/arjun/repos/engineering-llm-systems/notes/proxy_tuning/checkpoint_final"
validation = datasets.load_dataset("nuprl/llm-systems-math-word-problems", split="test")
tokenizer = AutoTokenizer.from_pretrained(
LLAMA3P1_1B_BASE,
padding_side="left",
clean_up_tokenization_spaces=False
)
tokenizer.pad_token = tokenizer.eos_token
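Left padding matters because generated tokens are appended on the right: with padding on the left, every prompt in a batch ends exactly where generation begins. A small sketch with made-up prompts:
example = tokenizer(["Two eggs", "A somewhat longer example prompt"], padding=True, return_tensors="pt")
# The shorter prompt is padded on the left, so both sequences end in real tokens.
print(example.input_ids)
print(example.attention_mask)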
llama3p1_8b_base = AutoModelForCausalLM.from_pretrained(
LLAMA3P1_8B_BASE,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
llama3p1_1b_base = AutoModelForCausalLM.from_pretrained(
LLAMA3P1_1B_BASE,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
llama3p1_1b_finetuned = AutoModelForCausalLM.from_pretrained(
LLAMA3P1_1B_FINETUNED,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00, 1.12it/s]
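Proxy tuning adds and subtracts logits element-wise across models, so all three models must assign the same vocabulary index to each token. A quick sanity check, sketched under the assumption that the 8B checkpoint ships its own tokenizer files:
tokenizer_8b = AutoTokenizer.from_pretrained(LLAMA3P1_8B_BASE)
# The logit arithmetic below is only meaningful if both tokenizers agree on the vocabulary.
assert tokenizer_8b.get_vocab() == tokenizer.get_vocab()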
Zero-Shot Accuracy on Math Word Problems
The code below computes zero-shot accuracy on the Math Word Problems dataset using greedy decoding. It is an adaptation of a solution to the first part of the Math Word Problems homework. The only change is that we query a model loaded locally with Transformers, instead of a model running remotely behind the OpenAI API.
def query_model(model, prompts: List[str], **kwargs) -> List[str]:
"""
Query a model with a list of prompts.
"""
inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(**inputs, **kwargs)
outputs_without_prefix = outputs[:,len(inputs.input_ids[0]):]
return tokenizer.batch_decode(outputs_without_prefix, skip_special_tokens=True)
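As a quick smoke test, query_model can be called directly on a made-up prompt:
# A hypothetical prompt; max_new_tokens keeps the completion short.
query_model(
    llama3p1_1b_base,
    ["The capital of France is"],
    do_sample=False,
    max_new_tokens=5,
    pad_token_id=tokenizer.eos_token_id,
)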
# Regex to match a number
NUMBER_RE = re.compile(r"[-+]?\d*\.\d+|\d+")
def prompt_zero_shot(problem: str) -> str:
return f"Problem: {problem}\nAnswer:"
def extract_zero_shot(completion: str, errors: List[str]) -> Optional[int]:
try:
num = re.findall(NUMBER_RE, completion)
return int(float(num[0]))
except Exception as e:
errors.append(completion)
return -1
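For example, extract_zero_shot pulls the first number out of a completion (the completion below is made up):
# Returns 72: the first number in the string, truncated to an int.
extract_zero_shot(" 72 eggs.\nProblem:", [])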
def solve_zero_shot(model, problem: str, errors: List[str]) -> Optional[int]:
resp = query_model(
model,
        # Wrap in a list to match query_model's List[str] signature.
        [prompt_zero_shot(problem)],
do_sample=False,
top_p=None,
temperature=None,
max_new_tokens=10,
pad_token_id=tokenizer.eos_token_id
)
return extract_zero_shot(resp[0], errors)
def accuracy(model, problems: List[dict]) -> Tuple[float, List[str]]:
errors = []
num_correct = 0
for problem in tqdm(problems):
prediction = solve_zero_shot(model, problem["question"], errors)
if prediction == problem["answer"]:
num_correct += 1
return num_correct / len(problems), errors
accuracy_8b, errors_8b = accuracy(llama3p1_8b_base, validation)
accuracy_1b, errors_1b = accuracy(llama3p1_1b_base, validation)
accuracy_1b_finetuned, errors_1b_finetuned = accuracy(llama3p1_1b_finetuned, validation)
print(f"Accuracy 8B: {accuracy_8b:.2f}")
print(f"Accuracy 1B: {accuracy_1b:.2f}")
print(f"Accuracy 1B Finetuned: {accuracy_1b_finetuned:.2f}")
100%|██████████| 50/50 [00:12<00:00, 4.00it/s]
100%|██████████| 50/50 [00:05<00:00, 8.59it/s]
100%|██████████| 50/50 [00:06<00:00, 8.23it/s]
Accuracy 8B: 0.12
Accuracy 1B: 0.06
Accuracy 1B Finetuned: 0.12
len(errors_8b), len(errors_1b), len(errors_1b_finetuned)
(3, 0, 0)
Proxy Tuning
The key idea is to do three forward passes at every decoding step: one with (1) the big model, one with (2) the "expert" fine-tuned little model, and one with (3) the "anti-expert" little base model. We then combine their logits:
logits = big_logits + expert_logits - anti_expert_logits
This steers the output of the big model in the direction that fine-tuning moved the little model: adding the logit difference expert_logits - anti_expert_logits is the same as multiplying the big model's next-token probabilities by the ratio of the expert's probabilities to the anti-expert's (and renormalizing). That is all that proxy_forward does.
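A toy check of that equivalence on made-up logit vectors (not part of the notebook's pipeline):
big = torch.tensor([1.0, 2.0, 0.5])
expert = torch.tensor([0.2, 1.5, 0.1])
anti = torch.tensor([0.3, 0.4, 0.9])
# Combining in logit space ...
combined = F.softmax(big + expert - anti, dim=-1)
# ... matches reweighting the big model's distribution by the expert/anti-expert ratio.
reweighted = F.softmax(big, dim=-1) * F.softmax(expert, dim=-1) / F.softmax(anti, dim=-1)
reweighted = reweighted / reweighted.sum()
print(torch.allclose(combined, reweighted))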
def proxy_forward(big_model, little_expert_model, little_anti_expert_model, input_ids):
    """
    Runs one forward pass per model and combines their logits:
    big + expert - anti_expert.
    """
    attention_mask = input_ids != tokenizer.pad_token_id
with torch.no_grad():
big_outputs = big_model(input_ids, attention_mask=attention_mask)
expert_outputs = little_expert_model(input_ids, attention_mask=attention_mask)
anti_expert_outputs = little_anti_expert_model(input_ids, attention_mask=attention_mask)
big_logits = big_outputs.logits
expert_logits = expert_outputs.logits
anti_expert_logits = anti_expert_outputs.logits
return big_logits + expert_logits - anti_expert_logits
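As a quick shape check, proxy_forward returns one combined logit vector per input position, over the shared vocabulary (a sketch with a made-up prompt):
ids = tokenizer("Problem: 2 + 2 = ?\nAnswer:", return_tensors="pt").input_ids.to("cuda")
# Shape: (batch, sequence length, vocabulary size)
print(proxy_forward(llama3p1_8b_base, llama3p1_1b_finetuned, llama3p1_1b_base, ids).shape)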
def proxy_generate(big_model, little_expert_model, little_anti_expert_model, input_ids, max_new_tokens=10):
"""
An implementation of greedy decoding that uses proxy_forward.
"""
input_len = input_ids.shape[1]
# A vector of booleans that indicate if we have reached eos_token_id for
# each input in the batch.
eot_reached = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
for _ in range(max_new_tokens):
logits = proxy_forward(big_model, little_expert_model, little_anti_expert_model, input_ids)
        # Greedy decoding: softmax is monotonic, so taking the argmax of the
        # combined logits directly is enough.
        next_token_id = torch.argmax(logits[:, -1], dim=-1)
# Check if any next_token_id is eos_token_id. Using logical or means that
# after eot_reached[i] becomes True, it will remain True for every
# iteration of this loop.
eot_reached = eot_reached | (next_token_id == tokenizer.eos_token_id)
if eot_reached.all():
break
# If eot_reached[i] is True, then use eos_token_id instead of next_token_id[i]
masked_next_token_id = torch.masked_fill(next_token_id, eot_reached, tokenizer.eos_token_id)
        # unsqueeze(-1) turns the (batch,) vector into a (batch, 1) column so it
        # can be concatenated onto the (batch, seq) input_ids.
        input_ids = torch.cat([input_ids, masked_next_token_id.unsqueeze(-1)], dim=-1)
outputs_without_prefix = input_ids[:,input_len:]
return tokenizer.batch_decode(outputs_without_prefix, skip_special_tokens=True)
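proxy_solve below always passes a single prompt, but proxy_generate also accepts a batch of left-padded prompts, e.g. (with hypothetical problems):
batch = tokenizer(
    [prompt_zero_shot("What is 3 + 4?"), prompt_zero_shot("What is 10 - 2?")],
    padding=True,
    return_tensors="pt",
).input_ids.to("cuda")
proxy_generate(llama3p1_8b_base, llama3p1_1b_finetuned, llama3p1_1b_base, batch)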
def proxy_solve(problem: str, errors: List[str]) -> Optional[int]:
"""
Solves a math word problem using proxy_generate.
"""
resp = proxy_generate(
llama3p1_8b_base,
llama3p1_1b_finetuned,
llama3p1_1b_base,
tokenizer(prompt_zero_shot(problem), return_tensors="pt").input_ids.to("cuda"),
)
return extract_zero_shot(resp[0], errors)
def proxy_accuracy(problems: List[dict]):
errors = []
num_correct = 0
for problem in tqdm(problems):
prediction = proxy_solve(problem["question"], errors)
if prediction == problem["answer"]:
num_correct += 1
return num_correct / len(problems), errors
proxy_accuracy(validation)
100%|██████████| 50/50 [00:24<00:00, 2.07it/s]
(0.14, [])