| """ |
| Визуализация предсказаний SYNTAX: |
| - точки (SYNTAX GT vs предсказания модели) для нескольких датасетов; |
| - зоны риска (низкий / высокий риск); |
| - области ±σ и ±2σ вокруг диагонали; |
| - логистические тренды для каждого датасета. |
| |
| Скрипт не зависит от PyTorch/Lightning и используется на этапе инференса. |
| Сохранение осуществляется в папку `visualizations/` внутри проекта. |
| """ |
|
|
| import os |
| import numpy as np |
| import plotly.graph_objects as go |
| from scipy.optimize import curve_fit |
|
|
|
|
| def visualize_final_syntax_plotly_multi( |
| datasets, |
| r2_values, |
| gt_row, |
| postfix=None, |
| threshold=22.0, |
| recall_values=None, |
| backbone=False, |
| ): |
| """ |
| Единая визуализация SYNTAX: точки, зоны риска и логистические тренды. |
| |
| Параметры |
| --------- |
| datasets : dict[str, tuple[list[float], list[float]]] |
| Словарь {имя_датасета: (syntax_true_list, syntax_pred_list)}. |
| r2_values : dict[str, float] |
| Словарь R^2 по датасетам. |
| gt_row : str |
| Строка, попадающая в заголовок (например, "ENSEMBLE" или "BOTH"). |
| postfix : str | None |
| Суффикс для имени сохраняемого файла. |
| threshold : float |
| Порог SYNTAX (обычно 22.0) для разделения зон риска. |
| recall_values : dict[str, float] | None |
| Словарь Recall по датасетам (может быть None). |
| backbone : bool |
| Если True, сохраняет в `visualizations/backbone`, иначе в `visualizations/`. |
| """ |
| |
| DATA_MIN = 0.0 |
| DATA_MAX = 60.0 |
|
|
| PADDING = 0.5 |
|
|
| SIGMA_SLOPE = 0.15 |
| SIGMA_BASE = 1.4 |
|
|
| PLOT_WIDTH = 980 |
| PLOT_HEIGHT = 980 |
|
|
| BASE_FONT_SIZE = 16 |
| TITLE_FONT_SIZE = 22 |
| AXIS_LABEL_FONT_SIZE = BASE_FONT_SIZE |
| AXIS_TICK_FONT_SIZE = 15 |
| LEGEND_FONT_SIZE = 14 |
|
|
| MARKER_SIZE = 11 |
| MARKER_LINE_WIDTH = 1.1 |
| LINE_WIDTH = 2 |
| TREND_LINE_WIDTH = 3 |
|
|
| PLOT_BG_COLOR = "rgba(235,238,245,1)" |
| PAPER_BG_COLOR = "white" |
| LEGEND_BG_COLOR = "rgba(255,255,255,0.94)" |
| GRID_COLOR = "rgba(100,116,139,0.18)" |
|
|
| MARGIN_LEFT = 70 |
| MARGIN_RIGHT = 24 |
| MARGIN_TOP = 78 |
| MARGIN_BOTTOM = 70 |
|
|
| LEGEND_X = 0.04 |
| LEGEND_Y = 0.99 |
|
|
| COLORS = ["#1E88E5", "#8E24AA", "#A0D137", "#EA1D1D", "#06EE0D", "#FB8C00"] |
| SYMBOLS = ["circle", "x", "square", "diamond", "triangle-up", "star"] |
|
|
| SIGMA_POINTS = 400 |
| TREND_POINTS = 500 |
|
|
| |
|
|
| def _logistic_time(t, R0, Rmax, t50, k): |
| """Логистическая функция по времени/оценке SYNTAX.""" |
| t = np.asarray(t, dtype=float) |
| t_safe = np.where(t <= 0, 1e-3, t) |
| return R0 + (Rmax - R0) / (1.0 + (t50 / t_safe) ** k) |
|
|
| def _fit_logistic(x, y, domain=(DATA_MIN, DATA_MAX), n=TREND_POINTS): |
| """ |
| Аппроксимация логистической кривой. |
| Возвращает X, Y или (None, None), если фит не удался. |
| """ |
| x = np.asarray(x, dtype=float) |
| y = np.asarray(y, dtype=float) |
| m = np.isfinite(x) & np.isfinite(y) |
| if m.sum() < 4: |
| return None, None |
|
|
| x_m, y_m = x[m], y[m] |
| x_min = max(float(np.min(x_m)), float(domain[0])) |
| x_max = min(float(np.max(x_m)), float(domain[1])) |
| if not np.isfinite(x_min) or not np.isfinite(x_max) or x_max <= x_min: |
| return None, None |
|
|
| x_pos = x_m[x_m > 0] |
| if x_pos.size == 0: |
| return None, None |
|
|
| R0_init = float(np.percentile(y_m, 10)) |
| Rmax_init = float(np.percentile(y_m, 90)) |
| t50_init = float(np.median(x_pos)) |
| k_init = 1.0 |
|
|
| lower = [-10.0, 0.0, 1e-3, 0.01] |
| upper = [60.0, 80.0, 60.0, 10.0] |
|
|
| try: |
| popt, _ = curve_fit( |
| _logistic_time, |
| x_m, |
| y_m, |
| p0=[R0_init, Rmax_init, t50_init, k_init], |
| bounds=(lower, upper), |
| maxfev=20000, |
| ) |
| except Exception: |
| return None, None |
|
|
| X = np.linspace(x_min, x_max, n) |
| Y = _logistic_time(X, *popt) |
| return X, Y |
|
|
| |
| fig = go.Figure() |
|
|
| line_min = DATA_MIN - PADDING |
| line_max = DATA_MAX + PADDING |
| domain = (line_min, line_max) |
|
|
| base_font = dict( |
| family="Inter, Roboto, Helvetica Neue, Arial, sans-serif", |
| size=BASE_FONT_SIZE, |
| ) |
|
|
| |
| fig.add_trace( |
| go.Scatter( |
| x=[line_min, threshold, threshold, line_min], |
| y=[line_min, line_min, threshold, threshold], |
| fill="toself", |
| fillcolor="rgba(255, 82, 82, 0.12)", |
| line=dict(color="rgba(0,0,0,0)"), |
| name="Low-risk zone", |
| legendgroup="zones", |
| legendgrouptitle_text="Пороги и линии", |
| showlegend=True, |
| hoverinfo="skip", |
| legendrank=0, |
| ) |
| ) |
| fig.add_trace( |
| go.Scatter( |
| x=[threshold, line_max, line_max, threshold], |
| y=[threshold, threshold, line_max, line_max], |
| fill="toself", |
| fillcolor="rgba(76, 175, 80, 0.14)", |
| line=dict(color="rgba(0,0,0,0)"), |
| name="High-risk zone", |
| legendgroup="zones", |
| showlegend=True, |
| hoverinfo="skip", |
| legendrank=0, |
| ) |
| ) |
|
|
| fig.add_trace( |
| go.Scatter( |
| x=[threshold, threshold, None, line_min, line_max], |
| y=[line_min, line_max, None, threshold, threshold], |
| mode="lines", |
| name=rf"$\mathrm{{SYNTAX}}={threshold}$", |
| legendgroup="zones", |
| showlegend=True, |
| line=dict(color="rgba(46,125,50,0.85)", width=LINE_WIDTH, dash="dash"), |
| legendrank=0, |
| hoverinfo="skip", |
| ) |
| ) |
|
|
| x_vals = np.linspace(line_min, line_max, SIGMA_POINTS) |
| sigma_upper = x_vals + SIGMA_BASE + SIGMA_SLOPE * x_vals |
| sigma_lower = x_vals - SIGMA_BASE - SIGMA_SLOPE * x_vals |
| two_sigma_upper = x_vals + 2 * SIGMA_BASE + 2 * SIGMA_SLOPE * x_vals |
| two_sigma_lower = x_vals - 2 * SIGMA_BASE - 2 * SIGMA_SLOPE * x_vals |
|
|
| fig.add_trace( |
| go.Scatter( |
| x=np.concatenate([x_vals, x_vals[::-1]]), |
| y=np.concatenate([two_sigma_lower, two_sigma_upper[::-1]]), |
| fill="toself", |
| fillcolor="rgba(255,193,7,0.18)", |
| line=dict(color="rgba(0,0,0,0)"), |
| name=r"$\pm 2\sigma$", |
| legendgroup="zones", |
| showlegend=True, |
| hoverinfo="skip", |
| legendrank=0, |
| ) |
| ) |
| fig.add_trace( |
| go.Scatter( |
| x=np.concatenate([x_vals, x_vals[::-1]]), |
| y=np.concatenate([sigma_lower, sigma_upper[::-1]]), |
| fill="toself", |
| fillcolor="rgba(255,152,0,0.30)", |
| line=dict(color="rgba(0,0,0,0)"), |
| name=r"$\pm \sigma$", |
| legendgroup="zones", |
| showlegend=True, |
| hoverinfo="skip", |
| legendrank=0, |
| ) |
| ) |
|
|
| fig.add_trace( |
| go.Scatter( |
| x=[line_min, line_max], |
| y=[line_min, line_max], |
| mode="lines", |
| name=r"$y=x$", |
| legendgroup="zones", |
| showlegend=True, |
| line=dict(color="rgba(30,30,30,0.85)", width=LINE_WIDTH), |
| legendrank=0, |
| ) |
| ) |
|
|
| |
| first_dataset = True |
| for i, (label, (syntax_true, syntax_pred)) in enumerate(datasets.items()): |
| x = np.array(syntax_true, dtype=float) |
| y = np.array(syntax_pred, dtype=float) |
| if x.size == 0 or y.size == 0: |
| continue |
|
|
| r2 = r2_values.get(label, None) |
| recall = recall_values.get(label, None) if recall_values else None |
| hover_lines = [f"<b>{label}</b>"] |
| if r2 is not None: |
| hover_lines.append(f"R² = {r2:.3f}") |
| if recall is not None: |
| hover_lines.append(f"Recall = {recall:.3f}") |
| hovertemplate = ( |
| "<br>".join(hover_lines) |
| + "<br>GT: %{x:.3f}<br>Pred: %{y:.3f}<extra></extra>" |
| ) |
|
|
| fig.add_trace( |
| go.Scatter( |
| x=x, |
| y=y, |
| mode="markers", |
| name=label, |
| legendgroup="datasets", |
| legendgrouptitle_text=("Датасеты" if first_dataset else None), |
| showlegend=True, |
| marker=dict( |
| color=COLORS[i % len(COLORS)], |
| size=MARKER_SIZE, |
| opacity=0.96, |
| symbol=SYMBOLS[i % len(SYMBOLS)], |
| line=dict( |
| width=MARKER_LINE_WIDTH, color="rgba(255,255,255,0.95)" |
| ), |
| ), |
| hovertemplate=hovertemplate, |
| legendrank=20, |
| ) |
| ) |
| first_dataset = False |
|
|
| |
| first_trend = True |
| for i, (label, (syntax_true, syntax_pred)) in enumerate(datasets.items()): |
| x = np.array(syntax_true, dtype=float) |
| y = np.array(syntax_pred, dtype=float) |
| if x.size == 0 or y.size == 0: |
| continue |
|
|
| Xc, Yc = _fit_logistic(x, y, domain=domain) |
| if Xc is not None: |
| fig.add_trace( |
| go.Scatter( |
| x=Xc, |
| y=Yc, |
| mode="lines", |
| name=label, |
| legendgroup="trends", |
| legendgrouptitle_text=( |
| "Тренды (логистические)" if first_trend else None |
| ), |
| showlegend=True, |
| line=dict( |
| color=COLORS[i % len(COLORS)], width=TREND_LINE_WIDTH |
| ), |
| hoverinfo="skip", |
| legendrank=30, |
| ) |
| ) |
| first_trend = False |
|
|
| |
| title_text = f"SYNTAX predictions ({gt_row})" |
| if postfix: |
| title_text += f" {postfix}" |
|
|
| fig.update_layout( |
| title=dict( |
| text=title_text, |
| x=0.5, |
| xanchor="center", |
| font=dict( |
| size=TITLE_FONT_SIZE, |
| family=base_font["family"], |
| color="rgba(15,23,42,1)", |
| ), |
| ), |
| font=base_font, |
| xaxis_title=r"$\mathrm{SYNTAX\ GT}$", |
| yaxis_title=r"$\mathrm{SYNTAX\ predictions}$", |
| width=PLOT_WIDTH, |
| height=PLOT_HEIGHT, |
| plot_bgcolor=PLOT_BG_COLOR, |
| paper_bgcolor=PAPER_BG_COLOR, |
| legend=dict( |
| x=LEGEND_X, |
| y=LEGEND_Y, |
| bgcolor=LEGEND_BG_COLOR, |
| bordercolor="#CBD5E1", |
| borderwidth=1, |
| font=dict(size=LEGEND_FONT_SIZE, family=base_font["family"]), |
| tracegroupgap=8, |
| itemclick="toggle", |
| itemdoubleclick="toggleothers", |
| groupclick="toggleitem", |
| ), |
| xaxis=dict( |
| showgrid=True, |
| gridcolor=GRID_COLOR, |
| gridwidth=1, |
| zeroline=False, |
| tickfont=dict(size=AXIS_TICK_FONT_SIZE), |
| range=[line_min, line_max], |
| constrain="domain", |
| ), |
| yaxis=dict( |
| showgrid=True, |
| gridcolor=GRID_COLOR, |
| gridwidth=1, |
| zeroline=False, |
| tickfont=dict(size=AXIS_TICK_FONT_SIZE), |
| range=[line_min, line_max], |
| scaleanchor="x", |
| scaleratio=1, |
| constrain="domain", |
| ), |
| margin=dict( |
| l=MARGIN_LEFT, |
| r=MARGIN_RIGHT, |
| t=MARGIN_TOP, |
| b=MARGIN_BOTTOM, |
| ), |
| ) |
|
|
| |
| save_dir = "visualizations" |
| if backbone: |
| save_dir = os.path.join(save_dir, "backbone") |
| os.makedirs(save_dir, exist_ok=True) |
|
|
| postfix_html = f"{postfix}" if postfix else "syntax" |
| save_path_html = os.path.join(save_dir, f"{postfix_html}.html") |
| fig.write_html(save_path_html, include_mathjax="cdn") |
| print(f"Saved visualization with logistic trends: {save_path_html}") |
|
|