import argparse
import logging
import os
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Optional

import pandas as pd
import torch
import torchmetrics

from configs.args_base import get_args
from data import build_dataloader
from models.MIQA_base import get_timm_model, get_torch_model
from models.RA_MIQA import RegionVisionTransformer
from models.hf_model_registry import HF_REPO_ID, HF_REVISION, MODEL_FILENAMES
from train import AverageMeter
from utils.hf_download_utils import ensure_checkpoint_from_hf


def get_checkpoint_path(model_name: str, train_dataset: str, metric_type: str = "composite") -> str:
    """
    Resolve the local checkpoint path for a model/dataset/metric combination.

    The directory is created if it does not exist, so callers can download
    straight into the returned location. Raises KeyError if the combination
    is not present in MODEL_FILENAMES.
    """
    base_dir = Path("models") / "checkpoints" / f"{metric_type}_metric"
    base_dir.mkdir(parents=True, exist_ok=True)
    filename = MODEL_FILENAMES[metric_type][train_dataset][model_name]
    return str(base_dir / filename)
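
# Illustrative resolution (a sketch; the actual filename comes from
# models.hf_model_registry and may differ from this hypothetical layout):
#
#   get_checkpoint_path('ra_miqa', 'cls', 'composite')
#   # -> 'models/checkpoints/composite_metric/<filename registered for ra_miqa/cls>'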


def get_available_models(train_dataset: str, metric_type: str) -> List[str]:
    """
    Get the list of available models for a given training dataset and metric type.

    This helper is used for validation and for producing informative error
    messages when a user requests a model that is not available for their
    chosen configuration.

    Args:
        train_dataset: Training dataset type (cls, det, ins)
        metric_type: Training metric objective (composite, consistency, accuracy)

    Returns:
        List of available model names for this configuration
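
    Example (illustrative; the exact names depend on models.hf_model_registry):
        >>> get_available_models('cls', 'composite')
        ['ra_miqa']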
    """
    if metric_type in MODEL_FILENAMES and train_dataset in MODEL_FILENAMES[metric_type]:
        return list(MODEL_FILENAMES[metric_type][train_dataset].keys())
    return []


def ensure_model_weights(model_name: str, train_dataset: str, metric_type: str,
                         logger: logging.Logger) -> Optional[str]:
    """
    Ensure model weights exist locally, downloading them if necessary.

    This implements a simple caching strategy: if the checkpoint already exists
    locally it is reused; otherwise it is downloaded from the Hugging Face Hub.
    The first run therefore downloads weights, while subsequent runs are much
    faster.

    Args:
        model_name: Name of the model architecture
        train_dataset: Training dataset type (cls, det, or ins)
        metric_type: Training metric objective (composite, consistency, or accuracy)
        logger: Logger instance for status messages

    Returns:
        Path to the checkpoint if successful, None if weights cannot be obtained
    """
    # Validate the requested combination first; get_checkpoint_path() would
    # otherwise raise a KeyError before the friendlier messages below could run.
    if metric_type not in MODEL_FILENAMES:
        logger.error(f"✗ Metric type '{metric_type}' not recognized")
        logger.error(f"  Available metric types: {list(MODEL_FILENAMES.keys())}")
        return None

    if train_dataset not in MODEL_FILENAMES[metric_type]:
        logger.error(f"✗ Train dataset '{train_dataset}' not available for metric type '{metric_type}'")
        return None

    if model_name not in MODEL_FILENAMES[metric_type][train_dataset]:
        available_models = get_available_models(train_dataset, metric_type)
        logger.error(f"✗ Model '{model_name}' not available for {train_dataset}/{metric_type}")
        logger.error(f"  Available models: {available_models}")
        return None

    # Check the local cache before hitting the network.
    checkpoint_path = get_checkpoint_path(model_name, train_dataset, metric_type)
    if os.path.exists(checkpoint_path):
        logger.info(f"✓ Found existing checkpoint: {checkpoint_path}")
        return checkpoint_path

    logger.info(f"Checkpoint not found at {checkpoint_path}")

    # Fall back to downloading the registered file from the Hugging Face Hub.
    filename = MODEL_FILENAMES[metric_type][train_dataset][model_name]
    logger.info(
        f"Attempting to download checkpoint from Hugging Face: "
        f"repo={HF_REPO_ID}, file={filename}, rev={HF_REVISION}"
    )
    try:
        local_path = ensure_checkpoint_from_hf(
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=str(Path("models") / "checkpoints" / f"{metric_type}_metric"),
            revision=HF_REVISION,
        )
        logger.info("✓ Successfully downloaded checkpoint from Hugging Face")
        return local_path
    except Exception as e:
        logger.error(f"✗ Failed to download checkpoint from Hugging Face: {e}")
        return None
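
# Typical call, sketched (assumes an 'ra_miqa' entry in the registry, as the
# CLI below also does):
#
#   path = ensure_model_weights('ra_miqa', 'cls', 'composite', logger)
#   if path is not None:
#       ...  # safe to pass `path` to load_model_weights()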


def load_model_weights(model: torch.nn.Module, weights_path: str, args: argparse.Namespace,
                       logger: logging.Logger) -> bool:
    """
    Load model weights from a checkpoint file.

    Handles the actual loading of weights into the model, with error handling
    and support for two checkpoint formats: a bare state dict, or a state dict
    wrapped in a dictionary alongside metadata.

    Args:
        model: The model to load weights into
        weights_path: Path to the checkpoint file
        args: Command line arguments
        logger: Logger instance

    Returns:
        True if weights loaded successfully, False otherwise
    """
    if not os.path.isfile(weights_path):
        logger.error(f"✗ Checkpoint file not found: '{weights_path}'")
        return False

    logger.info(f"Loading checkpoint from '{weights_path}'")

    try:
        checkpoint = torch.load(weights_path, map_location="cpu")

        # The checkpoint may be a bare state dict or wrapped under 'state_dict'.
        state_dict = checkpoint.get('state_dict', checkpoint)

        # Strip the leading 'module.' prefix left behind by DataParallel/DDP
        # training (only the first occurrence, to keep inner names intact).
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k.replace('module.', '', 1) if k.startswith('module.') else k
            new_state_dict[name] = v

        model.load_state_dict(new_state_dict)
        logger.info("✓ Successfully loaded checkpoint")

        # Report any metadata stored alongside the weights.
        if 'epoch' in checkpoint:
            logger.info(f"  Checkpoint epoch: {checkpoint['epoch']}")
        if 'best_srcc' in checkpoint:
            logger.info(f"  Best SRCC: {checkpoint['best_srcc']:.4f}")
        if 'metric_type' in checkpoint:
            logger.info(f"  Metric type: {checkpoint['metric_type']}")

        return True

    except Exception as e:
        logger.error(f"✗ Error loading checkpoint: {e}")
        return False
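
# The loader accepts either checkpoint layout; a minimal sketch of both
# (hypothetical file names, for illustration only):
#
#   torch.save(model.state_dict(), 'bare.pth')             # bare state dict
#   torch.save({'state_dict': model.state_dict(),          # wrapped with metadata
#               'epoch': 100, 'best_srcc': 0.91}, 'wrapped.pth')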


def create_model(model_name: str, args: argparse.Namespace, logger: logging.Logger) -> torch.nn.Module:
    """
    Create a model instance based on the model name.

    Handles instantiation of the different architectures. The
    RegionVisionTransformer (RA_MIQA) gets special handling because its
    initialization differs from that of standard vision models; all other
    names are resolved through get_torch_model first, with get_timm_model
    as a fallback.

    Args:
        model_name: Name of the model architecture
        args: Command line arguments
        logger: Logger instance

    Returns:
        Initialized model (weights not loaded yet)
    """
    if model_name == 'ra_miqa':
        logger.info("Creating RA_MIQA model")
        model = RegionVisionTransformer(
            base_model_name='vit_small_patch16_224',
            pretrained=True,
            mmseg_config_path='models/model_configs/fcn_sere-small_finetuned_fp16_8x32_224x224_3600_imagenets919.py',
            checkpoint_path='models/checkpoints/sere_finetuned_vit_small_ep100.pth'
        )
    else:
        # Try the PyTorch model zoo first, then fall back to timm.
        try:
            logger.info(f"Creating model from PyTorch: {model_name}")
            model = get_torch_model(model_name=model_name, pretrained=False, num_classes=1)
        except Exception:
            logger.info(f"PyTorch model not found, trying timm library: {model_name}")
            try:
                model = get_timm_model(model_name=model_name, pretrained=False, num_classes=1)
            except Exception as e:
                logger.error(f"✗ Failed to create model: {e}")
                raise

    return model
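
# Resolution order, sketched (hedged; 'resnet50' is a hypothetical name and is
# not among the current CLI choices):
#
#   create_model('ra_miqa', args, logger)   # -> RegionVisionTransformer
#   create_model('resnet50', args, logger)  # -> get_torch_model, else get_timm_model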


@torch.no_grad()
def inference(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module,
              args: argparse.Namespace, criterion: torch.nn.Module,
              logger: logging.Logger) -> Dict:
    """
    Run inference on the validation set and compute metrics.

    Performs the actual evaluation of the model on the test dataset: runs in
    evaluation mode with gradients disabled, processes all batches, and
    computes standard image quality assessment metrics (SRCC, PLCC, KLCC).

    Args:
        val_loader: DataLoader for validation data
        model: Model to evaluate
        args: Command line arguments
        criterion: Loss function (MSE)
        logger: Logger instance

    Returns:
        Dictionary containing predictions, ground truth, and computed metrics
    """
    model.eval()
    val_dataset_len = len(val_loader.dataset)
    val_loader_len = len(val_loader)

    # Run on whichever device the model lives on, so the CPU fallback in
    # main() also works (the previous unconditional .cuda() calls would crash
    # on CPU-only machines).
    device = next(model.parameters()).device

    losses = AverageMeter('Loss', ':.4e')

    temp_pred_scores = []
    temp_gt_scores = []
    temp_img_names = []

    logger.info(f"Starting inference on {val_dataset_len} images...")

    for i, batch in enumerate(val_loader):
        image_cropped = batch['image_cropped'].to(device, non_blocking=True)
        image_resized = batch['image_resized'].to(device, non_blocking=True)
        target = batch['label'].to(device, non_blocking=True).view(-1)

        output = model(image_cropped, image_resized)
        loss = criterion(output.view(-1), target)
        losses.update(loss.item(), target.size(0))

        temp_pred_scores.append(output.view(-1))
        temp_gt_scores.append(target)
        temp_img_names.extend(batch['image_name'])

        if i % args.print_freq == 0:
            logger.info(
                f"  [{i}/{val_loader_len}] "
                f"Loss: {losses.val:.4f} (avg: {losses.avg:.4f})"
            )

    final_preds = torch.cat(temp_pred_scores)
    final_ground_truth = torch.cat(temp_gt_scores)

    # When each image is scored as several patches, average the per-patch
    # scores back down to one score per image.
    if hasattr(args, 'patch_num') and args.patch_num > 1:
        logger.info(f"Averaging predictions over {args.patch_num} patches per image")
        final_preds = final_preds.view(-1, args.patch_num).mean(dim=-1)
        final_ground_truth = final_ground_truth.view(-1, args.patch_num).mean(dim=-1)
        # Keep one name per image, matching the averaged scores (the loader is
        # assumed to yield the patches of each image consecutively, which the
        # view(-1, patch_num) reshape above already relies on).
        temp_img_names = temp_img_names[::args.patch_num]

    logger.info(
        f"Dataset size: {val_dataset_len}, "
        f"Predictions shape: {final_preds.shape}, "
        f"Ground truth shape: {final_ground_truth.shape}"
    )

    # Guard against degenerate outputs before computing correlations.
    if torch.isnan(final_preds).any() or torch.isinf(final_preds).any():
        raise ValueError("Found NaN or inf values in predictions")
    if torch.isnan(final_ground_truth).any() or torch.isinf(final_ground_truth).any():
        raise ValueError("Found NaN or inf values in ground truth")

    test_srcc = torchmetrics.functional.spearman_corrcoef(final_preds, final_ground_truth).item()
    test_plcc = torchmetrics.functional.pearson_corrcoef(final_preds, final_ground_truth).item()
    test_klcc = torchmetrics.functional.kendall_rank_corrcoef(final_preds, final_ground_truth).item()

    results = {
        'image_names': temp_img_names,
        'predictions': final_preds.cpu().numpy().tolist(),
        'ground_truth': final_ground_truth.cpu().numpy().tolist(),
        'metrics': {
            'srcc': test_srcc,
            'plcc': test_plcc,
            'klcc': test_klcc,
            'loss': losses.avg
        }
    }

    return results
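
# Patch averaging on toy data (a sketch, not part of the pipeline): with
# patch_num=3, six per-patch scores collapse to two per-image scores.
#
#   >>> preds = torch.tensor([0.1, 0.2, 0.3, 0.7, 0.8, 0.9])
#   >>> preds.view(-1, 3).mean(dim=-1)
#   tensor([0.2000, 0.8000])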


def save_results(results: Dict, model_name: str, train_dataset: str,
                 test_dataset: str, metric_type: str, output_dir: str,
                 logger: logging.Logger) -> None:
    """
    Save inference results to a CSV file and log a metrics summary.

    Writes detailed per-image results to disk and logs a summary of the
    overall performance metrics. The filename encodes all relevant
    configuration details for easy identification.

    Args:
        results: Results dictionary from inference
        model_name: Name of the model
        train_dataset: Training dataset type
        test_dataset: Test dataset name
        metric_type: Training metric objective
        output_dir: Base directory to save results
        logger: Logger instance
    """
    eval_dir = os.path.join(output_dir, 'evaluations')
    os.makedirs(eval_dir, exist_ok=True)

    # One row per image: prediction, ground truth, and absolute error.
    csv_data = []
    for img_name, pred, gt in zip(results['image_names'],
                                  results['predictions'],
                                  results['ground_truth']):
        csv_data.append({
            'image_name': img_name,
            'prediction': pred,
            'ground_truth': gt,
            'absolute_error': abs(pred - gt)
        })

    csv_filename = f"{model_name}_{train_dataset}_{metric_type}_on_{test_dataset}.csv"
    csv_path = os.path.join(eval_dir, csv_filename)

    df = pd.DataFrame(csv_data)
    df.to_csv(csv_path, index=False)
    logger.info(f"Detailed results saved to: {csv_path}")

    # Human-readable summary of the run.
    logger.info("\n" + "=" * 70)
    logger.info("EVALUATION METRICS")
    logger.info("=" * 70)
    logger.info(f"Model: {model_name}")
    logger.info(f"Trained on: {train_dataset}")
    logger.info(f"Metric type: {metric_type}")
    logger.info(f"Tested on: {test_dataset}")
    logger.info("-" * 70)
    logger.info(f"SRCC (Spearman): {results['metrics']['srcc']:.4f}")
    logger.info(f"PLCC (Pearson):  {results['metrics']['plcc']:.4f}")
    logger.info(f"KLCC (Kendall):  {results['metrics']['klcc']:.4f}")
    logger.info(f"MSE Loss:        {results['metrics']['loss']:.4f}")
    logger.info("=" * 70 + "\n")


def main(args: argparse.Namespace, logger: logging.Logger) -> None:
    """
    Main inference pipeline orchestrating all steps.

    Coordinates the entire evaluation process: validating inputs, ensuring
    model weights are available, loading data, creating and loading the model,
    running inference, and saving results.

    Args:
        args: Command line arguments
        logger: Logger instance
    """
    if not args.model_name:
        raise ValueError("Please specify --model_name")
    if not args.train_dataset:
        raise ValueError("Please specify --train_dataset (cls, det, or ins)")
    if not args.test_dataset:
        raise ValueError("Please specify --test_dataset")
    if not args.metric_type:
        raise ValueError("Please specify --metric_type (composite, consistency, or accuracy)")

    logger.info("\nStarting MIQA Inference Pipeline")
    logger.info(f"Model: {args.model_name}")
    logger.info(f"Trained on: {args.train_dataset}")
    logger.info(f"Metric type: {args.metric_type}")
    logger.info(f"Testing on: {args.test_dataset}")

    # Step 1: make sure the checkpoint is available (cached or downloaded).
    checkpoint_path = ensure_model_weights(args.model_name, args.train_dataset,
                                           args.metric_type, logger)
    if checkpoint_path is None:
        logger.error("Cannot proceed without model weights")
        return

    # Step 2: build the evaluation dataset and loader.
    logger.info(f"\nLoading {args.test_dataset} dataset...")
    args.dataset = args.test_dataset
    args.eval_only = True
    val_dataset = build_dataloader.build_dataset(args)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True
    )
    logger.info(f"✓ Loaded {len(val_dataset)} images with {args.workers} workers")

    # Step 3: create the architecture and load the checkpoint into it.
    logger.info("\nCreating model architecture...")
    args.arch = args.model_name
    model = create_model(args.model_name, args, logger)

    if not load_model_weights(model, checkpoint_path, args, logger):
        logger.error("Failed to load model weights")
        return

    # Step 4: move the model to the requested GPU, falling back to CPU.
    if args.gpu is not None and torch.cuda.is_available():
        model = model.cuda(args.gpu)
        logger.info(f"✓ Model moved to GPU {args.gpu}")
    else:
        logger.warning("GPU not available, using CPU (this will be slower)")

    criterion = torch.nn.MSELoss()

    # Step 5: evaluate and persist the results.
    logger.info("\nRunning inference...")
    results = inference(val_loader, model, args, criterion, logger)

    save_results(results, args.model_name, args.train_dataset,
                 args.test_dataset, args.metric_type, args.output_dir, logger)


if __name__ == '__main__':
    parser = get_args()
    parser.add_argument('--model_name', type=str, required=True,
                        choices=['ra_miqa'],
                        help='Model architecture (Hub registry currently ships RA-MIQA only)')
    parser.add_argument('--train_dataset', type=str, required=True,
                        choices=['cls', 'det', 'ins'],
                        help='Dataset type the model was trained on (cls=classification, det=detection, ins=instance)')
    parser.add_argument('--test_dataset', type=str, required=True,
                        help='Name of the dataset to test on')
    parser.add_argument('--metric_type', type=str, required=True,
                        choices=['composite', 'consistency', 'accuracy'],
                        help='Training metric objective used (composite=both metrics, consistency=consistency-focused, accuracy=accuracy-focused)')
    parser.add_argument('--output_dir', type=str, default='outputs',
                        help='Directory to save results (default: outputs)')
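
    # Example invocation (the dataset name is illustrative; get_args() may add
    # further flags such as --batch_size, --workers, or --gpu):
    #
    #   python inference.py --model_name ra_miqa --train_dataset cls \
    #       --test_dataset my_test_set --metric_type composite --output_dir outputs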

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Log to both a per-run file and the console.
    log_filename = f"inference_{args.model_name}_{args.train_dataset}_{args.metric_type}.log"
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(os.path.join(args.output_dir, log_filename)),
            logging.StreamHandler()
        ]
    )

    logger = logging.getLogger('miqa_inference')

    try:
        main(args, logger)
    except Exception as e:
        logger.error(f"Error during inference: {e}", exc_info=True)
        raise