# CartoonGAN implementation in PyTorch, with some modifications.
# Modifications are:
# 1. Use InstanceNorm instead of BatchNorm: in image style transfer tasks, instance normalization tends to work better.
# 2. Use LeakyReLU instead of ReLU in both the Generator and the Discriminator.
# 3. Use a different network for feature extraction: VGG is older, slower, and less accurate than more recent backbones, so a ResNet is used instead.
# 4. Try other loss functions, such as WGAN, hinge, or MSE -> to be implemented in another file
#    (an illustrative MSE/LSGAN-style sketch appears after the Discriminator class below).
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tvmodels


class ResidualBlock(nn.Module):
    def __init__(self, channels=256, use_bias=False):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=use_bias),
            nn.InstanceNorm2d(channels),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=use_bias),
            nn.InstanceNorm2d(channels)
        )

    def forward(self, input):
        residual = input
        x = self.model(input)
        # element-wise sum: the skip connection of the residual block
        out = x + residual
        return out
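
# Illustrative note (a sketch, not from the original source): a ResidualBlock preserves
# its input shape, since both convolutions use kernel_size=3, stride=1, padding=1 and
# the channel count is unchanged:
#   ResidualBlock(channels=256)(torch.randn(1, 256, 64, 64)).shape  # -> [1, 256, 64, 64]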


class Generator(nn.Module):
    def __init__(self, n_res_block=8, use_bias=False):
        super().__init__()
        # down-sampling, i.e. layers before the residual blocks
        self.down_sampling = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3, bias=use_bias),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=use_bias),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=use_bias),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(inplace=True)
        )
        # residual blocks
        res_blocks = []
        for i in range(n_res_block):
            res_blocks.append(ResidualBlock(channels=256, use_bias=use_bias))
        self.res_blocks = nn.Sequential(*res_blocks)
        # up-sampling, i.e. layers after the residual blocks
        self.up_sampling = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
            nn.ConvTranspose2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=7, stride=1, padding=3, bias=use_bias),
            nn.Tanh()
        )

    def forward(self, input):
        x = self.down_sampling(input)
        x = self.res_blocks(x)
        out = self.up_sampling(x)
        return out
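
# Sizing note (an observation, not from the original source): with two stride-2
# convolutions on the way down and two stride-2 transposed convolutions on the way
# up, the Generator reproduces the input resolution exactly when H and W are
# multiples of 4 (e.g. 256x256 in -> 256x256 out).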


class Discriminator(nn.Module):
    def __init__(self, leaky_relu_negative_slope=0.2, use_bias=False):
        super().__init__()
        self.negative_slope = leaky_relu_negative_slope
        self.layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=use_bias),
            nn.LeakyReLU(self.negative_slope, inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=use_bias),
            nn.LeakyReLU(self.negative_slope, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(self.negative_slope, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=use_bias),
            nn.LeakyReLU(self.negative_slope, inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(self.negative_slope, inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(self.negative_slope, inplace=True),
            nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1, bias=use_bias),
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.layers(input)
        return output
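
# Modification 4 in the header (alternative adversarial losses) is implemented in
# another file. Purely as an illustration of one option, an MSE (LSGAN-style) loss
# over the Discriminator's sigmoid output map could look like the sketch below.
# These helper names are hypothetical and are not part of the training code here.
def illustrative_d_mse_loss(d_real, d_fake):
    # discriminator target: real patches -> 1, generated patches -> 0
    return F.mse_loss(d_real, torch.ones_like(d_real)) + \
           F.mse_loss(d_fake, torch.zeros_like(d_fake))


def illustrative_g_mse_loss(d_fake):
    # generator target: generated patches -> 1
    return F.mse_loss(d_fake, torch.ones_like(d_fake))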


class FeatureExtractor(nn.Module):
    def __init__(self, network='resnet-101'):
        # In the original paper, the authors used VGG.
        # However, there now exist stronger convolutional networks than VGG,
        # and we may experiment with them; possible backbones are VGG, ResNet, etc.
        super().__init__()
        assert network in ['vgg', 'resnet-101']
        # note: newer torchvision versions replace `pretrained=True` with a `weights=` argument
        if network == 'vgg':
            vgg = tvmodels.vgg19_bn(pretrained=True)
            self.feature_extractor = vgg.features[:37]
            # vgg.features[36] is conv4_4, which is the layer the authors used
            # when input has shape [3, 512, 512], output of the feature extractor is [512, 64, 64]
        elif network == 'resnet-101':
            # TODO
            resnet = tvmodels.resnet101(pretrained=True)
            layers = [resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1, resnet.layer2]
            self.feature_extractor = nn.Sequential(*layers)
            # when input has shape [3, 512, 512], output of the feature extractor is [512, 64, 64],
            # the same output shape as the VGG version
        # the FeatureExtractor should not be trained
        for child in self.feature_extractor.children():
            for param in child.parameters():
                param.requires_grad = False

    def forward(self, input):
        return self.feature_extractor(input)
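
# Minimal smoke test, as a sketch only. The 256x256 input size is an arbitrary
# choice for illustration, and the L1 feature-matching content loss mirrors the
# CartoonGAN formulation (L1 distance between extracted features of the photo
# and the generated image). Running this downloads pretrained ResNet-101 weights.
if __name__ == '__main__':
    photo = torch.randn(1, 3, 256, 256)  # placeholder batch of one RGB image
    g = Generator()
    d = Discriminator()
    fake = g(photo)
    print(fake.shape)     # expected: [1, 3, 256, 256]
    print(d(fake).shape)  # expected: [1, 1, 64, 64] (two stride-2 convolutions)
    extractor = FeatureExtractor(network='resnet-101')
    content_loss = F.l1_loss(extractor(fake), extractor(photo))
    print(content_loss.item())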