Example #1
    def __init__(self, in_channels, num_features, gen_blocks, dis_blocks,
                 growth_rate, bn_size):
        super(SNDiscriminator, self).__init__()
        self.crop_size = 4 * pow(2, dis_blocks)

        # image to features
        self.image_to_features = DisBlock(in_channels=in_channels,
                                          out_channels=num_features,
                                          bias=True,
                                          normalization=True)

        # features
        blocks = []
        for i in range(0, dis_blocks - 1):
            blocks.append(
                DisBlock(in_channels=num_features * min(pow(2, i), 8),
                         out_channels=num_features * min(pow(2, i + 1), 8),
                         bias=False,
                         normalization=True))
        self.features = nn.Sequential(*blocks)

        # classifier
        self.classifier = nn.Sequential(
            SpectralNorm(
                nn.Linear(
                    num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4,
                    100)),
            nn.LeakyReLU(negative_slope=0.1),
            SpectralNorm(nn.Linear(100, 1)))
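The excerpt shows only the constructor. A forward pass consistent with these layer sizes would flatten the final 4×4 feature map before the linear classifier; the sketch below is an assumption, since the original forward is not part of the snippet.

    def forward(self, x):
        # Hypothetical forward, inferred from the layer sizes above.
        x = self.image_to_features(x)   # stride-2 DisBlock: halves H and W
        x = self.features(x)            # each remaining DisBlock halves H and W again
        x = x.view(x.size(0), -1)       # -> num_features * min(2**(dis_blocks-1), 8) * 4 * 4
        return self.classifier(x)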
Example #2
    def __init__(self, mode, cnn_chunk, out_c, k, h_dim):
        super().__init__()

        self.conv = cnn_chunk
        self.k = k
        self.mode = mode

        # Discriminator
        self.disc = nn.Sequential(
            nn.Linear(out_c, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1),
        )
        if 'sn' in mode:
            for module in self.disc.modules():
                if isinstance(module, nn.Linear):
                    SpectralNorm.apply(module,
                                       'weight',
                                       n_power_iterations=1,
                                       eps=1e-12,
                                       dim=0)
        else:
            assert mode == 'wgan-gp'
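The registration loop above can be reproduced in isolation to see what the hook adds to each layer; the Sequential below is an illustrative stand-in, not the discriminator from the snippet.

import torch.nn as nn
from torch.nn.utils.spectral_norm import SpectralNorm

# Illustrative stand-in for self.disc; only the registration loop mirrors the snippet.
disc = nn.Sequential(nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 1))
for module in disc.modules():
    if isinstance(module, nn.Linear):
        SpectralNorm.apply(module, 'weight', n_power_iterations=1, eps=1e-12, dim=0)

print(hasattr(disc[0], 'weight_orig'))  # True: the weight is now re-parameterized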
Example #3
    def __init__(self,
                 in_channels=64,
                 out_channels=64,
                 bias=True,
                 normalization=False):
        super(DisBlock, self).__init__()
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=3,
                               padding=1,
                               bias=bias)
        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=4,
                               stride=2,
                               padding=1,
                               bias=bias)
        self.bn1 = nn.BatchNorm2d(out_channels, affine=True)
        self.bn2 = nn.BatchNorm2d(out_channels, affine=True)

        initialize_weights([self.conv1, self.conv2], 0.1)

        if normalization:
            self.conv1 = SpectralNorm(self.conv1)
            self.conv2 = SpectralNorm(self.conv2)
Example #4
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]] = (3, 3),
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        n_power_iterations: int = 1,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )

        SpectralNorm.apply(
            module=self,
            n_power_iterations=n_power_iterations,
            name="weight",
            dim=0,
            eps=1e-12
        )
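The subclass name is not visible in this excerpt. The same effect can be obtained with PyTorch's built-in helper; the check below is a minimal sketch showing that the flattened kernel ends up with spectral norm close to 1.

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Equivalent construction with the built-in helper (the subclass name above is unknown).
conv = spectral_norm(nn.Conv2d(3, 64, kernel_size=3, padding=1),
                     name="weight", n_power_iterations=1, dim=0, eps=1e-12)

x = torch.randn(1, 3, 32, 32)
y = conv(x)  # the forward pre-hook recomputes `weight` from `weight_orig`

# Largest singular value of the flattened kernel is ~1, up to power-iteration error.
w = conv.weight.reshape(conv.weight.size(0), -1)
print(torch.linalg.svdvals(w)[0])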
Example #5
def spectral_norm(module,
                  name="weight",
                  n_power_iterations=1,
                  eps=1e-12,
                  dim=None):
    if is_conv(module):
        ConvSpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    else:
        if dim is None:
            dim = 0
        SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
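The dispatcher relies on an is_conv helper that the excerpt does not include; a plausible definition, given here only as an assumption, is:

import torch.nn as nn

# Assumed helper; the actual definition is not part of the snippet above.
def is_conv(module):
    return isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d,
                               nn.ConvTranspose1d, nn.ConvTranspose2d,
                               nn.ConvTranspose3d))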
Example #6
def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
    """Copied from https://pytorch.org/docs/stable/_modules/torch/nn/utils/spectral_norm.html"""
    # Extra check of hasattr for classes with no 'weight'
    if hasattr(module, name):
        if dim is None:
            if isinstance(module, (torch.nn.ConvTranspose1d,
                                   torch.nn.ConvTranspose2d,
                                   torch.nn.ConvTranspose3d)):
                dim = 1
            else:
                dim = 0
        SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
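Thanks to the hasattr guard, this version can be swept over an entire model with Module.apply: layers without a weight are returned untouched, and transposed convolutions automatically get dim=1. The model below is only illustrative.

import torch.nn as nn

# Illustrative sweep: ReLU and the Sequential container have no 'weight' and are skipped.
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.ReLU(),
    nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
)
model.apply(spectral_norm)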
Example #7
    def __init__(
        self,
        name="weight",
        n_power_iterations=1,
        dim=1,
        eps=1e-12,
        mode="",
        strict=True,
    ):
        super().__init__(name, n_power_iterations, dim, eps)
        assert validate_mode(mode)
        self.alpha, self.mode = parse_mode(mode)
        self.strict = strict
        self.dense = SpectralNorm(name, n_power_iterations, dim, eps)
        self.conv = ConvSpectralNorm(name, n_power_iterations, None, eps)
Example #8
    def __init__(self, in_channels, num_features, gen_blocks, dis_blocks):
        super(SNDiscriminator, self).__init__()

        # image to features
        self.image_to_features = DisBlock(in_channels=in_channels,
                                          out_channels=num_features,
                                          bias=True,
                                          normalization=True)

        # features
        blocks = []
        for i in range(0, dis_blocks - 1):
            blocks.append(
                DisBlock(in_channels=num_features * min(pow(2, i), 8),
                         out_channels=num_features * min(pow(2, i + 1), 8),
                         bias=False,
                         normalization=True))
        self.features = nn.Sequential(*blocks)

        # classifier
        self.classifier = SpectralNorm(
            nn.Conv2d(in_channels=num_features *
                      min(pow(2, dis_blocks - 1), 8),
                      out_channels=1,
                      kernel_size=4,
                      padding=0))
Example #9
    def __init__(self,
                 input_nc=3,
                 ndf=64,
                 img_f=1024,
                 layers=6,
                 norm='none',
                 activation='LeakyReLU',
                 use_spect=True,
                 use_coord=False):
        super(ResDiscriminator, self).__init__()

        self.layers = layers

        norm_layer = get_norm_layer(norm_type=norm)
        nonlinearity = get_nonlinearity_layer(activation_type=activation)
        self.nonlinearity = nonlinearity

        # encoder part
        self.block0 = ResBlockEncoder(input_nc, ndf, ndf, norm_layer,
                                      nonlinearity, use_spect, use_coord)

        mult = 1
        for i in range(layers - 1):
            mult_prev = mult
            mult = min(2**(i + 1), img_f // ndf)
            block = ResBlockEncoder(ndf * mult_prev, ndf * mult,
                                    ndf * mult_prev, norm_layer, nonlinearity,
                                    use_spect, use_coord)
            setattr(self, 'encoder' + str(i), block)
        self.conv = SpectralNorm(nn.Conv2d(ndf * mult, 1, 1))
Example #10
    def __init__(self,
                 input_nc=3,
                 input_length=6,
                 ndf=64,
                 img_f=1024,
                 layers=6,
                 norm='none',
                 activation='LeakyReLU',
                 use_spect=True,
                 use_coord=False):
        super(TemporalDiscriminator, self).__init__()

        self.layers = layers
        norm_layer = get_norm_layer(norm_type=norm)
        nonlinearity = get_nonlinearity_layer(activation_type=activation)
        self.nonlinearity = nonlinearity

        # self.pool = nn.AvgPool3d(kernel_size=(1,2,2), stride=(1,2,2))

        # encoder part
        self.block0 = ResBlock3DEncoder(input_nc, 1 * ndf, 1 * ndf, norm_layer,
                                        nonlinearity, use_spect, use_coord)
        self.block1 = ResBlock3DEncoder(1 * ndf, 2 * ndf, 1 * ndf, norm_layer,
                                        nonlinearity, use_spect, use_coord)

        feature_len = input_length - 4
        mult = 2 * feature_len
        for i in range(layers - 2):
            mult_prev = mult
            mult = min(2**(i + 2), img_f // ndf)
            block = ResBlockEncoder(ndf * mult_prev, ndf * mult,
                                    ndf * mult_prev, norm_layer, nonlinearity,
                                    use_spect, use_coord)
            setattr(self, 'encoder' + str(i), block)
        self.conv = SpectralNorm(nn.Conv2d(ndf * mult, 1, 1))
Example #11
def spectral_norm(module,
                  name='weight',
                  apply_to=['conv2d'],
                  n_power_iterations=1,
                  eps=1e-12):
    # Apply only to modules in apply_to list
    module_name = module.__class__.__name__.lower()
    if module_name not in apply_to or 'adaptive' in module_name:
        return module

    if isinstance(module, nn.ConvTranspose2d):
        dim = 1
    else:
        dim = 0

    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)

    return module
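A short illustration of the class-name filter; the layers here are examples, not code from the original project.

import torch.nn as nn

# Only modules whose lowercased class name appears in apply_to are wrapped.
conv = spectral_norm(nn.Conv2d(3, 64, 3), apply_to=['conv2d'])
deconv = spectral_norm(nn.ConvTranspose2d(64, 3, 4), apply_to=['conv2d'])

print(hasattr(conv, 'weight_orig'))    # True: hook registered
print(hasattr(deconv, 'weight_orig'))  # False: 'convtranspose2d' not in apply_to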
Example #12
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        n_power_iterations: int = 1,
    ) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
        )

        SpectralNorm.apply(
            module=self,
            n_power_iterations=n_power_iterations,
            name="weight",
            dim=0,
            eps=1e-12
        )
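As with the convolutional variant, the subclass name is not shown in the excerpt. An equivalent construction with the built-in helper, plus a look at what the hook registers, might read as follows (a sketch, not the original code).

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Equivalent construction with the built-in helper (subclass name above is unknown).
fc = spectral_norm(nn.Linear(128, 64), n_power_iterations=1, dim=0, eps=1e-12)

# 'weight_orig' is the trainable parameter; 'weight' is recomputed by the
# forward pre-hook on every call.
print(sorted(name for name, _ in fc.named_parameters()))  # ['bias', 'weight_orig']
out = fc(torch.randn(4, 128))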
Example #13
def spectral_norm(module, use_spect=True):
    """use spectral normal layer to stable the training process"""
    if use_spect:
        return SpectralNorm(module)
    else:
        return module
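Call sites typically just forward a use_spect flag, so the same layer definition can be built with or without the wrapper; the lines below are illustrative.

import torch.nn as nn

# Illustrative toggle: identical layer, wrapped or left as-is.
conv_sn = spectral_norm(nn.Conv2d(64, 128, 3, padding=1), use_spect=True)
conv_plain = spectral_norm(nn.Conv2d(64, 128, 3, padding=1), use_spect=False)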
Example #14
class SmartSpectralNorm(SpectralNorm):
    def __init__(
        self,
        name="weight",
        n_power_iterations=1,
        dim=1,
        eps=1e-12,
        mode="",
        strict=True,
    ):
        super().__init__(name, n_power_iterations, dim, eps)
        assert validate_mode(mode)
        self.alpha, self.mode = parse_mode(mode)
        self.strict = strict
        self.dense = SpectralNorm(name, n_power_iterations, dim, eps)
        self.conv = ConvSpectralNorm(name, n_power_iterations, None, eps)

    @staticmethod
    def apply(module,
              name,
              n_power_iterations,
              dim,
              eps,
              mode="bug",
              strict=True):
        # conv spectral norm
        fn = SmartSpectralNorm(name, n_power_iterations, dim, eps, mode,
                               strict)
        weight = module._parameters[name]

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        module.register_buffer(fn.name, weight.data)

        # for conv spectral norm
        setattr(module, fn.name + "_ux", None)
        # regular spectral norm
        height = weight.size(dim)
        u = normalize(weight.new_empty(height).normal_(0, 1),
                      dim=0,
                      eps=fn.eps)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a
        # buffer, which will cause weight to be included in the state dict
        # and also supports nn.init due to shared storage.
        module.register_buffer(fn.name + "_u", u)

        if fn.mode.startswith("learn"):
            alpha = weight.new_empty(()).fill_(1.0)
            module.register_parameter(fn.name + "_alpha",
                                      torch.nn.Parameter(alpha))
        module.register_forward_pre_hook(fn)
        return fn

    def init_ux(self, module, inputs):
        if getattr(module, self.name + "_ux") is None:
            delattr(module, self.name + "_ux")
            u = inputs[0][0][None]
            if is_transposed(module):
                u = module.forward(u)
            module.register_buffer(self.name + "_ux", u)  # first item in batch

    def compute_weight(self, module):
        weight = getattr(module, self.name + "_orig")
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(
                self.dim,
                *[d for d in range(weight_mat.dim()) if d != self.dim])
        height = weight_mat.size(0)
        weight_mat = weight_mat.reshape(height, -1)
        weight_dense, u = self.dense.compute_weight(module)
        v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
        sigma_dense = torch.dot(u, torch.matmul(weight_mat, v))

        weight_conv, ux = self.conv.compute_weight(module)
        conv_params = dict(
            padding=module.padding,
            stride=module.stride,
            dilation=module.dilation,
            groups=module.groups,
        )
        ux, vx = conv_power_iteration(weight, ux, **conv_params)
        sigma_conv = conv_sigma(weight, ux, vx, **conv_params)
        if self.mode == "fix":
            sigma = sigma_conv
            alpha = self.alpha if self.alpha is not None else 1.0
        elif self.mode == "bug":
            sigma = sigma_dense
            alpha = self.alpha if self.alpha is not None else 1.0
        elif self.mode == "bug/fix":
            alpha = (sigma_conv / sigma_dense).detach()  # alpha as bug
            sigma = sigma_conv  # scaled as fix
        elif self.mode == "fix/bug":
            alpha = (sigma_dense / sigma_conv).detach()  # alpha as fix
            sigma = sigma_dense  # scaled as bug
        elif self.mode == "learn/fix":
            alpha = getattr(module, self.name + "_alpha")
            sigma = sigma_conv
        elif self.mode == "learn/bug":
            alpha = getattr(module, self.name + "_alpha")
            sigma = sigma_dense
        else:  # self.mode == '':
            sigma = 1.0
            alpha = 1.0
        if not self.strict:
            # non-strict mode only rescales weights whose spectral norm exceeds 1
            sigma = max(1.0, sigma)
        weight = weight / sigma * alpha
        return weight, u, ux

    def remove(self, module):
        delattr(module, self.name + "_ux")
        super().remove(module)

    def __call__(self, module, inputs):
        if getattr(module, self.name + "_ux") is None:
            self.init_ux(module, inputs)
        if module.training:
            weight, u, ux = self.compute_weight(module)
            setattr(module, self.name, weight)
            setattr(module, self.name + "_ux", ux)
            setattr(module, self.name + "_u", u)
        else:
            r_g = getattr(module, self.name + "_orig").requires_grad
            getattr(module, self.name).detach_().requires_grad_(r_g)