Beispiel #1
0
 def last_cnn(self, h, resolution, ch_in, ch_out):
     """Apply the last stage for ``resolution`` and reduce to one output unit.

     Layers, in order, all under the ``"phase_<resolution>"`` parameter scope:
       * "conv1": 3x3 conv to ``ch_in`` maps, stride 1, padding 1;
       * "conv2": 4x4 conv to ``ch_out`` maps, stride 1, no padding;
       * "linear": affine layer with a single output unit and a bias.
     Each convolution is followed by ``LN`` (controlled by ``self.use_ln``)
     and then ``self.activation``.

     Args:
         h: Input variable. Presumably a 4-D (batch, channels, 4, 4) map so
             that the unpadded 4x4 "conv2" yields a 1x1 output — TODO confirm.
         resolution: Only used to name the enclosing parameter scope.
         ch_in: Number of output maps of "conv1".
         ch_out: Number of output maps of "conv2".

     Returns:
         The output of the final affine layer (``n_outmaps=1``).
     """
     ln_enabled = self.use_ln
     # Options shared by both convolutions; the bias is dropped when LN is on.
     conv_opts = dict(stride=(1, 1),
                      with_bias=not ln_enabled,
                      use_wscale=self.use_wscale,
                      use_he_backward=self.use_he_backward)
     # (parameter scope, output maps, kernel, padding) per conv layer.
     conv_specs = (("conv1", ch_in, (3, 3), (1, 1)),
                   ("conv2", ch_out, (4, 4), (0, 0)))
     with nn.parameter_scope("phase_{}".format(resolution)):
         for scope_name, n_maps, kernel, pad in conv_specs:
             with nn.parameter_scope(scope_name):
                 h = conv(h, n_maps, kernel=kernel, pad=pad, **conv_opts)
                 h = self.activation(LN(h, use_ln=ln_enabled))
         with nn.parameter_scope("linear"):
             h = affine(h,
                        1,
                        with_bias=True,
                        use_wscale=self.use_wscale,
                        use_he_backward=self.use_he_backward)
     return h
Beispiel #2
0
 def first_cnn(self, h, resolution, channel, test=False):
     """Apply the first stage for ``resolution``, producing a 4x4 feature map.

     Layers, in order, all under the ``"phase_<resolution>"`` parameter scope:
       * "conv1": affine layer with ``channel * 4 * 4`` outputs, reshaped to
         ``(batch, channel, 4, 4)``;
       * "conv2": 3x3 conv to ``channel`` maps, stride 1, padding 1.
     Each layer is followed by ``BN`` (controlled by ``self.use_bn``), then
     ``pixel_wise_feature_vector_normalization``, then ``self.activation``.

     Args:
         h: Input variable fed to the affine layer (presumably a latent
             batch — TODO confirm shape against caller).
         resolution: Only used to name the enclosing parameter scope.
         channel: Number of feature maps produced by both layers.
         test: Forwarded to ``BN`` as its ``test`` flag.

     Returns:
         The activated ``(batch, channel, 4, 4)`` feature map.
     """
     with nn.parameter_scope("phase_{}".format(resolution)):
         # The affine layer plays the role of a 4x4 conv with padding 3 on a
         # 1x1 input: its channel * 4 * 4 outputs are reshaped to a 4x4 map.
         with nn.parameter_scope("conv1"):
             # Bias is dropped when BN is enabled.
             h = affine(h,
                        channel * 4 * 4,
                        with_bias=not self.use_bn,
                        use_wscale=self.use_wscale,
                        use_he_backward=self.use_he_backward)
             h = F.reshape(h, (h.shape[0], channel, 4, 4))
             # Apply BN exactly once, on the 4-D map, as "conv2" does.
             # Normalizing the flat affine output as well would be redundant.
             # If BN creates scoped parameters (e.g. PF.batch_normalization),
             # it would also request differently shaped parameters under the
             # same names in this "conv1" scope.
             h = pixel_wise_feature_vector_normalization(
                 BN(h, use_bn=self.use_bn, test=test))
             h = self.activation(h)
         with nn.parameter_scope("conv2"):
             h = conv(h,
                      channel,
                      kernel=(3, 3),
                      pad=(1, 1),
                      stride=(1, 1),
                      with_bias=not self.use_bn,
                      use_wscale=self.use_wscale,
                      use_he_backward=self.use_he_backward)
             h = pixel_wise_feature_vector_normalization(
                 BN(h, use_bn=self.use_bn, test=test))
             h = self.activation(h)
     return h