Example #1
def logit(h, is_training=True, update_batch_stats=True, stochastic=True, seed=1234, dropout_mask=None, return_mask=False, h_before_dropout=None):
    rng = np.random.RandomState(seed)
    if h_before_dropout is None:
        h = L.conv(h, ksize=3, stride=1, f_in=3, f_out=128, seed=rng.randint(123456), name='c1')
        h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b1'), FLAGS.lrelu_a)
        h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=128, seed=rng.randint(123456), name='c2')
        h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b2'), FLAGS.lrelu_a)
        h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=128, seed=rng.randint(123456), name='c3')
        h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b3'), FLAGS.lrelu_a)

        h = L.max_pool(h, ksize=2, stride=2)
        if stochastic:
            h = tf.nn.dropout(h, keep_prob=FLAGS.keep_prob_hidden)

        h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=256, seed=rng.randint(123456), name='c4')
        h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b4'), FLAGS.lrelu_a)
        h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=256, seed=rng.randint(123456), name='c5')
        h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b5'), FLAGS.lrelu_a)
        h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=256, seed=rng.randint(123456), name='c6')
        h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b6'), FLAGS.lrelu_a)

        h_before_dropout = L.max_pool(h, ksize=2, stride=2)

    # Making it possible to change or return a dropout mask
    if stochastic:
        if dropout_mask is None:
            dropout_mask = tf.cast(
                tf.greater_equal(tf.random_uniform(tf.shape(h_before_dropout), 0, 1, seed=rng.randint(123456)), 1.0 - FLAGS.keep_prob_hidden),
                tf.float32)
        else:
            dropout_mask = tf.reshape(dropout_mask, tf.shape(h_before_dropout))
        h = tf.multiply(h_before_dropout, dropout_mask)
        h = (1.0 / FLAGS.keep_prob_hidden) * h
    else:
        h = h_before_dropout
    h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=512, seed=rng.randint(123456), padding="VALID", name='c7')
    h = L.lrelu(L.bn(h, 512, is_training=is_training, update_batch_stats=update_batch_stats, name='b7'), FLAGS.lrelu_a)
    h = L.conv(h, ksize=1, stride=1, f_in=512, f_out=256, seed=rng.randint(123456), name='c8')
    h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b8'), FLAGS.lrelu_a)
    h = L.conv(h, ksize=1, stride=1, f_in=256, f_out=128, seed=rng.randint(123456), name='c9')
    h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b9'), FLAGS.lrelu_a)

    h = tf.reduce_mean(h, reduction_indices=[1, 2])  # Global average pooling
    h = L.fc(h, 128, 10, seed=rng.randint(123456), name='fc')

    if FLAGS.top_bn:
        h = L.bn(h, 10, is_training=is_training,
                 update_batch_stats=update_batch_stats, name='bfc')
    if return_mask:
        return h, tf.reshape(dropout_mask, [-1, 8*8*256]), h_before_dropout
    else:
        return h
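The `dropout_mask` / `return_mask` arguments above exist so a caller can capture the sampled dropout pattern and apply exactly the same pattern in a later pass. A minimal NumPy sketch of that masked-dropout mechanism (the tensor shape and keep_prob=0.5 are illustrative assumptions, not values taken from this code):

import numpy as np

keep_prob = 0.5  # assumed; the function above reads FLAGS.keep_prob_hidden
h_before_dropout = np.random.RandomState(0).randn(4, 8, 8, 256).astype(np.float32)

# Sample a binary keep mask once, mirroring tf.greater_equal(uniform, 1 - keep_prob) above.
mask = (np.random.RandomState(1).uniform(size=h_before_dropout.shape)
        >= 1.0 - keep_prob).astype(np.float32)

# Reusing the same mask yields two passes with an identical dropout pattern,
# rescaled by 1 / keep_prob exactly as in the function above.
h_pass_1 = (1.0 / keep_prob) * h_before_dropout * mask
h_pass_2 = (1.0 / keep_prob) * h_before_dropout * mask
assert np.allclose(h_pass_1, h_pass_2)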
Example #2
def discriminator(x,
                  y,
                  is_training=True,
                  update_batch_stats=True,
                  act_fn=L.lrelu,
                  bn=FLAGS.dis_bn,
                  reuse=True):
    with tf.variable_scope('discriminator', reuse=reuse):
        if FLAGS.method == 'cgan':
            h = L.fc(y,
                     y_dim,
                     X_dim * X_dim,
                     seed=rng.randint(123456),
                     name='fc_y')
            h = tf.reshape(h, [-1, X_dim, X_dim, 1])
            h = tf.concat((x, h), axis=3)
            h = L.conv(h, 3, 1, num_channels + 1, 32, name="conv1")
        else:
            h = L.conv(x, 3, 1, num_channels, 32, name="conv1")
        h = act_fn(h)

        # 64x64 -> 32x32
        h = L.conv(
            h,
            4,
            2,
            32,
            64,
            name="conv2",
        )
        h = L.bn(h,
                 64,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 use_gamma=False,
                 name='bn1') if bn else h
        h = act_fn(h)

        # 32x32 -> 16x16
        h = L.conv(h, 4, 2, 64, 128, name="conv3")
        h = L.bn(h,
                 128,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 use_gamma=False,
                 name='bn2') if bn else h
        h = act_fn(h)
        h = L.conv(h, X_dim // 4, 1, 128, 1, name="conv5", padding="VALID")
        logits = tf.reshape(h, [-1, 1])
        return logits
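For reference, the cGAN branch above turns the label vector into an extra image channel: `fc_y` projects `y` to an X_dim x X_dim plane, which is concatenated to `x` before `conv1`. A NumPy-only sketch of that conditioning step (batch size, X_dim, y_dim and num_channels are assumed values; in the original they come from the surrounding module):

import numpy as np

batch, X_dim, y_dim, num_channels = 2, 32, 10, 3              # assumed sizes
x = np.zeros((batch, X_dim, X_dim, num_channels), np.float32)
y = np.eye(y_dim, dtype=np.float32)[[3, 7]]                   # one-hot labels for the batch

W = np.random.randn(y_dim, X_dim * X_dim).astype(np.float32)  # stand-in for the 'fc_y' weights
h = (y @ W).reshape(-1, X_dim, X_dim, 1)                      # label plane
h = np.concatenate((x, h), axis=3)                            # image plus label channel
print(h.shape)                                                # (2, 32, 32, 4)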
Example #3
def logit(x,
          is_training=True,
          update_batch_stats=True,
          stochastic=True,
          seed=1234):
    x = tf.reshape(x, [x.get_shape().as_list()[0], -1])
    layer_sizes = numpy.asarray(FLAGS.layer_sizes.split('-'), numpy.int32)
    num_layers = len(layer_sizes) - 1
    rng = numpy.random.RandomState(seed)
    h = x
    for l, dim in enumerate(layer_sizes):
        inp_dim = h.get_shape()[1]
        with tf.variable_scope(str(l)):
            W = tf.get_variable(
                'W',
                shape=[inp_dim, dim],
                initializer=tf.contrib.layers.xavier_initializer(
                    uniform=False, seed=rng.randint(123456), dtype=tf.float32))
            b = tf.get_variable('b',
                                shape=[dim],
                                initializer=tf.constant_initializer(0.0))
            h = tf.nn.xw_plus_b(h, W, b)
            h = L.bn(h,
                     dim,
                     is_training=is_training,
                     update_batch_stats=update_batch_stats)

            if l < num_layers - 1:
                h = tf.nn.relu(h)
                h = gaussian_noise_layer(
                    h, stddev=FLAGS.noise_stddev, seed=rng.randint(123456)
                ) if FLAGS.noise_stddev > 0 and stochastic else h
    return h
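The layer widths above come from the `FLAGS.layer_sizes` string. A standalone snippet showing how a value such as '784-1200-10' (an assumed example, not necessarily the project default) is parsed into the sizes the loop iterates over:

import numpy

layer_sizes = numpy.asarray('784-1200-10'.split('-'), numpy.int32)
num_layers = len(layer_sizes) - 1
print(layer_sizes)  # [ 784 1200   10]
print(num_layers)   # 2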
Example #4
def logit_small(x,
                num_classes,
                is_training=True,
                update_batch_stats=True,
                stochastic=True,
                seed=1234):

    if is_training:
        scope = tf.name_scope("Training")

    else:
        scope = tf.name_scope("Testing")

    with scope:
        h = x

        rng = np.random.RandomState(seed)

        h = L.fc(h,
                 dim_in=x.shape[1],
                 dim_out=64,
                 seed=rng.randint(123456),
                 name="fc1")
        h = L.lrelu(
            L.bn(h,
                 64,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='fc1_normalized'), FLAGS.lrelu_a)
        h = L.fc(h,
                 dim_in=64,
                 dim_out=64,
                 seed=rng.randint(123456),
                 name="fc2")
        h = L.lrelu(
            L.bn(h,
                 64,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='fc2_normalized'), FLAGS.lrelu_a)
        h = L.fc(h,
                 dim_in=64,
                 dim_out=num_classes,
                 seed=rng.randint(123456),
                 name="fc3")
        return h
Example #5
    def __init__(self, F=None):
        from theano.tensor.nnet import sigmoid, relu
        from layers import initGain, fcLayer, convUnit, nonlinLayer, reshapeLayer, convLayer
        from layers import batchNormLayer2D as bn
        l = []
        l.append(convLayer(filterShape=(32, 1, 5, 5), stride=2))
        l.append(bn(n_out=32))
        l.append(nonlinLayer(activation=relu))
        l.append(convLayer(filterShape=(128, 32, 5, 5), stride=2))
        l.append(bn(n_out=128))
        l.append(nonlinLayer(activation=relu))
        l.append(convLayer(filterShape=(256, 128, 5, 5), stride=2))
        l.append(bn(n_out=256))
        l.append(nonlinLayer(activation=relu))
        l.append(reshapeLayer((-1, 256 * 14 * 14)))
        l.append(fcLayer(n_in=256 * 14 * 14, n_out=1, activation=sigmoid))
        self.l = l
        self.params = get_params(l)
Example #6
def logit(x, is_training=True, update_batch_stats=True, stochastic=True, seed=1234):
    h = x

    rng = numpy.random.RandomState(seed)

    h = L.conv(h, ksize=3, stride=1, f_in=3, f_out=128, seed=rng.randint(123456), name='c1')
    h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b1'), FLAGS.lrelu_a)
    h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=128, seed=rng.randint(123456), name='c2')
    h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b2'), FLAGS.lrelu_a)
    h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=128, seed=rng.randint(123456), name='c3')
    h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b3'), FLAGS.lrelu_a)

    h = L.max_pool(h, ksize=2, stride=2)
    h = tf.nn.dropout(h, keep_prob=FLAGS.keep_prob_hidden, seed=rng.randint(123456)) if stochastic else h

    h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=256, seed=rng.randint(123456), name='c4')
    h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b4'), FLAGS.lrelu_a)
    h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=256, seed=rng.randint(123456), name='c5')
    h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b5'), FLAGS.lrelu_a)
    h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=256, seed=rng.randint(123456), name='c6')
    h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b6'), FLAGS.lrelu_a)

    h = L.max_pool(h, ksize=2, stride=2)
    h = tf.nn.dropout(h, keep_prob=FLAGS.keep_prob_hidden, seed=rng.randint(123456)) if stochastic else h

    h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=512, seed=rng.randint(123456), padding="VALID", name='c7')
    h = L.lrelu(L.bn(h, 512, is_training=is_training, update_batch_stats=update_batch_stats, name='b7'), FLAGS.lrelu_a)
    h = L.conv(h, ksize=1, stride=1, f_in=512, f_out=256, seed=rng.randint(123456), name='c8')
    h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b8'), FLAGS.lrelu_a)
    h = L.conv(h, ksize=1, stride=1, f_in=256, f_out=128, seed=rng.randint(123456), name='c9')
    h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b9'), FLAGS.lrelu_a)

    h1 = tf.reduce_mean(h, reduction_indices=[1, 2])  # Features to be aligned
    h = L.fc(h1, 128, 10, seed=rng.randint(123456), name='fc')

    if FLAGS.top_bn:
        h = L.bn(h, 10, is_training=is_training,
                 update_batch_stats=update_batch_stats, name='bfc')

    return h, h1
Example #7
    def __init__(self, nz, F=None):
        from theano.tensor.nnet import sigmoid, relu
        from layers import initGain, fcLayer, convUnit, nonlinLayer, reshapeLayer, convLayer
        from layers import batchNormLayer2D as bn
        l = []
        l.append(fcLayer(n_in=nz, n_out=256 * 13 * 13))
        l.append(bn(n_out=256 * 13 * 13))
        l.append(nonlinLayer(activation=relu, a=0.2))
        l.append(reshapeLayer((-1, 256, 13, 13)))
        l.append(convLayer(filterShape=(128, 256, 5, 5), stride=0.5))
        l.append(bn(n_out=128))
        l.append(nonlinLayer(activation=relu, a=0.2))
        l.append(convLayer(filterShape=(64, 128, 5, 5), stride=0.5))
        l.append(bn(n_out=64))
        l.append(nonlinLayer(activation=relu, a=0.2))
        l.append(
            convLayer(filterShape=(1, 64, 4, 4),
                      stride=0.5,
                      activation=sigmoid))

        self.l = l
        self.params = get_params(l)
Example #8
  def __init__(self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=128,
               G_kernel_size=3, G_attn='64', n_classes=1000,
               num_G_SVs=1, num_G_SV_itrs=1,
               G_shared=True, shared_dim=0, hier=False,
               cross_replica=False, mybn=False,
               G_activation=nn.ReLU(inplace=False),
               G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
               BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
               G_init='ortho', skip_init=False, no_optim=False,
               G_param='SN', norm_style='bn',
               **kwargs):
    super(Generator, self).__init__()
    # Channel width multiplier
    self.ch = G_ch
    # Number of resblocks per stage
    self.G_depth = G_depth
    # Dimensionality of the latent space
    self.dim_z = dim_z
    # The initial spatial dimensions
    self.bottom_width = bottom_width
    # Resolution of the output
    self.resolution = resolution
    # Kernel size?
    self.kernel_size = G_kernel_size
    # Attention?
    self.attention = G_attn
    # number of classes, for use in categorical conditional generation
    self.n_classes = n_classes
    # Use shared embeddings?
    self.G_shared = G_shared
    # Dimensionality of the shared embedding? Unused if not using G_shared
    self.shared_dim = shared_dim if shared_dim > 0 else dim_z
    # Hierarchical latent space?
    self.hier = hier
    # Cross replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # nonlinearity for residual blocks
    self.activation = G_activation
    # Initialization style
    self.init = G_init
    # Parameterization style
    self.G_param = G_param
    # Normalization style
    self.norm_style = norm_style
    # Epsilon for BatchNorm?
    self.BN_eps = BN_eps
    # Epsilon for Spectral Norm?
    self.SN_eps = SN_eps
    # fp16?
    self.fp16 = G_fp16
    # Architecture dict
    self.arch = G_arch(self.ch, self.attention)[resolution]


    # Which convs, batchnorms, and linear layers to use
    if self.G_param == 'SN':
      self.which_conv = functools.partial(layers.SNConv2d,
                          kernel_size=3, padding=1,
                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                          eps=self.SN_eps)
      self.which_linear = functools.partial(layers.SNLinear,
                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                          eps=self.SN_eps)
    else:
      self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
      self.which_linear = nn.Linear
      
    # We use a non-spectral-normed embedding here regardless;
    # For some reason applying SN to G's embedding seems to randomly cripple G
    self.which_embedding = nn.Embedding
    bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
                 else self.which_embedding)
    self.which_bn = functools.partial(layers.ccbn,
                          which_linear=bn_linear,
                          cross_replica=self.cross_replica,
                          mybn=self.mybn,
                          input_size=(self.shared_dim + self.dim_z if self.G_shared
                                      else self.n_classes),
                          norm_style=self.norm_style,
                          eps=self.BN_eps)


    # Prepare model
    # If not using shared embeddings, self.shared is just a passthrough
    self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared 
                    else layers.identity())
    # First linear layer
    self.linear = self.which_linear(self.dim_z + self.shared_dim, self.arch['in_channels'][0] * (self.bottom_width **2))

    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    # while the inner loop is over a given block
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
      self.blocks += [[GBlock(in_channels=self.arch['in_channels'][index],
                             out_channels=self.arch['in_channels'][index] if g_index==0 else self.arch['out_channels'][index],
                             which_conv=self.which_conv,
                             which_bn=self.which_bn,
                             activation=self.activation,
                             upsample=(functools.partial(F.interpolate, scale_factor=2)
                                       if self.arch['upsample'][index] and g_index == (self.G_depth-1) else None))]
                       for g_index in range(self.G_depth)]

      # If attention on this block, attach it to the end
      if self.arch['attention'][self.arch['resolution'][index]]:
        print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
        self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    # output layer: batchnorm-relu-conv.
    # Consider using a non-spectral conv here
    self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                cross_replica=self.cross_replica,
                                                mybn=self.mybn),
                                    self.activation,
                                    self.which_conv(self.arch['out_channels'][-1], 3))

    # Initialize weights. Optionally skip init for testing.
    if not skip_init:
      self.init_weights()

    # Set up optimizer
    # If this is an EMA copy, no need for an optim, so just return now
    if no_optim:
      return
    self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
    if G_mixed_precision:
      print('Using fp16 adam in G...')
      import utils
      self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                           betas=(self.B1, self.B2), weight_decay=0,
                           eps=self.adam_eps)
    else:
      self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                           betas=(self.B1, self.B2), weight_decay=0,
                           eps=self.adam_eps)
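The `which_conv` / `which_linear` indirection above is plain `functools.partial`: layer hyperparameters are bound once so every block constructs its layers the same way. A small standalone sketch of that pattern using stock `torch.nn` layers (`layers.SNConv2d` / `layers.SNLinear` are project-specific, so the spectral-norm variants are not reproduced here):

import functools
import torch.nn as nn

which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
which_linear = nn.Linear

conv = which_conv(64, 128)      # 3x3 conv, 64 -> 128 channels, padding 1
linear = which_linear(128, 10)  # plain Linear(128, 10)
print(conv)
print(linear)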
Example #9
    def __init__(self,
                 G_ch=64,
                 dim_z=128,
                 bottom_width=4,
                 resolution=128,
                 G_kernel_size=3,
                 G_attn='64',
                 n_classes=1000,
                 num_G_SVs=1,
                 num_G_SV_itrs=1,
                 G_shared=True,
                 shared_dim=0,
                 hier=False,
                 cross_replica=False,
                 mybn=False,
                 G_activation=nn.ReLU(inplace=False),
                 G_lr=5e-5,
                 G_B1=0.0,
                 G_B2=0.999,
                 adam_eps=1e-8,
                 BN_eps=1e-5,
                 SN_eps=1e-12,
                 G_mixed_precision=False,
                 G_fp16=False,
                 G_init='ortho',
                 skip_init=False,
                 no_optim=False,
                 G_param='SN',
                 norm_style='bn',
                 add_blur=False,
                 add_noise=False,
                 add_style=False,
                 style_mlp=6,
                 attn_style='nl',
                 no_conditional=False,
                 sched_version='default',
                 num_epochs=500,
                 arch=None,
                 skip_z=False,
                 use_dog_cnt=False,
                 dim_dog_cnt_z=32,
                 mix_style=False,
                 **kwargs):
        super(Generator, self).__init__()
        # Channel width multiplier
        self.ch = G_ch
        # Dimensionality of the latent space
        self.dim_z = dim_z
        # The initial spatial dimensions
        self.bottom_width = bottom_width
        # Resolution of the output
        self.resolution = resolution
        # Kernel size?
        self.kernel_size = G_kernel_size
        # Attention?
        self.attention = G_attn
        # number of classes, for use in categorical conditional generation
        self.n_classes = n_classes
        # Use shared embeddings?
        self.G_shared = G_shared
        # Dimensionality of the shared embedding? Unused if not using G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        # Hierarchical latent space?
        self.hier = hier
        # Cross replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # nonlinearity for residual blocks
        self.activation = G_activation
        # Initialization style
        self.init = G_init
        # Parameterization style
        self.G_param = G_param
        # Normalization style
        self.norm_style = norm_style
        # Extra generator options: blur, noise, and style layers
        self.add_blur = add_blur
        self.add_noise = add_noise
        self.add_style = add_style
        self.skip_z = skip_z
        self.use_dog_cnt = use_dog_cnt
        self.dim_dog_cnt_z = dim_dog_cnt_z
        self.mix_style = mix_style

        # Epsilon for BatchNorm?
        self.BN_eps = BN_eps
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # fp16?
        self.fp16 = G_fp16
        # Architecture dict
        if arch is None:
            arch = f'{resolution}'
        self.arch = G_arch(self.ch, self.attention)[arch]

        # If using hierarchical latents, adjust z
        if self.hier:
            # Number of places z slots into
            self.num_slots = len(self.arch['in_channels']) + 1
            self.z_chunk_size = (self.dim_z // self.num_slots)
            # Recalculate latent dimensionality for even splitting into chunks
            self.dim_z = self.z_chunk_size * self.num_slots
        else:
            self.num_slots = 1
            self.z_chunk_size = 0

        # Which convs, batchnorms, and linear layers to use
        if self.G_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3,
                                                padding=1,
                                                num_svs=num_G_SVs,
                                                num_itrs=num_G_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_G_SVs,
                                                  num_itrs=num_G_SV_itrs,
                                                  eps=self.SN_eps)
        else:
            self.which_conv = functools.partial(nn.Conv2d,
                                                kernel_size=3,
                                                padding=1)
            self.which_linear = nn.Linear

        if attn_style == 'cbam':
            self.which_attn = layers.CBAM
        else:
            self.which_attn = layers.Attention

        # We use a non-spectral-normed embedding here regardless;
        # For some reason applying SN to G's embedding seems to randomly cripple G
        self.which_embedding = nn.Embedding
        bn_linear = (functools.partial(self.which_linear, bias=False)
                     if self.G_shared else self.which_embedding)
        input_size = self.shared_dim + self.z_chunk_size if self.G_shared else self.n_classes
        if self.G_shared and use_dog_cnt:
            input_size += dim_dog_cnt_z
        self.which_bn = functools.partial(
            layers.ccbn,
            which_linear=bn_linear,
            cross_replica=self.cross_replica,
            mybn=self.mybn,
            input_size=input_size,
            norm_style=self.norm_style,
            eps=self.BN_eps,
            style_linear=self.which_linear,
            dim_z=self.dim_z,
            no_conditional=no_conditional,
            skip_z=self.skip_z,
            use_dog_cnt=use_dog_cnt,
            g_shared=G_shared,
        )

        # Prepare model
        # If not using shared embeddings, self.shared is just a passthrough
        self.shared = (self.which_embedding(n_classes, self.shared_dim)
                       if G_shared else layers.identity())

        self.dog_cnt_shared = (self.which_embedding(4, self.dim_dog_cnt_z)
                               if G_shared else layers.identity())
        # First linear layer
        self.linear = self.which_linear(
            self.dim_z // self.num_slots,
            self.arch['in_channels'][0] * (self.bottom_width**2))

        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        # while the inner loop is over a given block
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[
                layers.GBlock(
                    in_channels=self.arch['in_channels'][index],
                    out_channels=self.arch['out_channels'][index],
                    which_conv=self.which_conv,
                    which_bn=self.which_bn,
                    activation=self.activation,
                    upsample=(functools.partial(F.interpolate, scale_factor=2)
                              if self.arch['upsample'][index] else None),
                    add_blur=add_blur,
                    add_noise=add_noise,
                )
            ]]

            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in G at resolution %d' %
                      self.arch['resolution'][index])
                self.blocks[-1] += [
                    self.which_attn(self.arch['out_channels'][index],
                                    self.which_conv)
                ]

        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks])

        # output layer: batchnorm-relu-conv.
        # Consider using a non-spectral conv here
        self.output_layer = nn.Sequential(
            layers.bn(self.arch['out_channels'][-1],
                      cross_replica=self.cross_replica,
                      mybn=self.mybn), self.activation,
            self.which_conv(self.arch['out_channels'][-1], 3))

        if self.add_style:
            # layers = [PixelNorm()]
            style_layers = []
            for i in range(style_mlp):
                style_layers.append(
                    layers.StyleLayer(self.dim_z, self.which_linear,
                                      self.activation))

            self.style = nn.Sequential(*style_layers)

        # Initialize weights. Optionally skip init for testing.
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        # If this is an EMA copy, no need for an optim, so just return now
        if no_optim:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print('Using fp16 adam in G...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(),
                                      lr=self.lr,
                                      betas=(self.B1, self.B2),
                                      weight_decay=0,
                                      eps=self.adam_eps,
                                      amsgrad=kwargs['amsgrad'])
        else:
            self.optim = optim.Adam(params=self.parameters(),
                                    lr=self.lr,
                                    betas=(self.B1, self.B2),
                                    weight_decay=0,
                                    eps=self.adam_eps,
                                    amsgrad=kwargs['amsgrad'])

        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0

        if sched_version == 'default':
            self.lr_sched = None
        elif sched_version == 'cal_v0':
            self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
                self.optim,
                T_max=num_epochs,
                eta_min=self.lr / 2,
                last_epoch=-1)
        elif sched_version == 'cal_v1':
            self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
                self.optim,
                T_max=num_epochs,
                eta_min=self.lr / 4,
                last_epoch=-1)
        elif sched_version == 'cawr_v0':
            self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                self.optim, T_0=10, T_mult=2, eta_min=self.lr / 2)
        elif sched_version == 'cawr_v1':
            self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                self.optim, T_0=25, T_mult=2, eta_min=self.lr / 4)
        else:
            self.lr_sched = None
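The `sched_version` branches above map a flag to a torch learning-rate scheduler. A minimal standalone sketch of the 'cal_v0' case with a throwaway parameter, using the defaults from the signature above (G_lr=5e-5, num_epochs=500):

import torch
from torch import nn, optim

param = nn.Parameter(torch.zeros(1))
opt = optim.Adam([param], lr=5e-5, betas=(0.0, 0.999), eps=1e-8)
lr_sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=500, eta_min=5e-5 / 2, last_epoch=-1)

for epoch in range(3):
    opt.step()       # stands in for one epoch of training
    lr_sched.step()  # cosine-anneal the learning rate toward eta_min
    print(epoch, lr_sched.get_last_lr())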
Example #10
def generator(z,
              y,
              is_training=True,
              update_batch_stats=True,
              act_fn=L.lrelu,
              bn=FLAGS.gen_bn,
              reuse=True,
              dropout=FLAGS.gen_dropout):
    with tf.variable_scope('generator', reuse=reuse):
        if FLAGS.method == "cgan":
            inputs = tf.concat(axis=1, values=[z, y])
            h = L.fc(inputs,
                     Z_dim + y_dim, ((X_dim / 4)**2) * 128,
                     seed=rng.randint(123456),
                     name='fc1')
        else:
            h = L.fc(z,
                     Z_dim, ((X_dim / 4)**2) * 128,
                     seed=rng.randint(123456),
                     name='fc1')
        h = L.bn(h, ((X_dim / 4)**2) * 128,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 use_gamma=False,
                 name='bn1') if bn else h
        h = act_fn(h)
        h = tf.reshape(h, [-1, X_dim / 4, X_dim / 4, 128])

        # 16x16 -> 32x32
        h = L.deconv(h, ksize=2, stride=2, f_in=128, f_out=64, name="deconv1")
        h = L.conv(h, 5, 1, 64, 64, name="conv1")
        h = L.bn(h,
                 64,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 use_gamma=False,
                 name='bn2') if bn else h
        h = tf.nn.dropout(h, keep_prob=0.5) if dropout else h
        h = act_fn(h)

        h = L.conv(h, 3, 1, 64, 64, name="conv2")
        h = L.bn(h,
                 64,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 use_gamma=False,
                 name='b3') if bn else h
        h = tf.nn.dropout(h, keep_prob=0.5) if dropout else h
        h = act_fn(h)

        # 32x32 -> 64x64
        h = L.deconv(h, ksize=2, stride=2, f_in=64, f_out=32, name="deconv2")
        h = L.conv(h, 5, 1, 32, 32, name="conv3")
        h = L.bn(h,
                 32,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 use_gamma=False,
                 name='b4') if bn else h
        h = tf.nn.dropout(h, keep_prob=0.5) if dropout else h
        h = act_fn(h)

        h = L.conv(h, 5, 1, 32, num_channels, name="conv4")
        h = tf.nn.tanh(h, name="output")
        return h
Example #11
def logit(x,
          num_classes=10,
          is_training=True,
          update_batch_stats=True,
          stochastic=True,
          seed=1234):

    if is_training:
        scope = tf.name_scope("Training")

    else:
        scope = tf.name_scope("Testing")

    with scope:
        h = x

        rng = np.random.RandomState(seed)

        h = L.conv(h,
                   ksize=3,
                   stride=1,
                   f_in=3,
                   f_out=128,
                   seed=rng.randint(123456),
                   name='c1')
        h = L.lrelu(
            L.bn(h,
                 128,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='b1'), FLAGS.lrelu_a)
        h = L.conv(h,
                   ksize=3,
                   stride=1,
                   f_in=128,
                   f_out=128,
                   seed=rng.randint(123456),
                   name='c2')
        h = L.lrelu(
            L.bn(h,
                 128,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='b2'), FLAGS.lrelu_a)
        h = L.conv(h,
                   ksize=3,
                   stride=1,
                   f_in=128,
                   f_out=128,
                   seed=rng.randint(123456),
                   name='c3')
        h = L.lrelu(
            L.bn(h,
                 128,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='b3'), FLAGS.lrelu_a)

        h = L.max_pool(h, ksize=2, stride=2)
        h = tf.nn.dropout(h,
                          keep_prob=FLAGS.keep_prob_hidden,
                          seed=rng.randint(123456)) if stochastic else h

        h = L.conv(h,
                   ksize=3,
                   stride=1,
                   f_in=128,
                   f_out=256,
                   seed=rng.randint(123456),
                   name='c4')
        h = L.lrelu(
            L.bn(h,
                 256,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='b4'), FLAGS.lrelu_a)
        h = L.conv(h,
                   ksize=3,
                   stride=1,
                   f_in=256,
                   f_out=256,
                   seed=rng.randint(123456),
                   name='c5')
        h = L.lrelu(
            L.bn(h,
                 256,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='b5'), FLAGS.lrelu_a)
        h = L.conv(h,
                   ksize=3,
                   stride=1,
                   f_in=256,
                   f_out=256,
                   seed=rng.randint(123456),
                   name='c6')
        h = L.lrelu(
            L.bn(h,
                 256,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='b6'), FLAGS.lrelu_a)

        h = L.max_pool(h, ksize=2, stride=2)
        h = tf.nn.dropout(h,
                          keep_prob=FLAGS.keep_prob_hidden,
                          seed=rng.randint(123456)) if stochastic else h

        h = L.conv(h,
                   ksize=3,
                   stride=1,
                   f_in=256,
                   f_out=512,
                   seed=rng.randint(123456),
                   padding="VALID",
                   name='c7')
        h = L.lrelu(
            L.bn(h,
                 512,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='b7'), FLAGS.lrelu_a)
        h = L.conv(h,
                   ksize=1,
                   stride=1,
                   f_in=512,
                   f_out=256,
                   seed=rng.randint(123456),
                   name='c8')
        h = L.lrelu(
            L.bn(h,
                 256,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='b8'), FLAGS.lrelu_a)
        h = L.conv(h,
                   ksize=1,
                   stride=1,
                   f_in=256,
                   f_out=128,
                   seed=rng.randint(123456),
                   name='c9')
        h = L.lrelu(
            L.bn(h,
                 128,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='b9'), FLAGS.lrelu_a)

        h = tf.reduce_mean(h, reduction_indices=[1,
                                                 2])  # Global average pooling
        h = L.fc(h, 128, num_classes, seed=rng.randint(123456), name='fc')

        if FLAGS.top_bn:
            h = L.bn(h,
                     num_classes,
                     is_training=is_training,
                     update_batch_stats=update_batch_stats,
                     name='bfc')

        return h
Example #12
def autoencoder(x,
                zca,
                is_training=True,
                update_batch_stats=True,
                stochastic=True,
                seed=1234,
                use_zca=True):

    if is_training:
        scope = tf.name_scope("Training")

    else:
        scope = tf.name_scope("Testing")

    with scope:
        #Initial shape (-1, 32, 32, 3)
        x = x + 0.5  #Recover [0,1] range
        if use_zca:
            h = zca
        else:
            h = x
        print(h.shape)
        rng = np.random.RandomState(seed)

        #h = tf.map_fn(lambda x:transform(x),h)

        #(1) conv + relu + maxpool (-1, 16, 16, 64)
        h = L.conv(h,
                   ksize=3,
                   stride=1,
                   f_in=3,
                   f_out=64,
                   seed=rng.randint(123456),
                   padding="SAME",
                   name='conv1')
        h = L.lrelu(
            L.bn(h,
                 64,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='conv1_bn'), FLAGS.lrelu_a)
        h = L.max_pool(h, ksize=2, stride=2)

        #(2) conv + relu + maxpool (-1, 8, 8, 32)
        h = L.conv(h,
                   ksize=3,
                   stride=1,
                   f_in=64,
                   f_out=32,
                   seed=rng.randint(123456),
                   padding="SAME",
                   name='conv2')
        h = L.lrelu(
            L.bn(h,
                 32,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='conv2_bn'), FLAGS.lrelu_a)
        h = L.max_pool(h, ksize=2, stride=2)

        #(3) conv + relu + maxpool (-1, 4, 4, 16)
        h = L.conv(h,
                   ksize=3,
                   stride=1,
                   f_in=32,
                   f_out=16,
                   seed=rng.randint(123456),
                   padding="SAME",
                   name='conv3')
        h = L.lrelu(
            L.bn(h,
                 16,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='conv3_bn'), FLAGS.lrelu_a)
        h = L.max_pool(h, ksize=2, stride=2)

        encoded = h
        #(4) deconv + relu (-1, 8, 8, 16)
        h = L.deconv(encoded,
                     ksize=5,
                     stride=1,
                     f_in=16,
                     f_out=16,
                     seed=rng.randint(123456),
                     padding="SAME",
                     name="deconv1")
        h = L.lrelu(
            L.bn(h,
                 16,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='deconv1_bn'), FLAGS.lrelu_a)

        #(5) deconv + relu (-1, 16, 16, 32)
        h = L.deconv(h,
                     ksize=5,
                     stride=1,
                     f_in=16,
                     f_out=32,
                     padding="SAME",
                     name="deconv2")
        h = L.lrelu(
            L.bn(h,
                 32,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='deconv2_bn'), FLAGS.lrelu_a)

        #(6) deconv + relu (-1, 32, 32, 64)
        h = L.deconv(h,
                     ksize=5,
                     stride=1,
                     f_in=32,
                     f_out=64,
                     padding="SAME",
                     name="deconv3")
        h = L.lrelu(
            L.bn(h,
                 64,
                 is_training=is_training,
                 update_batch_stats=update_batch_stats,
                 name='deconv3_bn'), FLAGS.lrelu_a)

        #(7) conv + sigmoid (-1, 32, 32, 3)
        h = L.conv(h,
                   ksize=3,
                   stride=1,
                   f_in=64,
                   f_out=3,
                   seed=rng.randint(123456),
                   padding="SAME",
                   name='convfinal')
        if use_zca:
            h = L.bn(h,
                     3,
                     is_training=is_training,
                     update_batch_stats=update_batch_stats,
                     name='deconv4_bn')
        else:
            h = tf.sigmoid(h)

        num_samples = 10
        sample_og_zca = tf.reshape(
            tf.slice(zca, [0, 0, 0, 0], [num_samples, 32, 32, 3]),
            (num_samples * 32, 32, 3))
        sample_og_color = tf.reshape(
            tf.slice(x, [0, 0, 0, 0], [num_samples, 32, 32, 3]),
            (num_samples * 32, 32, 3))
        sample_rec = tf.reshape(
            tf.slice(h, [0, 0, 0, 0], [num_samples, 32, 32, 3]),
            (num_samples * 32, 32, 3))
        if use_zca:
            sample = tf.concat([sample_og_zca, sample_rec], axis=1)
            m = tf.reduce_min(sample)
            sample = (sample - m) / (tf.reduce_max(sample) - m)
        else:
            m = tf.reduce_min(sample_og_zca)
            sample_og_zca = (sample_og_zca -
                             m) / (tf.reduce_max(sample_og_zca) - m)
            sample = tf.concat([sample_og_zca, sample_rec], axis=1)
        sample = tf.concat([sample_og_color, sample], axis=1)
        sample = tf.cast(255.0 * sample, tf.uint8)

        if use_zca:
            loss = tf.reduce_mean(tf.losses.mean_squared_error(zca, h))
        else:
            loss = tf.reduce_mean(tf.losses.log_loss(x, h))

        return loss, encoded, sample
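The visualization block above tiles a few inputs next to their reconstructions. A NumPy sketch of the `use_zca=True` branch with random stand-in tensors (all shapes follow the (32, 32, 3) sizes in the comments above):

import numpy as np

num_samples, H, W, C = 10, 32, 32, 3
zca = np.random.rand(16, H, W, C).astype(np.float32)  # stand-ins for the real zca / x / h tensors
x = np.random.rand(16, H, W, C).astype(np.float32)
h = np.random.rand(16, H, W, C).astype(np.float32)

sample_og_zca = zca[:num_samples].reshape(num_samples * H, W, C)
sample_og_color = x[:num_samples].reshape(num_samples * H, W, C)
sample_rec = h[:num_samples].reshape(num_samples * H, W, C)

sample = np.concatenate([sample_og_zca, sample_rec], axis=1)  # zca input | reconstruction
m = sample.min()
sample = (sample - m) / (sample.max() - m)                    # rescale to [0, 1]
sample = np.concatenate([sample_og_color, sample], axis=1)    # prepend the raw color image
sample = (255.0 * sample).astype(np.uint8)
print(sample.shape)  # (320, 96, 3): ten rows of [color | zca | reconstruction]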
Example #13
  def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=128,
               G_kernel_size=3, G_attn='64', n_classes=1000,
               num_G_SVs=1, num_G_SV_itrs=1,
               G_shared=True, shared_dim=0, hier=False,
               cross_replica=False, mybn=False,
               G_activation=nn.ReLU(inplace=False),
               G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
               BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
               G_init='ortho', skip_init=False, no_optim=False,
               G_param='SN', norm_style='bn',
               **kwargs):
    """
    utils中有这些参数的定义,通过parase和vars方法封装这些参数
    看一下模型到底是咋样
    G_ch 生成模型的信道 默认64,指的是一种模型机构的总和,64可解析为如下结构
    ch = 64
    arch[128] = {'in_channels' :  [ch * item for item in [16, 16, 8, 4, 2]],
                'out_channels' : [ch * item for item in [16, 8, 4, 2, 1]],
                'upsample' : [True] * 5,
                'resolution' : [8, 16, 32, 64, 128],
                'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
                                for i in range(3,8)}}
    dim_z 噪声的维度,默认为128

    """
    super(Generator, self).__init__()
    # Channel width multiplier
    self.ch = G_ch
    # Dimensionality of the latent space
    self.dim_z = dim_z
    # The initial spatial dimensions
    ## TODO: not yet clear what this is mainly used for
    self.bottom_width = bottom_width
    # Resolution of the output
    ## Selects which architecture entry to use
    self.resolution = resolution
    # Kernel size?
    ## TODO: this is not passed in from outside, and it is not used anywhere either
    self.kernel_size = G_kernel_size
    # Attention?
    ## Just an intermediary: it is handed straight to the self.arch lookup and is finally resolved inside the attention structure
    self.attention = G_attn
    # number of classes, for use in categorical conditional generation
    self.n_classes = n_classes
    # Use shared embeddings?
    ## Defaults to False
    self.G_shared = G_shared
    # Dimensionality of the shared embedding? Unused if not using G_shared
    self.shared_dim = shared_dim if shared_dim > 0 else dim_z
    # Hierarchical latent space?
    self.hier = hier
    # Cross replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # nonlinearity for residual blocks
    self.activation = G_activation
    # Initialization style
    self.init = G_init
    # Parameterization style
    self.G_param = G_param
    # Normalization style
    self.norm_style = norm_style
    # Epsilon for BatchNorm?
    self.BN_eps = BN_eps
    # Epsilon for Spectral Norm?
    ## https://zhuanlan.zhihu.com/p/68081406
    self.SN_eps = SN_eps
    # fp16?
    self.fp16 = G_fp16
    # Architecture dict
    self.arch = G_arch(self.ch, self.attention)[resolution]

    # If using hierarchical latents, adjust z
    if self.hier:
      # Number of places z slots into
      self.num_slots = len(self.arch['in_channels']) + 1
      self.z_chunk_size = (self.dim_z // self.num_slots)
      # Recalculate latent dimensionality for even splitting into chunks
      self.dim_z = self.z_chunk_size *  self.num_slots
    else:
      self.num_slots = 1
      self.z_chunk_size = 0

    # Which convs, batchnorms, and linear layers to use
    if self.G_param == 'SN':
      self.which_conv = functools.partial(layers.SNConv2d,
                          kernel_size=3, padding=1,
                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                          eps=self.SN_eps)
      self.which_linear = functools.partial(layers.SNLinear,
                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                          eps=self.SN_eps)
    else:
      self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
      self.which_linear = nn.Linear
      
    # We use a non-spectral-normed embedding here regardless;
    # For some reason applying SN to G's embedding seems to randomly cripple G
    ## *** fluid.dygraph.Embedding == nn.Embedding
    self.which_embedding = nn.Embedding
    bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
                 else self.which_embedding)
    
    self.which_bn = functools.partial(layers.ccbn,
                          which_linear=bn_linear,
                          cross_replica=self.cross_replica,
                          mybn=self.mybn,
                          input_size=(self.shared_dim + self.z_chunk_size if self.G_shared
                                      else self.n_classes),
                          norm_style=self.norm_style,
                          eps=self.BN_eps)

    # Prepare model
    # If not using shared embeddings, self.shared is just a passthrough
    self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared 
                    else layers.identity())
    # First linear layer
    self.linear = self.which_linear(self.dim_z // self.num_slots,
                                    self.arch['in_channels'][0] * (self.bottom_width **2))

    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    # while the inner loop is over a given block
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
      self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
                             out_channels=self.arch['out_channels'][index],
                             which_conv=self.which_conv,
                             which_bn=self.which_bn,
                             activation=self.activation,
                             upsample=(functools.partial(F.interpolate, scale_factor=2)
                                       if self.arch['upsample'][index] else None))]]

      # If attention on this block, attach it to the end
      if self.arch['attention'][self.arch['resolution'][index]]:
        print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
        self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    # output layer: batchnorm-relu-conv.
    # Consider using a non-spectral conv here
    self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                cross_replica=self.cross_replica,
                                                mybn=self.mybn),
                                    self.activation,
                                    self.which_conv(self.arch['out_channels'][-1], 3))

    # Initialize weights. Optionally skip init for testing.
    if not skip_init:
      self.init_weights()

    # Set up optimizer
    # If this is an EMA copy, no need for an optim, so just return now
    if no_optim:
      return
    self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
    if G_mixed_precision:
      print('Using fp16 adam in G...')
      import utils
      self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                           betas=(self.B1, self.B2), weight_decay=0,
                           eps=self.adam_eps)
    else:
      self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                           betas=(self.B1, self.B2), weight_decay=0,
                           eps=self.adam_eps)