| """PyTorch LLaMA model.""" |

import json
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint

from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.utils import (
    GenerateBeamDecoderOnlyOutput,
    GenerateBeamEncoderDecoderOutput,
    GenerateDecoderOnlyOutput,
    GenerateEncoderDecoderOutput,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.utils import logging


if TYPE_CHECKING:
    from transformers.modeling_utils import PreTrainedModel
    from transformers.generation.streamers import BaseStreamer


logger = logging.get_logger(__name__)
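
# Union aliases over the generation output dataclasses imported above.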
GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]
GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput]
GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]


class FunctionaryForCausalLM(LlamaForCausalLM):
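    """`LlamaForCausalLM` subclass whose `generate_tool_use` turns raw generations into a
    structured assistant message with `content` and `tool_calls` fields."""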

    def generate_tool_use(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
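        """Run `generate`, then rewrite each returned sequence as the original prompt followed
        by a pretty-printed JSON assistant message (`role`, `content`, `tool_calls`) re-encoded
        with the tokenizer. `input_ids` and `tokenizer` must be supplied via keyword arguments.
        """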
        tokenizer = kwargs.pop("tokenizer", None)
        if tokenizer is None:
            raise ValueError("`generate_tool_use` requires a `tokenizer` keyword argument.")

        results = self.generate(
            inputs=inputs,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            synced_gpus=synced_gpus,
            assistant_model=assistant_model,
            streamer=streamer,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            **kwargs,
        )

        # `input_ids` is still in `kwargs` (it was forwarded to `generate` above); it is
        # needed here to separate the prompt from the newly generated tokens.
        input_ids = kwargs.pop("input_ids")
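        # Tool calls appear in the decoded text as "<function=NAME>{JSON_ARGS}</function>",
        # optionally preceded by free-text content and followed by "<|eom_id|>"/"<|eot_id|>".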
        function_call_token = "<function="

        correct_results = []
        for input_id, result in zip(input_ids, results):
            final_output_json = {"role": "assistant", "content": None, "tool_calls": None}
            tool_calls = []
            # Decode only the newly generated tokens (everything after the prompt).
            raw_output_str = tokenizer.decode(result[len(input_id):].cpu())
            # Leading free-text content exists only if the output does not open with a tool call.
            has_text = not raw_output_str.startswith(function_call_token)
            chunks = raw_output_str.split(function_call_token)
            for i, chunk in enumerate(chunks):
                if len(chunk) == 0:
                    continue

                chunk = chunk.replace(tokenizer.pad_token, "")

                if i == 0 and has_text:
                    # Leading free-text content; strip trailing end-of-turn markers.
                    final_output_json["content"] = chunk.removesuffix("<|eom_id|>").removesuffix("<|eot_id|>")
                else:
                    # Tool-call chunk of the form "NAME>{JSON_ARGS}</function>": split at ">{".
                    tool_calls.append(
                        {
                            "name": chunk[: chunk.index(">{")],
                            "arguments": chunk[chunk.index(">{") + 1:].removesuffix("<|eom_id|>").removesuffix("</function>"),
                        }
                    )
            if len(tool_calls) > 0:
                final_output_json["tool_calls"] = tool_calls
            # Re-encode the structured message and splice it back in after the original prompt.
            final_output_str = json.dumps(final_output_json, indent=4)
            final_output_ids = tokenizer(final_output_str, add_special_tokens=False)["input_ids"]
            correct_results.append(
                torch.cat(
                    (result[: len(input_id)].cpu(), torch.tensor(final_output_ids))
                )
            )
        # Right-pad every sequence to the batch maximum so the results can be stacked.
        max_len = max(tensor.shape[0] for tensor in correct_results)
        correct_results = [
            torch.nn.functional.pad(
                correct_result, (0, max_len - correct_result.shape[0]), value=tokenizer.eos_token_id
            )
            for correct_result in correct_results
        ]
        correct_results = torch.stack(correct_results)

        return correct_results
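

# Minimal usage sketch (illustrative only): the checkpoint path and plain-text prompt below
# are placeholders, not part of this module; a real setup would format the prompt with the
# model's tool-use chat template before calling `generate_tool_use`.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "path/to/functionary-llama-checkpoint"  # hypothetical path
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    if tokenizer.pad_token is None:
        # The chunk cleanup above calls `str.replace(tokenizer.pad_token, "")`,
        # so a pad token must be defined.
        tokenizer.pad_token = tokenizer.eos_token
    model = FunctionaryForCausalLM.from_pretrained(checkpoint)

    prompt_ids = tokenizer("What is the weather in Istanbul?", return_tensors="pt")["input_ids"]
    outputs = model.generate_tool_use(
        input_ids=prompt_ids,
        tokenizer=tokenizer,
        max_new_tokens=256,
    )
    # Each row is the prompt followed by the re-encoded JSON assistant message.
    print(tokenizer.decode(outputs[0][prompt_ids.shape[1]:]))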