Example #1
  def __init__(self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=128,
               G_kernel_size=3, G_attn='64', n_classes=1000,
               num_G_SVs=1, num_G_SV_itrs=1,
               G_shared=True, shared_dim=0, hier=False,
               cross_replica=False, mybn=False,
               G_activation=nn.ReLU(inplace=False),
               G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
               BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
               G_init='ortho', skip_init=False, no_optim=False,
               G_param='SN', norm_style='bn',
               **kwargs):
    super(Generator, self).__init__()
    # Channel width multiplier
    self.ch = G_ch
    # Number of resblocks per stage
    self.G_depth = G_depth
    # Dimensionality of the latent space
    self.dim_z = dim_z
    # The initial spatial dimensions
    self.bottom_width = bottom_width
    # Resolution of the output
    self.resolution = resolution
    # Kernel size?
    self.kernel_size = G_kernel_size
    # Attention?
    self.attention = G_attn
    # number of classes, for use in categorical conditional generation
    self.n_classes = n_classes
    # Use shared embeddings?
    self.G_shared = G_shared
    # Dimensionality of the shared embedding? Unused if not using G_shared
    self.shared_dim = shared_dim if shared_dim > 0 else dim_z
    # Hierarchical latent space?
    self.hier = hier
    # Cross replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # nonlinearity for residual blocks
    self.activation = G_activation
    # Initialization style
    self.init = G_init
    # Parameterization style
    self.G_param = G_param
    # Normalization style
    self.norm_style = norm_style
    # Epsilon for BatchNorm?
    self.BN_eps = BN_eps
    # Epsilon for Spectral Norm?
    self.SN_eps = SN_eps
    # fp16?
    self.fp16 = G_fp16
    # Architecture dict
    self.arch = G_arch(self.ch, self.attention)[resolution]


    # Which convs, batchnorms, and linear layers to use
    if self.G_param == 'SN':
      self.which_conv = functools.partial(layers.SNConv2d,
                          kernel_size=3, padding=1,
                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                          eps=self.SN_eps)
      self.which_linear = functools.partial(layers.SNLinear,
                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                          eps=self.SN_eps)
    else:
      self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
      self.which_linear = nn.Linear
      
    # We use a non-spectral-normed embedding here regardless;
    # For some reason applying SN to G's embedding seems to randomly cripple G
    self.which_embedding = nn.Embedding
    bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
                 else self.which_embedding)
    self.which_bn = functools.partial(layers.ccbn,
                          which_linear=bn_linear,
                          cross_replica=self.cross_replica,
                          mybn=self.mybn,
                          input_size=(self.shared_dim + self.dim_z if self.G_shared
                                      else self.n_classes),
                          norm_style=self.norm_style,
                          eps=self.BN_eps)


    # Prepare model
    # If not using shared embeddings, self.shared is just a passthrough
    self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared 
                    else layers.identity())
    # First linear layer
    self.linear = self.which_linear(self.dim_z + self.shared_dim, self.arch['in_channels'][0] * (self.bottom_width **2))

    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    # while the inner loop is over a given block
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
      self.blocks += [[GBlock(in_channels=self.arch['in_channels'][index],
                             out_channels=self.arch['in_channels'][index] if g_index==0 else self.arch['out_channels'][index],
                             which_conv=self.which_conv,
                             which_bn=self.which_bn,
                             activation=self.activation,
                             upsample=(functools.partial(F.interpolate, scale_factor=2)
                                       if self.arch['upsample'][index] and g_index == (self.G_depth-1) else None))]
                       for g_index in range(self.G_depth)]

      # If attention on this block, attach it to the end
      if self.arch['attention'][self.arch['resolution'][index]]:
        print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
        self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    # output layer: batchnorm-relu-conv.
    # Consider using a non-spectral conv here
    self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                cross_replica=self.cross_replica,
                                                mybn=self.mybn),
                                    self.activation,
                                    self.which_conv(self.arch['out_channels'][-1], 3))

    # Initialize weights. Optionally skip init for testing.
    if not skip_init:
      self.init_weights()

    # Set up optimizer
    # If this is an EMA copy, no need for an optim, so just return now
    if no_optim:
      return
    self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
    if G_mixed_precision:
      print('Using fp16 adam in G...')
      import utils
      self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                           betas=(self.B1, self.B2), weight_decay=0,
                           eps=self.adam_eps)
    else:
      self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                           betas=(self.B1, self.B2), weight_decay=0,
                           eps=self.adam_eps)
Example #2
  def __init__(self, D_ch=64, D_wide=True, D_depth=2, resolution=128,
               D_kernel_size=3, D_attn='64', n_classes=1000,
               num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
               D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8,
               SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False,
               D_init='ortho', skip_init=False, D_param='SN', **kwargs):
    super(Discriminator, self).__init__()
    # Width multiplier
    self.ch = D_ch
    # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
    self.D_wide = D_wide
    # How many resblocks per stage?
    self.D_depth = D_depth
    # Resolution
    self.resolution = resolution
    # Kernel size
    self.kernel_size = D_kernel_size
    # Attention?
    self.attention = D_attn
    # Number of classes
    self.n_classes = n_classes
    # Activation
    self.activation = D_activation
    # Initialization style
    self.init = D_init
    # Parameterization style
    self.D_param = D_param
    # Epsilon for Spectral Norm?
    self.SN_eps = SN_eps
    # Fp16?
    self.fp16 = D_fp16
    # Architecture
    self.arch = D_arch(self.ch, self.attention)[resolution]


    # Which convs, batchnorms, and linear layers to use
    # No option to turn off SN in D right now
    if self.D_param == 'SN':
      self.which_conv = functools.partial(layers.SNConv2d,
                          kernel_size=3, padding=1,
                          num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                          eps=self.SN_eps)
      self.which_linear = functools.partial(layers.SNLinear,
                          num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                          eps=self.SN_eps)
      self.which_embedding = functools.partial(layers.SNEmbedding,
                              num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                              eps=self.SN_eps)
    
    
    # Prepare model
    # Stem convolution
    self.input_conv = self.which_conv(3, self.arch['in_channels'][0])
    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
      self.blocks += [[DBlock(in_channels=self.arch['in_channels'][index] if d_index==0 else self.arch['out_channels'][index],
                       out_channels=self.arch['out_channels'][index],
                       which_conv=self.which_conv,
                       wide=self.D_wide,
                       activation=self.activation,
                       preactivation=True,
                       downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] and d_index==0 else None))
                       for d_index in range(self.D_depth)]]
      # If attention on this block, attach it to the end
      if self.arch['attention'][self.arch['resolution'][index]]:
        print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
        self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
                                             self.which_conv)]
    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
    # Linear output layer. The output dimension is typically 1, but may be
    # larger if we're e.g. turning this into a VAE with an inference output
    self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
    # Embedding for projection discrimination
    self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

    # Initialize weights
    if not skip_init:
      self.init_weights()

    # Set up optimizer
    self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
    if D_mixed_precision:
      print('Using fp16 adam in D...')
      import utils
      self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                             betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
    else:
      self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                             betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
Example #3
def attention(inputs,
              kq_depth=None,
              max_relative_position=64,
              batch_norm=False,
              bn_momentum=0.99,
              bn_type='standard',
              **kwargs):
    """Construct a residual attention block.

  Args:
    inputs:                 [batch_size, seq_length, features] input sequence
    kq_depth:               Key-query feature depth
    max_relative_position:  Max relative position to differentiate w/ its own parameter

  Returns:
    output sequence
  """

    # flow through variable current
    current = inputs

    current_depth = current.shape[-1]
    if kq_depth is None:
        kq_depth = current_depth

    # key - who am I?
    key = tf.keras.layers.Conv1D(filters=kq_depth,
                                 kernel_size=1,
                                 padding='same',
                                 kernel_initializer='he_normal')(current)

    # query - what am I looking for?
    query = tf.keras.layers.Conv1D(filters=kq_depth,
                                   kernel_size=1,
                                   padding='same',
                                   kernel_initializer='he_normal')(current)

    # value - what do I have to say?
    value = tf.keras.layers.Conv1D(
        filters=current_depth,
        kernel_size=1,
        padding='same',
        kernel_initializer='he_normal',
    )(current)

    # apply layer
    z = layers.Attention(max_relative_position=max_relative_position)(
        [query, value, key])

    # batch norm
    if batch_norm:
        if bn_type == 'sync':
            bn_layer = tf.keras.layers.experimental.SyncBatchNormalization
        else:
            bn_layer = tf.keras.layers.BatchNormalization
        z = bn_layer(momentum=bn_momentum, gamma_initializer='zeros')(z)

    # residual add
    current = tf.keras.layers.Add()([current, z])

    return current
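
The block above wraps a custom relative-position attention layer (layers.Attention) in 1x1 convolutions for query/key/value plus a residual add. A rough stand-in using the stock tf.keras.layers.Attention, which drops the relative-position bias but keeps the same wiring:

import tensorflow as tf

def simple_residual_attention(inputs, kq_depth=None):
    # Same residual wiring as above, but with plain dot-product attention.
    current = inputs
    depth = current.shape[-1]
    if kq_depth is None:
        kq_depth = depth
    query = tf.keras.layers.Conv1D(kq_depth, 1, padding='same')(current)
    key = tf.keras.layers.Conv1D(kq_depth, 1, padding='same')(current)
    value = tf.keras.layers.Conv1D(depth, 1, padding='same')(current)
    z = tf.keras.layers.Attention()([query, value, key])
    return tf.keras.layers.Add()([current, z])

x = tf.keras.Input(shape=(64, 32))             # [seq_length, features]
model = tf.keras.Model(x, simple_residual_attention(x))
print(model.output_shape)                      # (None, 64, 32)
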
Example #4
    def __init__(self,
                 D=None,
                 is_enc_shared=True,
                 dim_z=128,
                 D_ch=64,
                 D_wide=True,
                 resolution=128,
                 D_kernel_size=3,
                 D_attn='64',
                 n_classes=1000,
                 num_D_SVs=1,
                 num_D_SV_itrs=1,
                 D_activation=nn.ReLU(inplace=False),
                 D_lr=2e-4,
                 D_B1=0.0,
                 D_B2=0.999,
                 adam_eps=1e-8,
                 SN_eps=1e-12,
                 output_dim=1,
                 D_mixed_precision=False,
                 D_fp16=False,
                 D_init='ortho',
                 skip_init=False,
                 D_param='SN',
                 optimizer_type='adam',
                 weight_decay=5e-4,
                 dataset_channel=3,
                 **kwargs):
        super(Encoder, self).__init__()
        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # weight decay
        self.weight_decay = weight_decay
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16
        # Architecture
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Which convs, batchnorms, and linear layers to use
        # No option to turn off SN in D right now
        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3,
                                                padding=1,
                                                num_svs=num_D_SVs,
                                                num_itrs=num_D_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_D_SVs,
                                                  num_itrs=num_D_SV_itrs,
                                                  eps=self.SN_eps)
            self.which_embedding = functools.partial(layers.SNEmbedding,
                                                     num_svs=num_D_SVs,
                                                     num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)
        # Prepare model
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        if not is_enc_shared:
            self.arch['in_channels'][0] = dataset_channel
            self.blocks = []
            for index in range(len(self.arch['out_channels'])):
                self.blocks += [[
                    layers.DBlock(
                        in_channels=self.arch['in_channels'][index],
                        out_channels=self.arch['out_channels'][index],
                        which_conv=self.which_conv,
                        wide=self.D_wide,
                        activation=self.activation,
                        preactivation=(index > 0),
                        downsample=(nn.AvgPool2d(2) if
                                    self.arch['downsample'][index] else None))
                ]]
                # If attention on this block, attach it to the end
                if self.arch['attention'][self.arch['resolution'][index]]:
                    print('Adding attention layer in D at resolution %d' %
                          self.arch['resolution'][index])
                    self.blocks[-1] += [
                        layers.Attention(self.arch['out_channels'][index],
                                         self.which_conv)
                    ]
            self.blocks = nn.ModuleList(
                [nn.ModuleList(block) for block in self.blocks])
        else:
            self.blocks = D.blocks

        self.blocks = self.blocks[:-1]

        self.blocks2 = []
        self.blocks2 += [[
            layers.DBlock(in_channels=self.arch['in_channels'][-1],
                          out_channels=self.arch['out_channels'][-1],
                          which_conv=self.which_conv,
                          wide=self.D_wide,
                          activation=self.activation,
                          preactivation=(True),
                          downsample=(None))
        ]]

        dim_reduction_enc = 2
        for index in range(dim_reduction_enc):
            self.blocks2 += [[
                layers.DBlock(in_channels=self.arch['in_channels'][-1],
                              out_channels=self.arch['out_channels'][-1],
                              which_conv=self.which_conv,
                              wide=self.D_wide,
                              activation=self.activation,
                              preactivation=(True),
                              downsample=(nn.AvgPool2d(2)))
            ]]
        self.blocks2 = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks2])

        #output linear VAE
        reduction = 2**(dim_reduction_enc +
                        np.sum(np.array(self.arch['downsample'])))
        o_res_dim = resolution // reduction
        self.input_linear_dim = self.arch['out_channels'][-1] * o_res_dim**2
        print(reduction)
        print(o_res_dim)
        self.linear_mu1 = self.which_linear(self.input_linear_dim,
                                            self.input_linear_dim // 2)
        self.linear_lv1 = self.which_linear(self.input_linear_dim,
                                            self.input_linear_dim // 2)
        self.linear_mu2 = self.which_linear(self.input_linear_dim // 2, dim_z)
        self.linear_lv2 = self.which_linear(self.input_linear_dim // 2, dim_z)

        # Initialize weights
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print('Using fp16 adam in D...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(),
                                      lr=self.lr,
                                      betas=(self.B1, self.B2),
                                      weight_decay=0,
                                      eps=self.adam_eps)
        else:
            if optimizer_type == 'adam':
                self.optim = optim.Adam(params=self.parameters(),
                                        lr=self.lr,
                                        betas=(self.B1, self.B2),
                                        weight_decay=0,
                                        eps=self.adam_eps)
            elif optimizer_type == 'radam':
                self.optim = optimizers.RAdam(params=self.parameters(),
                                              lr=self.lr,
                                              betas=(self.B1, self.B2),
                                              weight_decay=self.weight_decay,
                                              eps=self.adam_eps)
            elif optimizer_type == 'ranger':
                self.optim = optimizers.Ranger(params=self.parameters(),
                                               lr=self.lr,
                                               betas=(self.B1, self.B2),
                                               weight_decay=self.weight_decay,
                                               eps=self.adam_eps)
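
The four linear heads at the end (linear_mu1/linear_lv1 feeding linear_mu2/linear_lv2) produce a mean and a log-variance for a VAE-style latent of size dim_z. A small sketch of the reparameterization step such heads typically feed (the intermediate ReLU and the class name here are illustrative, not taken from the repo):

import torch
import torch.nn as nn

class GaussianHead(nn.Module):
    # Two-stage mu / logvar heads followed by the reparameterization trick.
    def __init__(self, in_dim, dim_z=128):
        super().__init__()
        self.mu1 = nn.Linear(in_dim, in_dim // 2)
        self.lv1 = nn.Linear(in_dim, in_dim // 2)
        self.mu2 = nn.Linear(in_dim // 2, dim_z)
        self.lv2 = nn.Linear(in_dim // 2, dim_z)

    def forward(self, h):
        mu = self.mu2(torch.relu(self.mu1(h)))
        logvar = self.lv2(torch.relu(self.lv1(h)))
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)    # z ~ N(mu, sigma^2)
        return z, mu, logvar

head = GaussianHead(in_dim=256, dim_z=128)
z, mu, logvar = head(torch.randn(4, 256))
print(z.shape)   # torch.Size([4, 128])
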
Example #5
    def __init__(self,
                 D_ch=64,
                 D_wide=True,
                 resolution=128,
                 D_kernel_size=3,
                 D_attn='64',
                 n_classes=1000,
                 num_D_SVs=1,
                 num_D_SV_itrs=1,
                 D_activation=nn.ReLU(inplace=False),
                 D_lr=2e-4,
                 D_B1=0.0,
                 D_B2=0.999,
                 adam_eps=1e-8,
                 SN_eps=1e-12,
                 output_dim=1,
                 D_mixed_precision=False,
                 D_fp16=False,
                 D_init='ortho',
                 skip_init=False,
                 D_param='SN',
                 decoder_skip_connection=True,
                 **kwargs):
        super(Unet_Discriminator, self).__init__()

        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16

        if self.resolution == 128:
            self.save_features = [0, 1, 2, 3, 4]
        elif self.resolution == 256:
            self.save_features = [0, 1, 2, 3, 4, 5]

        self.out_channel_multiplier = 1  #4
        # Architecture
        self.arch = D_unet_arch(
            self.ch,
            self.attention,
            out_channel_multiplier=self.out_channel_multiplier)[resolution]

        self.unconditional = kwargs["unconditional"]

        # Which convs, batchnorms, and linear layers to use
        # No option to turn off SN in D right now
        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3,
                                                padding=1,
                                                num_svs=num_D_SVs,
                                                num_itrs=num_D_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_D_SVs,
                                                  num_itrs=num_D_SV_itrs,
                                                  eps=self.SN_eps)

            self.which_embedding = functools.partial(layers.SNEmbedding,
                                                     num_svs=num_D_SVs,
                                                     num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)
        # Prepare model
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        self.blocks = []

        for index in range(len(self.arch['out_channels'])):

            if self.arch["downsample"][index]:
                self.blocks += [[
                    layers.DBlock(
                        in_channels=self.arch['in_channels'][index],
                        out_channels=self.arch['out_channels'][index],
                        which_conv=self.which_conv,
                        wide=self.D_wide,
                        activation=self.activation,
                        preactivation=(index > 0),
                        downsample=(nn.AvgPool2d(2) if
                                    self.arch['downsample'][index] else None))
                ]]

            elif self.arch["upsample"][index]:
                upsample_function = (
                    functools.partial(F.interpolate,
                                      scale_factor=2,
                                      mode="nearest")  #mode=nearest is default
                    if self.arch['upsample'][index] else None)

                self.blocks += [[
                    layers.GBlock2(
                        in_channels=self.arch['in_channels'][index],
                        out_channels=self.arch['out_channels'][index],
                        which_conv=self.which_conv,
                        #which_bn=self.which_bn,
                        activation=self.activation,
                        upsample=upsample_function,
                        skip_connection=True)
                ]]

            # If attention on this block, attach it to the end
            attention_condition = index < 5
            if self.arch['attention'][self.arch['resolution'][
                    index]] and attention_condition:  #index < 5
                print('Adding attention layer in D at resolution %d' %
                      self.arch['resolution'][index])
                print("index = ", index)
                self.blocks[-1] += [
                    layers.Attention(self.arch['out_channels'][index],
                                     self.which_conv)
                ]

        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks])

        last_layer = nn.Conv2d(self.ch * self.out_channel_multiplier,
                               1,
                               kernel_size=1)
        self.blocks.append(last_layer)
        #
        # Linear output layer. The output dimension is typically 1, but may be
        # larger if we're e.g. turning this into a VAE with an inference output
        self.linear = self.which_linear(self.arch['out_channels'][-1],
                                        output_dim)

        self.linear_middle = self.which_linear(16 * self.ch, output_dim)
        # Embedding for projection discrimination
        #if not kwargs["agnostic_unet"] and not kwargs["unconditional"]:
        #    self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1]+extra)
        if not kwargs["unconditional"]:
            self.embed_middle = self.which_embedding(self.n_classes,
                                                     16 * self.ch)
            self.embed = self.which_embedding(self.n_classes,
                                              self.arch['out_channels'][-1])

        # Initialize weights
        if not skip_init:
            self.init_weights()

        ###
        print("_____params______")
        for name, param in self.named_parameters():
            print(name, param.size())

        # Set up optimizer
        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print('Using fp16 adam in D...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(),
                                      lr=self.lr,
                                      betas=(self.B1, self.B2),
                                      weight_decay=0,
                                      eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(),
                                    lr=self.lr,
                                    betas=(self.B1, self.B2),
                                    weight_decay=0,
                                    eps=self.adam_eps)
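
This U-Net discriminator pairs the usual image-level heads (linear, linear_middle, embed) with a 1x1 conv (last_layer) that turns the decoder output into a per-pixel real/fake map, and save_features marks the encoder stages whose outputs are reused as skip connections on the way back up. A toy sketch of that encoder-decoder wiring (layer sizes are illustrative, not the repo's):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyUnetD(nn.Module):
    # Toy U-Net discriminator: strided encoder, upsampling decoder with a
    # skip connection, 1x1 conv for per-pixel logits, linear for image logit.
    def __init__(self, ch=16):
        super().__init__()
        self.enc1 = nn.Conv2d(3, ch, 3, stride=2, padding=1)
        self.enc2 = nn.Conv2d(ch, 2 * ch, 3, stride=2, padding=1)
        self.dec1 = nn.Conv2d(2 * ch, ch, 3, padding=1)
        self.dec2 = nn.Conv2d(2 * ch, ch, 3, padding=1)   # 2*ch after concat
        self.pixel_out = nn.Conv2d(ch, 1, kernel_size=1)  # per-pixel logits
        self.image_out = nn.Linear(2 * ch, 1)             # image-level logit

    def forward(self, x):
        e1 = F.relu(self.enc1(x))                         # saved as skip
        e2 = F.relu(self.enc2(e1))
        img_logit = self.image_out(e2.mean(dim=[2, 3]))
        d1 = F.relu(self.dec1(F.interpolate(e2, scale_factor=2)))
        d2 = F.relu(self.dec2(torch.cat([d1, e1], dim=1)))
        pix_logits = self.pixel_out(F.interpolate(d2, scale_factor=2))
        return img_logit, pix_logits

d = TinyUnetD()
img, pix = d(torch.randn(2, 3, 64, 64))
print(img.shape, pix.shape)   # torch.Size([2, 1]) torch.Size([2, 1, 64, 64])
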
Example #6
  def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=128,
               G_kernel_size=3, G_attn='64', n_classes=1000,
               num_G_SVs=1, num_G_SV_itrs=1,
               G_shared=True, shared_dim=0, hier=False,
               cross_replica=False, mybn=False,
               G_activation=nn.ReLU(inplace=False),
               G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
               BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
               G_init='ortho', skip_init=False, no_optim=False,
               G_param='SN', norm_style='bn',
               **kwargs):
    """
    utils中有这些参数的定义,通过parase和vars方法封装这些参数
    看一下模型到底是咋样
    G_ch 生成模型的信道 默认64,指的是一种模型机构的总和,64可解析为如下结构
    ch = 64
    arch[128] = {'in_channels' :  [ch * item for item in [16, 16, 8, 4, 2]],
                'out_channels' : [ch * item for item in [16, 8, 4, 2, 1]],
                'upsample' : [True] * 5,
                'resolution' : [8, 16, 32, 64, 128],
                'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
                                for i in range(3,8)}}
    dim_z 噪声的维度,默认为128

    """
    super(Generator, self).__init__()
    # Channel width multiplier
    self.ch = G_ch
    # Dimensionality of the latent space
    self.dim_z = dim_z
    # The initial spatial dimensions
    ## TODO: not yet clear what this is mainly used for
    self.bottom_width = bottom_width
    # Resolution of the output
    ## Selects which architecture entry is used
    self.resolution = resolution
    # Kernel size?
    ## TODO: this is not passed in from outside, and it is not used anywhere either
    self.kernel_size = G_kernel_size
    # Attention?
    ## Only acts as an intermediary: it is handed straight to the self.arch lookup and is ultimately resolved by the attention entries
    self.attention = G_attn
    # number of classes, for use in categorical conditional generation
    self.n_classes = n_classes
    # Use shared embeddings?
    ## Defaults to False
    self.G_shared = G_shared
    # Dimensionality of the shared embedding? Unused if not using G_shared
    self.shared_dim = shared_dim if shared_dim > 0 else dim_z
    # Hierarchical latent space?
    self.hier = hier
    # Cross replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # nonlinearity for residual blocks
    self.activation = G_activation
    # Initialization style
    self.init = G_init
    # Parameterization style
    self.G_param = G_param
    # Normalization style
    self.norm_style = norm_style
    # Epsilon for BatchNorm?
    self.BN_eps = BN_eps
    # Epsilon for Spectral Norm?
    ## https://zhuanlan.zhihu.com/p/68081406
    self.SN_eps = SN_eps
    # fp16?
    self.fp16 = G_fp16
    # Architecture dict
    self.arch = G_arch(self.ch, self.attention)[resolution]

    # If using hierarchical latents, adjust z
    if self.hier:
      # Number of places z slots into
      self.num_slots = len(self.arch['in_channels']) + 1
      self.z_chunk_size = (self.dim_z // self.num_slots)
      # Recalculate latent dimensionality for even splitting into chunks
      self.dim_z = self.z_chunk_size *  self.num_slots
    else:
      self.num_slots = 1
      self.z_chunk_size = 0

    # Which convs, batchnorms, and linear layers to use
    if self.G_param == 'SN':
      self.which_conv = functools.partial(layers.SNConv2d,
                          kernel_size=3, padding=1,
                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                          eps=self.SN_eps)
      self.which_linear = functools.partial(layers.SNLinear,
                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                          eps=self.SN_eps)
    else:
      self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
      self.which_linear = nn.Linear
      
    # We use a non-spectral-normed embedding here regardless;
    # For some reason applying SN to G's embedding seems to randomly cripple G
    ## *** fluid.dygraph.Embedding == nn.Embedding
    self.which_embedding = nn.Embedding
    bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
                 else self.which_embedding)
    
    self.which_bn = functools.partial(layers.ccbn,
                          which_linear=bn_linear,
                          cross_replica=self.cross_replica,
                          mybn=self.mybn,
                          input_size=(self.shared_dim + self.z_chunk_size if self.G_shared
                                      else self.n_classes),
                          norm_style=self.norm_style,
                          eps=self.BN_eps)

    # Prepare model
    # If not using shared embeddings, self.shared is just a passthrough
    self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared 
                    else layers.identity())
    # First linear layer
    self.linear = self.which_linear(self.dim_z // self.num_slots,
                                    self.arch['in_channels'][0] * (self.bottom_width **2))

    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    # while the inner loop is over a given block
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
      self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
                             out_channels=self.arch['out_channels'][index],
                             which_conv=self.which_conv,
                             which_bn=self.which_bn,
                             activation=self.activation,
                             upsample=(functools.partial(F.interpolate, scale_factor=2)
                                       if self.arch['upsample'][index] else None))]]

      # If attention on this block, attach it to the end
      if self.arch['attention'][self.arch['resolution'][index]]:
        print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
        self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    # output layer: batchnorm-relu-conv.
    # Consider using a non-spectral conv here
    self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                cross_replica=self.cross_replica,
                                                mybn=self.mybn),
                                    self.activation,
                                    self.which_conv(self.arch['out_channels'][-1], 3))

    # Initialize weights. Optionally skip init for testing.
    if not skip_init:
      self.init_weights()

    # Set up optimizer
    # If this is an EMA copy, no need for an optim, so just return now
    if no_optim:
      return
    self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
    if G_mixed_precision:
      print('Using fp16 adam in G...')
      import utils
      self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                           betas=(self.B1, self.B2), weight_decay=0,
                           eps=self.adam_eps)
    else:
      self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                           betas=(self.B1, self.B2), weight_decay=0,
                           eps=self.adam_eps)
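
When hier=True, dim_z is rounded down to a multiple of num_slots so that z splits into equal chunks: the first chunk feeds the initial linear layer, and each remaining chunk is concatenated with the shared class embedding to condition one stage's batchnorm (hence input_size = shared_dim + z_chunk_size above). A standalone sketch of that chunking arithmetic (an assumption about the forward pass, shown only to make the sizes concrete):

import torch

dim_z, num_stages = 128, 5               # e.g. len(arch['in_channels']) == 5
num_slots = num_stages + 1
z_chunk_size = dim_z // num_slots        # 21
dim_z = z_chunk_size * num_slots         # 126, rounded for even splitting

z = torch.randn(4, dim_z)
shared_y = torch.randn(4, 128)           # shared class embedding, shared_dim=128

zs = torch.split(z, z_chunk_size, dim=1)
z0 = zs[0]                               # input to the first linear layer
# each remaining chunk conditions one stage's class-conditional batchnorm
ys = [torch.cat([shared_y, chunk], dim=1) for chunk in zs[1:]]
print(z0.shape, ys[0].shape, len(ys))    # torch.Size([4, 21]) torch.Size([4, 149]) 5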