import torch
import datasets
import textwrap
from collections import namedtuple
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL = "/mnt/ssd/arjun/models/llama3p1_8b_base"
model = AutoModelForCausalLM.from_pretrained(MODEL).to("cuda:0")
tokenizer = AutoTokenizer.from_pretrained(MODEL)
def llama3(prompt, temperature=0, max_new_tokens=20, stop=None):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generate_kwargs = {
        "pad_token_id": tokenizer.eos_token_id,
        "max_new_tokens": max_new_tokens,
    }
    if temperature > 0:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature
    else:
        generate_kwargs["do_sample"] = False
        generate_kwargs["temperature"] = None
    with torch.no_grad():
        # Llama 3.1 uses top-p sampling by default, which we disable here.
        example_outputs = model.generate(**inputs, **generate_kwargs, top_p=None)
    outputs = tokenizer.decode(example_outputs[0, inputs["input_ids"].shape[1]:])
    if stop is not None:
        outputs = outputs.split(stop)[0]
    return outputs
llama3("Shakespeare was", temperature=0.2)
' a genius. He was a genius at writing plays, and he was a genius at writing poetry.'
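As a quick check of the stop-string handling: everything from the first occurrence of the stop string onward is discarded, so the call below returns at most one line of text. (An illustrative call; the exact completion depends on the checkpoint.)
llama3("Q: What is the capital of France?\nA:", max_new_tokens=30, stop="\n")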
Loading Datasets
bbh = datasets.load_dataset("maveriq/bigbenchhard", "reasoning_about_colored_objects", split="train")
def inspect_bbh(item):
    txt, options = item["input"].split("Options:", maxsplit=1)
    txt = textwrap.fill(txt, width=80)
    for opt in options.split("\n"):
        if item["target"] in opt:
            txt += f"\n\nAnswer: {opt}"
            break
    return txt
print(inspect_bbh(bbh[0]))
On the floor, there is one mauve cat toy, two purple cat toys, three grey cat
toys, two mauve notebooks, three grey notebooks, three burgundy cat toys, and
one purple notebook. If I remove all the notebooks from the floor, how many grey
objects remain on it?
Answer: (D) three
print(bbh[0]["input"])
On the floor, there is one mauve cat toy, two purple cat toys, three grey cat toys, two mauve notebooks, three grey notebooks, three burgundy cat toys, and one purple notebook. If I remove all the notebooks from the floor, how many grey objects remain on it?
Options:
(A) zero
(B) one
(C) two
(D) three
(E) four
(F) five
(G) six
(H) seven
(I) eight
(J) nine
(K) ten
(L) eleven
(M) twelve
(N) thirteen
(O) fourteen
(P) fifteen
(Q) sixteen
print(bbh[1]["input"])
On the desk, you see a set of things arranged in a row: a grey cup, a purple mug, and a blue teddy bear. What is the color of the thing directly to the right of the cup?
Options:
(A) red
(B) orange
(C) yellow
(D) green
(E) blue
(F) brown
(G) magenta
(H) fuchsia
(I) mauve
(J) teal
(K) turquoise
(L) burgundy
(M) silver
(N) gold
(O) black
(P) grey
(Q) purple
(R) pink
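Before prompting, it helps to know the shape of the split. A quick peek at the number of items and the most common target labels (a small illustrative cell; output not shown):
from collections import Counter
print(len(bbh))
print(Counter(item["target"] for item in bbh).most_common(5))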
Zero-shot
llama3(bbh[0]["input"])
'\n(R) seventeen\n(S) eighteen\n(T) nineteen\n(U) twenty\n(V) twenty'
print(llama3(bbh[0]["input"] + "\nThe answer is:", max_new_tokens=100))
(A) zero
Explanation:
There are 3 grey cat toys, 3 grey notebooks, and 3 burgundy cat toys. If we remove all the notebooks, we are left with 3 grey cat toys and 3 burgundy cat toys.<|end_of_text|>
print(llama3(bbh[1]["input"] + "\nThe answer is:", max_new_tokens=100))
(Q) purple
Explanation: The cup is grey, so the thing directly to the right of the cup is the purple mug.<|end_of_text|>
def bbh_zero_shot(problem, show_result):
    prompt = problem["input"] + "\nAnswer:"
    completion = llama3(prompt, temperature=0, max_new_tokens=5, stop="\n")
    expected = problem["target"]
    if show_result:
        print(inspect_bbh(problem))
        print(f"Result: {completion}")
    return completion.strip() == expected
zero_shot_failures = [ ]
for problem in tqdm(bbh):
    if not bbh_zero_shot(problem, False):
        zero_shot_failures.append(problem)
print(f"Zero-shot failures: {len(zero_shot_failures)}")
Zero-shot failures: 249
zero_shot_failures[100]
{'input': 'On the desk, you see several things arranged in a row: a burgundy bracelet, a grey mug, a green necklace, and a magenta textbook. What is the color of the thing directly to the left of the necklace?\nOptions:\n(A) red\n(B) orange\n(C) yellow\n(D) green\n(E) blue\n(F) brown\n(G) magenta\n(H) fuchsia\n(I) mauve\n(J) teal\n(K) turquoise\n(L) burgundy\n(M) silver\n(N) gold\n(O) black\n(P) grey\n(Q) purple\n(R) pink',
'target': '(P)'}
llama3(zero_shot_failures[100]["input"] + "\nAnswer:")
' P\nExplanation: To find the answer, start at the necklace and move left. The necklace is'
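The comparison in bbh_zero_shot is an exact string match against targets like "(P)", while the raw completion above starts with a bare " P" and extra text. A few more raw completions make the format mismatch visible (illustrative; completions vary with the checkpoint):
for problem in zero_shot_failures[:3]:
    completion = llama3(problem["input"] + "\nAnswer:", max_new_tokens=5, stop="\n")
    print(repr(problem["target"]), "vs", repr(completion))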
Few-Shot
FEW_SHOT_PROMPT = (
    bbh[0]["input"] + "\n\nCorrect option: " + bbh[0]["target"] + "\n\n"
    + bbh[1]["input"] + "\n\nCorrect option: " + bbh[1]["target"] + "\n\n"
    + bbh[2]["input"] + "\n\nCorrect option: " + bbh[2]["target"] + "\n\n"
)
def prompt_fewshot_shot(item):
    return FEW_SHOT_PROMPT + "\n\n" + item["input"] + "\n\nCorrect option:"

def extract_fewshot_shot(response):
    return response.strip()

def solve_fewshot_shot(item):
    response = extract_fewshot_shot(llama3(prompt_fewshot_shot(item), temperature=0, max_new_tokens=3))
    if item["target"] == response:
        return True
    else:
        print(f"Expected {item['target']} got {response}")
        return False
solve_fewshot_shot(bbh[2])
True
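Note that bbh[2] is one of the three examples baked into FEW_SHOT_PROMPT, so this check is effectively testing on a memorized item. A fairer spot check is an item outside the prompt, e.g. bbh[5] (output not shown):
solve_fewshot_shot(bbh[5])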
def accuracy_fewshot_shot(items):
    num_correct = 0
    failures = [ ]
    for item in tqdm(items):
        result = solve_fewshot_shot(item)
        if result:
            num_correct += 1
        else:
            failures.append(item)
    return (num_correct / len(items), failures)
accuracy, failures = accuracy_fewshot_shot(bbh)
Expected (D) got (F)
Expected (B) got (I)
Expected (F) got (G)
Expected (I) got (M)
Expected (D) got (A)
Expected (C) got (D)
Expected (I) got (F)
Expected (E) got (D)
Expected (D) got (E)
Expected (K) got (J)
Expected (A) got (D)
Expected (B) got (C)
Expected (B) got (F)
Expected (B) got (C)
Expected (I) got (A)
Expected (H) got (R)
Expected (R) got (D)
Expected (O) got (A)
Expected (F) got (E)
Expected (D) got (L)
Expected (N) got (M)
Expected (D) got (A)
Expected (B) got (A)
Expected (D) got (M)
Expected (F) got (D)
Expected (H) got (D)
Expected (B) got (D)
Expected (A) got (D)
Expected (D) got (E)
Expected (Q) got (M)
Expected (E) got (D)
Expected (A) got (C)
Expected (O) got (K)
Expected (G) got (D)
Expected (A) got (C)
Expected (B) got (D)
Expected (B) got (M)
Expected (G) got (M)
Expected (M) got (G)
Expected (R) got (A)
Expected (B) got (C)
Expected (C) got (A)
Expected (A) got (D)
Expected (O) got (B)
Expected (A) got (M)
Expected (I) got (M)
Expected (F) got (C)
Expected (D) got (E)
Expected (O) got (M)
Expected (G) got (D)
Expected (B) got (D)
Expected (P) got (L)
Expected (A) got (F)
Expected (F) got (M)
Expected (A) got (C)
Expected (R) got (A)
Expected (F) got (G)
Expected (E) got (P)
Expected (B) got (A)
Expected (D) got (C)
Expected (A) got (D)
Expected (A) got (G)
Expected (P) got (H)
Expected (G) got (D)
Expected (G) got (M)
Expected (A) got (B)
Expected (R) got (K)
Expected (B) got (E)
Expected (B) got (A)
Expected (A) got (F)
Expected (A) got (B)
Expected (B) got (C)
Expected (C) got (E)
Expected (D) got (C)
Expected (B) got (A)
Expected (C) got (D)
Expected (F) got (E)
Expected (G) got (D)
Expected (F) got (D)
Expected (A) got (B)
Expected (C) got (D)
Expected (F) got (D)
Expected (R) got (J)
Expected (D) got (F)
Expected (A) got (G)
Expected (H) got (M)
Expected (A) got (R)
Expected (A) got (G)
Expected (E) got (C)
Expected (B) got (C)
Expected (E) got (D)
Expected (B) got (F)
Expected (R) got (A)
Expected (E) got (A)
Expected (P) got (L)
Expected (O) got (L)
Expected (I) got (Q)
Expected (K) got (L)
Expected (N) got (M)
Expected (B) got (D)
Expected (I) got (H)
Expected (E) got (D)
Expected (A) got (L)
Expected (H) got (M)
Expected (F) got (C)
Expected (B) got (E)
Expected (N) got (L)
Expected (B) got (G)
Expected (P) got (G)
Expected (A) got (D)
Expected (B) got (F)
Expected (A) got (B)
Expected (F) got (E)
Expected (A) got (M)
Expected (O) got (M)
Expected (B) got (A)
Expected (P) got (G)
Expected (D) got (E)
Expected (I) got (D)
Expected (A) got (E)
Expected (J) got (D)
Expected (B) got (A)
Expected (B) got (F)
Expected (E) got (F)
Expected (R) got (A)
Expected (A) got (B)
Expected (C) got (D)
Expected (N) got (M)
Expected (A) got (E)
Expected (B) got (K)
Expected (E) got (L)
Expected (A) got (D)
accuracy
0.472
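For context, a rough chance baseline: the expected accuracy of guessing uniformly at random among each question's options. This is a sketch that assumes every option appears on its own "(X) ..." line in the input, which holds for the examples shown above.
import re
option_counts = [len(re.findall(r"^\([A-Z]\) ", item["input"], flags=re.MULTILINE)) for item in bbh]
sum(1 / n for n in option_counts) / len(option_counts)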
Chain-of-Thought
print(bbh[0]["target"])
(D)
print(bbh[2]["input"])
On the nightstand, you see a set of items arranged in a row: a gold plate, a silver stress ball, a fuchsia notebook, a mauve bracelet, a green jug, and a yellow fidget spinner. What is the color of the item directly to the left of the jug?
Options:
(A) red
(B) orange
(C) yellow
(D) green
(E) blue
(F) brown
(G) magenta
(H) fuchsia
(I) mauve
(J) teal
(K) turquoise
(L) burgundy
(M) silver
(N) gold
(O) black
(P) grey
(Q) purple
(R) pink
"""
"""
COT_PROMPT = """Question: On the floor, there is one mauve cat toy, two purple cat toys, three grey cat toys, two mauve notebooks, three grey notebooks, three burgundy cat toys, and one purple notebook. If I remove all the notebooks from the floor, how many grey objects remain on it?
Options:
(A) zero
(B) one
(C) two
(D) three
(E) four
(F) five
(G) six
(H) seven
(I) eight
(J) nine
(K) ten
(L) eleven
(M) twelve
(N) thirteen
(O) fourteen
(P) fifteen
(Q) sixteen
Reasoning: We are looking for grey objects, so we have three grey cat toys and three grey notebooks. We remove three grey notebooks. So we have three remaining grey objects. Three is associated with the letter (D).
Answer: (D)
Question: On the nightstand, you see a set of items arranged in a row: a gold plate, a silver stress ball, a fuchsia notebook, a mauve bracelet, a green jug, and a yellow fidget spinner. What is the color of the item directly to the left of the jug?
Options:
(A) red
(B) orange
(C) yellow
(D) green
(E) blue
(F) brown
(G) magenta
(H) fuchsia
(I) mauve
(J) teal
(K) turquoise
(L) burgundy
(M) silver
(N) gold
(O) black
(P) grey
(Q) purple
(R) pink
Reasoning: Mauve is the color of the bracelet and the bracelet is to the left of the jug. Mauve is associated with the letter I.
Answer: (I)
"""
COTAnswer = namedtuple("COTAnswer", ["problem", "reasoning", "answer", "ok"])
def bbh_cot(problem):
    _input = problem["input"]
    prompt = f"{COT_PROMPT}\n\nQuestion:{_input}\nReasoning:"
    completion = llama3(prompt, temperature=0, max_new_tokens=100, stop="\nQuestion:")
    reasoning, answer = completion.split("\nAnswer:", maxsplit=1)
    expected = problem["target"]
    return COTAnswer(inspect_bbh(problem), reasoning.strip(), answer.strip(), answer.strip() == expected)
result = bbh_cot(zero_shot_failures[6])
cot_failures = [ ]
cot_successes = [ ]
for problem in tqdm(zero_shot_failures):
    r = bbh_cot(problem)
    if r.ok:
        cot_successes.append(r)
    else:
        cot_failures.append(r)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[37], line 4
2 cot_successes = [ ]
3 for problem in tqdm(zero_shot_failures):
----> 4 r = bbh_cot(problem)
5 if r.ok:
6 cot_successes.append(r)
Cell In[35], line 7, in bbh_cot(problem)
5 prompt = f"{COT_PROMPT}\n\nQuestion:{_input}\nReasoning:"
6 completion = llama3(prompt, temperature=0, max_new_tokens=100, stop="\nQuestion:")
----> 7 reasoning, answer = completion.split("\nAnswer:", maxsplit=1)
8 expected = problem["target"]
9 return COTAnswer(inspect_bbh(problem), reasoning.strip(), answer.strip(), answer.strip() == expected)
ValueError: not enough values to unpack (expected 2, got 1)
len(cot_successes)
26
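The loop stopped with the ValueError above after 26 successes were collected, meaning one completion never contained "\nAnswer:". A small sketch to locate the offending problem (the index follows from how the loop appends results before raising):
idx = len(cot_successes) + len(cot_failures)
print(zero_shot_failures[idx]["input"][:200])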
print(cot_failures[0].problem)
On the desk, you see a set of things arranged in a row: a grey cup, a purple
mug, and a blue teddy bear. What is the color of the thing directly to the right
of the cup?
Answer: (Q) purple
print(cot_failures[0].reasoning)
Grey is the color of the cup and the cup is to the right of the mug. Grey is associated with the letter P.
COT v2
COT_PROMPT = """
Question: On the floor, there is one mauve cat toy, two purple cat toys, three grey cat toys, two mauve notebooks, three grey notebooks, three burgundy cat toys, and one purple notebook. If I remove all the notebooks from the floor, how many grey objects remain on it?
Options:
(A) zero
(B) one
(C) two
(D) three
(E) four
(F) five
(G) six
(H) seven
(I) eight
(J) nine
(K) ten
(L) eleven
(M) twelve
(N) thirteen
(O) fourteen
(P) fifteen
(Q) sixteen
Reasoning: The total number of grey objects is three cat toys and three notebooks, which is six grey objects. If I remove all the notebooks, the three grey cat toys remain.
Answer: (D)
Question: On the table, you see several items arranged in a row: a fuchsia teddy bear, a red jug, a gold puzzle, a burgundy bracelet, and a green notebook. How many non-magenta items do you see to the right of the red item?
Options:
(A) zero
(B) one
(C) two
(D) three
(E) four
(F) five
(G) six
Reasoning: The jug is the red item. The items to its right that are non-magenta are the puzzle, the bracelet, and the notebook. There are 3 non-magenta items to the right.
Answer: (D)""".strip()
cot_failures = [ ]
cot_successes = [ ]
for problem in tqdm(zero_shot_failures):
    r = bbh_cot(problem)
    if r.ok:
        cot_successes.append(r)
    else:
        cot_failures.append(r)
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
Cell In[44], line 46
44 cot_successes = [ ]
45 for problem in tqdm(zero_shot_failures):
---> 46 r = bbh_cot(problem)
47 if r.ok:
48 cot_successes.append(r)
Cell In[35], line 6, in bbh_cot(problem)
4 _input = problem["input"]
5 prompt = f"{COT_PROMPT}\n\nQuestion:{_input}\nReasoning:"
----> 6 completion = llama3(prompt, temperature=0, max_new_tokens=100, stop="\nQuestion:")
7 reasoning, answer = completion.split("\nAnswer:", maxsplit=1)
8 expected = problem["target"]
Cell In[3], line 15, in llama3(prompt, temperature, max_new_tokens, stop)
12 generate_kwargs["temperature"] = None
13 with torch.no_grad():
14 # Llama 3.1 uses top-p sampling by default, which we disable here.
---> 15 example_outputs = model.generate(**inputs, **generate_kwargs, top_p=None)
16 outputs = tokenizer.decode(example_outputs[0, inputs["input_ids"].shape[1]:])
17 if stop is not None:
    [... library frames in transformers/torch elided; the interrupt occurred inside model.generate(), in the Llama attention forward pass ...]
KeyboardInterrupt:
len(cot_successes) / (len(cot_failures) + len(cot_successes))
0.6268656716417911
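A recurring weak point above is answer extraction: bbh_zero_shot compares the raw completion to the target string, and bbh_cot splits on "\nAnswer:" (which raised the ValueError when the model never emitted it). A more forgiving extractor, sketched below, could be swapped into either function; it is not part of the runs above.
import re

def extract_option(completion):
    # Return the first "(X)" option found in the completion, or None if absent.
    match = re.search(r"\([A-Z]\)", completion)
    return match.group(0) if match else None

print(extract_option("Answer: (D) three"))      # (D)
print(extract_option(" P\nExplanation: ..."))   # None -- a bare letter is not matched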