| | import numpy as np |
| | import pandas as pd |
| | import os |
| | import matplotlib.pyplot as plt |
| | from pathlib import Path |
| | from typing import Union |
| | import os |
| |
|
def delete_file(file_pt: Union[str, Path]) -> None:
    """Delete the file at ``file_pt``; do nothing if it does not exist.

    Args:
        file_pt: Location of the file to remove, as a string or ``Path``.
    """
    try:
        # Wrapping in Path() lets callers pass plain strings as well.
        Path(file_pt).unlink()
    except FileNotFoundError:
        # Deliberate best-effort: an already-missing file is not an error.
        pass
| |
|
| |
|
def full_path(inp_dir_or_path: Union[str, Path]) -> Path:
    """Return the absolute, normalised form of a file or directory path.

    A leading ``~`` is expanded to the user's home directory, then the
    path is resolved (made absolute, ``..`` components collapsed and
    symlinks followed). The path does not need to exist.

    Args:
        inp_dir_or_path: Path to normalise, as a string or ``Path``.

    Returns:
        The resolved path as a ``Path`` object.
    """
    return Path(inp_dir_or_path).expanduser().resolve()
| |
|
| |
|
def mkdir_p(inp_dir_or_path: Union[str, Path]) -> Path:
    """Ensure the directories needed by a file or directory path exist.

    The path is first normalised with ``full_path``. If its last
    component has a suffix (e.g. ``.csv``) it is treated as a file and
    only its parent directory is created; otherwise the path itself is
    created as a directory. Already-existing directories are left alone.

    Returns:
        The normalised path.
    """
    resolved = full_path(inp_dir_or_path)
    # A suffix is the only signal used to tell a file path from a directory.
    target_dir = resolved.parent if resolved.suffix else resolved
    target_dir.mkdir(parents=True, exist_ok=True)
    return resolved
| |
|
def similarity_between_sent(sent1_encoded, sent2_encoded):
    """Compute cosine similarity for index-aligned pairs of encodings.

    Entry ``i`` of ``sent1_encoded`` is compared with entry ``i`` of
    ``sent2_encoded`` for every index in ``range(len(sent1_encoded))``.

    Returns:
        A tuple ``(mean_score, scores)`` where ``scores`` is the list of
        per-pair cosine similarities and ``mean_score`` is their
        ``np.mean``.
    """
    pair_scores = [
        cosine_similarity(sent1_encoded[idx], sent2_encoded[idx])
        for idx in range(len(sent1_encoded))
    ]
    average = np.mean(pair_scores)
    return average, pair_scores
| |
|
| |
|
def cosine_similarity(a, b):
    """
    Return the cosine of the angle between vectors ``a`` and ``b``.

    Computed as ``dot(a, b) / (||a|| * ||b||)`` using ``np.dot`` and
    ``np.linalg.norm``.
    """
    numerator = np.dot(a, b)
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    return numerator / denominator
| |
|
def load_data(path: Union[str, Path]) -> pd.DataFrame:
    """Read a delimited text file into a pandas DataFrame.

    Files whose name ends in ``.csv`` (case-insensitive) are parsed as
    comma-separated; every other file is parsed as tab-separated.

    Args:
        path: Location of the file, as a string or ``Path``.

    Returns:
        The parsed contents as a ``pandas.DataFrame``.
    """
    # str() allows Path objects; lower() makes ".CSV" match as well.
    if str(path).lower().endswith(".csv"):
        return pd.read_csv(path)
    return pd.read_csv(path, sep="\t")
| |
|
def read_data(
    dataset: str,
    data_dir: Union[str, Path] = "/home/yash/EMNLP-2024/data",
) -> pd.DataFrame:
    """Load one of the known datasets by name.

    Args:
        dataset: Dataset identifier. ``"mrpc"`` reads ``mrpc.csv``;
            ``"qqp"`` reads ``qoura.csv``; ``"paws"``, ``"paw"`` or
            ``"wiki"`` read ``paw_wiki.tsv``.
        data_dir: Directory that contains the dataset files.

    Returns:
        The dataset as a DataFrame. For ``"qqp"`` rows with missing
        values are dropped, column names are stripped of surrounding
        whitespace, and ``is_duplicate`` / ``question1`` / ``question2``
        are renamed to ``label`` / ``sentence1`` / ``sentence2``.

    Raises:
        ValueError: If ``dataset`` is not one of the known identifiers.
    """
    data_dir = Path(data_dir)

    if dataset == "mrpc":
        return load_data(str(data_dir / "mrpc.csv"))

    if dataset == "qqp":
        data = load_data(str(data_dir / "qoura.csv")).dropna()
        # Strip header whitespace before renaming so the keys below match.
        data.columns = data.columns.str.strip()
        return data.rename(
            columns={
                "is_duplicate": "label",
                "question1": "sentence1",
                "question2": "sentence2",
            }
        )

    if dataset in ("paws", "paw", "wiki"):
        return load_data(str(data_dir / "paw_wiki.tsv"))

    raise ValueError("No dataset found.")