miqa / evaluate.py
xiaoqi-wang's picture
Upload evaluate.py with huggingface_hub
e29b006 verified
import argparse
import logging
import os
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Optional
import pandas as pd
import torch
import torchmetrics
from configs.args_base import get_args
from data import build_dataloader
from models.MIQA_base import get_timm_model, get_torch_model
from models.RA_MIQA import RegionVisionTransformer
from models.hf_model_registry import HF_REPO_ID, HF_REVISION, MODEL_FILENAMES
from train import AverageMeter
from utils.hf_download_utils import ensure_checkpoint_from_hf
def get_checkpoint_path(model_name: str, train_dataset: str, metric_type: str = "composite") -> str:
base_dir = Path("models") / "checkpoints" / f"{metric_type}_metric"
base_dir.mkdir(parents=True, exist_ok=True)
filename = MODEL_FILENAMES[metric_type][train_dataset][model_name]
return str(base_dir / filename)
def get_available_models(train_dataset: str, metric_type: str) -> List[str]:
"""
Get list of available models for a specific training dataset and metric type.
This helper function is useful for validation and for providing helpful error messages
when a user requests a model that isn't available for their chosen configuration.
Args:
train_dataset: Training dataset type (cls, det, ins)
metric_type: Training metric objective (composite, consistency, accuracy)
Returns:
List of available model names for this configuration
"""
if metric_type in MODEL_FILENAMES:
if train_dataset in MODEL_FILENAMES[metric_type]:
return list(MODEL_FILENAMES[metric_type][train_dataset].keys())
return []
def ensure_model_weights(model_name: str, train_dataset: str, metric_type: str,
logger: logging.Logger) -> Optional[str]:
"""
Ensure model weights exist, download if necessary.
This function implements a caching strategy: it first checks if the checkpoint already
exists locally. If not, it downloads it from Hugging Face Hub. This means
the first run will download weights, but subsequent runs will be much faster.
Args:
model_name: Name of the model architecture
train_dataset: Training dataset type (cls, det, or ins)
metric_type: Training metric objective (composite, consistency, or accuracy)
logger: Logger instance for status messages
Returns:
Path to checkpoint if successful, None if weights cannot be obtained
"""
# Generate the expected checkpoint path
checkpoint_path = get_checkpoint_path(model_name, train_dataset, metric_type)
# First, check if we already have this checkpoint cached locally
if os.path.exists(checkpoint_path):
logger.info(f"✓ Found existing checkpoint: {checkpoint_path}")
return checkpoint_path
# Checkpoint not found locally, so we need to download it
logger.info(f"Checkpoint not found at {checkpoint_path}")
# Verify this model configuration is supported
if metric_type not in MODEL_FILENAMES:
logger.error(f"✗ Metric type '{metric_type}' not recognized")
logger.error(f" Available metric types: {list(MODEL_FILENAMES.keys())}")
return None
if train_dataset not in MODEL_FILENAMES[metric_type]:
logger.error(f"✗ Train dataset '{train_dataset}' not available for metric type '{metric_type}'")
return None
if model_name not in MODEL_FILENAMES[metric_type][train_dataset]:
available_models = get_available_models(train_dataset, metric_type)
logger.error(f"✗ Model '{model_name}' not available for {train_dataset}/{metric_type}")
logger.error(f" Available models: {available_models}")
return None
filename = MODEL_FILENAMES[metric_type][train_dataset][model_name]
logger.info(
f"Attempting to download checkpoint from Hugging Face: "
f"repo={HF_REPO_ID}, file={filename}, rev={HF_REVISION}"
)
try:
local_path = ensure_checkpoint_from_hf(
repo_id=HF_REPO_ID,
filename=filename,
local_dir=str(Path("models") / "checkpoints" / f"{metric_type}_metric"),
revision=HF_REVISION,
)
logger.info("✓ Successfully downloaded checkpoint from Hugging Face")
return local_path
except Exception as e:
logger.error(f"✗ Failed to download checkpoint from Hugging Face: {e}")
return None
def load_model_weights(model: torch.nn.Module, weights_path: str, args: argparse.Namespace,
logger: logging.Logger) -> bool:
"""
Load model weights from checkpoint file.
This function handles the actual loading of weights into the model, with proper error
handling and support for different checkpoint formats (direct state dict or wrapped
in a dictionary with metadata).
Args:
model: The model to load weights into
weights_path: Path to the checkpoint file
args: Command line arguments
logger: Logger instance
Returns:
True if weights loaded successfully, False otherwise
"""
if not os.path.isfile(weights_path):
logger.error(f"✗ Checkpoint file not found: '{weights_path}'")
return False
logger.info(f"Loading checkpoint from '{weights_path}'")
try:
# Load checkpoint to CPU first to avoid GPU memory issues
checkpoint = torch.load(weights_path, map_location="cpu")
# Extract state dict - handle different checkpoint formats
# Some checkpoints store weights directly, others wrap them in a 'state_dict' key
state_dict = checkpoint.get('state_dict', checkpoint)
# Remove 'module.' prefix if present
# This prefix is added when models are trained with DataParallel/DistributedDataParallel
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k.replace('module.', '') if k.startswith('module.') else k
new_state_dict[name] = v
# Load the processed weights into the model
model.load_state_dict(new_state_dict)
logger.info(f"✓ Successfully loaded checkpoint")
# Log additional useful information from the checkpoint if available
if 'epoch' in checkpoint:
logger.info(f" Checkpoint epoch: {checkpoint['epoch']}")
if 'best_srcc' in checkpoint:
logger.info(f" Best SRCC: {checkpoint['best_srcc']:.4f}")
if 'metric_type' in checkpoint:
logger.info(f" Metric type: {checkpoint['metric_type']}")
return True
except Exception as e:
logger.error(f"✗ Error loading checkpoint: {str(e)}")
return False
def create_model(model_name: str, args: argparse.Namespace, logger: logging.Logger) -> torch.nn.Module:
"""
Create model instance based on model name.
This function handles the instantiation of different model architectures. It includes
special handling for the RegionVisionTransformer (RA_MIQA) which has a different
initialization process than standard vision models.
Args:
model_name: Name of the model architecture
args: Command line arguments
logger: Logger instance
Returns:
Initialized model (without loaded weights yet)
"""
# Special handling for our custom RegionVisionTransformer architecture
if model_name == 'ra_miqa':
logger.info(f"Creating RA_MIQA Model")
model = RegionVisionTransformer(
base_model_name='vit_small_patch16_224',
pretrained=True,
mmseg_config_path='models/model_configs/fcn_sere-small_finetuned_fp16_8x32_224x224_3600_imagenets919.py',
checkpoint_path='models/checkpoints/sere_finetuned_vit_small_ep100.pth'
)
else:
# For standard architectures, try PyTorch hub first, then fall back to timm
try:
logger.info(f"Creating model from PyTorch: {model_name}")
model = get_torch_model(model_name=model_name, pretrained=False, num_classes=1)
except Exception as e:
logger.info(f"PyTorch model not found, trying timm library: {model_name}")
try:
model = get_timm_model(model_name=model_name, pretrained=False, num_classes=1)
except Exception as e:
logger.error(f"✗ Failed to create model: {str(e)}")
raise
return model
@torch.no_grad()
def inference(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module,
args: argparse.Namespace, criterion: torch.nn.Module,
logger: logging.Logger) -> Dict:
"""
Run inference on validation set and compute metrics.
This function performs the actual evaluation of the model on the test dataset. It runs
in evaluation mode with no gradient computation, processes all batches, and computes
standard image quality assessment metrics (SRCC, PLCC, KLCC).
Args:
val_loader: DataLoader for validation data
model: Model to evaluate
args: Command line arguments
criterion: Loss function (MSE)
logger: Logger instance
Returns:
Dictionary containing predictions, ground truth, and computed metrics
"""
# Set model to evaluation mode - this disables dropout and uses running stats for batchnorm
model.eval()
val_dataset_len = len(val_loader.dataset)
val_loader_len = len(val_loader)
# Initialize tracking variables for performance monitoring
batch_time = AverageMeter('Time', ':6.3f')
losses = AverageMeter('Loss', ':.4e')
# Storage lists for accumulating results across all batches
temp_pred_scores = []
temp_gt_scores = []
temp_img_names = []
logger.info(f"Starting inference on {val_dataset_len} images...")
for i, batch in enumerate(val_loader):
# Move data to GPU if available
image_cropped = batch['image_cropped'].cuda(args.gpu, non_blocking=True)
image_resized = batch['image_resized'].cuda(args.gpu, non_blocking=True)
target = batch['label'].cuda(args.gpu, non_blocking=True).view(-1)
# Forward pass - compute predictions
output = model(image_cropped, image_resized)
loss = criterion(output.view(-1), target.view(-1))
losses.update(loss.item(), target.size(0))
# Accumulate results for later metric computation
temp_pred_scores.append(output.view(-1))
temp_gt_scores.append(target.view(-1))
temp_img_names.extend(batch['image_name'])
# Log progress periodically
if i % args.print_freq == 0:
logger.info(
f" [{i}/{val_loader_len}] "
f"Loss: {losses.val:.4f} (avg: {losses.avg:.4f})"
)
# Concatenate all batch results into single tensors
final_preds = torch.cat(temp_pred_scores)
final_grotruth = torch.cat(temp_gt_scores)
# Handle patch-based predictions if the model uses multiple patches per image
if hasattr(args, 'patch_num') and args.patch_num > 1:
logger.info(f"Averaging predictions over {args.patch_num} patches per image")
preds_matrix = final_preds.view(-1, args.patch_num)
final_preds = preds_matrix.mean(dim=-1).squeeze()
final_grotruth = final_grotruth.view(-1, args.patch_num).mean(dim=-1).squeeze()
logger.info(
f"Dataset size: {val_dataset_len}, "
f"Predictions shape: {final_preds.shape}, "
f"Ground truth shape: {final_grotruth.shape}"
)
# Sanity check for invalid values that would corrupt metric computation
if torch.isnan(final_preds).any() or torch.isinf(final_preds).any():
raise ValueError("Found NaN or inf values in predictions")
if torch.isnan(final_grotruth).any() or torch.isinf(final_grotruth).any():
raise ValueError("Found NaN or inf values in ground truth")
# Compute standard image quality assessment metrics
# SRCC: Spearman's rank correlation coefficient - measures monotonic relationship
test_srcc = torchmetrics.functional.spearman_corrcoef(final_preds, final_grotruth).item()
# PLCC: Pearson's linear correlation coefficient - measures linear relationship
test_plcc = torchmetrics.functional.pearson_corrcoef(final_preds, final_grotruth).item()
# KLCC: Kendall's rank correlation coefficient - another rank-based metric
test_klcc = torchmetrics.functional.kendall_rank_corrcoef(final_preds, final_grotruth).item()
# Package all results into a dictionary for return
results = {
'image_names': temp_img_names,
'predictions': final_preds.cpu().numpy().tolist(),
'ground_truth': final_grotruth.cpu().numpy().tolist(),
'metrics': {
'srcc': test_srcc,
'plcc': test_plcc,
'klcc': test_klcc,
'loss': losses.avg
}
}
return results
def save_results(results: Dict, model_name: str, train_dataset: str,
test_dataset: str, metric_type: str, output_dir: str,
logger: logging.Logger) -> None:
"""
Save inference results to CSV file with detailed metrics.
This function saves both detailed per-image results and prints a summary of the
overall performance metrics. The filename includes all relevant configuration
details for easy identification.
Args:
results: Results dictionary from inference
model_name: Name of the model
train_dataset: Training dataset type
test_dataset: Test dataset name
metric_type: Training metric objective
output_dir: Base directory to save results
logger: Logger instance
"""
# Create the evaluations subdirectory
eval_dir = os.path.join(output_dir, 'evaluations')
os.makedirs(eval_dir, exist_ok=True)
# Prepare detailed per-image results with predictions and errors
csv_data = []
for img_name, pred, gt in zip(results['image_names'],
results['predictions'],
results['ground_truth']):
csv_data.append({
'image_name': img_name,
'prediction': pred,
'ground_truth': gt,
'absolute_error': abs(pred - gt)
})
# Create descriptive filename that includes all configuration details
csv_filename = f"{model_name}_{train_dataset}_{metric_type}_on_{test_dataset}.csv"
csv_path = os.path.join(eval_dir, csv_filename)
# Save detailed results to CSV
df = pd.DataFrame(csv_data)
df.to_csv(csv_path, index=False)
logger.info(f"Detailed results saved to: {csv_path}")
# Print formatted metrics summary to console and log
logger.info("\n" + "=" * 70)
logger.info("EVALUATION METRICS")
logger.info("=" * 70)
logger.info(f"Model: {model_name}")
logger.info(f"Trained on: {train_dataset}")
logger.info(f"Metric type: {metric_type}")
logger.info(f"Tested on: {test_dataset}")
logger.info("-" * 70)
logger.info(f"SRCC (Spearman): {results['metrics']['srcc']:.4f}")
logger.info(f"PLCC (Pearson): {results['metrics']['plcc']:.4f}")
logger.info(f"KLCC (Kendall): {results['metrics']['klcc']:.4f}")
logger.info(f"MSE Loss: {results['metrics']['loss']:.4f}")
logger.info("=" * 70 + "\n")
def main(args: argparse.Namespace, logger: logging.Logger) -> None:
"""
Main inference pipeline orchestrating all steps.
This function coordinates the entire evaluation process: validating inputs,
ensuring model weights are available, loading data, creating and loading the model,
running inference, and saving results.
Args:
args: Command line arguments
logger: Logger instance
"""
# Validate required arguments
if not args.model_name:
raise ValueError("Please specify --model_name")
if not args.train_dataset:
raise ValueError("Please specify --train_dataset (cls, det, or ins)")
if not args.test_dataset:
raise ValueError("Please specify --test_dataset")
if not args.metric_type:
raise ValueError("Please specify --metric_type (composite, consistency, or accuracy)")
logger.info(f"\nStarting MIQA Inference Pipeline")
logger.info(f"Model: {args.model_name}")
logger.info(f"Trained on: {args.train_dataset}")
logger.info(f"Metric type: {args.metric_type}")
logger.info(f"Testing on: {args.test_dataset}")
# Ensure model weights are available (download if necessary)
checkpoint_path = ensure_model_weights(args.model_name, args.train_dataset,
args.metric_type, logger)
if checkpoint_path is None:
logger.error("Cannot proceed without model weights")
return
# Build dataset and dataloader
logger.info(f"\nLoading {args.test_dataset} dataset...")
args.dataset = args.test_dataset # Set dataset name for dataloader builder
args.eval_only = True # Indicate evaluation mode
val_dataset = build_dataloader.build_dataset(args)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
pin_memory=True
)
logger.info(f"✓ Loaded {len(val_dataset)} images with {args.workers} workers")
# Create model architecture
logger.info(f"\nCreating model architecture...")
args.arch = args.model_name
model = create_model(args.model_name, args, logger)
# Load pre-trained weights into model
if not load_model_weights(model, checkpoint_path, args, logger):
logger.error("Failed to load model weights")
return
# Move model to GPU if available
if args.gpu is not None and torch.cuda.is_available():
model = model.cuda(args.gpu)
logger.info(f"✓ Model moved to GPU {args.gpu}")
else:
logger.warning("GPU not available, using CPU (this will be slower)")
# Create loss function for evaluation
criterion = torch.nn.MSELoss()
# Run inference on the test set
logger.info(f"\nRunning inference...")
results = inference(val_loader, model, args, criterion, logger)
# Save results and print summary
save_results(results, args.model_name, args.train_dataset,
args.test_dataset, args.metric_type, args.output_dir, logger)
if __name__ == '__main__':
# Parse command line arguments
parser = get_args()
parser.add_argument('--model_name', type=str, required=True,
choices=['ra_miqa'],
help='Model architecture (Hub registry currently ships RA-MIQA only)')
parser.add_argument('--train_dataset', type=str, required=True,
choices=['cls', 'det', 'ins'],
help='Dataset type the model was trained on (cls=classification, det=detection, ins=instance)')
parser.add_argument('--test_dataset', type=str, required=True,
help='Name of the dataset to test on')
parser.add_argument('--metric_type', type=str, required=True,
choices=['composite', 'consistency', 'accuracy'],
help='Training metric objective used (composite=both metrics, consistency=consistency-focused, accuracy=accuracy-focused)')
parser.add_argument('--output_dir', type=str, default='outputs',
help='Directory to save results (default: outputs)')
args = parser.parse_args()
# Create output directory structure
os.makedirs(args.output_dir, exist_ok=True)
# Configure logging to both file and console
log_filename = f"inference_{args.model_name}_{args.train_dataset}_{args.metric_type}.log"
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(os.path.join(args.output_dir, log_filename)),
logging.StreamHandler()
]
)
logger = logging.getLogger('miqa_inference')
# Run main inference pipeline with error handling
try:
main(args, logger)
except Exception as e:
logger.error(f"Error during inference: {str(e)}", exc_info=True)
raise