| | import torch |
| | import torch.nn.functional as F |
| | import torch.nn as nn |
| | from torch.distributions import multinomial, categorical |
| | import torch.optim as optim |
| |
|
| | import math |
| |
|
| | try: |
| | from . import helpers as h |
| | from . import ai |
| | from . import scheduling as S |
| | except: |
| | import helpers as h |
| | import ai |
| | import scheduling as S |
| |
|
| | import math |
| | import abc |
| |
|
| | from torch.nn.modules.conv import _ConvNd |
| | from enum import Enum |
| |
|
| |
|
class InferModule(nn.Module):
    """Base class for shape-inferring layers.

    Subclasses implement ``init(in_shape, ...) -> out_shape`` instead of
    ``__init__``; parameters are created lazily once the input shape is
    known, via :meth:`infer`.  Note ``nn.Module.__init__`` is deliberately
    deferred until :meth:`infer` runs.
    """

    def __init__(self, *args, normal = False, ibp_init = False, **kwargs):
        # Stash construction arguments until infer() knows the input shape.
        self.args = args
        self.kwargs = kwargs
        self.infered = False
        self.normal = normal      # normal-distribution weight init
        self.ibp_init = ibp_init  # orthogonal (IBP-style) weight init

    def infer(self, in_shape, global_args = None):
        """ this is really actually stateful. """
        if self.infered:
            return self
        self.infered = True

        super(InferModule, self).__init__()
        self.inShape = list(in_shape)
        out_shape = self.init(list(in_shape), *self.args, global_args = global_args, **self.kwargs)
        # BUG FIX: check before list() conversion, and raise a real exception --
        # `raise "..."` is itself a TypeError in Python 3, and the old check ran
        # after list(None) would already have failed.
        if out_shape is None:
            raise ValueError("init should set the out_shape")
        self.outShape = list(out_shape)

        self.reset_parameters()
        return self

    def reset_parameters(self):
        """Initialize weight/bias per the chosen scheme; no-op without weights."""
        if not hasattr(self, 'weight') or self.weight is None:
            return
        n = h.product(self.weight.size()) / self.outShape[0]
        stdv = 1 / math.sqrt(n)

        if self.ibp_init:
            torch.nn.init.orthogonal_(self.weight.data)
        elif self.normal:
            self.weight.data.normal_(0, stdv)
            self.weight.data.clamp_(-1, 1)
        else:
            self.weight.data.uniform_(-stdv, stdv)

        if self.bias is not None:
            if self.ibp_init:
                self.bias.data.zero_()
            elif self.normal:
                self.bias.data.normal_(0, stdv)
                self.bias.data.clamp_(-1, 1)
            else:
                self.bias.data.uniform_(-stdv, stdv)

    def clip_norm(self):
        """Apply weight normalization (once) and clamp its magnitude."""
        if not hasattr(self, "weight"):
            return
        if not hasattr(self, "weight_g"):
            if torch.__version__[0] == "0":
                nn.utils.weight_norm(self, dim=None)
            else:
                nn.utils.weight_norm(self)
        self.weight_g.data.clamp_(-h.max_c_for_norm, h.max_c_for_norm)

        if torch.__version__[0] != "0":
            self.weight_v.data.clamp_(-h.max_c_for_norm * 10000, h.max_c_for_norm * 10000)
            # BUG FIX: bias may be None (e.g. Conv2D(bias=False)); guard it.
            if hasattr(self, "bias") and self.bias is not None:
                self.bias.data.clamp_(-h.max_c_for_norm * 10000, h.max_c_for_norm * 10000)

    def regularize(self, p):
        """Return the p-norm regularization term over this module's parameters."""
        reg = 0
        if torch.__version__[0] == "0":
            for param in self.parameters():
                reg += param.norm(p)
        else:
            if hasattr(self, "weight_g"):
                reg += self.weight_g.norm().sum()
                reg += self.weight_v.norm().sum()
            elif hasattr(self, "weight"):
                reg += self.weight.norm().sum()

            # BUG FIX: guard against bias=None before calling .view on it.
            if hasattr(self, "bias") and self.bias is not None:
                reg += self.bias.view(-1).norm(p=p).sum()

        return reg

    def remove_norm(self):
        if hasattr(self, "weight_g"):
            torch.nn.utils.remove_weight_norm(self)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__)

    def printNet(self, f):
        print(self.__class__.__name__, file=f)

    @abc.abstractmethod
    def forward(self, *args, **kargs):
        pass

    def __call__(self, *args, onyx=False, **kargs):
        # onyx export path bypasses nn.Module's hook machinery on purpose.
        if onyx:
            return self.forward(*args, onyx=onyx, **kargs)
        else:
            return super(InferModule, self).__call__(*args, **kargs)

    @abc.abstractmethod
    def neuronCount(self):
        pass

    def depth(self):
        # Parameter-free / linear layers count as depth 0; activations override.
        return 0
| |
|
def getShapeConv(in_shape, conv_shape, stride = 1, padding = 0):
    """Output shape (C, H, W) of a 2D convolution applied to `in_shape`.

    `in_shape` is (channels, H, W); `conv_shape` starts with
    (out_channels, kernel_h, kernel_w); extra entries are ignored.
    """
    _, height, width = in_shape
    channels, kernel_h, kernel_w = conv_shape[:3]

    def out_dim(size, kernel):
        # Standard conv arithmetic, truncating partial windows.
        return 1 + int((2 * padding + size - kernel) / stride)

    return (channels, out_dim(height, kernel_h), out_dim(width, kernel_w))
| |
|
def getShapeConvTranspose(in_shape, conv_shape, stride = 1, padding = 0, out_padding=0):
    """Output shape (C, H, W) of a 2D transposed convolution on `in_shape`.

    Inverse of the conv arithmetic in getShapeConv, plus `out_padding`.
    """
    _, height, width = in_shape
    channels, kernel_h, kernel_w = conv_shape[:3]

    def out_dim(size, kernel):
        return (size - 1) * stride - 2 * padding + kernel + out_padding

    return (channels, out_dim(height, kernel_h), out_dim(width, kernel_w))
| |
|
| |
|
| |
|
class Linear(InferModule):
    """Fully-connected layer: flattens input and computes x @ W + b.

    The weight is stored as (in_neurons, out_neurons) so forward can use
    x.matmul(weight) directly without a transpose.
    """

    def init(self, in_shape, out_shape, **kargs):
        self.in_neurons = h.product(in_shape)
        if isinstance(out_shape, int):
            out_shape = [out_shape]
        self.out_neurons = h.product(out_shape)

        self.weight = torch.nn.Parameter(torch.Tensor(self.in_neurons, self.out_neurons))
        self.bias = torch.nn.Parameter(torch.Tensor(self.out_neurons))

        return out_shape

    def forward(self, x, **kargs):
        s = x.size()
        x = x.view(s[0], h.product(s[1:]))
        return (x.matmul(self.weight) + self.bias).view(s[0], *self.outShape)

    def neuronCount(self):
        return 0

    def showNet(self, t = ""):
        print(t + "Linear out=" + str(self.out_neurons))

    def printNet(self, f):
        # BUG FIX: the header line previously went to stdout, not to f,
        # corrupting the serialized network file.
        print("Linear(" + str(self.out_neurons) + ")", file=f)

        print(h.printListsNumpy(list(self.weight.transpose(1, 0).data)), file=f)
        print(h.printNumpy(self.bias), file=f)
| |
|
class Activation(InferModule):
    """Pointwise nonlinearity chosen by name at construction time."""

    def init(self, in_shape, global_args = None, activation = "ReLU", **kargs):
        names = ["ReLU", "Sigmoid", "Tanh", "Softplus", "ELU", "SELU"]
        self.activation = names.index(activation)
        self.activation_name = activation
        return in_shape

    def regularize(self, p):
        # Parameter-free layer: contributes nothing.
        return 0

    def forward(self, x, **kargs):
        # Dispatch on the precomputed index; methods are called on x itself
        # so abstract-domain values work as well as plain tensors.
        apply = (
            lambda v: v.relu(),
            lambda v: v.sigmoid(),
            lambda v: v.tanh(),
            lambda v: v.softplus(),
            lambda v: v.elu(),
            lambda v: v.selu(),
        )[self.activation]
        return apply(x)

    def neuronCount(self):
        return h.product(self.outShape)

    def depth(self):
        # Each nonlinearity counts as one layer of depth.
        return 1

    def showNet(self, t = ""):
        print(t + self.activation_name)

    def printNet(self, f):
        pass
| |
|
class ReLU(Activation):
    # Convenience subclass: Activation's default activation is already "ReLU".
    pass
| |
|
def activation(*args, batch_norm = False, **kargs):
    """Build an Activation, optionally preceded by a BatchNorm layer."""
    act = Activation(*args, **kargs)
    if batch_norm:
        return Seq(BatchNorm(), act)
    return act
| |
|
class Identity(InferModule):
    """No-op layer; also the base for annotation-style layers (e.g. ToZono)."""

    def init(self, in_shape, global_args = None, **kargs):
        return in_shape

    def forward(self, x, **kargs):
        return x

    def regularize(self, p):
        # Nothing to regularize.
        return 0

    def neuronCount(self):
        return 0

    def printNet(self, f):
        pass

    def showNet(self, *args, **kargs):
        pass
| |
|
class Dropout(InferModule):
    """Dropout with a (possibly scheduled) drop probability.

    Supports channel-wise (2d) dropout and alpha-dropout (SELU-compatible);
    the mask is sampled without gradient tracking and applied only in
    training mode.
    """

    def init(self, in_shape, p=0.5, use_2d = False, alpha_dropout = False, **kargs):
        self.p = S.Const.initConst(p)  # may be a schedule evaluated at `time`
        self.use_2d = use_2d
        self.alpha_dropout = alpha_dropout
        return in_shape

    def forward(self, x, time = 0, **kargs):
        if not self.training:
            return x
        with torch.no_grad():
            p = self.p.getVal(time = time)
            # Sample the mask on a ones tensor so abstract x is scaled uniformly.
            mask = (F.dropout2d if self.use_2d else F.dropout)(h.ones(x.size()), p=p, training=True)
        if self.alpha_dropout:
            with torch.no_grad():
                keep_prob = 1 - p
                alpha = -1.7580993408473766  # alpha-dropout saturation constant
                a = math.pow(keep_prob + alpha * alpha * keep_prob * (1 - keep_prob), -0.5)
                b = -a * alpha * (1 - keep_prob)
                mask = mask * a
            return x * mask + b
        else:
            return x * mask

    def neuronCount(self):
        return 0

    def showNet(self, t = ""):
        print(t + "Dropout p=" + str(self.p))

    def printNet(self, f):
        # BUG FIX: previously printed to stdout, ignoring the target file f.
        print("Dropout(" + str(self.p) + ")", file=f)
| |
|
class PrintActivation(Identity):
    """Identity layer that only records an activation name for printNet output."""

    def init(self, in_shape, global_args = None, activation = "ReLU", **kargs):
        self.activation = activation
        return in_shape

    def printNet(self, f):
        print(self.activation, file = f)
| |
|
class PrintReLU(PrintActivation):
    # Convenience subclass: serializes the default "ReLU" tag.
    pass
| |
|
class Conv2D(InferModule):
    """Lazily-built 2D convolution.

    forward calls the input's own conv2d method (not F.conv2d), so x may be
    a plain tensor or an abstract-domain value implementing that interface.
    """

    def init(self, in_shape, out_channels, kernel_size, stride = 1, global_args = None, bias=True, padding = 0, activation = "ReLU", **kargs):
        # in_shape is (channels, H, W); the batch dimension is implicit.
        self.prev = in_shape
        self.in_channels = in_shape[0]
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.activation = activation  # recorded only for printNet/showNet
        self.use_softplus = h.default(global_args, 'use_softplus', False)

        weights_shape = (self.out_channels, self.in_channels, kernel_size, kernel_size)
        self.weight = torch.nn.Parameter(torch.Tensor(*weights_shape))
        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(weights_shape[0]))
        else:
            self.bias = None

        outshape = getShapeConv(in_shape, (out_channels, kernel_size, kernel_size), stride, padding)
        return outshape

    def forward(self, input, **kargs):
        return input.conv2d(self.weight, bias=self.bias, stride=self.stride, padding = self.padding )

    def printNet(self, f):
        # Serializes the layer in the text format expected by the loader.
        print("Conv2D", file = f)
        sz = list(self.prev)
        print(self.activation + ", filters={}, kernel_size={}, input_shape={}, stride={}, padding={}".format(self.out_channels, [self.kernel_size, self.kernel_size], list(reversed(sz)), [self.stride, self.stride], self.padding ), file = f)
        print(h.printListsNumpy([[list(p) for p in l ] for l in self.weight.permute(2,3,1,0).data]) , file= f)
        # A missing bias is serialized as zeros so the format stays uniform.
        print(h.printNumpy(self.bias if self.bias is not None else h.dten(self.out_channels)), file= f)

    def showNet(self, t = ""):
        sz = list(self.prev)
        print(t + "Conv2D, filters={}, kernel_size={}, input_shape={}, stride={}, padding={}".format(self.out_channels, [self.kernel_size, self.kernel_size], list(reversed(sz)), [self.stride, self.stride], self.padding ))

    def neuronCount(self):
        return 0
| |
|
| |
|
class ConvTranspose2D(InferModule):
    """Lazily-built 2D transposed convolution (decoder counterpart of Conv2D)."""

    def init(self, in_shape, out_channels, kernel_size, stride = 1, global_args = None, bias=True, padding = 0, out_padding=0, activation = "ReLU", **kargs):
        self.prev = in_shape
        self.in_channels = in_shape[0]
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.out_padding = out_padding
        self.activation = activation  # recorded only for printNet
        self.use_softplus = h.default(global_args, 'use_softplus', False)

        # conv_transpose2d weights are laid out (in_channels, out_channels, kH, kW).
        weights_shape = (self.in_channels, self.out_channels, kernel_size, kernel_size)
        self.weight = torch.nn.Parameter(torch.Tensor(*weights_shape))
        if bias:
            # BUG FIX: the bias of a transposed conv has out_channels entries
            # (weights_shape[1]); the old code sized it by in_channels.
            self.bias = torch.nn.Parameter(torch.Tensor(weights_shape[1]))
        else:
            self.bias = None

        outshape = getShapeConvTranspose(in_shape, (out_channels, kernel_size, kernel_size), stride, padding, out_padding)
        return outshape

    def forward(self, input, **kargs):
        return input.conv_transpose2d(self.weight, bias=self.bias, stride=self.stride, padding = self.padding, output_padding=self.out_padding)

    def printNet(self, f):
        print("ConvTranspose2D", file = f)
        # BUG FIX: kernel_size is an int; list(self.kernel_size) raised TypeError.
        print(self.activation + ", filters={}, kernel_size={}, input_shape={}".format(self.out_channels, [self.kernel_size, self.kernel_size], list(self.prev) ), file = f)
        print(h.printListsNumpy([[list(p) for p in l ] for l in self.weight.permute(2,3,1,0).data]) , file= f)
        # BUG FIX: serialize a zero bias when bias=False, as Conv2D.printNet does.
        print(h.printNumpy(self.bias if self.bias is not None else h.dten(self.out_channels)), file= f)

    def neuronCount(self):
        return 0
| |
|
| |
|
| |
|
class MaxPool2D(InferModule):
    """Max pooling; stride defaults to the kernel size."""

    def init(self, in_shape, kernel_size, stride = None, **kargs):
        self.prev = in_shape
        self.kernel_size = kernel_size
        self.stride = kernel_size if stride is None else stride
        # BUG FIX: previously passed the raw `stride` argument (possibly None)
        # to getShapeConv, which then divided by None.
        return getShapeConv(in_shape, (in_shape[0], kernel_size, kernel_size), self.stride)

    def forward(self, x, **kargs):
        return x.max_pool2d(self.kernel_size, self.stride)

    def printNet(self, f):
        # BUG FIX: referenced the undefined attribute self.shape and called
        # list() on int stride; both raised before any output was written.
        print("MaxPool2D stride={}, kernel_size={}, input_shape={}".format([self.stride, self.stride], [self.kernel_size, self.kernel_size], list(self.prev[1:]+self.prev[:1]) ), file = f)

    def neuronCount(self):
        return h.product(self.outShape)
| |
|
class AvgPool2D(InferModule):
    """Average pooling with fixed padding=1; stride defaults to kernel_size."""

    def init(self, in_shape, kernel_size, stride = None, **kargs):
        self.prev = in_shape
        self.kernel_size = kernel_size
        self.stride = kernel_size if stride is None else stride
        out_size = getShapeConv(in_shape, (in_shape[0], kernel_size, kernel_size), self.stride, padding = 1)
        return out_size

    def forward(self, x, **kargs):
        # Spatially already 1x1: pooling would be pointless (padding would grow it).
        if h.product(x.size()[2:]) == 1:
            return x
        return x.avg_pool2d(kernel_size = self.kernel_size, stride = self.stride, padding = 1)

    def printNet(self, f):
        # BUG FIX: referenced the undefined attribute self.shape and called
        # list() on int stride; both raised before any output was written.
        print("AvgPool2D stride={}, kernel_size={}, input_shape={}".format([self.stride, self.stride], [self.kernel_size, self.kernel_size], list(self.prev[1:]+self.prev[:1]) ), file = f)

    def neuronCount(self):
        return h.product(self.outShape)
| |
|
class AdaptiveAvgPool2D(InferModule):
    """Average pooling to a fixed output spatial size (channels unchanged)."""

    def init(self, in_shape, out_shape, **kargs):
        self.prev = in_shape
        self.out_shape = list(out_shape)
        return [in_shape[0]] + self.out_shape

    def forward(self, x, **kargs):
        return x.adaptive_avg_pool2d(self.out_shape)

    def printNet(self, f):
        print("AdaptiveAvgPool2D out_Shape={} input_shape={}".format(list(self.out_shape), list(self.prev[1:]+self.prev[:1]) ), file = f)

    def neuronCount(self):
        return h.product(self.outShape)
| |
|
class Normalize(InferModule):
    """Per-channel input normalization: (x - mean) / std.

    The reciprocal of std is precomputed so forward only ever multiplies.
    """

    def init(self, in_shape, mean, std, **kargs):
        self.mean_v = mean  # raw values kept for printing
        self.std_v = std
        self.mean = h.dten(mean)
        self.std = 1 / h.dten(std)
        return in_shape

    def forward(self, x, **kargs):
        target = x.size()[1:]
        mean_ex = self.mean.view(self.mean.shape[0], 1, 1).expand(*target)
        std_ex = self.std.view(self.std.shape[0], 1, 1).expand(*target)
        return (x - mean_ex) * std_ex

    def neuronCount(self):
        return 0

    def printNet(self, f):
        print("Normalize mean={} std={}".format(self.mean_v, self.std_v), file = f)

    def showNet(self, t = ""):
        print(t + "Normalize mean={} std={}".format(self.mean_v, self.std_v))
| |
|
class Flatten(InferModule):
    """Collapse all non-batch dimensions into a single one."""

    def init(self, in_shape, **kargs):
        return h.product(in_shape)

    def forward(self, x, **kargs):
        shape = x.size()
        return x.view(shape[0], h.product(shape[1:]))

    def neuronCount(self):
        return 0
| |
|
class BatchNorm(InferModule):
    """Batch normalization over dim 0 with learnable per-element gamma/beta.

    Unlike nn.BatchNorm2d this normalizes each feature element over the
    batch only, and works on abstract values via vanillaTensorPart();
    running statistics are initialized lazily from the first batch.
    """

    def init(self, in_shape, track_running_stats = True, momentum = 0.1, eps=1e-5, **kargs):
        # gamma/beta cover the full feature shape (not per-channel).
        self.gamma = torch.nn.Parameter(torch.Tensor(*in_shape))
        self.beta = torch.nn.Parameter(torch.Tensor(*in_shape))
        self.eps = eps
        self.track_running_stats = track_running_stats
        self.momentum = momentum

        # Lazily initialized on the first training batch.
        self.running_mean = None
        self.running_var = None

        self.num_batches_tracked = 0
        return in_shape

    def reset_parameters(self):
        # Identity transform at initialization: gamma=1, beta=0.
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, x, **kargs):
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:
                    # Cumulative moving average when no momentum is given.
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:
                    exponential_average_factor = self.momentum

        # Statistics come from the concrete part of the (possibly abstract) x.
        new_mean = x.vanillaTensorPart().detach().mean(dim=0)
        new_var = x.vanillaTensorPart().detach().var(dim=0, unbiased=False)
        if torch.isnan(new_var * 0).any():
            # Degenerate stats (NaN/inf): skip normalization for this batch.
            return x
        if self.training:
            self.running_mean = (1 - exponential_average_factor) * self.running_mean + exponential_average_factor * new_mean if self.running_mean is not None else new_mean
            if self.running_var is None:
                self.running_var = new_var
            else:
                q = (1 - exponential_average_factor) * self.running_var
                r = exponential_average_factor * new_var
                self.running_var = q + r

        # Prefer the tracked statistics once they exist (also during training).
        if self.track_running_stats and self.running_mean is not None and self.running_var is not None:
            new_mean = self.running_mean
            new_var = self.running_var

        diver = 1 / (new_var + self.eps).sqrt()

        if torch.isnan(diver).any():
            print("Really shouldn't happen ever")
            return x
        else:
            out = (x - new_mean) * diver * self.gamma + self.beta
            return out

    def neuronCount(self):
        return 0
| |
|
class Unflatten2d(InferModule):
    """Reshape a flat feature vector into (channels, w, w)."""

    def init(self, in_shape, w, **kargs):
        self.w = w
        # Channel count is whatever remains after fixing the w x w spatial grid.
        self.outChan = int(h.product(in_shape) / (w * w))
        return (self.outChan, self.w, self.w)

    def forward(self, x, **kargs):
        batch = x.size()[0]
        return x.view(batch, self.outChan, self.w, self.w)

    def neuronCount(self):
        return 0
| |
|
| |
|
class View(InferModule):
    """Reshape the non-batch dimensions to a fixed out_shape."""

    def init(self, in_shape, out_shape, **kargs):
        # Element counts must match for the view to be valid.
        assert(h.product(in_shape) == h.product(out_shape))
        return out_shape

    def forward(self, x, **kargs):
        return x.view(x.size()[0], *self.outShape)

    def neuronCount(self):
        return 0
| | |
class Seq(InferModule):
    """Sequential container of InferModules.

    init eagerly infers each child's shape, threading one layer's outShape
    into the next; nn.Sequential registration makes the children visible to
    PyTorch's parameter machinery.
    """

    def init(self, in_shape, *layers, **kargs):
        self.layers = layers
        self.net = nn.Sequential(*layers)
        self.prev = in_shape
        for s in layers:
            in_shape = s.infer(in_shape, **kargs).outShape
        return in_shape

    def forward(self, x, **kargs):
        # Call children directly (not self.net) so **kargs, e.g. `time`,
        # reach every layer.
        for l in self.layers:
            x = l(x, **kargs)
        return x

    def clip_norm(self):
        for l in self.layers:
            l.clip_norm()

    def regularize(self, p):
        return sum(n.regularize(p) for n in self.layers)

    def remove_norm(self):
        for l in self.layers:
            l.remove_norm()

    def printNet(self, f):
        for l in self.layers:
            l.printNet(f)

    def showNet(self, *args, **kargs):
        for l in self.layers:
            l.showNet(*args, **kargs)

    def neuronCount(self):
        return sum([l.neuronCount() for l in self.layers ])

    def depth(self):
        return sum([l.depth() for l in self.layers ])
| | |
def FFNN(layers, last_lin = False, last_zono = False, **kargs):
    """Feed-forward network built from a list of layer widths.

    With last_lin the final layer is purely affine (no activation),
    optionally preceded by a zonotope correlation step when last_zono is set.
    """
    starts = layers
    ends = []
    if last_lin:
        ends = ([CorrelateAll(only_train=False)] if last_zono else []) + [PrintActivation(activation = "Affine"), Linear(layers[-1],**kargs)]
        starts = layers[:-1]

    return Seq(*([ Seq(PrintActivation(**kargs), Linear(s, **kargs), activation(**kargs)) for s in starts] + ends))
| |
|
def Conv(*args, **kargs):
    """Conv2D followed by a (possibly batch-normed) activation."""
    return Seq(Conv2D(*args, **kargs), activation(**kargs))
| |
|
def ConvTranspose(*args, **kargs):
    """ConvTranspose2D followed by a (possibly batch-normed) activation."""
    return Seq(ConvTranspose2D(*args, **kargs), activation(**kargs))
| |
|
# Short alias used by network-description specs.
MP = MaxPool2D
| |
|
def LeNet(conv_layers, ly = [], bias = True, normal=False, **kargs):
    """LeNet-style network: conv/pool specs followed by an optional FFNN.

    Each conv spec is either an InferModule (used as-is), a tuple starting
    with a string (treated as a max-pool spec), or
    (out_channels, kernel_size[, padding?, stride]) for a convolution.
    NOTE(review): mutable default ly=[] is only read here, so it is safe
    but fragile.
    """
    def transfer(tp):
        if isinstance(tp, InferModule):
            return tp
        if isinstance(tp[0], str):
            return MaxPool2D(*tp[1:])
        # A 4-tuple carries the stride in its last slot.
        return Conv(out_channels = tp[0], kernel_size = tp[1], stride = tp[-1] if len(tp) == 4 else 1, bias=bias, normal=normal, **kargs)
    conv = [transfer(s) for s in conv_layers]
    return Seq(*conv, FFNN(ly, **kargs, bias=bias)) if len(ly) > 0 else Seq(*conv)
| |
|
def InvLeNet(ly, w, conv_layers, bias = True, normal=False, **kargs):
    """Inverse LeNet (decoder): FFNN, unflatten to w x w, then transposed convs."""
    def transfer(tp):
        # tp = (out_channels, kernel_size, stride, padding, out_padding)
        return ConvTranspose(out_channels = tp[0], kernel_size = tp[1], stride = tp[2], padding = tp[3], out_padding = tp[4], bias=False, normal=normal)

    return Seq(FFNN(ly, bias=bias), Unflatten2d(w), *[transfer(s) for s in conv_layers])
| |
|
class FromByteImg(InferModule):
    """Convert byte-valued images to floats in [0, 1)."""

    def init(self, in_shape, **kargs):
        return in_shape

    def forward(self, x, **kargs):
        # to_dtype casts to the working float dtype (abstract-domain method).
        return x.to_dtype()/ 256.

    def neuronCount(self):
        return 0
| | |
class Skip(InferModule):
    """Run two subnets on the same input and concatenate along channels."""

    def init(self, in_shape, net1, net2, **kargs):
        self.net1 = net1.infer(in_shape, **kargs)
        self.net2 = net2.infer(in_shape, **kargs)
        # Spatial dims must agree; only the channel counts may differ.
        assert(net1.outShape[1:] == net2.outShape[1:])
        return [ net1.outShape[0] + net2.outShape[0] ] + net1.outShape[1:]

    def forward(self, x, **kargs):
        r1 = self.net1(x, **kargs)
        r2 = self.net2(x, **kargs)
        # cat on the result object so abstract domains concatenate soundly.
        return r1.cat(r2, dim=1)

    def regularize(self, p):
        return self.net1.regularize(p) + self.net2.regularize(p)

    def clip_norm(self):
        self.net1.clip_norm()
        self.net2.clip_norm()

    def remove_norm(self):
        self.net1.remove_norm()
        self.net2.remove_norm()

    def neuronCount(self):
        return self.net1.neuronCount() + self.net2.neuronCount()

    def printNet(self, f):
        print("SkipNet1", file=f)
        self.net1.printNet(f)
        print("SkipNet2", file=f)
        self.net2.printNet(f)
        print("SkipCat dim=1", file=f)

    def showNet(self, t = ""):
        print(t+"SkipNet1")
        self.net1.showNet(" "+t)
        print(t+"SkipNet2")
        self.net2.showNet(" "+t)
        print(t+"SkipCat dim=1")
| |
|
class ParSum(InferModule):
    """Run two subnets on the same input and sum their outputs (residual add)."""

    def init(self, in_shape, net1, net2, **kargs):
        self.net1 = net1.infer(in_shape, **kargs)
        self.net2 = net2.infer(in_shape, **kargs)
        # Summation requires identical output shapes.
        assert(net1.outShape == net2.outShape)
        return net1.outShape

    def forward(self, x, **kargs):
        r1 = self.net1(x, **kargs)
        r2 = self.net2(x, **kargs)
        # addPar lets abstract domains combine parallel branches soundly.
        return x.addPar(r1,r2)

    def clip_norm(self):
        self.net1.clip_norm()
        self.net2.clip_norm()

    def remove_norm(self):
        self.net1.remove_norm()
        self.net2.remove_norm()

    def neuronCount(self):
        return self.net1.neuronCount() + self.net2.neuronCount()

    def depth(self):
        # Parallel branches contribute the deeper of the two.
        return max(self.net1.depth(), self.net2.depth())

    def printNet(self, f):
        print("ParNet1", file=f)
        self.net1.printNet(f)
        print("ParNet2", file=f)
        self.net2.printNet(f)
        print("ParCat dim=1", file=f)

    def showNet(self, t = ""):
        print(t + "ParNet1")
        self.net1.showNet(" "+t)
        print(t + "ParNet2")
        self.net2.showNet(" "+t)
        print(t + "ParSum")
| |
|
class ToZono(Identity):
    """Convert a hybrid abstract value to a zonotope during forward.

    With only_train=True the conversion applies only in training mode;
    concrete tensors pass through the leaf-apply unchanged.
    """

    def init(self, in_shape, customRelu = None, only_train = False, **kargs):
        self.customRelu = customRelu  # optional custom ReLU transformer for the domain
        self.only_train = only_train
        return in_shape

    def forward(self, x, **kargs):
        return self.abstract_forward(x, **kargs) if self.training or not self.only_train else x

    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('hybrid_to_zono', customRelu = self.customRelu)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train))
| |
|
class CorrelateAll(ToZono):
    """Zonotope conversion that also correlates all error terms."""

    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('hybrid_to_zono',correlate=True, customRelu = self.customRelu)
| |
|
class ToHZono(ToZono):
    """Inverse of ToZono: convert a zonotope back to a hybrid domain value."""

    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('zono_to_hybrid',customRelu = self.customRelu)
| |
|
class Concretize(ToZono):
    """Collapse the abstract value to its concretization.

    Defaults to only_train=True so verification at test time is unaffected.
    """

    def init(self, in_shape, only_train = True, **kargs):
        self.only_train = only_train
        return in_shape

    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('concretize')
| |
|
| | |
class CorrRand(Concretize):
    """Stochastically correlate num_correlate abstract dimensions."""

    def init(self, in_shape, num_correlate, only_train = True, **kargs):
        self.only_train = only_train
        self.num_correlate = num_correlate
        return in_shape

    def abstract_forward(self, x):
        return x.abstractApplyLeaf("stochasticCorrelate", self.num_correlate)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " num_correlate="+ str(self.num_correlate))
| |
|
class CorrMaxK(CorrRand):
    """Correlate the top-k dimensions (selection done by the domain's correlateMaxK)."""

    def abstract_forward(self, x):
        return x.abstractApplyLeaf("correlateMaxK", self.num_correlate)
| |
|
| |
|
class CorrMaxPool2D(Concretize):
    """Correlate abstract dimensions chosen via 2D max-pooling windows."""

    def init(self,in_shape, kernel_size, only_train = True, max_type = ai.MaxTypes.head_beta, **kargs):
        self.only_train = only_train
        self.kernel_size = kernel_size
        self.max_type = max_type  # how the max over a window is measured
        return in_shape

    def abstract_forward(self, x):
        # Non-overlapping windows: stride equals kernel_size.
        return x.abstractApplyLeaf("correlateMaxPool", kernel_size = self.kernel_size, stride = self.kernel_size, max_type = self.max_type)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " kernel_size="+ str(self.kernel_size) + " max_type=" +str(self.max_type))
| |
|
class CorrMaxPool3D(Concretize):
    """Correlate abstract dimensions chosen via 3D max-pooling windows."""

    def init(self,in_shape, kernel_size, only_train = True, max_type = ai.MaxTypes.only_beta, **kargs):
        self.only_train = only_train
        self.kernel_size = kernel_size
        self.max_type = max_type  # how the max over a window is measured
        return in_shape

    def abstract_forward(self, x):
        # Non-overlapping windows: stride equals kernel_size.
        return x.abstractApplyLeaf("correlateMaxPool", kernel_size = self.kernel_size, stride = self.kernel_size, max_type = self.max_type, max_pool = F.max_pool3d)

    def showNet(self, t = ""):
        # BUG FIX: max_type is not a str (string concatenation raised TypeError);
        # wrap in str() as CorrMaxPool2D.showNet already does.
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " kernel_size="+ str(self.kernel_size) + " max_type=" + str(self.max_type))
| |
|
class CorrFix(Concretize):
    """Correlate a fixed, evenly-spaced set of k flattened feature indices."""

    def init(self,in_shape, k, only_train = True, **kargs):
        self.k = k
        self.only_train = only_train
        return in_shape

    def abstract_forward(self, x):
        sz = x.size()
        """
        # for more control in the future
        indxs_1 = torch.arange(start = 0, end = sz[1], step = math.ceil(sz[1] / self.dims[1]) )
        indxs_2 = torch.arange(start = 0, end = sz[2], step = math.ceil(sz[2] / self.dims[2]) )
        indxs_3 = torch.arange(start = 0, end = sz[3], step = math.ceil(sz[3] / self.dims[3]) )

        indxs = torch.stack(torch.meshgrid((indxs_1,indxs_2,indxs_3)), dim=3).view(-1,3)
        """
        # Evenly spaced indices over the flattened feature dims, replicated
        # per batch element.
        szm = h.product(sz[1:])
        indxs = torch.arange(start = 0, end = szm, step = math.ceil(szm / self.k))
        indxs = indxs.unsqueeze(0).expand(sz[0], indxs.size()[0])

        return x.abstractApplyLeaf("correlate", indxs)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " k="+ str(self.k))
| |
|
| |
|
class DecorrRand(Concretize):
    """Stochastically decorrelate num_decorrelate abstract dimensions."""

    def init(self, in_shape, num_decorrelate, only_train = True, **kargs):
        self.only_train = only_train
        self.num_decorrelate = num_decorrelate
        return in_shape

    def abstract_forward(self, x):
        return x.abstractApplyLeaf("stochasticDecorrelate", self.num_decorrelate)
| |
|
class DecorrMin(Concretize):
    """Decorrelate dimensions chosen by the domain's decorrelateMin.

    num_to_keep presumably reinterprets the count as dims to retain rather
    than remove -- see the domain implementation to confirm.
    """

    def init(self, in_shape, num_decorrelate, only_train = True, num_to_keep = False, **kargs):
        self.only_train = only_train
        self.num_decorrelate = num_decorrelate
        self.num_to_keep = num_to_keep
        return in_shape

    def abstract_forward(self, x):
        return x.abstractApplyLeaf("decorrelateMin", self.num_decorrelate, num_to_keep = self.num_to_keep)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " k="+ str(self.num_decorrelate) + " num_to_keep=" + str(self.num_to_keep) )
| |
|
class DeepLoss(ToZono):
    """Attach an auxiliary deep loss to the abstract value flowing through.

    The value is wrapped in a TaggedDomain whose MLoss blends the
    downstream loss with the value's deep_loss, weighted by bw.
    """

    def init(self, in_shape, bw = 0.01, act = F.relu, **kargs):
        self.only_train = True
        self.bw = S.Const.initConst(bw)  # blending weight, possibly scheduled
        self.act = act
        return in_shape

    def abstract_forward(self, x, **kargs):
        if x.isPoint():
            # Concrete points carry no abstraction width, so no deep loss applies.
            return x
        return ai.TaggedDomain(x, self.MLoss(self, x))

    class MLoss():
        # Deferred loss combiner carried inside the TaggedDomain.
        def __init__(self, obj, x):
            self.obj = obj  # the owning DeepLoss layer
            self.x = x      # the abstract value captured at this layer

        def loss(self, a, *args, lr = 1, time = 0, **kargs):
            bw = self.obj.bw.getVal(time = time)
            # Downstream loss is computed with a proportionally reduced lr.
            pre_loss = a.loss(*args, time = time, **kargs, lr = lr * (1 - bw))
            if bw <= 0.0:
                return pre_loss
            return (1 - bw) * pre_loss + bw * self.x.deep_loss(act = self.obj.act)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " bw="+ str(self.bw) + " act=" + str(self.act) )
| |
|
class IdentLoss(DeepLoss):
    """DeepLoss variant with the auxiliary loss disabled (pure pass-through)."""

    def abstract_forward(self, x, **kargs):
        return x
| | |
def SkipNet(net1, net2, ffnn, **kargs):
    """Two parallel subnets concatenated, then an FFNN head."""
    return Seq(Skip(net1,net2), FFNN(ffnn, **kargs))
| |
|
def WideBlock(out_filters, downsample=False, k=3, bias=False, **kargs):
    """Wide-ResNet style block: shortcut conv summed with a two-conv branch."""
    if not downsample:
        k_first = 3
        skip_stride = 1
        k_skip = 1
    else:
        # Downsampling halves the spatial dims in both branches.
        k_first = 4
        skip_stride = 2
        k_skip = 2

    # Shortcut branch (no activation of its own).
    blockA = Conv2D(out_filters, kernel_size=k_skip, stride=skip_stride, padding=0, bias=bias, normal=True, **kargs)

    # Main branch: conv+activation followed by a plain conv.
    blockB = Seq( Conv(out_filters, kernel_size = k_first, stride = skip_stride, padding = 1, bias=bias, normal=True, **kargs)
                , Conv2D(out_filters, kernel_size = k, stride = 1, padding = 1, bias=bias, normal=True, **kargs))
    return Seq(ParSum(blockA, blockB), activation(**kargs))
| |
|
| |
|
| |
|
def BasicBlock(in_planes, planes, stride=1, bias = False, skip_net = False, **kargs):
    """ResNet basic block; adds an identity or 1x1-conv shortcut as needed."""
    block = Seq( Conv(planes, kernel_size = 3, stride = stride, padding = 1, bias=bias, normal=True, **kargs)
               , Conv2D(planes, kernel_size = 3, stride = 1, padding = 1, bias=bias, normal=True, **kargs))

    if stride != 1 or in_planes != planes:
        # Shape change: the shortcut must project with a strided 1x1 conv.
        block = ParSum(block, Conv2D(planes, kernel_size=1, stride=stride, bias=bias, normal=True, **kargs))
    elif not skip_net:
        block = ParSum(block, Identity())
    return Seq(block, activation(**kargs))
| |
|
| | |
def ResNet(blocksList, extra = [], bias = False, **kargs):
    """ResNet built from a list of per-stage block counts.

    extra is a list of (module, index) pairs; each module is appended
    after the block at that index.
    NOTE(review): mutable default extra=[] is only iterated, never mutated.
    """
    layers = []
    in_planes = 64
    planes = 64
    stride = 0
    for num_blocks in blocksList:
        # First stage uses stride 1, later stages stride 2 (capped at 2).
        if stride < 2:
            stride += 1

        strides = [stride] + [1]*(num_blocks-1)
        for stride in strides:
            layers.append(BasicBlock(in_planes, planes, stride, bias = bias, **kargs))
            in_planes = planes
        planes *= 2

    print("RESlayers: ", len(layers))
    for e,l in extra:
        layers[l] = Seq(layers[l], e)

    return Seq(Conv(64, kernel_size=3, stride=1, padding = 1, bias=bias, normal=True, printShape=True),
               *layers)
| |
|
| |
|
| |
|
def DenseNet(growthRate, depth, reduction, num_classes, bottleneck = True):
    """DenseNet-BC style constructor using Skip-concatenation dense blocks."""

    def Bottleneck(growthRate):
        # 1x1 bottleneck to 4*growthRate channels, then a 3x3 conv;
        # Skip concatenates the input with the branch output (dense connectivity).
        interChannels = 4*growthRate

        n = Seq( ReLU(),
                 Conv2D(interChannels, kernel_size=1, bias=True, ibp_init = True),
                 ReLU(),
                 Conv2D(growthRate, kernel_size=3, padding=1, bias=True, ibp_init = True)
               )

        return Skip(Identity(), n)

    def SingleLayer(growthRate):
        # Non-bottleneck variant: a single 3x3 conv branch.
        n = Seq( ReLU(),
                 Conv2D(growthRate, kernel_size=3, padding=1, bias=True, ibp_init = True))
        return Skip(Identity(), n)

    def Transition(nOutChannels):
        # Compress channels and halve spatial resolution between dense blocks.
        return Seq( ReLU(),
                    Conv2D(nOutChannels, kernel_size = 1, bias = True, ibp_init = True),
                    AvgPool2D(kernel_size=2))

    def make_dense(growthRate, nDenseBlocks, bottleneck):
        return Seq(*[Bottleneck(growthRate) if bottleneck else SingleLayer(growthRate) for i in range(nDenseBlocks)])

    nDenseBlocks = (depth-4) // 3
    if bottleneck:
        # Bottleneck blocks contain two convs, so halve the block count.
        nDenseBlocks //= 2

    nChannels = 2*growthRate
    conv1 = Conv2D(nChannels, kernel_size=3, padding=1, bias=True, ibp_init = True)
    dense1 = make_dense(growthRate, nDenseBlocks, bottleneck)
    nChannels += nDenseBlocks * growthRate
    nOutChannels = int(math.floor(nChannels*reduction))
    trans1 = Transition(nOutChannels)

    nChannels = nOutChannels
    dense2 = make_dense(growthRate, nDenseBlocks, bottleneck)
    nChannels += nDenseBlocks*growthRate
    nOutChannels = int(math.floor(nChannels*reduction))
    trans2 = Transition(nOutChannels)

    nChannels = nOutChannels
    dense3 = make_dense(growthRate, nDenseBlocks, bottleneck)

    return Seq(conv1, dense1, trans1, dense2, trans2, dense3,
               ReLU(),
               AvgPool2D(kernel_size=8),
               CorrelateAll(only_train=False, ignore_point = True),
               Linear(num_classes, ibp_init = True))
| |
|
| |
|