Example #1
    def __init__(self,
                 D=None,
                 is_enc_shared=True,
                 dim_z=128,
                 D_ch=64,
                 D_wide=True,
                 resolution=128,
                 D_kernel_size=3,
                 D_attn='64',
                 n_classes=1000,
                 num_D_SVs=1,
                 num_D_SV_itrs=1,
                 D_activation=nn.ReLU(inplace=False),
                 D_lr=2e-4,
                 D_B1=0.0,
                 D_B2=0.999,
                 adam_eps=1e-8,
                 SN_eps=1e-12,
                 output_dim=1,
                 D_mixed_precision=False,
                 D_fp16=False,
                 D_init='ortho',
                 skip_init=False,
                 D_param='SN',
                 optimizer_type='adam',
                 weight_decay=5e-4,
                 dataset_channel=3,
                 **kwargs):
        super(Encoder, self).__init__()
        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # weight decay
        self.weight_decay = weight_decay
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16
        # Architecture
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Which convs, batchnorms, and linear layers to use
        # No option to turn off SN in D right now
        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3,
                                                padding=1,
                                                num_svs=num_D_SVs,
                                                num_itrs=num_D_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_D_SVs,
                                                  num_itrs=num_D_SV_itrs,
                                                  eps=self.SN_eps)
            self.which_embedding = functools.partial(layers.SNEmbedding,
                                                     num_svs=num_D_SVs,
                                                     num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)
        # Prepare model
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        if not is_enc_shared:
            # Build a private trunk; the first block consumes raw images
            # with dataset_channel input channels.
            self.arch['in_channels'][0] = dataset_channel
            self.blocks = []
            for index in range(len(self.arch['out_channels'])):
                self.blocks += [[
                    layers.DBlock(
                        in_channels=self.arch['in_channels'][index],
                        out_channels=self.arch['out_channels'][index],
                        which_conv=self.which_conv,
                        wide=self.D_wide,
                        activation=self.activation,
                        preactivation=(index > 0),
                        downsample=(nn.AvgPool2d(2) if
                                    self.arch['downsample'][index] else None))
                ]]
                # If attention on this block, attach it to the end
                if self.arch['attention'][self.arch['resolution'][index]]:
                    print('Adding attention layer in D at resolution %d' %
                          self.arch['resolution'][index])
                    self.blocks[-1] += [
                        layers.Attention(self.arch['out_channels'][index],
                                         self.which_conv)
                    ]
            self.blocks = nn.ModuleList(
                [nn.ModuleList(block) for block in self.blocks])
        else:
            # Shared-trunk mode: reuse the discriminator's blocks
            # (a Discriminator instance must be passed in as D).
            self.blocks = D.blocks

        # Drop the last shared block; the encoder-specific blocks2 below take over.
        self.blocks = self.blocks[:-1]

        # Encoder-specific head: one non-downsampling DBlock followed by
        # dim_reduction_enc extra downsampling DBlocks (built just below).
        self.blocks2 = []
        self.blocks2 += [[
            layers.DBlock(in_channels=self.arch['in_channels'][-1],
                          out_channels=self.arch['out_channels'][-1],
                          which_conv=self.which_conv,
                          wide=self.D_wide,
                          activation=self.activation,
                          preactivation=True,
                          downsample=None)
        ]]

        dim_reduction_enc = 2
        for index in range(dim_reduction_enc):
            self.blocks2 += [[
                layers.DBlock(in_channels=self.arch['in_channels'][-1],
                              out_channels=self.arch['out_channels'][-1],
                              which_conv=self.which_conv,
                              wide=self.D_wide,
                              activation=self.activation,
                              preactivation=True,
                              downsample=nn.AvgPool2d(2))
            ]]
        self.blocks2 = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks2])

        # Output linear layers for the VAE heads (mu and log-variance).
        # Total spatial reduction = 2^(dim_reduction_enc + number of
        # downsampling blocks in the shared D architecture).
        reduction = 2**(dim_reduction_enc +
                        np.sum(np.array(self.arch['downsample'])))
        o_res_dim = resolution // reduction
        self.input_linear_dim = self.arch['out_channels'][-1] * o_res_dim**2
        print('Encoder spatial reduction factor:', reduction)
        print('Encoder output feature-map size:', o_res_dim)
        self.linear_mu1 = self.which_linear(self.input_linear_dim,
                                            self.input_linear_dim // 2)
        self.linear_lv1 = self.which_linear(self.input_linear_dim,
                                            self.input_linear_dim // 2)
        self.linear_mu2 = self.which_linear(self.input_linear_dim // 2, dim_z)
        self.linear_lv2 = self.which_linear(self.input_linear_dim // 2, dim_z)

        # Initialize weights
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print('Using fp16 adam in D...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(),
                                      lr=self.lr,
                                      betas=(self.B1, self.B2),
                                      weight_decay=0,
                                      eps=self.adam_eps)
        else:
            if optimizer_type == 'adam':
                self.optim = optim.Adam(params=self.parameters(),
                                        lr=self.lr,
                                        betas=(self.B1, self.B2),
                                        weight_decay=0,
                                        eps=self.adam_eps)
            elif optimizer_type == 'radam':
                self.optim = optimizers.RAdam(params=self.parameters(),
                                              lr=self.lr,
                                              betas=(self.B1, self.B2),
                                              weight_decay=self.weight_decay,
                                              eps=self.adam_eps)
            elif optimizer_type == 'ranger':
                self.optim = optimizers.Ranger(params=self.parameters(),
                                               lr=self.lr,
                                               betas=(self.B1, self.B2),
                                               weight_decay=self.weight_decay,
                                               eps=self.adam_eps)
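
A minimal usage sketch for the Encoder above. Only the constructor is shown in this example, so the Discriminator class, the tensor shapes, and the forward() call below are assumptions, not code from the snippet.

# Hypothetical sketch: with is_enc_shared=True the encoder reuses the trunk
# of an existing discriminator, so one must be built first and passed as D.
D = Discriminator(D_ch=64, resolution=128, n_classes=1000)
E = Encoder(D=D, is_enc_shared=True, dim_z=128, D_ch=64, resolution=128)

# The forward pass is not part of this snippet; it is assumed to return the
# VAE heads built above (linear_mu2 / linear_lv2), e.g.:
# x = torch.randn(4, 3, 128, 128)
# mu, logvar = E(x)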
Example #2
    def __init__(self, D_ch=64, D_wide=True, resolution=128,
                 D_kernel_size=3, D_attn='64', n_classes=1000,
                 num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
                 D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8,
                 SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False,
                 D_init='ortho', skip_init=False, D_param='SN', **kwargs):
        super(Discriminator, self).__init__()
        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16
        # Architecture
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Which convs, batchnorms, and linear layers to use
        # No option to turn off SN in D right now
        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                kernel_size=3, padding=1,
                                num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                eps=self.SN_eps)
            self.which_embedding = functools.partial(layers.SNEmbedding,
                                    num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                    eps=self.SN_eps)
        # Prepare model
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
                             out_channels=self.arch['out_channels'][index],
                             which_conv=self.which_conv,
                             wide=self.D_wide,
                             activation=self.activation,
                             preactivation=(index > 0),
                             downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
                                                     self.which_conv)]
        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
        # Linear output layer. The output dimension is typically 1, but may be
        # larger if we're e.g. turning this into a VAE with an inference output
        self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
        # Embedding for projection discrimination
        self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

        # Initialize weights
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print('Using fp16 adam in D...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                   betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                   betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
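
A short sketch of how this constructor is typically exercised. The forward(x, y) signature and tensor shapes are assumptions, since only __init__ appears in the example.

# Hypothetical sketch: projection discrimination combines a per-image score
# (self.linear on pooled features) with a class-embedding term (self.embed).
D = Discriminator(D_ch=64, resolution=128, n_classes=1000)
# x = torch.randn(8, 3, 128, 128)   # a batch of real or generated images
# y = torch.randint(0, 1000, (8,))  # class labels
# score = D(x, y)                   # assumed forward(x, y) signature
# The optimizer built in __init__ is stepped externally, e.g. D.optim.step()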
Example #3
    def __init__(self,
                 D_ch=64,
                 D_wide=True,
                 resolution=128,
                 D_kernel_size=3,
                 D_attn='64',
                 n_classes=1000,
                 num_D_SVs=1,
                 num_D_SV_itrs=1,
                 D_activation=nn.ReLU(inplace=False),
                 D_lr=2e-4,
                 D_B1=0.0,
                 D_B2=0.999,
                 adam_eps=1e-8,
                 SN_eps=1e-12,
                 output_dim=1,
                 D_mixed_precision=False,
                 D_fp16=False,
                 D_init='ortho',
                 skip_init=False,
                 D_param='SN',
                 decoder_skip_connection=True,
                 **kwargs):
        super(Unet_Discriminator, self).__init__()

        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16

        # Indices of downsampling blocks whose features are saved for the
        # U-Net skip connections (only 128 and 256 resolutions are handled).
        if self.resolution == 128:
            self.save_features = [0, 1, 2, 3, 4]
        elif self.resolution == 256:
            self.save_features = [0, 1, 2, 3, 4, 5]

        self.out_channel_multiplier = 1  #4
        # Architecture
        self.arch = D_unet_arch(
            self.ch,
            self.attention,
            out_channel_multiplier=self.out_channel_multiplier)[resolution]

        # Required keyword argument; if True, no class embeddings are built below.
        self.unconditional = kwargs["unconditional"]

        # Which convs, batchnorms, and linear layers to use
        # No option to turn off SN in D right now
        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3,
                                                padding=1,
                                                num_svs=num_D_SVs,
                                                num_itrs=num_D_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_D_SVs,
                                                  num_itrs=num_D_SV_itrs,
                                                  eps=self.SN_eps)

            self.which_embedding = functools.partial(layers.SNEmbedding,
                                                     num_svs=num_D_SVs,
                                                     num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)
        # Prepare model
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        self.blocks = []

        for index in range(len(self.arch['out_channels'])):

            if self.arch["downsample"][index]:
                self.blocks += [[
                    layers.DBlock(
                        in_channels=self.arch['in_channels'][index],
                        out_channels=self.arch['out_channels'][index],
                        which_conv=self.which_conv,
                        wide=self.D_wide,
                        activation=self.activation,
                        preactivation=(index > 0),
                        downsample=(nn.AvgPool2d(2) if
                                    self.arch['downsample'][index] else None))
                ]]

            elif self.arch["upsample"][index]:
                # Decoder half of the U-Net: upsample by 2x with nearest-
                # neighbour interpolation ("nearest" is F.interpolate's default).
                upsample_function = functools.partial(F.interpolate,
                                                      scale_factor=2,
                                                      mode="nearest")

                self.blocks += [[
                    layers.GBlock2(
                        in_channels=self.arch['in_channels'][index],
                        out_channels=self.arch['out_channels'][index],
                        which_conv=self.which_conv,
                        #which_bn=self.which_bn,
                        activation=self.activation,
                        upsample=upsample_function,
                        skip_connection=True)
                ]]

            # If attention on this block, attach it to the end
            attention_condition = index < 5
            if (self.arch['attention'][self.arch['resolution'][index]]
                    and attention_condition):
                print('Adding attention layer in D at resolution %d' %
                      self.arch['resolution'][index])
                print("index = ", index)
                self.blocks[-1] += [
                    layers.Attention(self.arch['out_channels'][index],
                                     self.which_conv)
                ]

        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks])

        last_layer = nn.Conv2d(self.ch * self.out_channel_multiplier,
                               1,
                               kernel_size=1)
        self.blocks.append(last_layer)
        # Linear output layer. The output dimension is typically 1, but may be
        # larger if we're e.g. turning this into a VAE with an inference output
        self.linear = self.which_linear(self.arch['out_channels'][-1],
                                        output_dim)

        self.linear_middle = self.which_linear(16 * self.ch, output_dim)
        # Embedding for projection discrimination
        #if not kwargs["agnostic_unet"] and not kwargs["unconditional"]:
        #    self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1]+extra)
        if not kwargs["unconditional"]:
            self.embed_middle = self.which_embedding(self.n_classes,
                                                     16 * self.ch)
            self.embed = self.which_embedding(self.n_classes,
                                              self.arch['out_channels'][-1])

        # Initialize weights
        if not skip_init:
            self.init_weights()

        # Debug: list every registered parameter name and its shape.
        print("_____params______")
        for name, param in self.named_parameters():
            print(name, param.size())

        # Set up optimizer
        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print('Using fp16 adam in D...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(),
                                      lr=self.lr,
                                      betas=(self.B1, self.B2),
                                      weight_decay=0,
                                      eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(),
                                    lr=self.lr,
                                    betas=(self.B1, self.B2),
                                    weight_decay=0,
                                    eps=self.adam_eps)
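
A brief instantiation sketch for the U-Net discriminator. Note that unconditional is read from **kwargs, so it is effectively a required argument; everything beyond the constructor call is an assumption.

# Hypothetical sketch: kwargs["unconditional"] must be supplied explicitly.
D_unet = Unet_Discriminator(D_ch=64, resolution=128, unconditional=True)
# With unconditional=False, embed and embed_middle are also created, and the
# forward pass (not shown in this snippet) is assumed to accept class labels.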