import torch
import datasets
import textwrap
from collections import namedtuple
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL = "/mnt/ssd/arjun/models/llama3p1_8b_base"
model = AutoModelForCausalLM.from_pretrained(MODEL).to("cuda:0")
tokenizer = AutoTokenizer.from_pretrained(MODEL)
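The model loads in full precision (float32) on a single GPU by default. A lower-memory variant is sketched below; it assumes a bf16-capable GPU and the accelerate package (for device_map="auto"), neither of which the rest of this notebook requires.

# Lower-memory loading (a sketch, not the configuration used for the outputs
# below): bfloat16 halves the weight footprint, and device_map="auto" lets
# accelerate place layers across available devices.
model = AutoModelForCausalLM.from_pretrained(
    MODEL, torch_dtype=torch.bfloat16, device_map="auto"
)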
def llama3(prompt, temperature=0, max_new_tokens=20, stop=None):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generate_kwargs = {
        "pad_token_id": tokenizer.eos_token_id,
        "max_new_tokens": max_new_tokens,
    }
    if temperature > 0:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature
    else:
        # temperature == 0 means greedy decoding; clear temperature so
        # transformers does not warn about sampling flags when do_sample=False.
        generate_kwargs["do_sample"] = False
        generate_kwargs["temperature"] = None
    with torch.no_grad():
        # Llama 3.1 uses top-p sampling by default, which we disable here.
        example_outputs = model.generate(**inputs, **generate_kwargs, top_p=None)
    outputs = tokenizer.decode(example_outputs[0, inputs["input_ids"].shape[1]:])
    if stop is not None:
        outputs = outputs.split(stop)[0]
    return outputs


llama3("Shakespeare was", temperature=0.2)
' a genius. He was a genius at writing plays, and he was a genius at writing poetry.'
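The stop string above is applied by truncating after generation, so the model may spend tokens generating past it. A minimal sketch of stopping early instead, using transformers' StoppingCriteria API (the class name StopOnSubstring is hypothetical):

from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSubstring(StoppingCriteria):
    def __init__(self, tokenizer, stop, prompt_len):
        self.tokenizer = tokenizer
        self.stop = stop
        self.prompt_len = prompt_len

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the newly generated tokens and halt once `stop` appears.
        new_text = self.tokenizer.decode(input_ids[0, self.prompt_len:])
        return self.stop in new_text

# Usage sketch: model.generate(..., stopping_criteria=StoppingCriteriaList(
#     [StopOnSubstring(tokenizer, stop, inputs["input_ids"].shape[1])]))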

Loading the Dataset

bbh = datasets.load_dataset("maveriq/bigbenchhard", "reasoning_about_colored_objects", split="train")
def inspect_bbh(item):
    # Pretty-print a BBH item: wrap the question text and append the full
    # option that matches the target letter.
    txt, options = item["input"].split("Options:", maxsplit=1)
    txt = textwrap.fill(txt, width=80)
    for opt in options.split("\n"):
        if item["target"] in opt:
            txt += f"\n\nAnswer: {opt}"
            break
    return txt

print(inspect_bbh(bbh[0]))
On the floor, there is one mauve cat toy, two purple cat toys, three grey cat
toys, two mauve notebooks, three grey notebooks, three burgundy cat toys, and
one purple notebook. If I remove all the notebooks from the floor, how many grey
objects remain on it?

Answer: (D) three
print(bbh[0]["input"])
On the floor, there is one mauve cat toy, two purple cat toys, three grey cat toys, two mauve notebooks, three grey notebooks, three burgundy cat toys, and one purple notebook. If I remove all the notebooks from the floor, how many grey objects remain on it?
Options:
(A) zero
(B) one
(C) two
(D) three
(E) four
(F) five
(G) six
(H) seven
(I) eight
(J) nine
(K) ten
(L) eleven
(M) twelve
(N) thirteen
(O) fourteen
(P) fifteen
(Q) sixteen
print(bbh[1]["input"])
On the desk, you see a set of things arranged in a row: a grey cup, a purple mug, and a blue teddy bear. What is the color of the thing directly to the right of the cup?
Options:
(A) red
(B) orange
(C) yellow
(D) green
(E) blue
(F) brown
(G) magenta
(H) fuchsia
(I) mauve
(J) teal
(K) turquoise
(L) burgundy
(M) silver
(N) gold
(O) black
(P) grey
(Q) purple
(R) pink

Zero-Shot

llama3(bbh[0]["input"])
'\n(R) seventeen\n(S) eighteen\n(T) nineteen\n(U) twenty\n(V) twenty'
print(llama3(bbh[0]["input"] + "\nThe answer is:", max_new_tokens=100))
 (A) zero
Explanation:
There are 3 grey cat toys, 3 grey notebooks, and 3 burgundy cat toys. If we remove all the notebooks, we are left with 3 grey cat toys and 3 burgundy cat toys.<|end_of_text|>
print(llama3(bbh[1]["input"] + "\nThe answer is:", max_new_tokens=100))
 (Q) purple
Explanation: The cup is grey, so the thing directly to the right of the cup is the purple mug.<|end_of_text|>
def bbh_zero_shot(problem, show_result):
    prompt = problem["input"] + "\nAnswer:"
    completion = llama3(prompt, temperature=0, max_new_tokens=5, stop="\n")
    expected = problem["target"]
    if show_result:
        print(inspect_bbh(problem))
        print(f"Result: {completion}")
    # Strict exact match: completions like " P" or " (D) three" do not equal
    # targets like "(P)" or "(D)", so formatting mismatches count as failures.
    return completion.strip() == expected

zero_shot_failures = [ ]
for problem in tqdm(bbh):
    if not bbh_zero_shot(problem, False):
        zero_shot_failures.append(problem)
print(f"Zero-shot failures: {len(zero_shot_failures)}")


Zero-shot failures: 249
zero_shot_failures[100]
{'input': 'On the desk, you see several things arranged in a row: a burgundy bracelet, a grey mug, a green necklace, and a magenta textbook. What is the color of the thing directly to the left of the necklace?\nOptions:\n(A) red\n(B) orange\n(C) yellow\n(D) green\n(E) blue\n(F) brown\n(G) magenta\n(H) fuchsia\n(I) mauve\n(J) teal\n(K) turquoise\n(L) burgundy\n(M) silver\n(N) gold\n(O) black\n(P) grey\n(Q) purple\n(R) pink',
 'target': '(P)'}
llama3(zero_shot_failures[100]["input"] + "\nAnswer:")
' P\nExplanation: To find the answer, start at the necklace and move left. The necklace is'
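Many of these "failures" are formatting mismatches: here the model answers " P" where the target is "(P)". A more lenient extractor is sketched below (extract_option is a hypothetical helper, not used in the runs above):

import re

def extract_option(completion):
    # Accept "(P)", "(D) three", or a bare leading letter like "P".
    m = re.search(r"\(([A-Z])\)", completion)
    if m:
        return f"({m.group(1)})"
    m = re.match(r"\s*([A-Z])\b", completion)
    return f"({m.group(1)})" if m else None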

Few-Shot

FEW_SHOT_PROMPT = "".join(
    bbh[i]["input"] + "\n\nCorrect option: " + bbh[i]["target"] + "\n\n"
    for i in range(3)
)

def prompt_few_shot(item):
    return FEW_SHOT_PROMPT + "\n\n" + item["input"] + "\n\nCorrect option:"

def extract_few_shot(response):
    return response.strip()

def solve_few_shot(item):
    response = extract_few_shot(llama3(prompt_few_shot(item), temperature=0, max_new_tokens=3))
    if item["target"] == response:
        return True
    else:
        print(f"Expected {item['target']} got {response}")
        return False

# Sanity check (note that bbh[2] is itself one of the few-shot examples).
solve_few_shot(bbh[2])
True
def accuracy_few_shot(items):
    num_correct = 0
    failures = [ ]
    for item in tqdm(items):
        result = solve_few_shot(item)
        if result:
            num_correct += 1
        else:
            failures.append(item)

    return (num_correct / len(items), failures)


accuracy, failures = accuracy_few_shot(bbh)


Expected (D) got (F)
Expected (B) got (I)
Expected (F) got (G)
Expected (I) got (M)
Expected (D) got (A)
Expected (C) got (D)
[... 126 more "Expected X got Y" mismatches elided ...]
accuracy
0.472
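Which target options does the few-shot model miss most often? A small diagnostic sketch over the failures list returned above (the counts are not reproduced here):

from collections import Counter

# Tally the expected options among the few-shot failures.
missed = Counter(item["target"] for item in failures)
print(missed.most_common(5))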

Chain-of-Thought

print(bbh[0]["target"])
(D)
print(bbh[2]["input"])
On the nightstand, you see a set of items arranged in a row: a gold plate, a silver stress ball, a fuchsia notebook, a mauve bracelet, a green jug, and a yellow fidget spinner. What is the color of the item directly to the left of the jug?
Options:
(A) red
(B) orange
(C) yellow
(D) green
(E) blue
(F) brown
(G) magenta
(H) fuchsia
(I) mauve
(J) teal
(K) turquoise
(L) burgundy
(M) silver
(N) gold
(O) black
(P) grey
(Q) purple
(R) pink
"""
"""
COT_PROMPT = """Question: On the floor, there is one mauve cat toy, two purple cat toys, three grey cat toys, two mauve notebooks, three grey notebooks, three burgundy cat toys, and one purple notebook. If I remove all the notebooks from the floor, how many grey objects remain on it?
Options:
(A) zero
(B) one
(C) two
(D) three
(E) four
(F) five
(G) six
(H) seven
(I) eight
(J) nine
(K) ten
(L) eleven
(M) twelve
(N) thirteen
(O) fourteen
(P) fifteen
(Q) sixteen

Reasoning: We are looking for grey objects, so we have three grey cat toys and three grey notebooks. We remove three grey notebooks. So we have three remaining grey objects. Three is associated with the letter (D).

Answer: (D)

Question: On the nightstand, you see a set of items arranged in a row: a gold plate, a silver stress ball, a fuchsia notebook, a mauve bracelet, a green jug, and a yellow fidget spinner. What is the color of the item directly to the left of the jug?
Options:
(A) red
(B) orange
(C) yellow
(D) green
(E) blue
(F) brown
(G) magenta
(H) fuchsia
(I) mauve
(J) teal
(K) turquoise
(L) burgundy
(M) silver
(N) gold
(O) black
(P) grey
(Q) purple
(R) pink

Reasoning: Mauve is the color of the bracelet and the bracelet is to the left of the jug. Mauve is associated with the letter I.

Answer: (I)
"""
COTAnswer = namedtuple("COTAnswer", ["problem", "reasoning", "answer", "ok"])

def bbh_cot(problem):
    _input = problem["input"]
    prompt = f"{COT_PROMPT}\n\nQuestion:{_input}\nReasoning:"
    completion = llama3(prompt, temperature=0, max_new_tokens=100, stop="\nQuestion:")
    # Assumes the completion contains "\nAnswer:"; if the model spends all 100
    # tokens on reasoning, this unpacking raises ValueError (as seen below).
    reasoning, answer = completion.split("\nAnswer:", maxsplit=1)
    expected = problem["target"]
    return COTAnswer(inspect_bbh(problem), reasoning.strip(), answer.strip(), answer.strip() == expected)

result = bbh_cot(zero_shot_failures[6])
cot_failures = [ ]
cot_successes = [ ]
for problem in tqdm(zero_shot_failures):
    r = bbh_cot(problem)
    if r.ok:
        cot_successes.append(r)
    else:
        cot_failures.append(r)



---------------------------------------------------------------------------

ValueError                                Traceback (most recent call last)

Cell In[37], line 4
      2 cot_successes = [ ]
      3 for problem in tqdm(zero_shot_failures):
----> 4     r = bbh_cot(problem)
      5     if r.ok:
      6         cot_successes.append(r)


Cell In[35], line 7, in bbh_cot(problem)
      5 prompt = f"{COT_PROMPT}\n\nQuestion:{_input}\nReasoning:"
      6 completion = llama3(prompt, temperature=0, max_new_tokens=100, stop="\nQuestion:")
----> 7 reasoning, answer = completion.split("\nAnswer:", maxsplit=1)
      8 expected = problem["target"]
      9 return COTAnswer(inspect_bbh(problem), reasoning.strip(), answer.strip(), answer.strip() == expected)


ValueError: not enough values to unpack (expected 2, got 1)
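The loop died because one completion used all 100 tokens without ever emitting "\nAnswer:". A defensive variant is sketched below (bbh_cot_safe is a hypothetical name; it was not used for the numbers that follow):

def bbh_cot_safe(problem):
    _input = problem["input"]
    prompt = f"{COT_PROMPT}\n\nQuestion:{_input}\nReasoning:"
    completion = llama3(prompt, temperature=0, max_new_tokens=100, stop="\nQuestion:")
    # partition never raises: if "\nAnswer:" is absent, answer is "" and ok is False.
    reasoning, sep, answer = completion.partition("\nAnswer:")
    ok = bool(sep) and answer.strip() == problem["target"]
    return COTAnswer(inspect_bbh(problem), reasoning.strip(), answer.strip(), ok)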
len(cot_successes) 
26
print(cot_failures[0].problem)
On the desk, you see a set of things arranged in a row: a grey cup, a purple
mug, and a blue teddy bear. What is the color of the thing directly to the right
of the cup?

Answer: (Q) purple
print(cot_failures[0].reasoning)
Grey is the color of the cup and the cup is to the right of the mug. Grey is associated with the letter P.

CoT v2

COT_PROMPT = """
Question: On the floor, there is one mauve cat toy, two purple cat toys, three grey cat toys, two mauve notebooks, three grey notebooks, three burgundy cat toys, and one purple notebook. If I remove all the notebooks from the floor, how many grey objects remain on it?
Options:
(A) zero
(B) one
(C) two
(D) three
(E) four
(F) five
(G) six
(H) seven
(I) eight
(J) nine
(K) ten
(L) eleven
(M) twelve
(N) thirteen
(O) fourteen
(P) fifteen
(Q) sixteen

Reasoning: The total number of grey objects is three cat toys and three notebooks, which is six grey objects. If I remove all the notebooks, the three grey cat toys remain.

Answer: (D)

Question: On the table, you see several items arranged in a row: a fuchsia teddy bear, a red jug, a gold puzzle, a burgundy bracelet, and a green notebook. How many non-magenta items do you see to the right of the red item?
Options:
(A) zero
(B) one
(C) two
(D) three
(E) four
(F) five
(G) six

Reasoning: The jug is the red item. The items to its right that are non-magenta are the puzzle, the bracelet, and the notebook. There are 3 non-magenta items to the right.

Answer: (D)""".strip()


cot_failures = [ ]
cot_successes = [ ]
for problem in tqdm(zero_shot_failures):
    r = bbh_cot(problem)
    if r.ok:
        cot_successes.append(r)
    else:
        cot_failures.append(r)




---------------------------------------------------------------------------

KeyboardInterrupt                         Traceback (most recent call last)

Cell In[44], line 46
     44 cot_successes = [ ]
     45 for problem in tqdm(zero_shot_failures):
---> 46     r = bbh_cot(problem)
     47     if r.ok:
     48         cot_successes.append(r)


Cell In[35], line 6, in bbh_cot(problem)
      4 _input = problem["input"]
      5 prompt = f"{COT_PROMPT}\n\nQuestion:{_input}\nReasoning:"
----> 6 completion = llama3(prompt, temperature=0, max_new_tokens=100, stop="\nQuestion:")
      7 reasoning, answer = completion.split("\nAnswer:", maxsplit=1)
      8 expected = problem["target"]


Cell In[3], line 15, in llama3(prompt, temperature, max_new_tokens, stop)
     12     generate_kwargs["temperature"] = None
     13 with torch.no_grad():
     14     # Llama 3.1 uses top-p sampling by default, which we disable here.
---> 15     example_outputs = model.generate(**inputs, **generate_kwargs, top_p=None)
     16 outputs = tokenizer.decode(example_outputs[0, inputs["input_ids"].shape[1]:])
     17 if stop is not None:


[... stack frames inside torch and transformers elided; the interrupt landed in
apply_rotary_pos_emb inside LlamaSdpaAttention.forward ...]

KeyboardInterrupt: 
len(cot_successes) / (len(cot_failures) + len(cot_successes))
0.6268656716417911
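Since the v2 loop was interrupted, this ratio covers only the problems processed before the KeyboardInterrupt. To finish the run without redoing work, a resume sketch (hypothetical; it keys on the pretty-printed problem text stored in COTAnswer.problem):

# Skip problems already processed before the interrupt, then continue the run.
done = {r.problem for r in cot_successes + cot_failures}
for problem in tqdm(zero_shot_failures):
    if inspect_bbh(problem) in done:
        continue  # already scored in the partial run
    r = bbh_cot(problem)
    (cot_successes if r.ok else cot_failures).append(r)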