示例#1
0
 def __init__(self, in_channels, out_channels):
     """Create the layers of a residual discriminator block.

     Args:
         in_channels: Number of channels of the block input.
         out_channels: Number of channels produced by the convolutions.

     Sub-modules are created in a fixed order (relu, conv1, conv2, pool,
     shortcut conv). Keep that order so parameter registration, and any
     random initialization the layer factories may perform, stay unchanged.
     """
     super().__init__()
     # Keyword settings shared by the convolutions below.
     same_size_3x3 = {"kernel_size": 3, "stride": 1, "padding": 1}
     pointwise_1x1 = {"kernel_size": 1, "stride": 1, "padding": 0}

     self.relu = nn.ReLU(inplace=True)
     # Main path: in_channels -> out_channels -> out_channels, 3x3 kernels.
     self.spectral_norm_conv2d1 = spectral_norm_conv2d(
         in_channels=in_channels, out_channels=out_channels, **same_size_3x3
     )
     self.spectral_norm_conv2d2 = spectral_norm_conv2d(
         in_channels=out_channels, out_channels=out_channels, **same_size_3x3
     )
     # 2x2 average pooling.
     self.down_sample = nn.AvgPool2d(2)
     # True exactly when the input and output channel counts differ.
     self.channel_mismatch = bool(in_channels != out_channels)
     # 1x1 projection in_channels -> out_channels. It is always built,
     # whatever channel_mismatch is. Presumably forward() (not shown) uses it
     # on the shortcut path.
     self.spectral_norm_conv2d0 = spectral_norm_conv2d(
         in_channels=in_channels, out_channels=out_channels, **pointwise_1x1
     )
示例#2
0
 def __init__(self, in_channels, out_channels, num_classes):
     """Create the layers of a class-conditional residual generator block.

     Args:
         in_channels: Number of channels of the block input.
         out_channels: Number of channels produced by the convolutions.
         num_classes: Class count passed to both ConditionalBatchNorm2d layers.

     Sub-modules are created in a fixed order. Keep that order so parameter
     registration, and any random initialization the factories may do, stay
     unchanged.
     """
     super().__init__()
     # Keyword settings shared by the convolutions below.
     same_size_3x3 = {"kernel_size": 3, "stride": 1, "padding": 1}
     pointwise_1x1 = {"kernel_size": 1, "stride": 1, "padding": 0}

     # Conditional norm sized for the block input (in_channels).
     self.cond_bn1 = ConditionalBatchNorm2d(in_channels, num_classes)
     self.relu = nn.ReLU(inplace=True)
     # 3x3 conv: in_channels -> out_channels.
     self.spectral_norm_conv2d1 = spectral_norm_conv2d(
         in_channels=in_channels, out_channels=out_channels, **same_size_3x3
     )
     # Conditional norm sized for the intermediate features (out_channels).
     self.cond_bn2 = ConditionalBatchNorm2d(out_channels, num_classes)
     # 3x3 conv: out_channels -> out_channels.
     self.spectral_norm_conv2d2 = spectral_norm_conv2d(
         in_channels=out_channels, out_channels=out_channels, **same_size_3x3
     )
     # 1x1 projection in_channels -> out_channels. Presumably forward()
     # (not shown) uses it on the shortcut path.
     self.spectral_norm_conv2d0 = spectral_norm_conv2d(
         in_channels=in_channels, out_channels=out_channels, **pointwise_1x1
     )
示例#3
0
    def __init__(self, z_dim, g_conv_dim, num_classes):
        """Build the class-conditional generator.

        Args:
            z_dim: Size of the latent vector fed to the first linear layer.
            g_conv_dim: Base channel width. Every block width is a multiple
                of it (16x, 8x, 4x, 2x, 1x).
            num_classes: Class count forwarded to every GenBlock.

        Sub-modules are created in a fixed order, and `init_weights` is
        applied last. Keep that order so parameter registration and
        initialization stay unchanged.
        """
        super().__init__()

        self.z_dim = z_dim
        self.g_conv_dim = g_conv_dim

        # Channel widths used by the residual stack.
        ch16, ch8, ch4, ch2 = (g_conv_dim * mult for mult in (16, 8, 4, 2))
        ch1 = g_conv_dim

        # Linear map from z to ch16 * 4 * 4 features. Presumably forward()
        # (not shown) reshapes this to (ch16, 4, 4).
        self.spectral_norm_linear0 = spectral_norm_linear(
            in_features=z_dim, out_features=ch16 * 4 * 4)

        # Residual blocks, with widths narrowing 16x -> 8x -> 4x -> 2x -> 1x.
        # The self-attention module is sized for the 4x stage and is created
        # between block3 and block4.
        self.block1 = GenBlock(ch16, ch16, num_classes)
        self.block2 = GenBlock(ch16, ch8, num_classes)
        self.block3 = GenBlock(ch8, ch4, num_classes)
        self.self_attn = SelfAttentionModule(ch4)
        self.block4 = GenBlock(ch4, ch2, num_classes)
        self.block5 = GenBlock(ch2, ch1, num_classes)

        # NOTE(review): momentum=0.0001 is far below PyTorch's default of 0.1,
        # so the running statistics update very slowly. Confirm this is
        # intended.
        self.bn = nn.BatchNorm2d(ch1, eps=1e-5, momentum=0.0001, affine=True)
        self.relu = nn.ReLU(inplace=True)
        # Final 3x3 conv from ch1 to 3 output channels, followed by tanh.
        self.spectral_norm_conv2d1 = spectral_norm_conv2d(
            in_channels=ch1, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

        # nn.Module.apply calls init_weights on every submodule and on self.
        self.apply(init_weights)
示例#4
0
    def __init__(self, in_channels):
        """Create the projections and parameters of a self-attention layer.

        Args:
            in_channels: Channel count of the feature map the layer works on.
                theta and phi project to in_channels // 8 channels, and g
                projects to in_channels // 2. The attn projection maps
                in_channels // 2 back to in_channels.
        """
        super().__init__()
        self.in_channels = in_channels
        reduced = in_channels // 8
        half = in_channels // 2
        # Keyword settings shared by all 1x1 convolutions below.
        pointwise_1x1 = {"kernel_size": 1, "stride": 1, "padding": 0}

        # Creation order (theta, phi, g, attn) is kept as in the original.
        self.spectral_norm_conv1x1_theta = spectral_norm_conv2d(
            in_channels=in_channels, out_channels=reduced, **pointwise_1x1)
        self.spectral_norm_conv1x1_phi = spectral_norm_conv2d(
            in_channels=in_channels, out_channels=reduced, **pointwise_1x1)
        self.spectral_norm_conv1x1_g = spectral_norm_conv2d(
            in_channels=in_channels, out_channels=half, **pointwise_1x1)
        self.spectral_norm_conv1x1_attn = spectral_norm_conv2d(
            in_channels=half, out_channels=in_channels, **pointwise_1x1)

        # 2x2 max pooling with stride 2.
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        # TODO: consider log_softmax. Confirm against forward() that dim=-1
        # (not dim=1) is the axis that should be normalized.
        self.softmax = nn.Softmax(dim=-1)

        # Learnable one-element parameter initialized to 0. It presumably
        # scales the attention output in forward() (not shown).
        self.sigma = nn.Parameter(torch.zeros(1), requires_grad=True)