This repository has been archived by the owner on Jan 14, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
network.py
96 lines (79 loc) · 4.6 KB
/
network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from tensorflow.keras.applications.vgg19 import VGG19
import tensorflow.keras.backend as K
import tensorflow.keras.layers as layers
import tensorflow.keras.models as models
class ReflectionPad(layers.Layer):
    """Reflection-padding layer for 4-D NHWC tensors.

    Pads axis 1 and axis 2 by mirroring interior pixels; the border pixel
    itself is not repeated (same semantics as ``torch.nn.ReflectionPad2d``).

    Args:
        padding: 4-tuple ``(pad_left, pad_right, pad_top, pad_bottom)``.
            NOTE(review): "left/right" are applied to axis 1 and
            "top/bottom" to axis 2 — confirm this matches the caller's
            height/width convention.
    """

    def __init__(self, padding, name='reflection', *args, **kwargs):
        super(ReflectionPad, self).__init__(name=name, **kwargs)
        self.pad_left, self.pad_right, self.pad_top, self.pad_bottom = padding

    def compute_output_shape(self, input_shape):
        # Bug fix: the original called super() inside a try/except and only
        # returned a shape from the except branch, so when the base
        # implementation did NOT raise NotImplementedError the method fell
        # off the end and returned None. Compute the padded shape directly.
        return (input_shape[0],
                input_shape[1] + self.pad_left + self.pad_right,
                input_shape[2] + self.pad_top + self.pad_bottom,
                input_shape[3])

    def call(self, x):
        # Mirror interior rows/columns around each edge. With pad p, the
        # reversed-tensor slices select indices 1..p of the original
        # (excluding the edge pixel); a pad of 0 yields an empty slice
        # and is a no-op.
        x = K.concatenate([K.reverse(x, 1)[:, (-1 - self.pad_left):-1, :, :], x, K.reverse(x, 1)[:, 1:(1 + self.pad_right), :, :]], axis=1)
        x = K.concatenate([K.reverse(x, 2)[:, :, (-1 - self.pad_top):-1, :], x, K.reverse(x, 2)[:, :, 1:(1 + self.pad_bottom), :]], axis=2)
        return x

    def get_config(self):
        """Return the config dict so models using this layer can be re-serialized."""
        config = super(ReflectionPad, self).get_config().copy()
        config.update({
            'padding': (self.pad_left, self.pad_right, self.pad_top, self.pad_bottom)
        })
        return config
class AdaIN(layers.Layer):
    """Adaptive Instance Normalization layer.

    Re-normalizes content features to have the per-channel mean/std of the
    style features, then blends with the raw content features.

    Args:
        alpha: blend factor in [0, 1]; 1.0 means fully stylized output.
        epsilon: small constant for numerical stability of the std computation.

    Call input: a list/tuple ``[content_features, style_features]`` of 4-D
    tensors; statistics are taken over axes 1 and 2 (spatial dims).
    """

    def __init__(self, name='adain', alpha=1.0, epsilon=1e-5, **kwargs):
        super(AdaIN, self).__init__(name=name, **kwargs)
        self.alpha = alpha
        self.epsilon = epsilon

    def compute_output_shape(self, input_shape):
        # Bug fix: the original only returned a value from the except branch
        # of a try/except around super(); when the base implementation did
        # not raise, the method returned None. The output always has the
        # shape of the content input (first element).
        return input_shape[0]

    def call(self, x):
        content_features, style_features = x
        # Per-channel statistics over the spatial axes (1, 2).
        content_mean = K.mean(content_features, axis=[1, 2], keepdims=True)
        content_std = K.sqrt(K.var(content_features, axis=[1, 2], keepdims=True) + self.epsilon)
        style_mean = K.mean(style_features, axis=[1, 2], keepdims=True)
        style_std = K.sqrt(K.var(style_features, axis=[1, 2], keepdims=True) + self.epsilon)
        # NOTE(review): epsilon appears twice (inside the sqrt above and in
        # the denominator below); harmless but redundant — kept as-is to
        # preserve the trained model's exact numerics.
        normalized_content_features = (content_features - content_mean) / (content_std + self.epsilon) * style_std + style_mean
        # Blend stylized and original content features by alpha.
        return self.alpha * normalized_content_features + (1 - self.alpha) * content_features

    def get_config(self):
        """Return the config dict so models using this layer can be re-serialized."""
        config = super(AdaIN, self).get_config().copy()
        config.update({
            'alpha': self.alpha,
            'epsilon': self.epsilon
        })
        return config
def Encoder(input_tensor=None, input_shape=(256, 256, 3), pretrained=True, name='encoder', **kwargs):
    """Build a multi-output VGG19 feature extractor.

    Returns a model that emits the activations of the first conv layer of
    VGG19 blocks 1-4 for a single image input; ImageNet weights are loaded
    when ``pretrained`` is True, otherwise the network is randomly initialized.
    """
    if input_tensor is None:
        input_tensor = layers.Input(shape=input_shape)
    weights = 'imagenet' if pretrained else None
    vgg = VGG19(input_tensor=input_tensor, weights=weights, include_top=False)
    tap_names = ('block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1')
    outputs = [vgg.get_layer(tap).output for tap in tap_names]
    return models.Model(inputs=[input_tensor], outputs=outputs, name=name, **kwargs)
def Decoder(input_tensor=None, input_shape=(32, 32, 512), name='decoder', **kwargs):
    """Build the decoder that mirrors the VGG19 encoder up to block4_conv1.

    Upsamples a (32, 32, 512) feature map back to a (256, 256, 3) image via
    three nearest-neighbor 2x upsamplings interleaved with reflection-padded
    3x3 convolutions. The final conv has no activation (raw RGB output).

    Refactor: the original repeated the pad+conv pair ten times inline; the
    pattern is extracted into ``_pad_conv``. All layer names, filter counts,
    and activations are identical to the original.
    """
    def _pad_conv(x, filters, block, index, activation='relu'):
        # One reflection-pad + 'valid' 3x3 conv step; the 1-pixel pad keeps
        # the spatial size unchanged.
        x = ReflectionPad((1, 1, 1, 1), name='block%d_reflection%d' % (block, index))(x)
        return layers.Conv2D(filters, (3, 3), activation=activation,
                             name='block%d_conv%d' % (block, index))(x)

    if input_tensor is None:
        input_tensor = layers.Input(shape=input_shape)
    x = _pad_conv(input_tensor, 256, 4, 1)
    x = layers.UpSampling2D(size=2, interpolation='nearest', name='block3_upsample')(x)
    x = _pad_conv(x, 256, 3, 4)
    x = _pad_conv(x, 256, 3, 3)
    x = _pad_conv(x, 256, 3, 2)
    x = _pad_conv(x, 128, 3, 1)
    x = layers.UpSampling2D(size=2, interpolation='nearest', name='block2_upsample')(x)
    x = _pad_conv(x, 128, 2, 2)
    x = _pad_conv(x, 64, 2, 1)
    x = layers.UpSampling2D(size=2, interpolation='nearest', name='block1_upsample')(x)
    x = _pad_conv(x, 64, 1, 2)
    # Final conv maps to 3 channels with no activation (unbounded RGB).
    x = _pad_conv(x, 3, 1, 1, activation=None)
    return models.Model(inputs=[input_tensor], outputs=[x], name=name, **kwargs)