Example #1
0
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]] = (3, 3),
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        n_power_iterations: int = 1,
    ) -> None:
        """Build the convolution, then hook spectral normalization onto its weight.

        All convolution arguments are forwarded unchanged to the parent
        constructor; `n_power_iterations` controls the power-iteration count
        used by the spectral-norm hook.
        """
        conv_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )
        super().__init__(**conv_kwargs)

        # Reparameterize `weight` in place via the spectral-norm hook
        # (dim=0 normalizes over output channels).
        SpectralNorm.apply(
            module=self,
            name="weight",
            dim=0,
            n_power_iterations=n_power_iterations,
            eps=1e-12,
        )
    def __init__(self, mode, cnn_chunk, out_c, k, h_dim):
        """Wrap a CNN feature extractor with an MLP discriminator head.

        Args:
            mode: training regime string; either contains 'sn' (spectral
                normalization on the MLP) or is exactly 'wgan-gp'.
            cnn_chunk: pre-built convolutional feature extractor (stored as-is).
            out_c: feature dimension produced by `cnn_chunk`, i.e. the MLP
                input width.
            k: stored on the instance; semantics defined by the caller.
                # NOTE(review): usage of `k` is not visible here — confirm.
            h_dim: hidden width of every MLP layer.

        Raises:
            ValueError: if `mode` is neither an 'sn' variant nor 'wgan-gp'.
        """
        super().__init__()

        self.conv = cnn_chunk
        self.k = k
        self.mode = mode

        # Discriminator: four hidden ReLU layers, scalar critic/logit output.
        self.disc = nn.Sequential(
            nn.Linear(out_c, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1),
        )
        if 'sn' in mode:
            # Constrain each linear layer to be ~1-Lipschitz via spectral norm.
            for module in self.disc.modules():
                if isinstance(module, nn.Linear):
                    SpectralNorm.apply(module,
                                       'weight',
                                       n_power_iterations=1,
                                       eps=1e-12,
                                       dim=0)
        elif mode != 'wgan-gp':
            # Was `assert mode == 'wgan-gp'`: asserts are stripped under
            # `python -O`, so validate the configuration explicitly.
            raise ValueError(
                f"unsupported mode {mode!r}: expected an 'sn' variant or 'wgan-gp'"
            )
Example #3
0
def spectral_norm(module,
                  name="weight",
                  n_power_iterations=1,
                  eps=1e-12,
                  dim=None):
    """Attach spectral normalization to the `name` parameter of `module`.

    Convolutional modules (per `is_conv`) get the conv-specific hook;
    everything else falls back to the standard hook, where `dim`
    defaults to 0 when unspecified. Returns the module for chaining.
    """
    if is_conv(module):
        ConvSpectralNorm.apply(module, name, n_power_iterations, dim, eps)
        return module

    effective_dim = 0 if dim is None else dim
    SpectralNorm.apply(module, name, n_power_iterations, effective_dim, eps)
    return module
Example #4
0
def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
    """Copied from https://pytorch.org/docs/stable/_modules/torch/nn/utils/spectral_norm.html

    Unlike the upstream version, modules lacking the `name` parameter are
    returned untouched instead of raising.
    """
    # Guard clause: nothing to normalize on modules without the parameter.
    if not hasattr(module, name):
        return module

    if dim is None:
        # Transposed convolutions keep output channels on axis 1,
        # so the normalization dimension differs from the usual 0.
        transposed_convs = (torch.nn.ConvTranspose1d,
                            torch.nn.ConvTranspose2d,
                            torch.nn.ConvTranspose3d)
        dim = 1 if isinstance(module, transposed_convs) else 0

    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
Example #5
0
def spectral_norm(module,
                  name='weight',
                  apply_to=None,
                  n_power_iterations=1,
                  eps=1e-12):
    """Apply spectral normalization to `module` if its class is whitelisted.

    Args:
        module: candidate layer; returned unchanged when not eligible.
        name: name of the parameter to normalize.
        apply_to: iterable of lowercase class names eligible for
            normalization; defaults to ('conv2d',). Classes whose name
            contains 'adaptive' are always skipped.
        n_power_iterations: power-iteration count for the spectral-norm hook.
        eps: numerical-stability epsilon for the hook.

    Returns:
        The same module, normalized in place when eligible.
    """
    # `None` sentinel instead of a mutable list default argument;
    # the effective default behavior is unchanged.
    if apply_to is None:
        apply_to = ('conv2d',)

    # Apply only to modules in the apply_to whitelist.
    module_name = module.__class__.__name__.lower()
    if module_name not in apply_to or 'adaptive' in module_name:
        return module

    # Transposed convs keep output channels on dim 1; everything else on dim 0.
    dim = 1 if isinstance(module, nn.ConvTranspose2d) else 0

    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)

    return module
Example #6
0
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        n_power_iterations: int = 1,
    ) -> None:
        """Construct the linear layer and register the spectral-norm hook.

        `in_features`, `out_features`, and `bias` are forwarded unchanged
        to the parent constructor; `n_power_iterations` controls the hook's
        power-iteration count.
        """
        super().__init__(in_features=in_features,
                         out_features=out_features,
                         bias=bias)

        # Reparameterize `weight` in place (dim=0 normalizes over rows,
        # i.e. output features).
        SpectralNorm.apply(
            module=self,
            name="weight",
            dim=0,
            n_power_iterations=n_power_iterations,
            eps=1e-12,
        )