| | import torch |
| | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
| | import os |
| | from tqdm import tqdm |
| | import pandas as pd |
| | import time |
| | import sys |
| | from datasets import load_dataset |
| | from src.utils import read_data |
| |
|
class NLLBTranslator:
    """Thin wrapper around an NLLB seq2seq checkpoint for text translation."""

    def __init__(self, model_name="facebook/nllb-200-3.3B"):
        """
        Initialize the NLLB model and tokenizer for translation.

        Args:
            model_name (str): Hugging Face hub id of an NLLB checkpoint.
        """
        # Prefer GPU when available; model and inputs must live on the same device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)

    def _get_nllb_code(self, language: str) -> str:
        """
        Map a common language name to its NLLB language code.

        Args:
            language (str): Common language name (case-insensitive).

        Returns:
            str: NLLB language code, or None if the language is not found.

        Examples:
            >>> translator._get_nllb_code("english")
            'eng_Latn'
            >>> translator._get_nllb_code("Chinese")
            'zho_Hans'
        """
        language_mapping = {
            # English
            "english": "eng_Latn",
            "eng": "eng_Latn",
            "en": "eng_Latn",
            # Hindi
            "hindi": "hin_Deva",
            "hi": "hin_Deva",
            # French
            "french": "fra_Latn",
            "fr": "fra_Latn",
            # Korean
            "korean": "kor_Hang",
            "ko": "kor_Hang",
            # Spanish
            "spanish": "spa_Latn",
            "es": "spa_Latn",
            # Chinese (simplified is the default for bare "chinese"/"mandarin")
            "chinese": "zho_Hans",
            "chinese simplified": "zho_Hans",
            "chinese traditional": "zho_Hant",
            "mandarin": "zho_Hans",
            "zh-cn": "zho_Hans",
            # Japanese
            "japanese": "jpn_Jpan",
            "jpn": "jpn_Jpan",
            "ja": "jpn_Jpan",
            # German
            "german": "deu_Latn",
            "de": "deu_Latn",
        }
        return language_mapping.get(language.lower().strip())

    def add_language_code(self, name_code_dict, language, code):
        """
        Add a language code to the dictionary if it is not already present.

        Args:
            name_code_dict (dict): Dictionary of language names to codes.
            language (str): Language name (normalized to lowercase, stripped).
            code (str): Language code.

        Returns:
            dict: The (possibly updated) dictionary.
        """
        normalized_language = language.lower().strip()
        # Existing entries win; only fill in missing names.
        if normalized_language not in name_code_dict:
            name_code_dict[normalized_language] = code
        return name_code_dict

    def _resolve_code(self, language: str) -> str:
        """Resolve a common name OR an NLLB code to an NLLB code.

        Fixes the original bug where already-valid codes such as the
        defaults 'eng_Latn'/'fra_Latn' mapped to None.
        """
        return self._get_nllb_code(language) or language

    def translate(self, text, source_lang="eng_Latn", target_lang="fra_Latn", batch_size=None):
        """
        Translate text from a source language to a target language.

        Args:
            text (str or list[str]): Text(s) to translate.
            source_lang (str): Source language name or NLLB code.
            target_lang (str): Target language name or NLLB code.
            batch_size (int, optional): If given, inputs are translated in
                chunks of this size; otherwise all at once.

        Returns:
            str or list[str]: Translated text (a str for str input,
            a list of str for list input).
        """
        source_code = self._resolve_code(source_lang)
        target_code = self._resolve_code(target_lang)

        # NLLB prepends a source-language token during tokenization, so
        # src_lang MUST be set before calling the tokenizer (the original
        # code tokenized first and never set it).
        self.tokenizer.src_lang = source_code
        forced_bos_token_id = self.tokenizer.convert_tokens_to_ids(target_code)

        single_input = isinstance(text, str)
        texts = [text] if single_input else list(text)
        step = batch_size if batch_size else len(texts)

        translations = []
        for start in range(0, len(texts), step):
            batch = texts[start:start + step]
            # truncation guards against inputs longer than the model limit
            inputs = self.tokenizer(
                batch, return_tensors="pt", padding=True, truncation=True
            ).to(self.device)
            translated_tokens = self.model.generate(
                **inputs,
                max_length=256,
                num_beams=5,
                temperature=0.5,
                do_sample=True,
                forced_bos_token_id=forced_bos_token_id,
            )
            translations.extend(
                self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
            )

        # Preserve the original contract: str in -> str out.
        return translations[0] if single_input else translations
| |
|
def main():
    """Demo: translate a few English sentences to French and print them."""
    print("Loading model and tokenizer...")
    translator = NLLBTranslator()

    texts = [
        "Hello, how are you?",
        "This is a test of the NLLB translation model.",
        "Machine learning is fascinating.",
    ]
    print("\nTranslating texts from English to French:")
    # Original had a stray double assignment (`trt=translation = ...`) and
    # discarded the result without showing it.
    translations = translator.translate(texts, target_lang="fr", batch_size=2)
    for source, translated in zip(texts, translations):
        print(f"{source} -> {translated}")


if __name__ == "__main__":
    main()
| |
|