Example #1
  def __init__(self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=128,
               G_kernel_size=3, G_attn='64', n_classes=1000,
               num_G_SVs=1, num_G_SV_itrs=1,
               G_shared=True, shared_dim=0, hier=False,
               cross_replica=False, mybn=False,
               G_activation=nn.ReLU(inplace=False),
               G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
               BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
               G_init='ortho', skip_init=False, no_optim=False,
               G_param='SN', norm_style='bn',
               **kwargs):
    super(Generator, self).__init__()
    # Channel width multiplier
    self.ch = G_ch
    # Number of resblocks per stage
    self.G_depth = G_depth
    # Dimensionality of the latent space
    self.dim_z = dim_z
    # The initial spatial dimensions
    self.bottom_width = bottom_width
    # Resolution of the output
    self.resolution = resolution
    # Kernel size?
    self.kernel_size = G_kernel_size
    # Attention?
    self.attention = G_attn
    # number of classes, for use in categorical conditional generation
    self.n_classes = n_classes
    # Use shared embeddings?
    self.G_shared = G_shared
    # Dimensionality of the shared embedding? Unused if not using G_shared
    self.shared_dim = shared_dim if shared_dim > 0 else dim_z
    # Hierarchical latent space?
    self.hier = hier
    # Cross replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # nonlinearity for residual blocks
    self.activation = G_activation
    # Initialization style
    self.init = G_init
    # Parameterization style
    self.G_param = G_param
    # Normalization style
    self.norm_style = norm_style
    # Epsilon for BatchNorm?
    self.BN_eps = BN_eps
    # Epsilon for Spectral Norm?
    self.SN_eps = SN_eps
    # fp16?
    self.fp16 = G_fp16
    # Architecture dict
    self.arch = G_arch(self.ch, self.attention)[resolution]


    # Which convs, batchnorms, and linear layers to use
    if self.G_param == 'SN':
      self.which_conv = functools.partial(layers.SNConv2d,
                          kernel_size=3, padding=1,
                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                          eps=self.SN_eps)
      self.which_linear = functools.partial(layers.SNLinear,
                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                          eps=self.SN_eps)
    else:
      self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
      self.which_linear = nn.Linear
      
    # We use a non-spectral-normed embedding here regardless;
    # For some reason applying SN to G's embedding seems to randomly cripple G
    self.which_embedding = nn.Embedding
    bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
                 else self.which_embedding)
    self.which_bn = functools.partial(layers.ccbn,
                          which_linear=bn_linear,
                          cross_replica=self.cross_replica,
                          mybn=self.mybn,
                          input_size=(self.shared_dim + self.dim_z if self.G_shared
                                      else self.n_classes),
                          norm_style=self.norm_style,
                          eps=self.BN_eps)


    # Prepare model
    # If not using shared embeddings, self.shared is just a passthrough
    self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared 
                    else layers.identity())
    # First linear layer
    self.linear = self.which_linear(self.dim_z + self.shared_dim, self.arch['in_channels'][0] * (self.bottom_width **2))

    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    # while the inner loop is over a given block
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
      self.blocks += [[GBlock(in_channels=self.arch['in_channels'][index],
                             out_channels=self.arch['in_channels'][index] if g_index==0 else self.arch['out_channels'][index],
                             which_conv=self.which_conv,
                             which_bn=self.which_bn,
                             activation=self.activation,
                             upsample=(functools.partial(F.interpolate, scale_factor=2)
                                       if self.arch['upsample'][index] and g_index == (self.G_depth-1) else None))]
                       for g_index in range(self.G_depth)]

      # If attention on this block, attach it to the end
      if self.arch['attention'][self.arch['resolution'][index]]:
        print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
        self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    # output layer: batchnorm-relu-conv.
    # Consider using a non-spectral conv here
    self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                cross_replica=self.cross_replica,
                                                mybn=self.mybn),
                                    self.activation,
                                    self.which_conv(self.arch['out_channels'][-1], 3))

    # Initialize weights. Optionally skip init for testing.
    if not skip_init:
      self.init_weights()

    # Set up optimizer
    # If this is an EMA copy, no need for an optim, so just return now
    if no_optim:
      return
    self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
    if G_mixed_precision:
      print('Using fp16 adam in G...')
      import utils
      self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                           betas=(self.B1, self.B2), weight_decay=0,
                           eps=self.adam_eps)
    else:
      self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                           betas=(self.B1, self.B2), weight_decay=0,
                           eps=self.adam_eps)
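
The `which_conv` / `which_linear` pattern above is just `functools.partial` pre-binding constructor arguments so the rest of the model can create layers without repeating them. A minimal, self-contained sketch, using plain `nn.Conv2d` / `nn.Linear` in place of the spectral-norm wrappers:

import functools
import torch.nn as nn

# Pre-bind kernel_size/padding once; layers.SNConv2d would additionally take the SN args.
which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
which_linear = nn.Linear

conv = which_conv(64, 128)   # Conv2d(64, 128, kernel_size=3, padding=1)
fc = which_linear(128, 10)   # Linear(128, 10)
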
Example #2
    def __init__(self,
                 D_ch=64,
                 D_wide=True,
                 resolution=128,
                 D_kernel_size=3,
                 D_attn='64',
                 n_classes=1000,
                 num_D_SVs=1,
                 num_D_SV_itrs=1,
                 D_activation=nn.ReLU(inplace=False),
                 D_lr=2e-4,
                 D_B1=0.0,
                 D_B2=0.999,
                 adam_eps=1e-8,
                 SN_eps=1e-12,
                 output_dim=1,
                 D_mixed_precision=False,
                 D_fp16=False,
                 D_init='ortho',
                 skip_init=False,
                 D_param='SN',
                 decoder_skip_connection=True,
                 **kwargs):
        super(Unet_Discriminator, self).__init__()

        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16

        if self.resolution == 128:
            self.save_features = [0, 1, 2, 3, 4]
        elif self.resolution == 256:
            self.save_features = [0, 1, 2, 3, 4, 5]

        self.out_channel_multiplier = 1  #4
        # Architecture
        self.arch = D_unet_arch(
            self.ch,
            self.attention,
            out_channel_multiplier=self.out_channel_multiplier)[resolution]

        self.unconditional = kwargs["unconditional"]

        # Which convs, batchnorms, and linear layers to use
        # No option to turn off SN in D right now
        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3,
                                                padding=1,
                                                num_svs=num_D_SVs,
                                                num_itrs=num_D_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_D_SVs,
                                                  num_itrs=num_D_SV_itrs,
                                                  eps=self.SN_eps)

            self.which_embedding = functools.partial(layers.SNEmbedding,
                                                     num_svs=num_D_SVs,
                                                     num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)
        # Prepare model
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        self.blocks = []

        for index in range(len(self.arch['out_channels'])):

            if self.arch["downsample"][index]:
                self.blocks += [[
                    layers.DBlock(
                        in_channels=self.arch['in_channels'][index],
                        out_channels=self.arch['out_channels'][index],
                        which_conv=self.which_conv,
                        wide=self.D_wide,
                        activation=self.activation,
                        preactivation=(index > 0),
                        downsample=(nn.AvgPool2d(2) if
                                    self.arch['downsample'][index] else None))
                ]]

            elif self.arch["upsample"][index]:
                upsample_function = (
                    functools.partial(F.interpolate,
                                      scale_factor=2,
                                      mode="nearest")  #mode=nearest is default
                    if self.arch['upsample'][index] else None)

                self.blocks += [[
                    layers.GBlock2(
                        in_channels=self.arch['in_channels'][index],
                        out_channels=self.arch['out_channels'][index],
                        which_conv=self.which_conv,
                        #which_bn=self.which_bn,
                        activation=self.activation,
                        upsample=upsample_function,
                        skip_connection=True)
                ]]

            # If attention on this block, attach it to the end
            attention_condition = index < 5
            if self.arch['attention'][self.arch['resolution'][
                    index]] and attention_condition:  #index < 5
                print('Adding attention layer in D at resolution %d' %
                      self.arch['resolution'][index])
                print("index = ", index)
                self.blocks[-1] += [
                    layers.Attention(self.arch['out_channels'][index],
                                     self.which_conv)
                ]

        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks])

        last_layer = nn.Conv2d(self.ch * self.out_channel_multiplier,
                               1,
                               kernel_size=1)
        self.blocks.append(last_layer)
        #
        # Linear output layer. The output dimension is typically 1, but may be
        # larger if we're e.g. turning this into a VAE with an inference output
        self.linear = self.which_linear(self.arch['out_channels'][-1],
                                        output_dim)

        self.linear_middle = self.which_linear(16 * self.ch, output_dim)
        # Embedding for projection discrimination
        #if not kwargs["agnostic_unet"] and not kwargs["unconditional"]:
        #    self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1]+extra)
        if not kwargs["unconditional"]:
            self.embed_middle = self.which_embedding(self.n_classes,
                                                     16 * self.ch)
            self.embed = self.which_embedding(self.n_classes,
                                              self.arch['out_channels'][-1])

        # Initialize weights
        if not skip_init:
            self.init_weights()

        ###
        print("_____params______")
        for name, param in self.named_parameters():
            print(name, param.size())

        # Set up optimizer
        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print('Using fp16 adam in D...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(),
                                      lr=self.lr,
                                      betas=(self.B1, self.B2),
                                      weight_decay=0,
                                      eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(),
                                    lr=self.lr,
                                    betas=(self.B1, self.B2),
                                    weight_decay=0,
                                    eps=self.adam_eps)
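
A hypothetical helper (not part of the class above) showing how the doubly-nested `self.blocks` and the `save_features` indices could drive the encoder half of this U-Net discriminator: run every sub-block and stash the activation after the stages listed in `save_features`, so the decoder half can concatenate them later.

def run_encoder_blocks(blocks, save_features, x):
    # blocks: nn.ModuleList of nn.ModuleLists; save_features: list of stage indices
    h, skips = x, []
    for index, blocklist in enumerate(blocks):
        for block in blocklist:
            h = block(h)
        if index in save_features:
            skips.append(h)
    return h, skips
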
Example #3
  def __init__(self, D_ch=64, D_wide=True, D_depth=2, resolution=128,
               D_kernel_size=3, D_attn='64', n_classes=1000,
               num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
               D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8,
               SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False,
               D_init='ortho', skip_init=False, D_param='SN', **kwargs):
    super(Discriminator, self).__init__()
    # Width multiplier
    self.ch = D_ch
    # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
    self.D_wide = D_wide
    # How many resblocks per stage?
    self.D_depth = D_depth
    # Resolution
    self.resolution = resolution
    # Kernel size
    self.kernel_size = D_kernel_size
    # Attention?
    self.attention = D_attn
    # Number of classes
    self.n_classes = n_classes
    # Activation
    self.activation = D_activation
    # Initialization style
    self.init = D_init
    # Parameterization style
    self.D_param = D_param
    # Epsilon for Spectral Norm?
    self.SN_eps = SN_eps
    # Fp16?
    self.fp16 = D_fp16
    # Architecture
    self.arch = D_arch(self.ch, self.attention)[resolution]


    # Which convs, batchnorms, and linear layers to use
    # No option to turn off SN in D right now
    if self.D_param == 'SN':
      self.which_conv = functools.partial(layers.SNConv2d,
                          kernel_size=3, padding=1,
                          num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                          eps=self.SN_eps)
      self.which_linear = functools.partial(layers.SNLinear,
                          num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                          eps=self.SN_eps)
      self.which_embedding = functools.partial(layers.SNEmbedding,
                              num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                              eps=self.SN_eps)
    
    
    # Prepare model
    # Stem convolution
    self.input_conv = self.which_conv(3, self.arch['in_channels'][0])
    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
      self.blocks += [[DBlock(in_channels=self.arch['in_channels'][index] if d_index==0 else self.arch['out_channels'][index],
                       out_channels=self.arch['out_channels'][index],
                       which_conv=self.which_conv,
                       wide=self.D_wide,
                       activation=self.activation,
                       preactivation=True,
                       downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] and d_index==0 else None))
                       for d_index in range(self.D_depth)]]
      # If attention on this block, attach it to the end
      if self.arch['attention'][self.arch['resolution'][index]]:
        print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
        self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
                                             self.which_conv)]
    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
    # Linear output layer. The output dimension is typically 1, but may be
    # larger if we're e.g. turning this into a VAE with an inference output
    self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
    # Embedding for projection discrimination
    self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

    # Initialize weights
    if not skip_init:
      self.init_weights()

    # Set up optimizer
    self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
    if D_mixed_precision:
      print('Using fp16 adam in D...')
      import utils
      self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                             betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
    else:
      self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                             betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
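
For reference, the projection-discriminator head set up above (`self.linear` plus `self.embed`) is conventionally combined as below. This is a hedged sketch, not the class's actual `forward()` (which lives outside this snippet); `h` is the globally pooled feature vector and `y` the integer class labels.

import torch

def projection_logits(linear, embed, h, y):
    out = linear(h)                                           # unconditional logit
    out = out + torch.sum(embed(y) * h, dim=1, keepdim=True)  # class projection term
    return out
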
Example #4
    def __init__(self,
                 D_ch=64,
                 D_wide=True,
                 resolution=128,
                 D_kernel_size=3,
                 D_attn='64',
                 n_classes=1000,
                 num_D_SVs=1,
                 num_D_SV_itrs=1,
                 D_activation=nn.ReLU(inplace=False),
                 D_lr=2e-4,
                 D_B1=0.0,
                 D_B2=0.999,
                 adam_eps=1e-8,
                 SN_eps=1e-12,
                 output_dim=1,
                 D_mixed_precision=False,
                 D_fp16=False,
                 D_init='ortho',
                 skip_init=False,
                 D_param='SN',
                 attn_style='nl',
                 sched_version='default',
                 num_epochs=500,
                 arch=None,
                 use_dog_cnt=False,
                 **kwargs):
        super(Discriminator, self).__init__()
        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16

        self.use_dog_cnt = use_dog_cnt
        # Architecture
        if arch is None:
            arch = f'{resolution}'
        self.arch = D_arch(self.ch, self.attention)[arch]

        # Which convs, batchnorms, and linear layers to use
        # No option to turn off SN in D right now
        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3,
                                                padding=1,
                                                num_svs=num_D_SVs,
                                                num_itrs=num_D_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_D_SVs,
                                                  num_itrs=num_D_SV_itrs,
                                                  eps=self.SN_eps)
            self.which_embedding = functools.partial(layers.SNEmbedding,
                                                     num_svs=num_D_SVs,
                                                     num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)

        if attn_style == 'cbam':
            self.which_attn = layers.CBAM
        else:
            self.which_attn = layers.Attention

        # Prepare model
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[
                layers.DBlock(
                    in_channels=self.arch['in_channels'][index],
                    out_channels=self.arch['out_channels'][index],
                    which_conv=self.which_conv,
                    wide=self.D_wide,
                    activation=self.activation,
                    preactivation=(index > 0),
                    downsample=(nn.AvgPool2d(2)
                                if self.arch['downsample'][index] else None))
            ]]
            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in D at resolution %d' %
                      self.arch['resolution'][index])
                self.blocks[-1] += [
                    self.which_attn(self.arch['out_channels'][index],
                                    self.which_conv)
                ]
        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks])
        # Linear output layer. The output dimension is typically 1, but may be
        # larger if we're e.g. turning this into a VAE with an inference output
        self.linear = self.which_linear(self.arch['out_channels'][-1],
                                        output_dim)
        # Embedding for projection discrimination
        self.embed = self.which_embedding(self.n_classes,
                                          self.arch['out_channels'][-1])

        if self.use_dog_cnt:
            # Embedding for dog count projection discrimination
            self.embed_dog_cnt = self.which_embedding(
                4, self.arch['out_channels'][-1])

        # Initialize weights
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print('Using fp16 adam in D...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(),
                                      lr=self.lr,
                                      betas=(self.B1, self.B2),
                                      weight_decay=0,
                                      eps=self.adam_eps,
                                      amsgrad=kwargs['amsgrad'])
        else:
            self.optim = optim.Adam(params=self.parameters(),
                                    lr=self.lr,
                                    betas=(self.B1, self.B2),
                                    weight_decay=0,
                                    eps=self.adam_eps,
                                    amsgrad=kwargs['amsgrad'])
        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0
        if sched_version == 'default':
            self.lr_sched = None
        elif sched_version == 'cal_v0':
            self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
                self.optim,
                T_max=num_epochs,
                eta_min=self.lr / 2,
                last_epoch=-1)
        elif sched_version == 'cal_v1':
            self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
                self.optim,
                T_max=num_epochs,
                eta_min=self.lr / 4,
                last_epoch=-1)
        elif sched_version == 'cawr_v0':
            self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                self.optim, T_0=10, T_mult=2, eta_min=self.lr / 2)
        elif sched_version == 'cawr_v1':
            self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                self.optim, T_0=25, T_mult=2, eta_min=self.lr / 4)
        else:
            self.lr_sched = None
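
Usage note: the schedulers built above are ordinary `torch.optim.lr_scheduler` objects, so a training loop (hypothetical fragment below) simply steps them once per epoch after that epoch's optimizer updates.

def step_scheduler(D):
    # No-op for sched_version='default', where lr_sched is None.
    if D.lr_sched is not None:
        D.lr_sched.step()
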
Example #5
    def __init__(self,
                 G_ch=64,
                 dim_z=128,
                 bottom_width=4,
                 resolution=128,
                 G_kernel_size=3,
                 G_attn='64',
                 n_classes=1000,
                 num_G_SVs=1,
                 num_G_SV_itrs=1,
                 G_shared=True,
                 shared_dim=0,
                 hier=False,
                 cross_replica=False,
                 mybn=False,
                 G_activation=nn.ReLU(inplace=False),
                 G_lr=5e-5,
                 G_B1=0.0,
                 G_B2=0.999,
                 adam_eps=1e-8,
                 BN_eps=1e-5,
                 SN_eps=1e-12,
                 G_mixed_precision=False,
                 G_fp16=False,
                 G_init='ortho',
                 skip_init=False,
                 no_optim=False,
                 G_param='SN',
                 norm_style='bn',
                 add_blur=False,
                 add_noise=False,
                 add_style=False,
                 style_mlp=6,
                 attn_style='nl',
                 no_conditional=False,
                 sched_version='default',
                 num_epochs=500,
                 arch=None,
                 skip_z=False,
                 use_dog_cnt=False,
                 dim_dog_cnt_z=32,
                 mix_style=False,
                 **kwargs):
        super(Generator, self).__init__()
        # Channel width multiplier
        self.ch = G_ch
        # Dimensionality of the latent space
        self.dim_z = dim_z
        # The initial spatial dimensions
        self.bottom_width = bottom_width
        # Resolution of the output
        self.resolution = resolution
        # Kernel size?
        self.kernel_size = G_kernel_size
        # Attention?
        self.attention = G_attn
        # number of classes, for use in categorical conditional generation
        self.n_classes = n_classes
        # Use shared embeddings?
        self.G_shared = G_shared
        # Dimensionality of the shared embedding? Unused if not using G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        # Hierarchical latent space?
        self.hier = hier
        # Cross replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # nonlinearity for residual blocks
        self.activation = G_activation
        # Initialization style
        self.init = G_init
        # Parameterization style
        self.G_param = G_param
        # Normalization style
        self.norm_style = norm_style
        # Extra generator options: blur, noise, style MLP, skip-z, dog-count conditioning, style mixing
        self.add_blur = add_blur
        self.add_noise = add_noise
        self.add_style = add_style
        self.skip_z = skip_z
        self.use_dog_cnt = use_dog_cnt
        self.dim_dog_cnt_z = dim_dog_cnt_z
        self.mix_style = mix_style

        # Epsilon for BatchNorm?
        self.BN_eps = BN_eps
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # fp16?
        self.fp16 = G_fp16
        # Architecture dict
        if arch is None:
            arch = f'{resolution}'
        self.arch = G_arch(self.ch, self.attention)[arch]

        # If using hierarchical latents, adjust z
        if self.hier:
            # Number of places z slots into
            self.num_slots = len(self.arch['in_channels']) + 1
            self.z_chunk_size = (self.dim_z // self.num_slots)
            # Recalculate latent dimensionality for even splitting into chunks
            self.dim_z = self.z_chunk_size * self.num_slots
        else:
            self.num_slots = 1
            self.z_chunk_size = 0

        # Which convs, batchnorms, and linear layers to use
        if self.G_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3,
                                                padding=1,
                                                num_svs=num_G_SVs,
                                                num_itrs=num_G_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_G_SVs,
                                                  num_itrs=num_G_SV_itrs,
                                                  eps=self.SN_eps)
        else:
            self.which_conv = functools.partial(nn.Conv2d,
                                                kernel_size=3,
                                                padding=1)
            self.which_linear = nn.Linear

        if attn_style == 'cbam':
            self.which_attn = layers.CBAM
        else:
            self.which_attn = layers.Attention

        # We use a non-spectral-normed embedding here regardless;
        # For some reason applying SN to G's embedding seems to randomly cripple G
        self.which_embedding = nn.Embedding
        bn_linear = (functools.partial(self.which_linear, bias=False)
                     if self.G_shared else self.which_embedding)
        input_size = self.shared_dim + self.z_chunk_size if self.G_shared else self.n_classes
        if self.G_shared and use_dog_cnt:
            input_size += dim_dog_cnt_z
        self.which_bn = functools.partial(
            layers.ccbn,
            which_linear=bn_linear,
            cross_replica=self.cross_replica,
            mybn=self.mybn,
            input_size=input_size,
            norm_style=self.norm_style,
            eps=self.BN_eps,
            style_linear=self.which_linear,
            dim_z=self.dim_z,
            no_conditional=no_conditional,
            skip_z=self.skip_z,
            use_dog_cnt=use_dog_cnt,
            g_shared=G_shared,
        )

        # Prepare model
        # If not using shared embeddings, self.shared is just a passthrough
        self.shared = (self.which_embedding(n_classes, self.shared_dim)
                       if G_shared else layers.identity())

        self.dog_cnt_shared = (self.which_embedding(4, self.dim_dog_cnt_z)
                               if G_shared else layers.identity())
        # First linear layer
        self.linear = self.which_linear(
            self.dim_z // self.num_slots,
            self.arch['in_channels'][0] * (self.bottom_width**2))

        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        # while the inner loop is over a given block
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[
                layers.GBlock(
                    in_channels=self.arch['in_channels'][index],
                    out_channels=self.arch['out_channels'][index],
                    which_conv=self.which_conv,
                    which_bn=self.which_bn,
                    activation=self.activation,
                    upsample=(functools.partial(F.interpolate, scale_factor=2)
                              if self.arch['upsample'][index] else None),
                    add_blur=add_blur,
                    add_noise=add_noise,
                )
            ]]

            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in G at resolution %d' %
                      self.arch['resolution'][index])
                self.blocks[-1] += [
                    self.which_attn(self.arch['out_channels'][index],
                                    self.which_conv)
                ]

        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks])

        # output layer: batchnorm-relu-conv.
        # Consider using a non-spectral conv here
        self.output_layer = nn.Sequential(
            layers.bn(self.arch['out_channels'][-1],
                      cross_replica=self.cross_replica,
                      mybn=self.mybn), self.activation,
            self.which_conv(self.arch['out_channels'][-1], 3))

        if self.add_style:
            # layers = [PixelNorm()]
            style_layers = []
            for i in range(style_mlp):
                style_layers.append(
                    layers.StyleLayer(self.dim_z, self.which_linear,
                                      self.activation))

            self.style = nn.Sequential(*style_layers)

        # Initialize weights. Optionally skip init for testing.
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        # If this is an EMA copy, no need for an optim, so just return now
        if no_optim:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print('Using fp16 adam in G...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(),
                                      lr=self.lr,
                                      betas=(self.B1, self.B2),
                                      weight_decay=0,
                                      eps=self.adam_eps,
                                      amsgrad=kwargs['amsgrad'])
        else:
            self.optim = optim.Adam(params=self.parameters(),
                                    lr=self.lr,
                                    betas=(self.B1, self.B2),
                                    weight_decay=0,
                                    eps=self.adam_eps,
                                    amsgrad=kwargs['amsgrad'])

        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0

        if sched_version == 'default':
            self.lr_sched = None
        elif sched_version == 'cal_v0':
            self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
                self.optim,
                T_max=num_epochs,
                eta_min=self.lr / 2,
                last_epoch=-1)
        elif sched_version == 'cal_v1':
            self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
                self.optim,
                T_max=num_epochs,
                eta_min=self.lr / 4,
                last_epoch=-1)
        elif sched_version == 'cawr_v0':
            self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                self.optim, T_0=10, T_mult=2, eta_min=self.lr / 2)
        elif sched_version == 'cawr_v1':
            self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                self.optim, T_0=25, T_mult=2, eta_min=self.lr / 4)
        else:
            self.lr_sched = None
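
A hedged sketch of how the hierarchical-latent bookkeeping above (`num_slots`, `z_chunk_size`) is typically consumed in the forward pass: `z` is split into equal chunks, the first chunk feeds the initial linear layer, and each remaining chunk is concatenated with the class embedding for one generator stage.

import torch

def split_hier_z(z, y, z_chunk_size):
    # z: (batch, dim_z); y: (batch, shared_dim) class embedding
    zs = torch.split(z, z_chunk_size, dim=1)
    z0 = zs[0]                                               # feeds the first linear layer
    ys = [torch.cat([y, chunk], dim=1) for chunk in zs[1:]]  # one per generator stage
    return z0, ys
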
Example #6
    def __init__(self,
                 D=None,
                 is_enc_shared=True,
                 dim_z=128,
                 D_ch=64,
                 D_wide=True,
                 resolution=128,
                 D_kernel_size=3,
                 D_attn='64',
                 n_classes=1000,
                 num_D_SVs=1,
                 num_D_SV_itrs=1,
                 D_activation=nn.ReLU(inplace=False),
                 D_lr=2e-4,
                 D_B1=0.0,
                 D_B2=0.999,
                 adam_eps=1e-8,
                 SN_eps=1e-12,
                 output_dim=1,
                 D_mixed_precision=False,
                 D_fp16=False,
                 D_init='ortho',
                 skip_init=False,
                 D_param='SN',
                 optimizer_type='adam',
                 weight_decay=5e-4,
                 dataset_channel=3,
                 **kwargs):
        super(Encoder, self).__init__()
        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # weight decay
        self.weight_decay = weight_decay
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16
        # Architecture
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Which convs, batchnorms, and linear layers to use
        # No option to turn off SN in D right now
        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3,
                                                padding=1,
                                                num_svs=num_D_SVs,
                                                num_itrs=num_D_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_D_SVs,
                                                  num_itrs=num_D_SV_itrs,
                                                  eps=self.SN_eps)
            self.which_embedding = functools.partial(layers.SNEmbedding,
                                                     num_svs=num_D_SVs,
                                                     num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)
        # Prepare model
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        if not is_enc_shared:
            self.arch['in_channels'][0] = dataset_channel
            self.blocks = []
            for index in range(len(self.arch['out_channels'])):
                self.blocks += [[
                    layers.DBlock(
                        in_channels=self.arch['in_channels'][index],
                        out_channels=self.arch['out_channels'][index],
                        which_conv=self.which_conv,
                        wide=self.D_wide,
                        activation=self.activation,
                        preactivation=(index > 0),
                        downsample=(nn.AvgPool2d(2) if
                                    self.arch['downsample'][index] else None))
                ]]
                # If attention on this block, attach it to the end
                if self.arch['attention'][self.arch['resolution'][index]]:
                    print('Adding attention layer in D at resolution %d' %
                          self.arch['resolution'][index])
                    self.blocks[-1] += [
                        layers.Attention(self.arch['out_channels'][index],
                                         self.which_conv)
                    ]
            self.blocks = nn.ModuleList(
                [nn.ModuleList(block) for block in self.blocks])
        else:
            self.blocks = D.blocks

        self.blocks = self.blocks[:-1]

        self.blocks2 = []
        self.blocks2 += [[
            layers.DBlock(in_channels=self.arch['in_channels'][-1],
                          out_channels=self.arch['out_channels'][-1],
                          which_conv=self.which_conv,
                          wide=self.D_wide,
                          activation=self.activation,
                          preactivation=(True),
                          downsample=(None))
        ]]

        dim_reduction_enc = 2
        for index in range(dim_reduction_enc):
            self.blocks2 += [[
                layers.DBlock(in_channels=self.arch['in_channels'][-1],
                              out_channels=self.arch['out_channels'][-1],
                              which_conv=self.which_conv,
                              wide=self.D_wide,
                              activation=self.activation,
                              preactivation=(True),
                              downsample=(nn.AvgPool2d(2)))
            ]]
        self.blocks2 = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks2])

        #output linear VAE
        reduction = 2**(dim_reduction_enc +
                        np.sum(np.array(self.arch['downsample'])))
        o_res_dim = resolution // reduction
        self.input_linear_dim = self.arch['out_channels'][-1] * o_res_dim**2
        print(reduction)
        print(o_res_dim)
        self.linear_mu1 = self.which_linear(self.input_linear_dim,
                                            self.input_linear_dim // 2)
        self.linear_lv1 = self.which_linear(self.input_linear_dim,
                                            self.input_linear_dim // 2)
        self.linear_mu2 = self.which_linear(self.input_linear_dim // 2, dim_z)
        self.linear_lv2 = self.which_linear(self.input_linear_dim // 2, dim_z)

        # Initialize weights
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print('Using fp16 adam in D...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(),
                                      lr=self.lr,
                                      betas=(self.B1, self.B2),
                                      weight_decay=0,
                                      eps=self.adam_eps)
        else:
            if optimizer_type == 'adam':
                self.optim = optim.Adam(params=self.parameters(),
                                        lr=self.lr,
                                        betas=(self.B1, self.B2),
                                        weight_decay=0,
                                        eps=self.adam_eps)
            elif optimizer_type == 'radam':
                self.optim = optimizers.RAdam(params=self.parameters(),
                                              lr=self.lr,
                                              betas=(self.B1, self.B2),
                                              weight_decay=self.weight_decay,
                                              eps=self.adam_eps)
            elif optimizer_type == 'ranger':
                self.optim = optimizers.Ranger(params=self.parameters(),
                                               lr=self.lr,
                                               betas=(self.B1, self.B2),
                                               weight_decay=self.weight_decay,
                                               eps=self.adam_eps)
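
A hedged sketch of how the `linear_mu*` / `linear_lv*` heads above are typically used: the flattened encoder features produce a mean and log-variance, and a latent is drawn with the reparameterization trick. Names here are illustrative, not the class's actual forward.

import torch

def reparameterize(mu, logvar):
    # z = mu + eps * sigma, with sigma = exp(0.5 * logvar) and eps ~ N(0, I)
    eps = torch.randn_like(mu)
    return mu + eps * torch.exp(0.5 * logvar)
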
Example #7
    def __init__(self,
                 G_batch_size=100,
                 batch_size=100,
                 dim_z=128,
                 n_classes=1000,
                 sigma=0.5,
                 is_y_uniform=False,
                 prior_type='default',
                 G_fp16=False,
                 arch_aux=0,
                 G_param='SN',
                 D_param='SN',
                 device='cuda',
                 P_lr=2e-4,
                 G_lr=5e-5,
                 G_B1=0.0,
                 G_B2=0.999,
                 adam_eps=1e-8,
                 num_G_SVs=1,
                 num_G_SV_itrs=1,
                 SN_eps=1e-12,
                 G_mixed_precision=False,
                 num_D_SVs=1,
                 num_D_SV_itrs=1,
                 G_activation=nn.ReLU(inplace=True),
                 GMM_init='ortho',
                 sharpen=1.0,
                 optimizer_type='adam',
                 weight_decay=5e-4,
                 **kwargs):

        super(Prior, self).__init__()
        dtype = torch.float16 if G_fp16 else torch.float32

        self.dim_z = dim_z
        self.sigma = sigma
        self.n_classes = n_classes
        self.prior_type = prior_type
        self.is_y_uniform = is_y_uniform
        self.bs = max(G_batch_size, batch_size)
        self.sharpen = sharpen
        self.weight_decay = weight_decay

        import utils
        self.z_, self.y_ = utils.prepare_z_y(self.bs,
                                             dim_z,
                                             n_classes,
                                             device=device,
                                             fp16=G_fp16)
        self.eps_ = self.z_

        which_embedding = nn.Embedding
        self.sample_ = self.sample_default
        self.obtain_latent_from_z_y = self.obtain_latent_from_z_y_default
        self.latent_classification = self.latent_classification_default
        G_activation = nn.ReLU(inplace=True)

        if prior_type == 'default':
            self.y_aux = (
                1 / self.n_classes *
                torch.arange(self.n_classes, dtype=torch.float).reshape(
                    1, n_classes)).cuda()

        elif prior_type == 'aux':
            if G_param == 'SN':
                which_linear = functools.partial(SNLinear,
                                                 num_svs=num_G_SVs,
                                                 num_itrs=num_G_SV_itrs,
                                                 eps=SN_eps)
            else:
                which_linear = nn.Linear

            if arch_aux == 0:
                self.gen_linear = which_linear(2 * dim_z, dim_z)
                latent_classification = nn.Sequential(
                    which_linear(dim_z, dim_z), G_activation,
                    which_linear(dim_z, n_classes), nn.Softmax())

            elif arch_aux == 1:
                self.gen_linear = nn.Sequential(which_linear(2 * dim_z, dim_z),
                                                nn.Tanh())  # nn.Tanh takes no inplace argument
                latent_classification = nn.Sequential(
                    which_linear(dim_z, dim_z), G_activation,
                    which_linear(dim_z, n_classes), nn.Softmax())

            self.first_embedding = which_embedding(n_classes, dim_z)
            self.sample_ = self.sample_aux
            self.latent_classification = latent_classification
            self.obtain_latent_from_z_y = self.obtain_latent_from_z_y_aux

        elif prior_type == 'GMM':

            self.init = GMM_init
            self.mu_c = nn.Parameter(data=torch.zeros((n_classes, dim_z),
                                                      dtype=dtype),
                                     requires_grad=True)
            self.lv_c = nn.Parameter(data=torch.ones((n_classes, dim_z),
                                                     dtype=dtype),
                                     requires_grad=True)
            self.phi_c = nn.Parameter(data=self.sigma *
                                      torch.ones(n_classes, dtype=dtype),
                                      requires_grad=False)
            self.sample_ = self.sample_from_gmm
            self.latent_classification = self.gmm_membressy2
            self.obtain_latent_from_z_y = self.obtain_latent_from_z_y_gmm

            if self.init == 'ortho':
                init.orthogonal_(self.mu_c)
            elif self.init == 'N02':
                init.normal_(self.mu_c, 0, 0.02)
            elif self.init in ['glorot', 'xavier']:
                init.xavier_uniform_(self.mu_c)
            elif self.init == 'mu_sep':
                extra_dim = dim_z % n_classes
                reap_dim = dim_z // n_classes
                mu_init = 1
                gmm_mu = mu_init * (1 + self.sigma) * np.hstack(
                    (np.eye(n_classes).repeat(
                        reap_dim, 1), np.zeros((n_classes, extra_dim))))
                del self.mu_c
                self.mu_c = nn.Parameter(data=torch.tensor(gmm_mu,
                                                           dtype=dtype),
                                         requires_grad=True)

        if prior_type == 'aux' or prior_type == 'GMM':
            self.lr, self.B1, self.B2, self.adam_eps = P_lr, G_B1, G_B2, adam_eps
            if G_mixed_precision:
                print('Using fp16 adam in Prior...')
                self.optim = utils.Adam16(params=self.parameters(),
                                          lr=self.lr,
                                          betas=(self.B1, self.B2),
                                          weight_decay=0,
                                          eps=self.adam_eps)
            else:
                if optimizer_type == 'adam':
                    self.optim = optim.Adam(params=self.parameters(),
                                            lr=self.lr,
                                            betas=(self.B1, self.B2),
                                            weight_decay=0,
                                            eps=self.adam_eps)
                elif optimizer_type == 'radam':
                    self.optim = optimizers.RAdam(
                        params=self.parameters(),
                        lr=self.lr,
                        betas=(self.B1, self.B2),
                        weight_decay=self.weight_decay,
                        eps=self.adam_eps)

                elif optimizer_type == 'ranger':
                    self.optim = optimizers.Ranger(
                        params=self.parameters(),
                        lr=self.lr,
                        betas=(self.B1, self.B2),
                        weight_decay=self.weight_decay,
                        eps=self.adam_eps)

        if is_y_uniform:
            del self.y_
            self.y_ = torch.arange(n_classes).repeat(
                self.bs // n_classes).to(
                    device, torch.float16 if G_fp16 else torch.float32)
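
A hedged sketch of drawing latents from the per-class Gaussian components defined in the 'GMM' branch above (`mu_c`, `lv_c`); treating `lv_c` as a log-variance is an assumption of this sketch.

import torch

def sample_gmm_component(mu_c, lv_c, y):
    # y: (batch,) integer class labels; returns (batch, dim_z) latents
    mu = mu_c[y]
    std = torch.exp(0.5 * lv_c[y])   # assumes lv_c stores log-variances
    return mu + std * torch.randn_like(mu)
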
Example #8
  def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=128,
               G_kernel_size=3, G_attn='64', n_classes=1000,
               num_G_SVs=1, num_G_SV_itrs=1,
               G_shared=True, shared_dim=0, hier=False,
               cross_replica=False, mybn=False,
               G_activation=nn.ReLU(inplace=False),
               G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
               BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
               G_init='ortho', skip_init=False, no_optim=False,
               G_param='SN', norm_style='bn',
               **kwargs):
    """
    These parameters are all defined in utils and are packaged via the argument parser and vars().
    Let's take a look at what the model actually looks like.
    G_ch: channel width of the generator, default 64. It is shorthand for a whole model
    structure; ch = 64 expands into the following architecture:
    ch = 64
    arch[128] = {'in_channels' :  [ch * item for item in [16, 16, 8, 4, 2]],
                'out_channels' : [ch * item for item in [16, 8, 4, 2, 1]],
                'upsample' : [True] * 5,
                'resolution' : [8, 16, 32, 64, 128],
                'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
                                for i in range(3,8)}}
    dim_z: dimensionality of the noise vector, default 128.

    """
    super(Generator, self).__init__()
    # Channel width multiplier
    self.ch = G_ch
    # Dimensionality of the latent space
    self.dim_z = dim_z
    # The initial spatial dimensions
    ## TODO: the main purpose of this is not yet clear
    self.bottom_width = bottom_width
    # Resolution of the output
    ## Selects which architecture entry is used
    self.resolution = resolution
    # Kernel size?
    ## TODO: not passed in from the external config, and not actually used anywhere
    self.kernel_size = G_kernel_size
    # Attention?
    ## Just an intermediary: it is handed straight to the self.arch lookup and finally resolved in the attention entries of the architecture
    self.attention = G_attn
    # number of classes, for use in categorical conditional generation
    self.n_classes = n_classes
    # Use shared embeddings?
    ## Defaults to False in the utils config
    self.G_shared = G_shared
    # Dimensionality of the shared embedding? Unused if not using G_shared
    self.shared_dim = shared_dim if shared_dim > 0 else dim_z
    # Hierarchical latent space?
    self.hier = hier
    # Cross replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # nonlinearity for residual blocks
    self.activation = G_activation
    # Initialization style
    self.init = G_init
    # Parameterization style
    self.G_param = G_param
    # Normalization style
    self.norm_style = norm_style
    # Epsilon for BatchNorm?
    self.BN_eps = BN_eps
    # Epsilon for Spectral Norm?
    ## https://zhuanlan.zhihu.com/p/68081406
    self.SN_eps = SN_eps
    # fp16?
    self.fp16 = G_fp16
    # Architecture dict
    self.arch = G_arch(self.ch, self.attention)[resolution]

    # If using hierarchical latents, adjust z
    if self.hier:
      # Number of places z slots into
      self.num_slots = len(self.arch['in_channels']) + 1
      self.z_chunk_size = (self.dim_z // self.num_slots)
      # Recalculate latent dimensionality for even splitting into chunks
      self.dim_z = self.z_chunk_size *  self.num_slots
    else:
      self.num_slots = 1
      self.z_chunk_size = 0

    # Which convs, batchnorms, and linear layers to use
    if self.G_param == 'SN':
      self.which_conv = functools.partial(layers.SNConv2d,
                          kernel_size=3, padding=1,
                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                          eps=self.SN_eps)
      self.which_linear = functools.partial(layers.SNLinear,
                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                          eps=self.SN_eps)
    else:
      self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
      self.which_linear = nn.Linear
      
    # We use a non-spectral-normed embedding here regardless;
    # For some reason applying SN to G's embedding seems to randomly cripple G
    ## *** fluid.dygraph.Embedding == nn.Embedding
    self.which_embedding = nn.Embedding
    bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
                 else self.which_embedding)
    
    self.which_bn = functools.partial(layers.ccbn,
                          which_linear=bn_linear,
                          cross_replica=self.cross_replica,
                          mybn=self.mybn,
                          input_size=(self.shared_dim + self.z_chunk_size if self.G_shared
                                      else self.n_classes),
                          norm_style=self.norm_style,
                          eps=self.BN_eps)

    # Prepare model
    # If not using shared embeddings, self.shared is just a passthrough
    self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared 
                    else layers.identity())
    # First linear layer
    self.linear = self.which_linear(self.dim_z // self.num_slots,
                                    self.arch['in_channels'][0] * (self.bottom_width **2))

    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    # while the inner loop is over a given block
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
      self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
                             out_channels=self.arch['out_channels'][index],
                             which_conv=self.which_conv,
                             which_bn=self.which_bn,
                             activation=self.activation,
                             upsample=(functools.partial(F.interpolate, scale_factor=2)
                                       if self.arch['upsample'][index] else None))]]

      # If attention on this block, attach it to the end
      if self.arch['attention'][self.arch['resolution'][index]]:
        print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
        self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    # output layer: batchnorm-relu-conv.
    # Consider using a non-spectral conv here
    self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                cross_replica=self.cross_replica,
                                                mybn=self.mybn),
                                    self.activation,
                                    self.which_conv(self.arch['out_channels'][-1], 3))

    # Initialize weights. Optionally skip init for testing.
    if not skip_init:
      self.init_weights()

    # Set up optimizer
    # If this is an EMA copy, no need for an optim, so just return now
    if no_optim:
      return
    self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
    if G_mixed_precision:
      print('Using fp16 adam in G...')
      import utils
      self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                           betas=(self.B1, self.B2), weight_decay=0,
                           eps=self.adam_eps)
    else:
      self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                           betas=(self.B1, self.B2), weight_decay=0,
                           eps=self.adam_eps)
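
As a worked example of the arch dict quoted in the docstring above, for ch = 64 and attention at resolution 64 the multipliers expand to the concrete per-stage channel counts shown in the comments:

ch = 64
attention = '64'
arch_128 = {
    'in_channels':  [ch * item for item in [16, 16, 8, 4, 2]],   # [1024, 1024, 512, 256, 128]
    'out_channels': [ch * item for item in [16, 8, 4, 2, 1]],    # [1024, 512, 256, 128, 64]
    'upsample':     [True] * 5,
    'resolution':   [8, 16, 32, 64, 128],
    'attention':    {2**i: (2**i in [int(item) for item in attention.split('_')])
                     for i in range(3, 8)},                      # True only at 64 here
}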