import sys
import numpy as np
import pandas as pd
from metrics import ece_logits, aurc_logits, multi_aurc_plot
from sklearn.metrics import f1_score
from collections import OrderedDict


EXPERIMENT_ROOT = "/mnt/lerna/experiments"
|
|
|
|
def softmax(x, axis=-1):
    """Numerically stable softmax: shift by the max before exponentiating."""
    x = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(x)
    return exps / np.sum(exps, axis=axis, keepdims=True)
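
# Sanity check (values computed by hand): softmax normalizes along the last axis,
# so softmax(np.array([1.0, 2.0, 3.0])) is approximately [0.090, 0.245, 0.665],
# and the result sums to 1.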
|
|
|
|
def predictions_loader(predictions_path):
    """Load an .npz predictions dump and return (logits, labels, predictions, dataset_idx)."""
    data = np.load(predictions_path)["arr_0"]
    dataset_idx = data[:, -1]
    if "DiT-base-rvl_cdip_MP" in predictions_path and any(x in predictions_path for x in ["first", "second", "last"]):
        # Columns: [class logits..., label, dataset_idx]; derive predictions from the logits.
        labels = data[:, -2].astype(int)
        data = data[:, :-2]
        predictions = np.argmax(data, -1)
    else:
        # Columns: [class logits..., prediction, label, dataset_idx].
        labels = data[:, -2].astype(int)
        predictions = data[:, -3].astype(int)
        data = data[:, :-3]
    return data, labels, predictions, dataset_idx
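
# Example usage (path pattern borrowed from compare_errors below; the archive is
# assumed to hold a single "arr_0" matrix in one of the two column layouts above):
# logits, labels, predictions, dataset_idx = predictions_loader(
#     f"{EXPERIMENT_ROOT}/rvl_cdip_n_mp/dit-base-finetuned-rvlcdip_first-0-final.npz"
# )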
|
|
|
|
def compare_errors():
    """Measure how complementary the errors of the page-selection strategies are.

    Scratch correlation results between per-sample correctness indicators:

    from scipy.stats import pearsonr, spearmanr
    spearmanr(strategy_correctness['first'], strategy_correctness['second'])
    # SignificanceResult(statistic=0.5429413617297623, pvalue=0.0)
    spearmanr(strategy_correctness['first'], strategy_correctness['last'])
    # SignificanceResult(statistic=0.5005224326802595, pvalue=0.0)

    pearsonr(strategy_correctness['first'], strategy_correctness['second'])
    # PearsonRResult(statistic=0.5429413617297617, pvalue=0.0)
    pearsonr(strategy_correctness['first'], strategy_correctness['last'])
    # PearsonRResult(statistic=0.5005224326802583, pvalue=0.0)
    """
    for dataset in ["rvl_cdip_n_mp"]:
        strategy_logits = {}
        strategy_correctness = {}
        for strategy in ["first", "second", "last"]:
            path = f"{EXPERIMENT_ROOT}/{dataset}/dit-base-finetuned-rvlcdip_{strategy}-0-final.npz"
            strategy_logits[strategy], labels, predictions, dataset_idx = predictions_loader(path)
            strategy_correctness[strategy] = (predictions == labels).astype(int)

        def oracle_accuracy(primary, *fallbacks):
            # A sample counts as correct when the primary strategy or any fallback
            # got it right: an elementwise OR over the 0/1 correctness indicators.
            combined = strategy_correctness[primary]
            for fallback in fallbacks:
                combined = np.maximum(combined, strategy_correctness[fallback])
            return np.mean(combined)

        print("Base accuracy of first: ", np.mean(strategy_correctness["first"]))
        print(f"Accuracy of first when adding knowledge from second page: {oracle_accuracy('first', 'second')}")
        print(f"Accuracy of first when adding knowledge from last page: {oracle_accuracy('first', 'last')}")
        print(f"Accuracy of first when adding knowledge from second/last page: {oracle_accuracy('first', 'second', 'last')}")

        print("Base accuracy of second: ", np.mean(strategy_correctness["second"]))
        print(f"Accuracy of second when adding knowledge from first page: {oracle_accuracy('second', 'first')}")
        print(f"Accuracy of second when adding knowledge from last page: {oracle_accuracy('second', 'last')}")

        print("Base accuracy of last: ", np.mean(strategy_correctness["last"]))
        print(f"Accuracy of last when adding knowledge from first page: {oracle_accuracy('last', 'first')}")
        print(f"Accuracy of last when adding knowledge from second page: {oracle_accuracy('last', 'second')}")
|
|
|
|
def review_one(path):
    """Evaluate one predictions dump: AURC, accuracy, weighted/macro F1, and ECE."""
    collect = OrderedDict()
    try:
        logits, labels, predictions, dataset_idx = predictions_loader(path)
    except Exception as e:
        print(f"something went wrong in inference loading: {e}")
        return None, None

    y_correct = (predictions == labels).astype(int)
    acc = np.mean(y_correct)
    # Confidence of the predicted class for each sample.
    p_hat = np.array([softmax(p, -1)[predictions[i]] for i, p in enumerate(logits)])

    res = aurc_logits(y_correct, p_hat, plot=False, get_cache=True, use_as_is=True)

    collect["aurc"] = res["aurc"]
    collect["accuracy"] = 100 * acc
    collect["f1"] = 100 * f1_score(labels, predictions, average="weighted")
    collect["f1_macro"] = 100 * f1_score(labels, predictions, average="macro")
    collect["ece"] = ece_logits(np.logical_not(y_correct), np.expand_dims(p_hat, -1), use_as_is=True)

    df = pd.DataFrame.from_dict([collect])
    print(df.to_latex())
    print(df.to_string())
    return collect, res
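
# Standalone usage (file name taken from the DEFAULT in the entrypoint below):
# metrics_dict, aurc_result = review_one("./dit-base-finetuned-rvlcdip_last-10.npz")
# Reading the numbers: AURC (area under the risk-coverage curve) is lower-is-better
# and rewards ranking correct predictions above errors by confidence; ECE (expected
# calibration error) is lower-is-better and measures the confidence/accuracy gap.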
|
|
|
|
def experiments_review():
    STRATEGIES = ["first", "second", "last", "max_confidence", "soft_voting", "hard_voting", "grid"]
    for dataset in ["DiT-base-rvl_cdip_MP", "rvl_cdip_n_mp"]:
        collect = {}
        aurcs = []
        caches = []
        for strategy in STRATEGIES:
            path = f"{EXPERIMENT_ROOT}/{dataset}/dit-base-finetuned-rvlcdip_{strategy}-0-final.npz"
            collect[strategy], res = review_one(path)
            if res is None:  # loading failed inside review_one; skip this strategy
                del collect[strategy]
                continue
            aurcs.append(res["aurc"])
            caches.append(res["cache"])

        df = pd.DataFrame.from_dict(collect, orient="index")
        df = df[["accuracy", "f1", "f1_macro", "ece", "aurc"]]
        print(df.to_latex())
        print(df.to_string())
        """
        subset = [0, 1, 2]
        multi_aurc_plot(
            [x for i, x in enumerate(caches) if i in subset],
            [x for i, x in enumerate(STRATEGIES) if i in subset],
            aurcs=[x for i, x in enumerate(aurcs) if i in subset],
        )
        """
|
|
|
|
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser(description="Deeper evaluation of different inference strategies to classify a document")
    DEFAULT = "./dit-base-finetuned-rvlcdip_last-10.npz"
    parser.add_argument(
        "predictions_path",
        type=str,
        default=DEFAULT,
        nargs="?",
        help="path to predictions",
    )

    args = parser.parse_args()
    if args.predictions_path == DEFAULT:
        # No explicit path passed: run the full strategy review and error comparison.
        experiments_review()
        compare_errors()
        sys.exit(0)

    print(f"Reviewing predictions at {args.predictions_path}")
    review_one(args.predictions_path)
|
|