| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from vgg_model import vgg19 |
|
|
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv => BatchNorm => LeakyReLU(0.1)) stages.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the second stage.
        mid_channels: Channel count between the two stages. Falls back to
            ``out_channels`` when omitted (or otherwise falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        # Build the two stages in order so the Sequential indices (0..5)
        # stay conv, bn, act, conv, bn, act.
        stages = []
        for src, dst in ((in_channels, hidden), (hidden, out_channels)):
            stages += [
                nn.Conv2d(src, dst, kernel_size=3, padding=1),
                nn.BatchNorm2d(dst),
                nn.LeakyReLU(0.1, True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages; 3x3 kernels with padding=1 keep H and W."""
        return self.double_conv(x)
|
|
class ResBlock(nn.Module):
    """1x1 channel projection followed by a two-conv residual branch.

    The input is first projected to ``out_channels`` by a 1x1 convolution.
    That projection feeds a (3x3 conv => BatchNorm => LeakyReLU(0.2) =>
    3x3 conv) branch and also serves as the skip connection. The sum of
    both is divided by sqrt(2).

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bottle_conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        branch = [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        ]
        self.double_conv = nn.Sequential(*branch)

    def forward(self, x):
        """Return (branch(proj(x)) + proj(x)) / sqrt(2)."""
        projected = self.bottle_conv(x)
        residual = self.double_conv(projected)
        return (residual + projected) / math.sqrt(2)
|
|
|
|
class Down(nn.Module):
    """Downscale with a stride-2 convolution, then refine with a ResBlock.

    The 4x4 / stride 2 / padding 1 convolution keeps ``in_channels`` and
    halves even spatial sizes. It is followed by LeakyReLU(0.1) and a
    ``ResBlock`` that maps ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the ResBlock.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        strided_conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=4, stride=2, padding=1
        )
        activation = nn.LeakyReLU(0.1, True)
        refine = ResBlock(in_channels, out_channels)
        self.main = nn.Sequential(strided_conv, activation, refine)

    def forward(self, x):
        """Return the downscaled, refined feature map."""
        return self.main(x)
|
|
class SDFT(nn.Module):
    """Convolution whose shared kernel is modulated per sample by a style.

    A single learned kernel of shape ``(1, channels, channels, k, k)`` is
    scaled per input channel by a style vector derived from
    ``color_style``, then each output filter is re-normalised
    (demodulated) to unit L2 norm. The per-sample kernels are applied in
    one grouped convolution by folding the batch into the channel axis.

    Args:
        color_dim: Channel count of the ``color_style`` input.
        channels: Channel count of the feature map (input and output).
        kernel_size: Spatial size of the kernel. Padding is
            ``kernel_size // 2``, so only odd sizes preserve H and W,
            which the final reshape in ``forward`` requires.
    """

    def __init__(self, color_dim, channels, kernel_size=3):
        super().__init__()

        fan_in = channels * kernel_size ** 2
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2

        # Constant multiplier applied to the raw kernel in forward().
        self.scale = 1 / math.sqrt(fan_in)
        # 1x1 conv mapping the style input to one factor per channel.
        self.modulation = nn.Conv2d(color_dim, channels, 1)
        self.weight = nn.Parameter(
            torch.randn(1, channels, channels, kernel_size, kernel_size)
        )

    def forward(self, fea, color_style):
        """Apply the style-modulated convolution.

        Args:
            fea: Feature map of shape ``(B, C, H, W)`` with ``C`` equal to
                the ``channels`` given at construction. May be
                non-contiguous (e.g. channels_last or a channel slice).
            color_style: Style input whose modulation output must hold
                exactly ``B * C`` values; presumably ``(B, color_dim, 1, 1)``
                -- confirm against the caller.

        Returns:
            Tensor of shape ``(B, C, H, W)``.
        """
        B, C, H, W = fea.size()

        # One multiplicative factor per (sample, input channel).
        style = self.modulation(color_style).reshape(B, 1, C, 1, 1)
        weight = self.scale * self.weight * style

        # Demodulate: normalise every (sample, output channel) filter.
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
        weight = weight * demod.view(B, C, 1, 1, 1)

        weight = weight.reshape(B * C, C, self.kernel_size, self.kernel_size)

        # reshape (not view): folding batch into channels is impossible as
        # a pure view for non-contiguous inputs and would raise.
        fea = fea.reshape(1, B * C, H, W)
        # groups=B gives each sample its own set of C filters.
        fea = F.conv2d(fea, weight, padding=self.padding, groups=B)
        return fea.reshape(B, C, H, W)
|
|
|
|
class UpBlock(nn.Module):
    """Decoder stage: upsample, fuse a thinned skip map, apply SDFT.

    Args:
        color_dim: Channel count of the style input handed to ``SDFT``.
        in_channels: Nominal input width. The upsampled map is expected to
            carry ``in_channels // 2`` channels and the thinned skip map
            ``in_channels // 8``.
        out_channels: Channel count of the result.
        kernel_size: Kernel size used by the ``SDFT`` layer.
        bilinear: If true, upsample by 2x bilinear interpolation;
            otherwise use a 2x2 / stride 2 transposed convolution that
            maps ``in_channels`` to ``in_channels // 2``.
    """

    def __init__(self, color_dim, in_channels, out_channels, kernel_size=3, bilinear=True):
        super().__init__()
        up_channels = in_channels // 2
        skip_channels = in_channels // 8

        if not bilinear:
            self.up = nn.ConvTranspose2d(
                in_channels, up_channels, kernel_size=2, stride=2
            )
        else:
            self.up = nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False
            )

        # Fuses the upsampled map with the thinned skip connection.
        fuse_layers = [
            nn.Conv2d(up_channels + skip_channels, out_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
        ]
        self.conv_cat = nn.Sequential(*fuse_layers)

        # 1x1 projection of the upsampled map, added back after SDFT.
        self.conv_s = nn.Conv2d(
            up_channels, out_channels, kernel_size=1, stride=1, padding=0
        )

        self.SDFT = SDFT(color_dim, out_channels, kernel_size)

    def forward(self, x1, x2, color_style):
        """Upsample ``x1``, merge with ``x2`` and modulate by ``color_style``.

        Args:
            x1: Feature map from the previous (coarser) stage.
            x2: Skip feature map; only every 4th channel is used.
            color_style: Style input forwarded to the SDFT layer.

        Returns:
            SDFT output plus the 1x1-projected upsampled map.
        """
        upsampled = self.up(x1)
        shortcut = self.conv_s(upsampled)

        thinned_skip = x2[:, ::4, :, :]
        fused = self.conv_cat(torch.cat([upsampled, thinned_skip], dim=1))
        modulated = self.SDFT(fused, color_style)

        return modulated + shortcut
|
|
|
|
class ColorEncoder(nn.Module):
    """Encode an image into a colour style tensor via VGG19 features.

    Args:
        color_dim: Channel count expected from the VGG feature map and
            produced by the final 1x1 convolution.
    """

    def __init__(self, color_dim=512):
        super().__init__()

        self.vgg = vgg19()

        half_dim = color_dim // 2
        head = [
            # Two strided stages shrink the feature map spatially.
            nn.Conv2d(color_dim, color_dim, kernel_size=4, stride=2, padding=2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(color_dim, color_dim, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(color_dim, color_dim, kernel_size=4, stride=2, padding=2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(color_dim, color_dim, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, True),
            # Collapse to 1x1, then a small bottleneck of 1x1 convs.
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(color_dim, half_dim, kernel_size=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(half_dim, half_dim, kernel_size=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(half_dim, color_dim, kernel_size=1),
        ]
        self.feature2vector = nn.Sequential(*head)

        self.color_dim = color_dim

    def forward(self, x):
        """Return the style tensor computed from the last VGG feature.

        ``self.vgg`` is called with ``layer_name='relu5_2'`` and its result
        is indexed with ``[-1]``; presumably it returns a sequence of
        feature maps ending at that layer -- confirm in ``vgg_model``.
        """
        features = self.vgg(x, layer_name='relu5_2')
        deepest = features[-1]
        return self.feature2vector(deepest)
|
|
|
|
class ColorUNet(nn.Module):
    """U-Net style generator whose decoder is modulated by a colour style.

    Args:
        n_channels: Channel count of the input image.
        n_classes: Stored on the instance only; the output head is fixed
            at 2 channels regardless of this value.
        bilinear: Chooses bilinear upsampling (true) or transposed
            convolutions (false) in every ``UpBlock`` and sets the
            channel-halving factor used for the bottleneck and decoder.
    """

    def __init__(self, n_channels=1, n_classes=3, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Bilinear upsampling keeps the channel count, so the bottleneck
        # and decoder widths are halved to compensate.
        factor = 2 if bilinear else 1

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder: every stage receives a 512-channel style input.
        self.up1 = UpBlock(512, 1024, 512 // factor, 3, bilinear)
        self.up2 = UpBlock(512, 512, 256 // factor, 3, bilinear)
        self.up3 = UpBlock(512, 256, 128 // factor, 5, bilinear)
        self.up4 = UpBlock(512, 128, 64, 5, bilinear)

        # Output head: 2 channels squashed to [-1, 1] by Tanh.
        head = [
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 2, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.outc = nn.Sequential(*head)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: Indexable pair; ``x[0]`` is the image fed to the encoder and
                ``x[1]`` is the style input passed to every decoder stage.

        Returns:
            The 2-channel output of the Tanh head.
        """
        image = x[0]
        color_style = x[1]

        # Encoder path; each result doubles as a skip connection.
        enc1 = self.inc(image)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        enc4 = self.down3(enc3)
        bottleneck = self.down4(enc4)

        # Decoder path, consuming skips from deepest to shallowest.
        dec = self.up1(bottleneck, enc4, color_style)
        dec = self.up2(dec, enc3, color_style)
        dec = self.up3(dec, enc2, color_style)
        dec = self.up4(dec, enc1, color_style)

        return self.outc(dec)
|
|