| | from utils import CustomDataset, transform, Convert_ONNX |
| | from torch.utils.data import Dataset, DataLoader |
| | from utils import CustomDataset, TestingDataset, transform |
| | from tqdm import tqdm |
| | import torch |
| | import numpy as np |
| | from resnet_model_mask import ResidualBlock, ResNet |
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | from tqdm import tqdm |
| | import torch.nn.functional as F |
| | from torch.optim.lr_scheduler import ReduceLROnPlateau |
| | import pickle |
| | import matplotlib.pyplot as plt |
| | import pandas as pd |
| |
|
# Fix the global torch RNG seed so runs are reproducible.
torch.manual_seed(1)

# Prefer the first CUDA device when CUDA is available; otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Report how many CUDA devices are visible.
num_gpus = torch.cuda.device_count()
print(num_gpus)
| |
|
test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'
test_dataset = TestingDataset(test_data_dir, transform=transform)

num_classes = 2

# Evaluation does not need shuffling: accuracy, precision and recall are
# order-independent. With shuffle=False, DataLoader visits samples in index
# order, so the per-sample rows collected into `results` (and pickled) line up
# with the dataset's own sample order.
testloader = DataLoader(test_dataset, batch_size=420, shuffle=False, num_workers=32)
| |
|
# Build the network on the target device first: DataParallel requires the
# wrapped module's parameters to already live on device_ids[0].
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
model = nn.DataParallel(model)
params = sum(p.numel() for p in model.parameters())
print("num params ", params)

model_1 = 'models_mask/model-43-99.235_42.pt'

# map_location remaps tensors saved on a GPU onto `device`, so the checkpoint
# also loads on a CPU-only host. The state dict is loaded into the DataParallel
# wrapper, so its keys presumably carry a "module." prefix -- TODO confirm
# against the training script that produced this checkpoint.
model.load_state_dict(torch.load(model_1, weights_only=True, map_location=device))
model = model.eval()
| |
|
| | |
correct_valid = 0
total = 0
# Per-sample values gathered across all test batches; pickled and analysed below.
results = {'output': [], 'pred': [], 'true': [], 'freq': [], 'snr': [], 'dm': [], 'boxcar': []}

model.eval()
with torch.no_grad():
    for images, labels in tqdm(testloader):
        # `labels` is indexed as a sequence of per-sample tensors. The index
        # -> key mapping below mirrors how `results` is filled:
        # [0] true class, [1] dm, [2] freq, [3] snr, [4] boxcar
        # (TODO confirm against TestingDataset).
        outputs = model(images.to(device), return_mask=True)
        _, predicted = torch.max(outputs, 1)

        # Move once to the CPU and reuse for both bookkeeping and accuracy.
        predicted = predicted.cpu()
        true_labels = labels[0].cpu()

        results['output'].extend(outputs.cpu().numpy().tolist())
        results['pred'].extend(predicted.numpy().tolist())
        results['true'].extend(true_labels.numpy().tolist())
        results['freq'].extend(labels[2].cpu().numpy().tolist())
        results['dm'].extend(labels[1].cpu().numpy().tolist())
        results['snr'].extend(labels[3].cpu().numpy().tolist())
        results['boxcar'].extend(labels[4].cpu().numpy().tolist())

        total += true_labels.size(0)
        correct_valid += (predicted == true_labels).sum().item()

# Fail with a clear message instead of a bare ZeroDivisionError on an empty loader.
if total == 0:
    raise RuntimeError("testloader yielded no samples; cannot compute accuracy")

val_accuracy = correct_valid / total * 100.0
print("===========================")
print('accuracy: ', val_accuracy)
print("===========================")
| |
|
| | import pickle |
| |
|
| | |
# Write the per-sample evaluation results to disk as a pickle.
results_path = 'models_mask/test_42.pkl'
with open(results_path, mode='wb') as results_file:
    pickle.dump(results, results_file)
| |
|
| | from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix |
| |
|
| | |
true = results['true']
pred = results['pred']

# Binary metrics with sklearn's default positive class (pos_label=1).
precision = precision_score(true, pred)
recall = recall_score(true, pred)
f1 = f1_score(true, pred)

# labels=[0, 1] forces a 2x2 matrix. Without it, a run where only one class
# appears in both `true` and `pred` yields a 1x1 matrix, and the four-way
# unpack raises ValueError.
tn, fp, fn, tp = confusion_matrix(true, pred, labels=[0, 1]).ravel()

# False positive rate = FP / (FP + TN). It is undefined (NaN) when the test set
# has no negative samples, rather than dividing 0 by 0.
negatives = fp + tn
fpr = fp / negatives if negatives > 0 else float('nan')

print(f"False Positive Rate: {fpr:.3f}")

print(f"Precision: {precision:.3f}")
print(f"Recall: {recall:.3f}")
print(f"F1 Score: {f1:.3f}")
| |
|
| | |
| | |
# Tabulate the per-sample results.
# NOTE(review): boxcar values are halved here -- presumably a unit or
# sample-count conversion; confirm against how labels[4] is produced.
df = pd.DataFrame(
    {
        'dm': results['dm'],
        'true': results['true'],
        'pred': results['pred'],
        'snr': results['snr'],
        'freq': results['freq'],
        'boxcar': np.array(results['boxcar']) / 2,
    }
)

# Keep only samples whose true label is 1; "accuracy" on this subset is recall.
is_positive = df['true'] == 1
df = df[is_positive].copy()

print(f"Filtered to {len(df)} samples with true label = 1")

# 20 evenly spaced edges over the observed DM range give 19 bins; the lowest
# edge is inclusive so the minimum-DM sample is not dropped.
dm_range = (df['dm'].min(), df['dm'].max())
dm_bins = np.linspace(*dm_range, 20)
df['dm_bin'] = pd.cut(df['dm'], bins=dm_bins, include_lowest=True)
print('min boxcar', df['boxcar'].min())
| | |
def calculate_accuracy_with_uncertainty(group):
    """Return the accuracy of one group of predictions and its binomial standard error.

    Args:
        group: DataFrame with ``true`` and ``pred`` columns, one row per sample.

    Returns:
        pd.Series with ``accuracy`` (percent), ``std_error`` (percent, the
        normal-approximation error ``sqrt(p * (1 - p) / n)``) and
        ``n_samples``. For an empty group, ``accuracy`` and ``std_error`` are
        NaN and ``n_samples`` is 0, instead of computing 0 / 0.

    Note:
        The normal-approximation error is 0 when every (or no) sample in the
        group is correct, so bins at exactly 100% get no error bar.
    """
    total = len(group)
    if total == 0:
        return pd.Series({'accuracy': np.nan, 'std_error': np.nan, 'n_samples': 0})

    p = (group['true'] == group['pred']).sum() / total
    accuracy = p * 100
    std_error = np.sqrt(p * (1 - p) / total) * 100
    return pd.Series({'accuracy': accuracy, 'std_error': std_error, 'n_samples': total})
| |
|
# Accuracy (on true==1 samples) per DM bin, with a binomial standard error.
dm_accuracy = df.groupby('dm_bin').apply(calculate_accuracy_with_uncertainty)
dm_accuracy = dm_accuracy.reset_index()

# Place each bin's point at the centre of its DM interval.
dm_accuracy['dm_midpoint'] = dm_accuracy['dm_bin'].apply(lambda interval: interval.mid)

# Styling for the error-bar series.
errbar_kw = dict(fmt='o-', color='#b80707', linewidth=2, markersize=6,
                 capsize=5, capthick=2, elinewidth=1)

plt.figure(figsize=(10, 6))
ax1 = plt.gca()
ax1.errorbar(
    dm_accuracy['dm_midpoint'],
    dm_accuracy['accuracy'],
    yerr=dm_accuracy['std_error'],
    **errbar_kw,
)
ax1.set_xlabel('Dispersion Measure (DM) [pc cm$^{-3}$]', fontsize=16)
ax1.set_ylabel('Accuracy (%)', fontsize=16)
ax1.set_title('Accuracy vs Dispersion Measure', fontsize=18)
ax1.grid(True, alpha=0.3)
ax1.set_ylim(97, 100)
ax1.tick_params(axis='both', which='major', labelsize=14)

# Never label a tick above 100%.
yticks = ax1.get_yticks()
ax1.set_yticks(yticks[yticks <= 100])

# NOTE(review): the dashed reference is val_accuracy over the whole test set
# (both classes), while the points cover true==1 samples only.
overall_label = f'Overall: {val_accuracy:.2f}%'
ax1.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7, label=overall_label)
ax1.legend(fontsize=14)

ax1.text(-0.1, -0.15, '(a)', transform=ax1.transAxes, fontsize=18, fontweight='bold')

plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_dm.pdf', dpi=300, bbox_inches='tight')
plt.show()
| |
|
| | |
| | |
# Only samples with a positive SNR enter the SNR analysis.
df_snr_filtered = df.loc[df['snr'] > 0].copy()

# 20 evenly spaced edges over the observed SNR range (19 bins, lowest edge inclusive).
snr_lo = df_snr_filtered['snr'].min()
snr_hi = df_snr_filtered['snr'].max()
snr_bins = np.linspace(snr_lo, snr_hi, 20)
df_snr_filtered['snr_bin'] = pd.cut(df_snr_filtered['snr'], bins=snr_bins, include_lowest=True)

snr_accuracy = (
    df_snr_filtered
    .groupby('snr_bin')
    .apply(calculate_accuracy_with_uncertainty)
    .reset_index()
)
snr_accuracy['snr_midpoint'] = snr_accuracy['snr_bin'].apply(lambda interval: interval.mid)

plt.figure(figsize=(10, 6))
ax2 = plt.gca()
ax2.errorbar(snr_accuracy['snr_midpoint'], snr_accuracy['accuracy'],
             yerr=snr_accuracy['std_error'],
             fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax2.set_xlabel('Signal-to-Noise Ratio (SNR)', fontsize=16)
ax2.set_ylabel('Accuracy (%)', fontsize=16)
ax2.set_title('Accuracy vs SNR', fontsize=18)
ax2.grid(True, alpha=0.3)
ax2.set_ylim(80, 100)
ax2.tick_params(axis='both', which='major', labelsize=14)

# Drop any tick position above 100%.
yticks = ax2.get_yticks()
ax2.set_yticks([t for t in yticks if t <= 100])

# Reference line: accuracy over the whole test set.
ax2.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax2.legend(fontsize=14)

ax2.text(-0.1, -0.15, '(b)', transform=ax2.transAxes, fontsize=18, fontweight='bold')

plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_snr.pdf', dpi=300, bbox_inches='tight')
plt.show()
| |
|
| | |
| | |
| | |
| | |
# Keep samples with a positive boxcar width (the x-axis below is logarithmic).
df_boxcar_filtered = df[df['boxcar'] > 0].copy()
# qcut: up to 20 bins holding roughly equal sample counts. duplicates='drop'
# merges repeated bin edges instead of raising when many widths coincide.
df_boxcar_filtered['boxcar_bin'] = pd.qcut(df_boxcar_filtered['boxcar'], q=20, duplicates='drop')

boxcar_accuracy = df_boxcar_filtered.groupby('boxcar_bin').apply(calculate_accuracy_with_uncertainty).reset_index()
boxcar_accuracy['boxcar_midpoint'] = boxcar_accuracy['boxcar_bin'].apply(lambda x: x.mid)

plt.figure(figsize=(10, 6))
ax3 = plt.gca()
ax3.errorbar(boxcar_accuracy['boxcar_midpoint'], boxcar_accuracy['accuracy'],
             yerr=boxcar_accuracy['std_error'], fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax3.set_xscale('log')
ax3.set_xlabel('Boxcar Width (log scale)', fontsize=16)
# Label the y-axis and title the figure, as the standalone DM and SNR figures do;
# previously this figure had neither.
ax3.set_ylabel('Accuracy (%)', fontsize=16)
ax3.set_title('Accuracy vs Boxcar Width', fontsize=18)
ax3.grid(True, alpha=0.3)
ax3.set_ylim(0, 100)
ax3.tick_params(axis='both', which='major', labelsize=14)

# Never label a tick above 100%.
yticks = ax3.get_yticks()
ax3.set_yticks([tick for tick in yticks if tick <= 100])

# Reference line: accuracy over the whole test set.
ax3.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax3.legend(fontsize=14)

ax3.text(-0.1, -0.15, '(c)', transform=ax3.transAxes, fontsize=18, fontweight='bold')

plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_boxcar.pdf', dpi=300, bbox_inches='tight')
plt.show()
| |
|
| |
|
# Plain string: the message has no interpolated values, so no f-prefix is needed.
print("Plots saved to models_mask/accuracy_vs_dm.pdf, models_mask/accuracy_vs_snr.pdf, and models_mask/accuracy_vs_boxcar.pdf")
| |
|
| | |
def _draw_accuracy_panel(ax, table, x_col, xlabel, ylim, *, overall,
                         ylabel=None, log_x=False, panel_tag=None):
    """Draw one accuracy-vs-parameter panel onto ``ax``.

    Args:
        ax: Matplotlib Axes to draw on.
        table: DataFrame with ``accuracy`` and ``std_error`` columns plus the
            x-value column named by ``x_col``, one row per bin.
        x_col: Column holding the bin midpoints plotted on the x-axis.
        xlabel: X-axis label text.
        ylim: ``(bottom, top)`` y-axis limits, in percent.
        overall: Reference accuracy (percent) drawn as a dashed horizontal line.
        ylabel: Optional y-axis label; none is set when None.
        log_x: Use a logarithmic x-axis when True.
        panel_tag: Optional bold panel letter (e.g. ``'(a)'``) drawn below-left.
    """
    ax.errorbar(table[x_col], table['accuracy'], yerr=table['std_error'],
                fmt='o-', color='#b80707', linewidth=2, markersize=6,
                capsize=5, capthick=2, elinewidth=1)
    if log_x:
        ax.set_xscale('log')
    ax.set_xlabel(xlabel, fontsize=16)
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=16)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(*ylim)
    ax.tick_params(axis='both', which='major', labelsize=14)

    # The top limit leaves headroom above 100% for error bars; drop any tick
    # above 100 so no impossible accuracy value gets a label.
    ax.set_yticks([tick for tick in ax.get_yticks() if tick <= 100])

    ax.axhline(y=overall, color='r', linestyle='--', alpha=0.7,
               label=f'Overall: {overall:.2f}%')
    ax.legend(fontsize=14)

    if panel_tag is not None:
        ax.text(-0.1, -0.15, panel_tag, transform=ax.transAxes,
                fontsize=18, fontweight='bold')


# Side-by-side panels: accuracy vs DM, SNR and boxcar width. Only the leftmost
# panel carries the shared y-axis label.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))

_draw_accuracy_panel(ax1, dm_accuracy, 'dm_midpoint',
                     'Dispersion Measure (DM) [pc cm$^{-3}$]', (97, 100.5),
                     overall=val_accuracy, ylabel='Accuracy (%)', panel_tag='(a)')
_draw_accuracy_panel(ax2, snr_accuracy, 'snr_midpoint',
                     'Signal-to-Noise Ratio (SNR)', (88, 100.5),
                     overall=val_accuracy, panel_tag='(b)')
# NOTE(review): the last (highest-boxcar) quantile bin is left out of this panel
# but kept in the standalone boxcar figure -- confirm this is intended.
_draw_accuracy_panel(ax3, boxcar_accuracy.iloc[:-1], 'boxcar_midpoint',
                     'Boxcar Width (log scale) [s]', (96, 100.5),
                     overall=val_accuracy, log_x=True, panel_tag='(c)')

plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_all_parameters.pdf',
            dpi=300, bbox_inches='tight',
            pad_inches=0.1, format='pdf')
plt.show()

print("Combined plot saved to models_mask/accuracy_vs_all_parameters.pdf")
| |
|