| import os |
| import numpy as np |
| from tqdm import tqdm |
| import torch |
| from datasets import load_dataset, ClassLabel |
| from datasets import Features, Array3D |
| from transformers import AutoFeatureExtractor, AutoModelForImageClassification |
| from metrics import apply_metrics |
|
|
|
|
def process_label_ids(batch, remapper, label_column="label"):
    """Rewrite a batch's integer labels through *remapper*.

    Args:
        batch: a ``datasets`` batch dict; ``batch[label_column]`` is a list of ints.
        remapper: mapping from the dataset's label ids to the model's label ids.
        label_column: name of the label column to rewrite (default ``"label"``).

    Returns:
        The same batch dict, with its label column remapped in place.
    """
    remapped = []
    for original_id in batch[label_column]:
        remapped.append(remapper[original_id])
    batch[label_column] = remapped
    return batch
|
|
|
|
# Shared HuggingFace cache on the lerna mount when available; ``None`` falls
# back to the ``datasets``/``transformers`` default cache (~/.cache/huggingface).
CACHE_DIR = "/mnt/lerna/data/HFcache" if os.path.exists("/mnt/lerna/data/HFcache") else None
|
|
|
|
def main(args):
    """Evaluate DiT (finetuned on RVL-CDIP) on the test split of ``args.dataset``.

    Loads the dataset, aligns its label ids with the model's label order when
    the two disagree, extracts pixel values, runs batched inference on CPU and
    prints the metrics returned by ``apply_metrics``.

    Args:
        args: parsed CLI namespace; only ``args.dataset`` (a HF dataset id) is read.
    """
    dataset = load_dataset(args.dataset, split="test", cache_dir=CACHE_DIR)
    if args.dataset == "rvl_cdip":
        # NOTE(review): index 33669 is skipped — presumably a corrupt sample
        # in rvl_cdip's test split; confirm before changing.
        dataset = dataset.select([i for i in range(len(dataset)) if i != 33669])
    batch_size = 100 if args.dataset == "jordyvl/RVL-CDIP-N" else 1000

    feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
    model = AutoModelForImageClassification.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")

    # The model's label names may use spaces where the dataset uses
    # underscores; normalize so the two vocabularies are comparable.
    label2idx = {label.replace(" ", "_"): i for label, i in model.config.label2id.items()}
    # BUG FIX: was ``dict(zip(enumerate(...)))`` — ``zip`` with one iterable
    # yields 1-tuples, so ``dict()`` raised ValueError. ``dict(enumerate(...))``
    # builds the intended index -> label mapping.
    data_idx2label = dict(enumerate(dataset.features["label"].names))
    data_label2idx = {label: i for i, label in enumerate(dataset.features["label"].names)}
    model_idx2label = dict(zip(label2idx.values(), label2idx.keys()))
    # Indices where the dataset's label order disagrees with the model's.
    diff = [i for i in range(len(data_label2idx)) if data_idx2label[i] != model_idx2label[i]]

    if diff:
        print(f"aligning labels {diff}")
        print(f"model labels: {model_idx2label}")
        print(f"data labels: {data_idx2label}")
        print(f"Remapping to {label2idx}")

        # dataset label id -> model label id, matched by (normalized) name.
        remapper = {}
        for k, v in label2idx.items():
            if k in data_label2idx:
                remapper[data_label2idx[k]] = v

        print(remapper)
        # Re-declare the label feature with the model's label order so the
        # remapped ids decode to the right names.
        new_features = Features(
            {
                **{k: v for k, v in dataset.features.items() if k != "label"},
                "label": ClassLabel(num_classes=len(label2idx), names=list(label2idx.keys())),
            }
        )

        dataset = dataset.map(
            lambda example: process_label_ids(example, remapper),
            features=new_features,
            batched=True,
            batch_size=batch_size,
            desc="Aligning the labels",
        )

    # DiT expects 3x224x224 float32 pixel values.
    features = Features({**dataset.features, "pixel_values": Array3D(dtype="float32", shape=(3, 224, 224))})

    encoded_dataset = dataset.map(
        lambda examples: feature_extractor([image.convert("RGB") for image in examples["image"]]),
        batched=True,
        batch_size=batch_size,
        features=features,
    )
    encoded_dataset.set_format(type="torch", columns=["pixel_values", "label"])
    BATCH_SIZE = 16
    dataloader = torch.utils.data.DataLoader(encoded_dataset, batch_size=BATCH_SIZE)

    # Pre-allocate result buffers: one logit row per example plus the gold label.
    all_logits = np.zeros((len(encoded_dataset), len(label2idx)))
    all_references = np.zeros(len(encoded_dataset), dtype=int)

    count = 0
    for batch in tqdm(dataloader):
        with torch.no_grad():
            outputs = model(batch["pixel_values"])
        # Use the actual batch length so the final (possibly short) batch is
        # written with an exactly matching slice.
        n = len(batch["label"])
        all_logits[count : count + n] = outputs.logits.detach().cpu().numpy()
        all_references[count : count + n] = batch["label"].detach().cpu().numpy()
        count += n

    results = apply_metrics(all_references, all_logits)
    print(results)
|
|
|
|
if __name__ == "__main__":
    from argparse import ArgumentParser

    # CLI entry point: pick the dataset whose test split will be evaluated.
    cli = ArgumentParser("""DiT inference on dataset test set""")
    cli.add_argument("-d", dest="dataset", type=str, default="rvl_cdip", help="the dataset to be evaluated")
    main(cli.parse_args())
|
|