| import os |
| from tqdm import tqdm |
| import pandas as pd |
| import numpy as np |
| import torch |
| from datasets import load_dataset, logging |
| from datasets import Features, Value, Image, Sequence, Array3D, Array4D |
| import evaluate |
| from metrics import apply_metrics |
|
|
| from transformers import AutoFeatureExtractor, AutoModelForImageClassification |
|
|
| logger = logging.get_logger(__name__) |
|
|
| from mapping_functions import ( |
| pdf_to_pixelvalues_extractor, |
| nativepdf_to_pixelvalues_extractor, |
| ) |
| from inference_methods import InferenceMethod |
|
|
| EXPERIMENT_ROOT = "/mnt/lerna/experiments" |
|
|
|
|
def load_base_model():
    """Load the default DiT classifier fine-tuned on RVL-CDIP.

    Returns:
        tuple: (model, feature_extractor) for the
        "microsoft/dit-base-finetuned-rvlcdip" checkpoint.
    """
    checkpoint = "microsoft/dit-base-finetuned-rvlcdip"
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = AutoModelForImageClassification.from_pretrained(checkpoint)
    return model, feature_extractor
|
|
|
|
def logits_monitor(args, running_logits, references, predictions, identifier="a"):
    """Persist one buffer of raw model outputs to a compressed ``.npz`` file.

    Each row of the saved array is
    ``[per-class logits | reference label | predicted label | global sample index]``.

    Args:
        args: parsed CLI namespace; ``model``, ``dataset``, ``inference_method``
            and ``downsampling`` are used to build the output file name.
        running_logits: list of (batch, num_classes) tensors to concatenate.
        references: ground-truth label ids for this buffer.
        predictions: predicted label ids for this buffer.
        identifier: string suffix distinguishing this buffer's file. It must be
            parseable by ``int()`` — the placeholder default "a" would raise
            ValueError; callers always pass ``str(i)``.

    NOTE(review): ``int(identifier)`` is treated as the exclusive end of the
    global sample-index range; that is exact only when the identifier equals
    the running sample count (true for batch_size=1) — verify for larger batches.
    """
    output_path = f"{EXPERIMENT_ROOT}/{args.model.split('/')[-1]}_{args.dataset.split('/')[-1]}_{args.inference_method}-{args.downsampling}-i{identifier}.npz"

    raw_output = torch.cat(
        [
            torch.cat(running_logits, dim=0).cpu(),
            torch.Tensor(references).unsqueeze(1),
            torch.Tensor(predictions).unsqueeze(1),
            torch.Tensor(np.arange(int(identifier) - len(references), int(identifier))).unsqueeze(1),
        ],
        dim=1,
    )
    # raw_output is already assembled from CPU tensors; the original's extra
    # .cpu() here was a redundant no-op transfer.
    np.savez_compressed(output_path, raw_output.data.numpy())
    tqdm.write("saved raw test outputs to {}".format(output_path))
|
|
|
|
def monitor_cleanup(args, buffer_keys):
    """Merge all per-buffer ``.npz`` files into one final archive, then delete them.

    Args:
        args: parsed CLI namespace used to reconstruct the buffer file names.
        buffer_keys: the ``identifier`` strings previously passed to
            ``logits_monitor`` (one per intermediate file).
    """
    output_path = f"{EXPERIMENT_ROOT}/{args.model.split('/')[-1]}_{args.dataset.split('/')[-1]}_{args.inference_method}-{args.downsampling}"

    # Guard: with no buffers the original referenced an unbound variable.
    if not buffer_keys:
        return

    # Load every buffer, then concatenate once — the original concatenated
    # incrementally inside the loop, which is O(n^2) in total copied rows.
    chunks = [np.load(f"{output_path}-i{identifier}.npz")["arr_0"] for identifier in buffer_keys]
    catted = np.concatenate(chunks)

    out_path = f"{output_path}-final.npz"
    np.savez_compressed(out_path, catted)
    tqdm.write("saved raw test outputs to {}".format(out_path))

    # Only remove the intermediate files after the merged archive is safely written.
    for identifier in buffer_keys:
        os.remove(f"{output_path}-i{identifier}.npz")
|
|
|
|
def main(args):
    """Evaluate a document-classification model on a (multipage) test set.

    Loads the test split, aligns dataset labels with the checkpoint's label
    mapping, converts PDFs to pixel values, runs inference with the selected
    strategy, periodically flushes raw logits to disk, and finally reports
    accuracy and merges the flushed buffers.

    Args:
        args: parsed CLI namespace with ``dataset``, ``model``,
            ``inference_method``, ``downsampling`` and ``batch_size``.
    """
    # The pinned revision keeps results reproducible for the multipage dataset.
    testds = load_dataset(
        args.dataset,
        cache_dir="/mnt/lerna/data/HFcache",
        split="test",
        revision=None if args.dataset != "bdpc/rvl_cdip_mp" else "d3a654c9f63f14d0aaa94e08aa30aa3dc20713c1",
    )

    if args.downsampling:
        testds = testds.select(list(range(0, args.downsampling)))

    model = AutoModelForImageClassification.from_pretrained(args.model)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    # Plain copy of the checkpoint's label->id mapping (the original used an
    # identity dict comprehension).
    label2idx = dict(model.config.label2id)
    print(label2idx)

    # Re-map dataset label ids onto the model's label ids when their orders differ.
    data_idx2label = dict(enumerate(testds.features["labels"].names))
    model_idx2label = {idx: label for label, idx in label2idx.items()}
    diff = [i for i in range(len(data_idx2label)) if data_idx2label[i] != model_idx2label[i]]
    if diff:
        print(f"aligning labels {diff}")
        testds = testds.align_labels_with_mapping(label2idx, "labels")

    inference_method = InferenceMethod[args.inference_method.upper()]
    dummy_inference_method = inference_method
    feature_extractor = AutoFeatureExtractor.from_pretrained(args.model)

    features = {
        **{k: v for k, v in testds.features.items() if k in ["labels", "pixel_values", "id"]},
        "pages": Value(dtype="int32"),
        "pixel_values": Array3D(dtype="float32", shape=(3, 224, 224)),
    }
    if "sample" not in inference_method.scope:
        # Document-level strategies keep every page: (pages, C, H, W) per example.
        features["pixel_values"] = Array4D(dtype="float32", shape=(None, 3, 224, 224))
        dummy_inference_method = InferenceMethod["max_confidence".upper()]
    features = Features(features)

    remove_columns = ["file"]
    if args.dataset == "bdpc/rvl_cdip_mp":
        image_preprocessor = lambda batch: pdf_to_pixelvalues_extractor(
            batch, feature_extractor, dummy_inference_method
        )
        encoded_testds = testds.map(
            image_preprocessor, features=features, remove_columns=remove_columns, desc="pdf_to_pixelvalues"
        )
    else:
        image_preprocessor = lambda batch: nativepdf_to_pixelvalues_extractor(
            batch, feature_extractor, dummy_inference_method
        )
        encoded_testds = testds.map(
            image_preprocessor,
            features=features,
            remove_columns=remove_columns,
            desc="pdf_to_pixelvalues",
            batch_size=10,
        )

    # Drop unreadable documents (0 pages) and any sample whose pixel values
    # contain NaNs (failed extraction).
    print(f"Before filtering: {len(encoded_testds)}")
    more_complex_filter = lambda example: example["pages"] != 0 and not np.any(np.isnan(example["pixel_values"]))
    good_indices = [i for i, x in tqdm(enumerate(encoded_testds), desc="filter") if more_complex_filter(x)]
    encoded_testds = encoded_testds.select(good_indices)
    print(f"After filtering: {len(encoded_testds)}")

    metric = evaluate.load("accuracy")

    encoded_testds.set_format(type="torch", columns=["pixel_values", "labels"])
    # Document-level strategies process one variable-page document at a time.
    args.batch_size = args.batch_size if "sample" in inference_method.scope else 1
    dataloader = torch.utils.data.DataLoader(encoded_testds, batch_size=args.batch_size)

    running_logits = []
    predictions, references = [], []
    buffer_references = []
    buffer_predictions = []
    buffer = 0
    BUFFER_SIZE = 5000  # flush raw logits to disk roughly every 5000 samples
    buffer_keys = []
    for i, batch in tqdm(enumerate(dataloader), desc="Inference loop"):
        with torch.no_grad():
            batch["labels"] = batch["labels"].to(device)
            batch["pixel_values"] = batch["pixel_values"].to(device)
            if "sample" in inference_method.scope:
                # Page/sample level: one logit row per sample. pixel_values are
                # already on `device` (the original called .to(device) twice).
                outputs = model(batch["pixel_values"])
                logits = outputs.logits
                buffer_predictions.extend(logits.argmax(-1).tolist())
                buffer_references.extend(batch["labels"].tolist())
                running_logits.append(logits)
            else:
                # Document level: batch size is forced to 1; indexing [0] drops
                # the batch dim so pages become the model's batch dimension.
                try:
                    page_logits = model(batch["pixel_values"][0]).logits
                except Exception as e:
                    # Best-effort: skip documents that fail (e.g. too many pages).
                    print(f"something went wrong in inference {e}")
                    continue
                prediction = inference_method.apply_decision_strategy(page_logits)
                buffer_predictions.append(prediction.tolist())
                buffer_references.extend(batch["labels"].tolist())
                # Keep one averaged logit row per document so buffers stay rectangular.
                running_logits.append(page_logits.mean(0).unsqueeze(0))

        buffer += args.batch_size
        if buffer >= BUFFER_SIZE:
            predictions.extend(buffer_predictions)
            references.extend(buffer_references)
            logits_monitor(args, running_logits, buffer_references, buffer_predictions, identifier=str(i))
            buffer_keys.append(str(i))
            running_logits = []
            buffer_references = []
            buffer_predictions = []
            buffer = 0

    # Flush the last, partially filled buffer.
    if buffer != 0:
        predictions.extend(buffer_predictions)
        references.extend(buffer_references)
        logits_monitor(args, running_logits, buffer_references, buffer_predictions, identifier=str(i))
        buffer_keys.append(str(i))

    accuracy = metric.compute(references=references, predictions=predictions)
    print(f"Accuracy on this inference configuration {inference_method}:", accuracy)
    monitor_cleanup(args, buffer_keys)
|
|
|
|
if __name__ == "__main__":
    from argparse import ArgumentParser

    # Command-line interface for the inference-strategy evaluation script.
    cli = ArgumentParser("""Test different inference strategies to classify a document""")
    # Optional positional strategy name; defaults to "first" when omitted.
    cli.add_argument(
        "inference_method",
        type=str,
        default="first",
        nargs="?",
        help="how to evaluate DiT on RVL-CDIP_multi",
    )
    cli.add_argument("-s", dest="downsampling", type=int, default=0, help="number of testset samples")
    cli.add_argument("-d", dest="dataset", type=str, default="bdpc/rvl_cdip_mp", help="the dataset to be evaluated")
    cli.add_argument(
        "-m",
        dest="model",
        type=str,
        default="microsoft/dit-base-finetuned-rvlcdip",
        help="the model checkpoint to be evaluated",
    )
    cli.add_argument("-b", dest="batch_size", type=int, default=16, help="batch size")
    # NOTE(review): keep_in_memory is parsed but not read by main() in this
    # file — confirm it is unused elsewhere before removing.
    cli.add_argument(
        "-k",
        dest="keep_in_memory",
        default=False,
        action="store_true",
        help="do not cache operations (for testing)",
    )

    main(cli.parse_args())
|
|