from typing import Any, Optional

import lightning.pytorch as pl
import numpy as np
import sklearn.metrics as skm
import torch
import torchvision.models.video as tvmv
from torch import nn, optim
|
|
|
|
class SyntaxLightningModule(pl.LightningModule):
    """LightningModule that trains a 3D-ResNet (r3d_18) backbone to predict
    the SYNTAX score from video angiography.

    The final fc layer produces ``num_classes`` outputs per clip, interpreted as:
      - ``yp_clf`` (column 0): logit for lesion presence (syntax > threshold) —
        binary classification;
      - ``yp_reg`` (remaining columns): log-transformed SYNTAX value — regression.
    """

    def __init__(
        self,
        num_classes: int,
        lr: float,
        weight_decay: float = 0.0,
        max_epochs: Optional[int] = None,
        weight_path: Optional[str] = None,
        sigma_a: float = 0.0,
        sigma_b: float = 1.0,
        **kwargs,
    ):
        """Build the backbone, replace its head and set up losses/accumulators.

        Args:
            num_classes: output size of the final fc layer
                (1 classification logit + regression output(s)).
            lr: peak learning rate for AdamW (and OneCycleLR when enabled).
            weight_decay: AdamW weight decay.
            max_epochs: total scheduler steps for OneCycleLR; when ``None``
                no scheduler is configured.
            weight_path: optional Lightning checkpoint path to warm-start the
                backbone; also switches ``configure_optimizers`` from
                head-only training to full fine-tuning.
            sigma_a: slope of the linear noise model ``sigma = a * target + b``
                used to down-weight the regression loss for large targets.
            sigma_b: intercept of the same noise model.
            **kwargs: extra hyperparameters, only stored via
                ``save_hyperparameters``.
        """
        super().__init__()
        self.save_hyperparameters()

        self.num_classes = num_classes
        self.lr = lr
        self.weight_decay = weight_decay
        self.max_epochs = max_epochs
        self.weight_path = weight_path
        self.sigma_a = sigma_a
        self.sigma_b = sigma_b

        # Kinetics-400 pretrained 3D ResNet-18 backbone.
        self.model = tvmv.r3d_18(weights=tvmv.R3D_18_Weights.DEFAULT)

        # Replace the classification head with a task-specific one.
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features=in_features, out_features=num_classes, bias=True)

        if self.weight_path is not None:
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only load checkpoints from trusted sources.
            ckpt = torch.load(self.weight_path, map_location="cpu", weights_only=False)
            state_dict = ckpt["state_dict"]
            # Strip the "model." prefix that Lightning adds to submodule keys.
            new_state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
            # strict=False: the checkpoint head may differ from the new fc layer.
            self.model.load_state_dict(new_state_dict, strict=False)

        # reduction="none" so per-sample weights can be applied before averaging.
        self.loss_clf = nn.BCEWithLogitsLoss(reduction="none")
        self.loss_reg = nn.MSELoss(reduction="none")

        # Per-epoch accumulators for validation metrics.
        self.y_val = []   # true binary labels
        self.p_val = []   # predicted probabilities
        self.r_val = []   # rounded (binary) predictions
        self.ty_val = []  # true regression targets
        self.tp_val = []  # predicted regression values

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone; returns raw head outputs of shape (B, num_classes)."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """One training step.

        - binary lesion classification (BCE, zeros down-weighted to 0.45);
        - regression of the log SYNTAX value, scaled by the heteroscedastic
          noise model ``sigma = sigma_a * target + sigma_b``.

        Returns the combined loss ``clf_loss + 0.5 * reg_loss``.
        """
        x, y, target, sample_weight, path, original_label = batch

        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:]

        # Down-weight negative samples to counter class imbalance.
        weights_clf = torch.where(y > 0, 1.0, 0.45)
        clf_loss = self.loss_clf(yp_clf, y)
        clf_loss = (clf_loss * weights_clf).mean()

        # Heteroscedastic weighting: larger sigma -> smaller loss contribution.
        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()

        loss = clf_loss + 0.5 * reg_loss

        # Per-batch classification metrics (computed on CPU for sklearn).
        y_pred = torch.sigmoid(yp_clf)
        y_bin = torch.round(y.detach().cpu()).int()
        y_pred_bin = torch.round(y_pred.detach().cpu()).int()

        self.log("train_clf_loss", clf_loss, prog_bar=True, sync_dist=True)
        self.log("train_val_loss", reg_loss, prog_bar=True, sync_dist=True)
        self.log("train_full_loss", loss, prog_bar=True, sync_dist=True)
        self.log(
            "train_f1",
            skm.f1_score(y_bin, y_pred_bin, zero_division=0),
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            "train_acc",
            skm.accuracy_score(y_bin, y_pred_bin),
            prog_bar=True,
            sync_dist=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step: compute the same combined loss and accumulate
        per-sample predictions for epoch-level metrics.

        Unlike the original single-sample version, accumulation now works for
        any batch size (identical behavior for batch size 1).
        """
        x, y, target, sample_weight, path, original_label = batch

        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:]

        # No class weighting on validation (unlike training).
        clf_loss = self.loss_clf(yp_clf, y)
        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()
        loss = clf_loss.mean() + 0.5 * reg_loss

        y_prob = torch.sigmoid(yp_clf)

        # Accumulate per-sample values; assumes y/target are (B, 1) — matches
        # the column-0 indexing used throughout this module.
        probs = y_prob[..., 0].detach().cpu().tolist()
        self.y_val.extend(int(v) for v in y[..., 0].detach().cpu().tolist())
        self.p_val.extend(probs)
        self.r_val.extend(round(p) for p in probs)
        self.ty_val.extend(target[..., 0].detach().cpu().tolist())
        self.tp_val.extend(yp_reg[..., 0].detach().cpu().tolist())

        return loss

    def on_validation_epoch_end(self) -> None:
        """Compute epoch-level validation metrics and log them."""
        try:
            auc = skm.roc_auc_score(self.y_val, self.p_val)
            f1 = skm.f1_score(self.y_val, self.r_val, zero_division=0)
            acc = skm.accuracy_score(self.y_val, self.r_val)
            # Fix: MAE belongs to the regression head (previously computed on
            # binary labels/predictions, which is just 1 - accuracy).
            mae = skm.mean_absolute_error(self.ty_val, self.tp_val)
            rmse = skm.root_mean_squared_error(self.ty_val, self.tp_val)

            self.log("val_auc", auc, prog_bar=True, sync_dist=True)
            self.log("val_f1", f1, prog_bar=True, sync_dist=True)
            self.log("val_acc", acc, prog_bar=True, sync_dist=True)
            self.log("val_mae", mae, prog_bar=True, sync_dist=True)
            self.log("val_rmse", rmse, prog_bar=True, sync_dist=True)

        except ValueError as err:
            # e.g. roc_auc_score with a single class present; dump state for debugging.
            print(err)
            print("Y_VAL", self.y_val)
            print("P_VAL", self.p_val)

        # Reset accumulators for the next epoch.
        self.y_val.clear()
        self.p_val.clear()
        self.r_val.clear()
        self.ty_val.clear()
        self.tp_val.clear()

    def on_train_epoch_end(self) -> None:
        """Log the current learning rate."""
        opt = self.optimizers()
        # Lightning may wrap the optimizer (LightningOptimizer) or return it raw.
        if hasattr(opt, "optimizer"):
            lr = opt.optimizer.param_groups[0]["lr"]
        else:
            lr = opt.param_groups[0]["lr"]
        self.log("lr", lr, on_step=False, on_epoch=True, sync_dist=True)

    def configure_optimizers(self):
        """Configure AdamW and (optionally) a OneCycleLR scheduler.

        - ``weight_path`` not set -> pretrain mode: freeze the backbone and
          train only the final fc layer.
        - ``weight_path`` set -> full fine-tuning: train the whole backbone.
        """
        if not self.weight_path:
            # Head-only training: freeze everything, then re-enable fc.
            for param in self.parameters():
                param.requires_grad = False
            for p in self.model.fc.parameters():
                p.requires_grad = True
            params = list(self.model.fc.parameters())
        else:
            # Full fine-tuning.
            for param in self.parameters():
                param.requires_grad = True
            params = self.parameters()

        optimizer = optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay)

        if self.max_epochs is not None:
            # Lightning steps epoch-wise by default, so total_steps is in epochs.
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer=optimizer,
                max_lr=self.lr,
                total_steps=self.max_epochs,
            )
            return [optimizer], [scheduler]
        else:
            return optimizer

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Inference: return binary predictions, probabilities and the
        regression output together with the ground truth from the batch.
        """
        x, y, target, sample_weight, path, original_label = batch
        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:]
        y_prob = torch.sigmoid(yp_clf)
        y_pred = torch.round(y_prob)

        return {
            "y": y,
            "y_pred": y_pred,
            "y_prob": y_prob,
            "y_reg": yp_reg,
            "target": target,
            "original_label": original_label,
        }
|
|