예제 #1
0
    def __init__(self,
                 layer,
                 dim=0,
                 power_iters=1,
                 eps=1e-12,
                 dtype='float32'):
        """Wrap ``layer`` and take over ownership of its weight (and bias).

        The wrapped layer's ``bias`` (only when a ``'bias'`` key exists in
        ``layer._parameters``) and its ``weight`` are copied into freshly
        created parameters ``bias_orig`` / ``weight_orig`` on this module and
        then deleted from ``layer._parameters``. A ``dygraph.SpectralNorm``
        sub-layer is built from the weight's shape; for ``dygraph.Linear``
        layers that shape is reversed and ``is_fc`` is set to True.

        Args:
            layer: the layer to wrap. Its ``_parameters`` must contain a
                ``'weight'`` entry (it is indexed directly).
            dim: forwarded to ``dygraph.SpectralNorm``; also stored as
                ``self.dim``.
            power_iters: forwarded to ``dygraph.SpectralNorm``; also stored.
            eps: forwarded to ``dygraph.SpectralNorm``; also stored.
            dtype: forwarded to ``dygraph.SpectralNorm``.
        """
        super(Spectralnorm, self).__init__()

        self.dim = dim
        self.power_iters = power_iters
        self.eps = eps
        self.layer = layer
        self.has_bias = False
        self.is_fc = False

        params = layer._parameters

        # Move the bias (if the wrapped layer registered one) onto this module.
        if 'bias' in params:
            src_bias = params['bias']
            self.bias_orig = self.create_parameter(src_bias.shape,
                                                   dtype=src_bias.dtype)
            self.bias_orig.set_value(src_bias)
            self.has_bias = True
            del params['bias']

        # Copy the weight now; it is removed from ``layer`` only at the end,
        # because ``layer.weight`` is still read below to size SpectralNorm.
        src_weight = params['weight']
        self.weight_orig = self.create_parameter(src_weight.shape,
                                                 dtype=src_weight.dtype)
        self.weight_orig.set_value(src_weight)

        sn_shape = layer.weight.shape
        if isinstance(layer, dygraph.Linear):
            # NOTE(review): presumably dygraph.Linear stores its weight as
            # (in, out), so the shape is reversed to put the output axis
            # first for ``dim`` — confirm against the forward pass.
            self.is_fc = True
            sn_shape = sn_shape[::-1]
        self.spectral_norm = dygraph.SpectralNorm(sn_shape, dim, power_iters,
                                                  eps, dtype)
        del params['weight']
 def __init__(self,
              layer,
              dim=0,
              power_iters=1,
              eps=1e-12,
              dtype='float32'):
     """Wrap ``layer`` and detach its weight into ``weight_orig``.

     A ``dygraph.SpectralNorm`` sub-layer is built from ``layer.weight``'s
     shape, then the ``'weight'`` entry is removed from
     ``layer._parameters`` and copied into a new parameter ``weight_orig``
     owned by this module. Any bias on ``layer`` is left untouched.

     Args:
         layer: the layer to wrap; its ``_parameters`` must contain
             ``'weight'``.
         dim, power_iters, eps, dtype: forwarded to
             ``dygraph.SpectralNorm``; the first three are also stored on
             ``self``.
     """
     super(Spectralnorm, self).__init__()
     # Sized from the live weight, before it is detached below.
     self.spectral_norm = dygraph.SpectralNorm(
         layer.weight.shape, dim, power_iters, eps, dtype)
     self.dim, self.power_iters, self.eps = dim, power_iters, eps
     self.layer = layer
     # Detach the weight from the wrapped layer, then keep a copy of it.
     source_weight = layer._parameters.pop('weight')
     self.weight_orig = self.create_parameter(
         source_weight.shape, dtype=source_weight.dtype)
     self.weight_orig.set_value(source_weight)
예제 #3
0
    def __init__(self,
                 num_channels=3,
                 block_expansion=64,
                 num_blocks=4,
                 max_features=512,
                 sn=False,
                 use_kp=False,
                 num_kp=10,
                 kp_variance=0.01,
                 **kwargs):
        """Build the down-sampling stack and the 1-channel output conv.

        ``num_blocks`` ``DownBlock2d`` layers are stacked. Block ``i`` maps
        ``min(max_features, block_expansion * 2**i)`` channels (or
        ``num_channels + num_kp * use_kp`` for the first block) to
        ``min(max_features, block_expansion * 2**(i + 1))`` channels with a
        kernel size of 4. Normalisation is requested on every block except
        the first, and pooling on every block except the last. A 1x1
        ``dygraph.Conv2D`` then maps the last block's channels to 1.

        Args:
            num_channels: input image channels.
            block_expansion: base channel multiplier per block.
            num_blocks: number of ``DownBlock2d`` layers.
            max_features: upper bound on any block's channel count.
            sn: when truthy, passed to every block and also builds a
                ``dygraph.SpectralNorm`` for the output conv (else
                ``self.sn`` is None).
            use_kp: when truthy, ``num_kp`` extra input channels are added
                to the first block; also stored on ``self``.
            num_kp: number of extra key-point channels.
            kp_variance: stored on ``self``; not used in this constructor.
            **kwargs: accepted and ignored.
        """
        super(Discriminator, self).__init__()

        blocks = []
        for idx in range(num_blocks):
            if idx == 0:
                in_ch = num_channels + num_kp * use_kp
            else:
                in_ch = min(max_features, block_expansion * (2**idx))
            out_ch = min(max_features, block_expansion * (2**(idx + 1)))
            blocks.append(
                DownBlock2d(in_ch,
                            out_ch,
                            norm=(idx != 0),
                            kernel_size=4,
                            pool=(idx != num_blocks - 1),
                            sn=sn))
        self.down_blocks = dygraph.LayerList(blocks)

        last_block = self.down_blocks[len(self.down_blocks) - 1]
        # NOTE(review): parameters()[0] is presumably the conv weight, whose
        # leading dimension is the output-channel count — confirm.
        last_channels = last_block.conv.parameters()[0].shape[0]
        self.conv = dygraph.Conv2D(last_channels, 1, filter_size=1)

        self.sn = (dygraph.SpectralNorm(self.conv.parameters()[0].shape, dim=0)
                   if sn else None)
        self.use_kp = use_kp
        self.kp_variance = kp_variance
예제 #4
0
    def __init__(self,
                 in_features,
                 out_features,
                 norm=False,
                 kernel_size=4,
                 pool=False,
                 sn=False):
        """Build one down-sampling block: conv, optional SN / norm, pool flag.

        Args:
            in_features: input channel count for the ``dygraph.Conv2D``.
            out_features: output channel count for the conv and, when
                ``norm`` is truthy, the ``dygraph.InstanceNorm``.
            norm: when truthy, build an ``InstanceNorm`` (epsilon 1e-05,
                'float32'); otherwise ``self.norm`` is None.
            kernel_size: conv ``filter_size``.
            pool: stored as ``self.pool`` (presumably read by ``forward``,
                which is not shown here).
            sn: when truthy, build a ``dygraph.SpectralNorm`` sized from the
                conv weight with ``dim=0``; otherwise ``self.sn`` is None.
        """
        super(DownBlock2d, self).__init__()
        self.conv = dygraph.Conv2D(in_features, out_features,
                                   filter_size=kernel_size)
        self.sn = (dygraph.SpectralNorm(self.conv.weight.shape, dim=0)
                   if sn else None)
        self.norm = (dygraph.InstanceNorm(num_channels=out_features,
                                          epsilon=1e-05,
                                          dtype='float32')
                     if norm else None)
        self.pool = pool