| |
| |
| |
| |
| |
|
|
| from dataclasses import dataclass, field |
|
|
| from torch import Tensor |
|
|
|
|
| @dataclass |
| class PathSample: |
| r"""Represents a sample of a conditional-flow generated probability path. |
| |
| Attributes: |
| x_1 (Tensor): the target sample :math:`X_1`. |
| x_0 (Tensor): the source sample :math:`X_0`. |
| t (Tensor): the time sample :math:`t`. |
| x_t (Tensor): samples :math:`X_t \sim p_t(X_t)`, shape (batch_size, ...). |
| dx_t (Tensor): conditional target :math:`\frac{\partial X}{\partial t}`, shape: (batch_size, ...). |
| |
| """ |
|
|
| x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."}) |
| x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."}) |
| t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."}) |
| x_t: Tensor = field( |
| metadata={"help": "samples x_t ~ p_t(X_t), shape (batch_size, ...)."} |
| ) |
| dx_t: Tensor = field( |
| metadata={"help": "conditional target dX_t, shape: (batch_size, ...)."} |
| ) |
|
|
|
|
| @dataclass |
| class DiscretePathSample: |
| """ |
| Represents a sample of a conditional-flow generated discrete probability path. |
| |
| Attributes: |
| x_1 (Tensor): the target sample :math:`X_1`. |
| x_0 (Tensor): the source sample :math:`X_0`. |
| t (Tensor): the time sample :math:`t`. |
| x_t (Tensor): the sample along the path :math:`X_t \sim p_t`. |
| """ |
|
|
| x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."}) |
| x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."}) |
| t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."}) |
| x_t: Tensor = field( |
| metadata={"help": "samples X_t ~ p_t(X_t), shape (batch_size, ...)."} |
| ) |
|
|