def logit(h, is_training=True, update_batch_stats=True, stochastic=True,
          seed=1234, dropout_mask=None, return_mask=False, h_before_dropout=None):
    rng = np.random.RandomState(seed)
    if h_before_dropout is None:
        h = L.conv(h, ksize=3, stride=1, f_in=3, f_out=128, seed=rng.randint(123456), name='c1')
        h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b1'), FLAGS.lrelu_a)
        h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=128, seed=rng.randint(123456), name='c2')
        h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b2'), FLAGS.lrelu_a)
        h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=128, seed=rng.randint(123456), name='c3')
        h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b3'), FLAGS.lrelu_a)
        h = L.max_pool(h, ksize=2, stride=2)
        if stochastic:
            h = tf.nn.dropout(h, keep_prob=FLAGS.keep_prob_hidden)

        h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=256, seed=rng.randint(123456), name='c4')
        h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b4'), FLAGS.lrelu_a)
        h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=256, seed=rng.randint(123456), name='c5')
        h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b5'), FLAGS.lrelu_a)
        h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=256, seed=rng.randint(123456), name='c6')
        h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b6'), FLAGS.lrelu_a)
        h_before_dropout = L.max_pool(h, ksize=2, stride=2)

    # Making it possible to change or return a dropout mask
    if stochastic:
        if dropout_mask is None:
            dropout_mask = tf.cast(
                tf.greater_equal(
                    tf.random_uniform(tf.shape(h_before_dropout), 0, 1, seed=rng.randint(123456)),
                    1.0 - FLAGS.keep_prob_hidden),
                tf.float32)
        else:
            dropout_mask = tf.reshape(dropout_mask, tf.shape(h_before_dropout))
        h = tf.multiply(h_before_dropout, dropout_mask)
        h = (1.0 / FLAGS.keep_prob_hidden) * h
    else:
        h = h_before_dropout

    h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=512, seed=rng.randint(123456), padding="VALID", name='c7')
    h = L.lrelu(L.bn(h, 512, is_training=is_training, update_batch_stats=update_batch_stats, name='b7'), FLAGS.lrelu_a)
    h = L.conv(h, ksize=1, stride=1, f_in=512, f_out=256, seed=rng.randint(123456), name='c8')
    h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b8'), FLAGS.lrelu_a)
    h = L.conv(h, ksize=1, stride=1, f_in=256, f_out=128, seed=rng.randint(123456), name='c9')
    h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b9'), FLAGS.lrelu_a)
    h = tf.reduce_mean(h, reduction_indices=[1, 2])  # Global average pooling
    h = L.fc(h, 128, 10, seed=rng.randint(123456), name='fc')
    if FLAGS.top_bn:
        h = L.bn(h, 10, is_training=is_training, update_batch_stats=update_batch_stats, name='bfc')

    if return_mask:
        return h, tf.reshape(dropout_mask, [-1, 8 * 8 * 256]), h_before_dropout
    else:
        return h
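# Illustration (added, not part of the original source): the dropout_mask /
# h_before_dropout / return_mask arguments above exist so that two forward passes
# can share one dropout pattern. A minimal NumPy sketch of the same inverted-dropout
# scheme with a reusable mask; all `_demo`/underscore names are hypothetical.
import numpy as np


def _sample_dropout_mask(shape, keep_prob, rng):
    # Keep each unit with probability keep_prob, mirroring the tf.greater_equal test in logit()
    return (rng.uniform(0.0, 1.0, size=shape) >= 1.0 - keep_prob).astype(np.float32)


def _apply_dropout(h, mask, keep_prob):
    # Inverted dropout: rescale so the expected activation is unchanged
    return h * mask / keep_prob


_rng_demo = np.random.RandomState(0)
_h_demo = _rng_demo.randn(2, 8, 8, 256).astype(np.float32)
_mask_demo = _sample_dropout_mask(_h_demo.shape, keep_prob=0.5, rng=_rng_demo)
_out_a = _apply_dropout(_h_demo, _mask_demo, keep_prob=0.5)
_out_b = _apply_dropout(_h_demo, _mask_demo, keep_prob=0.5)  # same mask -> identical output
assert np.allclose(_out_a, _out_b)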
def discriminator(x, y, is_training=True, update_batch_stats=True, act_fn=L.lrelu,
                  bn=FLAGS.dis_bn, reuse=True):
    with tf.variable_scope('discriminator', reuse=reuse):
        if FLAGS.method == 'cgan':
            h = L.fc(y, y_dim, X_dim * X_dim, seed=rng.randint(123456), name='fc_y')
            h = tf.reshape(h, [-1, X_dim, X_dim, 1])
            h = tf.concat((x, h), axis=3)
            h = L.conv(h, 3, 1, num_channels + 1, 32, name="conv1")
        else:
            h = L.conv(x, 3, 1, num_channels, 32, name="conv1")
        h = act_fn(h)

        # 64x64 -> 32x32
        h = L.conv(h, 4, 2, 32, 64, name="conv2")
        h = L.bn(h, 64, is_training=is_training, update_batch_stats=update_batch_stats,
                 use_gamma=False, name='bn1') if bn else h
        h = act_fn(h)

        # 32x32 -> 16x16
        h = L.conv(h, 4, 2, 64, 128, name="conv3")
        h = L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats,
                 use_gamma=False, name='bn2') if bn else h
        h = act_fn(h)

        # Final conv collapses the remaining X_dim // 4 spatial extent to 1x1
        h = L.conv(h, X_dim // 4, 1, 128, 1, name="conv5", padding="VALID")
        logits = tf.reshape(h, [-1, 1])
        return logits
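# Illustration (added, not from the original source): in the 'cgan' branch above the
# label vector y is projected to an X_dim x X_dim plane and concatenated to the image
# as one extra channel, which is why conv1 takes num_channels + 1 inputs. A NumPy
# sketch of that shape bookkeeping with made-up sizes (all `_demo` names hypothetical):
import numpy as np

_batch, _X_dim_demo, _num_channels_demo, _y_dim_demo = 4, 64, 3, 10
_x_demo = np.zeros((_batch, _X_dim_demo, _X_dim_demo, _num_channels_demo), np.float32)
_y_demo = np.zeros((_batch, _y_dim_demo), np.float32)
_W_demo = np.zeros((_y_dim_demo, _X_dim_demo * _X_dim_demo), np.float32)  # stands in for L.fc

_y_plane = (_y_demo @ _W_demo).reshape(_batch, _X_dim_demo, _X_dim_demo, 1)
_x_cond = np.concatenate((_x_demo, _y_plane), axis=3)
assert _x_cond.shape == (_batch, _X_dim_demo, _X_dim_demo, _num_channels_demo + 1)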
def logit(x, is_training=True, update_batch_stats=True, stochastic=True, seed=1234):
    x = tf.reshape(x, [x.get_shape().as_list()[0], -1])
    layer_sizes = numpy.asarray(FLAGS.layer_sizes.split('-'), numpy.int32)
    num_layers = len(layer_sizes) - 1
    rng = numpy.random.RandomState(seed)
    h = x
    for l, dim in enumerate(layer_sizes):
        inp_dim = h.get_shape()[1]
        with tf.variable_scope(str(l)):
            W = tf.get_variable(
                'W', shape=[inp_dim, dim],
                initializer=tf.contrib.layers.xavier_initializer(
                    uniform=False, seed=rng.randint(123456), dtype=tf.float32))
            b = tf.get_variable('b', shape=[dim], initializer=tf.constant_initializer(0.0))
            h = tf.nn.xw_plus_b(h, W, b)
            h = L.bn(h, dim, is_training=is_training, update_batch_stats=update_batch_stats)
            if l < num_layers - 1:
                h = tf.nn.relu(h)
                h = gaussian_noise_layer(
                    h, stddev=FLAGS.noise_stddev, seed=rng.randint(123456)
                ) if FLAGS.noise_stddev > 0 and stochastic else h
    return h
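# Illustration (added): FLAGS.layer_sizes above is a '-'-separated string of layer
# widths; the loop creates one W/b pair (plus batch norm) per entry. The value used
# below is a hypothetical example, not the repo's default.
import numpy

_layer_sizes_demo = numpy.asarray('1200-600-10'.split('-'), numpy.int32)
# -> array([1200, 600, 10], dtype=int32): three fully connected layers, ReLU + optional
#    Gaussian noise applied only to the earlier ones.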
def logit_small(x, num_classes, is_training=True, update_batch_stats=True, stochastic=True, seed=1234):
    if is_training:
        scope = tf.name_scope("Training")
    else:
        scope = tf.name_scope("Testing")
    with scope:
        h = x
        rng = np.random.RandomState(seed)
        h = L.fc(h, dim_in=x.shape[1], dim_out=64, seed=rng.randint(123456), name="fc1")
        h = L.lrelu(
            L.bn(h, 64, is_training=is_training, update_batch_stats=update_batch_stats,
                 name='fc1_normalized'), FLAGS.lrelu_a)
        h = L.fc(h, dim_in=64, dim_out=64, seed=rng.randint(123456), name="fc2")
        h = L.lrelu(
            L.bn(h, 64, is_training=is_training, update_batch_stats=update_batch_stats,
                 name='fc2_normalized'), FLAGS.lrelu_a)
        h = L.fc(h, dim_in=64, dim_out=num_classes, seed=rng.randint(123456), name="fc3")
    return h
def __init__(self, F=None):
    from theano.tensor.nnet import sigmoid, relu
    from layers import initGain, fcLayer, convUnit, nonlinLayer, reshapeLayer, convLayer
    from layers import batchNormLayer2D as bn

    l = []
    l.append(convLayer(filterShape=(32, 1, 5, 5), stride=2))
    l.append(bn(n_out=32))
    l.append(nonlinLayer(activation=relu))
    l.append(convLayer(filterShape=(128, 32, 5, 5), stride=2))
    l.append(bn(n_out=128))
    l.append(nonlinLayer(activation=relu))
    l.append(convLayer(filterShape=(256, 128, 5, 5), stride=2))
    l.append(bn(n_out=256))
    l.append(nonlinLayer(activation=relu))
    l.append(reshapeLayer((-1, 256 * 14 * 14)))
    l.append(fcLayer(n_in=256 * 14 * 14, n_out=1, activation=sigmoid))
    self.l = l
    self.params = get_params(l)
def logit(x, is_training=True, update_batch_stats=True, stochastic=True, seed=1234):
    h = x
    rng = numpy.random.RandomState(seed)

    h = L.conv(h, ksize=3, stride=1, f_in=3, f_out=128, seed=rng.randint(123456), name='c1')
    h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b1'), FLAGS.lrelu_a)
    h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=128, seed=rng.randint(123456), name='c2')
    h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b2'), FLAGS.lrelu_a)
    h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=128, seed=rng.randint(123456), name='c3')
    h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b3'), FLAGS.lrelu_a)
    h = L.max_pool(h, ksize=2, stride=2)
    h = tf.nn.dropout(h, keep_prob=FLAGS.keep_prob_hidden, seed=rng.randint(123456)) if stochastic else h

    h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=256, seed=rng.randint(123456), name='c4')
    h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b4'), FLAGS.lrelu_a)
    h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=256, seed=rng.randint(123456), name='c5')
    h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b5'), FLAGS.lrelu_a)
    h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=256, seed=rng.randint(123456), name='c6')
    h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b6'), FLAGS.lrelu_a)
    h = L.max_pool(h, ksize=2, stride=2)
    h = tf.nn.dropout(h, keep_prob=FLAGS.keep_prob_hidden, seed=rng.randint(123456)) if stochastic else h

    h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=512, seed=rng.randint(123456), padding="VALID", name='c7')
    h = L.lrelu(L.bn(h, 512, is_training=is_training, update_batch_stats=update_batch_stats, name='b7'), FLAGS.lrelu_a)
    h = L.conv(h, ksize=1, stride=1, f_in=512, f_out=256, seed=rng.randint(123456), name='c8')
    h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b8'), FLAGS.lrelu_a)
    h = L.conv(h, ksize=1, stride=1, f_in=256, f_out=128, seed=rng.randint(123456), name='c9')
    h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b9'), FLAGS.lrelu_a)

    h1 = tf.reduce_mean(h, reduction_indices=[1, 2])  # Features to be aligned
    h = L.fc(h1, 128, 10, seed=rng.randint(123456), name='fc')
    if FLAGS.top_bn:
        h = L.bn(h, 10, is_training=is_training, update_batch_stats=update_batch_stats, name='bfc')
    return h, h1
def __init__(self, nz, F=None):
    from theano.tensor.nnet import sigmoid, relu
    from layers import initGain, fcLayer, convUnit, nonlinLayer, reshapeLayer, convLayer
    from layers import batchNormLayer2D as bn

    l = []
    l.append(fcLayer(n_in=nz, n_out=256 * 13 * 13))
    l.append(bn(n_out=256 * 13 * 13))
    l.append(nonlinLayer(activation=relu, a=0.2))
    l.append(reshapeLayer((-1, 256, 13, 13)))
    l.append(convLayer(filterShape=(128, 256, 5, 5), stride=0.5))
    l.append(bn(n_out=128))
    l.append(nonlinLayer(activation=relu, a=0.2))
    l.append(convLayer(filterShape=(64, 128, 5, 5), stride=0.5))
    l.append(bn(n_out=64))
    l.append(nonlinLayer(activation=relu, a=0.2))
    l.append(convLayer(filterShape=(1, 64, 4, 4), stride=0.5, activation=sigmoid))
    self.l = l
    self.params = get_params(l)
def __init__(self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=128,
             G_kernel_size=3, G_attn='64', n_classes=1000,
             num_G_SVs=1, num_G_SV_itrs=1,
             G_shared=True, shared_dim=0, hier=False,
             cross_replica=False, mybn=False,
             G_activation=nn.ReLU(inplace=False),
             G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
             BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
             G_init='ortho', skip_init=False, no_optim=False,
             G_param='SN', norm_style='bn',
             **kwargs):
    super(Generator, self).__init__()
    # Channel width multiplier
    self.ch = G_ch
    # Number of resblocks per stage
    self.G_depth = G_depth
    # Dimensionality of the latent space
    self.dim_z = dim_z
    # The initial spatial dimensions
    self.bottom_width = bottom_width
    # Resolution of the output
    self.resolution = resolution
    # Kernel size?
    self.kernel_size = G_kernel_size
    # Attention?
    self.attention = G_attn
    # number of classes, for use in categorical conditional generation
    self.n_classes = n_classes
    # Use shared embeddings?
    self.G_shared = G_shared
    # Dimensionality of the shared embedding? Unused if not using G_shared
    self.shared_dim = shared_dim if shared_dim > 0 else dim_z
    # Hierarchical latent space?
    self.hier = hier
    # Cross replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # nonlinearity for residual blocks
    self.activation = G_activation
    # Initialization style
    self.init = G_init
    # Parameterization style
    self.G_param = G_param
    # Normalization style
    self.norm_style = norm_style
    # Epsilon for BatchNorm?
    self.BN_eps = BN_eps
    # Epsilon for Spectral Norm?
    self.SN_eps = SN_eps
    # fp16?
    self.fp16 = G_fp16
    # Architecture dict
    self.arch = G_arch(self.ch, self.attention)[resolution]

    # Which convs, batchnorms, and linear layers to use
    if self.G_param == 'SN':
        self.which_conv = functools.partial(layers.SNConv2d,
                                            kernel_size=3, padding=1,
                                            num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                            eps=self.SN_eps)
        self.which_linear = functools.partial(layers.SNLinear,
                                              num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                              eps=self.SN_eps)
    else:
        self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
        self.which_linear = nn.Linear

    # We use a non-spectral-normed embedding here regardless;
    # For some reason applying SN to G's embedding seems to randomly cripple G
    self.which_embedding = nn.Embedding
    bn_linear = (functools.partial(self.which_linear, bias=False)
                 if self.G_shared else self.which_embedding)
    self.which_bn = functools.partial(layers.ccbn,
                                      which_linear=bn_linear,
                                      cross_replica=self.cross_replica,
                                      mybn=self.mybn,
                                      input_size=(self.shared_dim + self.dim_z
                                                  if self.G_shared else self.n_classes),
                                      norm_style=self.norm_style,
                                      eps=self.BN_eps)

    # Prepare model
    # If not using shared embeddings, self.shared is just a passthrough
    self.shared = (self.which_embedding(n_classes, self.shared_dim)
                   if G_shared else layers.identity())
    # First linear layer
    self.linear = self.which_linear(self.dim_z + self.shared_dim,
                                    self.arch['in_channels'][0] * (self.bottom_width ** 2))

    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    # while the inner loop is over a given block
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
        self.blocks += [[GBlock(in_channels=self.arch['in_channels'][index],
                                out_channels=self.arch['in_channels'][index] if g_index == 0
                                             else self.arch['out_channels'][index],
                                which_conv=self.which_conv,
                                which_bn=self.which_bn,
                                activation=self.activation,
                                upsample=(functools.partial(F.interpolate, scale_factor=2)
                                          if self.arch['upsample'][index] and g_index == (self.G_depth - 1)
                                          else None))]
                        for g_index in range(self.G_depth)]
        # If attention on this block, attach it to the end
        if self.arch['attention'][self.arch['resolution'][index]]:
            print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
            self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    # output layer: batchnorm-relu-conv.
    # Consider using a non-spectral conv here
    self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                cross_replica=self.cross_replica,
                                                mybn=self.mybn),
                                      self.activation,
                                      self.which_conv(self.arch['out_channels'][-1], 3))

    # Initialize weights. Optionally skip init for testing.
    if not skip_init:
        self.init_weights()

    # Set up optimizer
    # If this is an EMA copy, no need for an optim, so just return now
    if no_optim:
        return
    self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
    if G_mixed_precision:
        print('Using fp16 adam in G...')
        import utils
        self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                  betas=(self.B1, self.B2), weight_decay=0,
                                  eps=self.adam_eps)
    else:
        self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                betas=(self.B1, self.B2), weight_decay=0,
                                eps=self.adam_eps)
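# Illustration (added, not from the BigGAN source): self.blocks above is a ModuleList of
# ModuleLists, one inner list per output resolution (G_depth resblocks, optionally followed
# by an attention module). A minimal self-contained sketch of that registration/iteration
# pattern, using plain Conv2d stand-ins; TinyG and all its names are hypothetical.
import torch
import torch.nn as nn


class TinyG(nn.Module):
    def __init__(self, widths=(16, 8), depth=2):
        super(TinyG, self).__init__()
        blocks = []
        for w in widths:
            # one inner list per stage, holding `depth` blocks
            blocks.append([nn.Conv2d(w, w, 3, padding=1) for _ in range(depth)])
        # nested ModuleLists so every submodule is properly registered
        self.blocks = nn.ModuleList([nn.ModuleList(b) for b in blocks])

    def forward(self, feature_maps):
        # feature_maps: dict mapping channel width -> tensor, purely for demonstration
        outs = []
        for stage in self.blocks:          # outer loop: stages / resolutions
            h = feature_maps[stage[0].in_channels]
            for block in stage:            # inner loop: blocks within a stage
                h = block(h)
            outs.append(h)
        return outs


_tiny_g = TinyG()
_tiny_outs = _tiny_g({16: torch.zeros(1, 16, 4, 4), 8: torch.zeros(1, 8, 8, 8)})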
def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=128,
             G_kernel_size=3, G_attn='64', n_classes=1000,
             num_G_SVs=1, num_G_SV_itrs=1,
             G_shared=True, shared_dim=0, hier=False,
             cross_replica=False, mybn=False,
             G_activation=nn.ReLU(inplace=False),
             G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
             BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
             G_init='ortho', skip_init=False, no_optim=False,
             G_param='SN', norm_style='bn',
             add_blur=False, add_noise=False, add_style=False, style_mlp=6,
             attn_style='nl', no_conditional=False,
             sched_version='default', num_epochs=500,
             arch=None, skip_z=False,
             use_dog_cnt=False, dim_dog_cnt_z=32, mix_style=False,
             **kwargs):
    super(Generator, self).__init__()
    # Channel width multiplier
    self.ch = G_ch
    # Dimensionality of the latent space
    self.dim_z = dim_z
    # The initial spatial dimensions
    self.bottom_width = bottom_width
    # Resolution of the output
    self.resolution = resolution
    # Kernel size?
    self.kernel_size = G_kernel_size
    # Attention?
    self.attention = G_attn
    # number of classes, for use in categorical conditional generation
    self.n_classes = n_classes
    # Use shared embeddings?
    self.G_shared = G_shared
    # Dimensionality of the shared embedding? Unused if not using G_shared
    self.shared_dim = shared_dim if shared_dim > 0 else dim_z
    # Hierarchical latent space?
    self.hier = hier
    # Cross replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # nonlinearity for residual blocks
    self.activation = G_activation
    # Initialization style
    self.init = G_init
    # Parameterization style
    self.G_param = G_param
    # Normalization style
    self.norm_style = norm_style
    # Extra architectural options
    self.add_blur = add_blur
    self.add_noise = add_noise
    self.add_style = add_style
    self.skip_z = skip_z
    self.use_dog_cnt = use_dog_cnt
    self.dim_dog_cnt_z = dim_dog_cnt_z
    self.mix_style = mix_style
    # Epsilon for BatchNorm?
    self.BN_eps = BN_eps
    # Epsilon for Spectral Norm?
    self.SN_eps = SN_eps
    # fp16?
    self.fp16 = G_fp16
    # Architecture dict
    if arch is None:
        arch = f'{resolution}'
    self.arch = G_arch(self.ch, self.attention)[arch]

    # If using hierarchical latents, adjust z
    if self.hier:
        # Number of places z slots into
        self.num_slots = len(self.arch['in_channels']) + 1
        self.z_chunk_size = (self.dim_z // self.num_slots)
        # Recalculate latent dimensionality for even splitting into chunks
        self.dim_z = self.z_chunk_size * self.num_slots
    else:
        self.num_slots = 1
        self.z_chunk_size = 0

    # Which convs, batchnorms, and linear layers to use
    if self.G_param == 'SN':
        self.which_conv = functools.partial(layers.SNConv2d,
                                            kernel_size=3, padding=1,
                                            num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                            eps=self.SN_eps)
        self.which_linear = functools.partial(layers.SNLinear,
                                              num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                              eps=self.SN_eps)
    else:
        self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
        self.which_linear = nn.Linear

    if attn_style == 'cbam':
        self.which_attn = layers.CBAM
    else:
        self.which_attn = layers.Attention

    # We use a non-spectral-normed embedding here regardless;
    # For some reason applying SN to G's embedding seems to randomly cripple G
    self.which_embedding = nn.Embedding
    bn_linear = (functools.partial(self.which_linear, bias=False)
                 if self.G_shared else self.which_embedding)
    input_size = self.shared_dim + self.z_chunk_size if self.G_shared else self.n_classes
    if self.G_shared and use_dog_cnt:
        input_size += dim_dog_cnt_z
    self.which_bn = functools.partial(
        layers.ccbn,
        which_linear=bn_linear,
        cross_replica=self.cross_replica,
        mybn=self.mybn,
        input_size=input_size,
        norm_style=self.norm_style,
        eps=self.BN_eps,
        style_linear=self.which_linear,
        dim_z=self.dim_z,
        no_conditional=no_conditional,
        skip_z=self.skip_z,
        use_dog_cnt=use_dog_cnt,
        g_shared=G_shared,
    )

    # Prepare model
    # If not using shared embeddings, self.shared is just a passthrough
    self.shared = (self.which_embedding(n_classes, self.shared_dim)
                   if G_shared else layers.identity())
    self.dog_cnt_shared = (self.which_embedding(4, self.dim_dog_cnt_z)
                           if G_shared else layers.identity())
    # First linear layer
    self.linear = self.which_linear(
        self.dim_z // self.num_slots,
        self.arch['in_channels'][0] * (self.bottom_width ** 2))

    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    # while the inner loop is over a given block
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
        self.blocks += [[
            layers.GBlock(
                in_channels=self.arch['in_channels'][index],
                out_channels=self.arch['out_channels'][index],
                which_conv=self.which_conv,
                which_bn=self.which_bn,
                activation=self.activation,
                upsample=(functools.partial(F.interpolate, scale_factor=2)
                          if self.arch['upsample'][index] else None),
                add_blur=add_blur,
                add_noise=add_noise,
            )
        ]]
        # If attention on this block, attach it to the end
        if self.arch['attention'][self.arch['resolution'][index]]:
            print('Adding attention layer in G at resolution %d' %
                  self.arch['resolution'][index])
            self.blocks[-1] += [
                self.which_attn(self.arch['out_channels'][index], self.which_conv)
            ]

    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList(
        [nn.ModuleList(block) for block in self.blocks])

    # output layer: batchnorm-relu-conv.
    # Consider using a non-spectral conv here
    self.output_layer = nn.Sequential(
        layers.bn(self.arch['out_channels'][-1],
                  cross_replica=self.cross_replica,
                  mybn=self.mybn),
        self.activation,
        self.which_conv(self.arch['out_channels'][-1], 3))

    if self.add_style:
        # layers = [PixelNorm()]
        style_layers = []
        for i in range(style_mlp):
            style_layers.append(
                layers.StyleLayer(self.dim_z, self.which_linear, self.activation))
        self.style = nn.Sequential(*style_layers)

    # Initialize weights. Optionally skip init for testing.
    if not skip_init:
        self.init_weights()

    # Set up optimizer
    # If this is an EMA copy, no need for an optim, so just return now
    if no_optim:
        return
    self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
    if G_mixed_precision:
        print('Using fp16 adam in G...')
        import utils
        self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                  betas=(self.B1, self.B2), weight_decay=0,
                                  eps=self.adam_eps, amsgrad=kwargs['amsgrad'])
    else:
        self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                betas=(self.B1, self.B2), weight_decay=0,
                                eps=self.adam_eps, amsgrad=kwargs['amsgrad'])

    # LR scheduling, left here for forward compatibility
    # self.lr_sched = {'itr' : 0}  # if self.progressive else {}
    # self.j = 0
    if sched_version == 'default':
        self.lr_sched = None
    elif sched_version == 'cal_v0':
        self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
            self.optim, T_max=num_epochs, eta_min=self.lr / 2, last_epoch=-1)
    elif sched_version == 'cal_v1':
        self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
            self.optim, T_max=num_epochs, eta_min=self.lr / 4, last_epoch=-1)
    elif sched_version == 'cawr_v0':
        self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optim, T_0=10, T_mult=2, eta_min=self.lr / 2)
    elif sched_version == 'cawr_v1':
        self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optim, T_0=25, T_mult=2, eta_min=self.lr / 4)
    else:
        self.lr_sched = None
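# Illustration (added, not from the original source): the hierarchical-latent branch in
# the constructor above rounds dim_z down so it splits evenly into num_slots chunks
# (one for the first linear layer plus one per block). A small self-contained check of
# that arithmetic; all `_demo` names are hypothetical.
import torch

_dim_z_demo, _num_stages_demo = 128, 5                  # e.g. 5 entries in arch['in_channels']
_num_slots_demo = _num_stages_demo + 1                  # first linear layer + one per block
_z_chunk_demo = _dim_z_demo // _num_slots_demo          # 128 // 6 == 21
_dim_z_demo = _z_chunk_demo * _num_slots_demo           # recomputed latent size: 126
_z_demo = torch.zeros(2, _dim_z_demo)
_chunks = torch.split(_z_demo, _z_chunk_demo, dim=1)    # 6 chunks of width 21
assert len(_chunks) == _num_slots_demo and _chunks[0].shape == (2, _z_chunk_demo)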
def generator(z, y, is_training=True, update_batch_stats=True, act_fn=L.lrelu,
              bn=FLAGS.gen_bn, reuse=True, dropout=FLAGS.gen_dropout):
    with tf.variable_scope('generator', reuse=reuse):
        if FLAGS.method == "cgan":
            inputs = tf.concat(axis=1, values=[z, y])
            h = L.fc(inputs, Z_dim + y_dim, ((X_dim // 4) ** 2) * 128,
                     seed=rng.randint(123456), name='fc1')
        else:
            h = L.fc(z, Z_dim, ((X_dim // 4) ** 2) * 128,
                     seed=rng.randint(123456), name='fc1')
        h = L.bn(h, ((X_dim // 4) ** 2) * 128, is_training=is_training,
                 update_batch_stats=update_batch_stats, use_gamma=False,
                 name='bn1') if bn else h
        h = act_fn(h)
        h = tf.reshape(h, [-1, X_dim // 4, X_dim // 4, 128])

        # 16x16 -> 32x32
        h = L.deconv(h, ksize=2, stride=2, f_in=128, f_out=64, name="deconv1")
        h = L.conv(h, 5, 1, 64, 64, name="conv1")
        h = L.bn(h, 64, is_training=is_training, update_batch_stats=update_batch_stats,
                 use_gamma=False, name='bn2') if bn else h
        h = tf.nn.dropout(h, keep_prob=0.5) if dropout else h
        h = act_fn(h)
        h = L.conv(h, 3, 1, 64, 64, name="conv2")
        h = L.bn(h, 64, is_training=is_training, update_batch_stats=update_batch_stats,
                 use_gamma=False, name='b3') if bn else h
        h = tf.nn.dropout(h, keep_prob=0.5) if dropout else h
        h = act_fn(h)

        # 32x32 -> 64x64
        h = L.deconv(h, ksize=2, stride=2, f_in=64, f_out=32, name="deconv2")
        h = L.conv(h, 5, 1, 32, 32, name="conv3")
        h = L.bn(h, 32, is_training=is_training, update_batch_stats=update_batch_stats,
                 use_gamma=False, name='b4') if bn else h
        h = tf.nn.dropout(h, keep_prob=0.5) if dropout else h
        h = act_fn(h)
        h = L.conv(h, 5, 1, 32, num_channels, name="conv4")
        h = tf.nn.tanh(h, name="output")
        return h
def logit(x, num_classes=10, is_training=True, update_batch_stats=True, stochastic=True, seed=1234):
    if is_training:
        scope = tf.name_scope("Training")
    else:
        scope = tf.name_scope("Testing")
    with scope:
        h = x
        rng = np.random.RandomState(seed)

        h = L.conv(h, ksize=3, stride=1, f_in=3, f_out=128, seed=rng.randint(123456), name='c1')
        h = L.lrelu(
            L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b1'),
            FLAGS.lrelu_a)
        h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=128, seed=rng.randint(123456), name='c2')
        h = L.lrelu(
            L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b2'),
            FLAGS.lrelu_a)
        h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=128, seed=rng.randint(123456), name='c3')
        h = L.lrelu(
            L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b3'),
            FLAGS.lrelu_a)
        h = L.max_pool(h, ksize=2, stride=2)
        h = tf.nn.dropout(h, keep_prob=FLAGS.keep_prob_hidden, seed=rng.randint(123456)) if stochastic else h

        h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=256, seed=rng.randint(123456), name='c4')
        h = L.lrelu(
            L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b4'),
            FLAGS.lrelu_a)
        h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=256, seed=rng.randint(123456), name='c5')
        h = L.lrelu(
            L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b5'),
            FLAGS.lrelu_a)
        h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=256, seed=rng.randint(123456), name='c6')
        h = L.lrelu(
            L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b6'),
            FLAGS.lrelu_a)
        h = L.max_pool(h, ksize=2, stride=2)
        h = tf.nn.dropout(h, keep_prob=FLAGS.keep_prob_hidden, seed=rng.randint(123456)) if stochastic else h

        h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=512, seed=rng.randint(123456), padding="VALID", name='c7')
        h = L.lrelu(
            L.bn(h, 512, is_training=is_training, update_batch_stats=update_batch_stats, name='b7'),
            FLAGS.lrelu_a)
        h = L.conv(h, ksize=1, stride=1, f_in=512, f_out=256, seed=rng.randint(123456), name='c8')
        h = L.lrelu(
            L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b8'),
            FLAGS.lrelu_a)
        h = L.conv(h, ksize=1, stride=1, f_in=256, f_out=128, seed=rng.randint(123456), name='c9')
        h = L.lrelu(
            L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b9'),
            FLAGS.lrelu_a)

        h = tf.reduce_mean(h, reduction_indices=[1, 2])  # Global average pooling
        h = L.fc(h, 128, num_classes, seed=rng.randint(123456), name='fc')
        if FLAGS.top_bn:
            h = L.bn(h, num_classes, is_training=is_training, update_batch_stats=update_batch_stats, name='bfc')
    return h
def autoencoder(x, zca, is_training=True, update_batch_stats=True, stochastic=True,
                seed=1234, use_zca=True):
    if is_training:
        scope = tf.name_scope("Training")
    else:
        scope = tf.name_scope("Testing")
    with scope:
        # Initial shape (-1, 32, 32, 3)
        x = x + 0.5  # Recover [0, 1] range
        if use_zca:
            h = zca
        else:
            h = x
        print(h.shape)
        rng = np.random.RandomState(seed)
        # h = tf.map_fn(lambda x: transform(x), h)

        # (1) conv + relu + maxpool (-1, 16, 16, 64)
        h = L.conv(h, ksize=3, stride=1, f_in=3, f_out=64, seed=rng.randint(123456), padding="SAME", name='conv1')
        h = L.lrelu(
            L.bn(h, 64, is_training=is_training, update_batch_stats=update_batch_stats, name='conv1_bn'),
            FLAGS.lrelu_a)
        h = L.max_pool(h, ksize=2, stride=2)

        # (2) conv + relu + maxpool (-1, 8, 8, 32)
        h = L.conv(h, ksize=3, stride=1, f_in=64, f_out=32, seed=rng.randint(123456), padding="SAME", name='conv2')
        h = L.lrelu(
            L.bn(h, 32, is_training=is_training, update_batch_stats=update_batch_stats, name='conv2_bn'),
            FLAGS.lrelu_a)
        h = L.max_pool(h, ksize=2, stride=2)

        # (3) conv + relu + maxpool (-1, 4, 4, 16)
        h = L.conv(h, ksize=3, stride=1, f_in=32, f_out=16, seed=rng.randint(123456), padding="SAME", name='conv3')
        h = L.lrelu(
            L.bn(h, 16, is_training=is_training, update_batch_stats=update_batch_stats, name='conv3_bn'),
            FLAGS.lrelu_a)
        h = L.max_pool(h, ksize=2, stride=2)
        encoded = h

        # (4) deconv + relu (-1, 8, 8, 16)
        h = L.deconv(encoded, ksize=5, stride=1, f_in=16, f_out=16, seed=rng.randint(123456), padding="SAME", name="deconv1")
        h = L.lrelu(
            L.bn(h, 16, is_training=is_training, update_batch_stats=update_batch_stats, name='deconv1_bn'),
            FLAGS.lrelu_a)

        # (5) deconv + relu (-1, 16, 16, 32)
        h = L.deconv(h, ksize=5, stride=1, f_in=16, f_out=32, padding="SAME", name="deconv2")
        h = L.lrelu(
            L.bn(h, 32, is_training=is_training, update_batch_stats=update_batch_stats, name='deconv2_bn'),
            FLAGS.lrelu_a)

        # (6) deconv + relu (-1, 32, 32, 64)
        h = L.deconv(h, ksize=5, stride=1, f_in=32, f_out=64, padding="SAME", name="deconv3")
        h = L.lrelu(
            L.bn(h, 64, is_training=is_training, update_batch_stats=update_batch_stats, name='deconv3_bn'),
            FLAGS.lrelu_a)

        # (7) conv + sigmoid (-1, 32, 32, 3)
        h = L.conv(h, ksize=3, stride=1, f_in=64, f_out=3, seed=rng.randint(123456), padding="SAME", name='convfinal')
        if use_zca:
            h = L.bn(h, 3, is_training=is_training, update_batch_stats=update_batch_stats, name='deconv4_bn')
        else:
            h = tf.sigmoid(h)

        # Assemble a small visualization strip of originals next to reconstructions
        num_samples = 10
        sample_og_zca = tf.reshape(
            tf.slice(zca, [0, 0, 0, 0], [num_samples, 32, 32, 3]),
            (num_samples * 32, 32, 3))
        sample_og_color = tf.reshape(
            tf.slice(x, [0, 0, 0, 0], [num_samples, 32, 32, 3]),
            (num_samples * 32, 32, 3))
        sample_rec = tf.reshape(
            tf.slice(h, [0, 0, 0, 0], [num_samples, 32, 32, 3]),
            (num_samples * 32, 32, 3))
        if use_zca:
            sample = tf.concat([sample_og_zca, sample_rec], axis=1)
            m = tf.reduce_min(sample)
            sample = (sample - m) / (tf.reduce_max(sample) - m)
        else:
            m = tf.reduce_min(sample_og_zca)
            sample_og_zca = (sample_og_zca - m) / (tf.reduce_max(sample_og_zca) - m)
            sample = tf.concat([sample_og_zca, sample_rec], axis=1)
        sample = tf.concat([sample_og_color, sample], axis=1)
        sample = tf.cast(255.0 * sample, tf.uint8)

        if use_zca:
            loss = tf.reduce_mean(tf.losses.mean_squared_error(zca, h))
        else:
            loss = tf.reduce_mean(tf.losses.log_loss(x, h))
    return loss, encoded, sample
def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=128,
             G_kernel_size=3, G_attn='64', n_classes=1000,
             num_G_SVs=1, num_G_SV_itrs=1,
             G_shared=True, shared_dim=0, hier=False,
             cross_replica=False, mybn=False,
             G_activation=nn.ReLU(inplace=False),
             G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
             BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
             G_init='ortho', skip_init=False, no_optim=False,
             G_param='SN', norm_style='bn',
             **kwargs):
    """
    These parameters are defined in utils and are collected there via parse/vars.
    A note on what the model actually is:
    G_ch is the generator's channel width multiplier, default 64; it selects a whole
    architecture, and with ch = 64 the 128-resolution entry resolves to
    arch[128] = {'in_channels':  [ch * item for item in [16, 16, 8, 4, 2]],
                 'out_channels': [ch * item for item in [16, 8, 4, 2, 1]],
                 'upsample':     [True] * 5,
                 'resolution':   [8, 16, 32, 64, 128],
                 'attention':    {2**i: (2**i in [int(item) for item in attention.split('_')])
                                  for i in range(3, 8)}}
    dim_z is the dimensionality of the noise vector, default 128.
    """
    super(Generator, self).__init__()
    # Channel width multiplier
    self.ch = G_ch
    # Dimensionality of the latent space
    self.dim_z = dim_z
    # The initial spatial dimensions
    ## TODO: not yet clear what this is mainly used for
    self.bottom_width = bottom_width
    # Resolution of the output
    ## selects which architecture entry to use
    self.resolution = resolution
    # Kernel size?
    ## TODO: not passed in from outside, and not used anywhere either
    self.kernel_size = G_kernel_size
    # Attention?
    ## only an intermediary: it is handed straight to the self.arch lookup and is
    ## finally resolved in the 'attention' entry of the architecture dict
    self.attention = G_attn
    # number of classes, for use in categorical conditional generation
    self.n_classes = n_classes
    # Use shared embeddings?
    ## default False
    self.G_shared = G_shared
    # Dimensionality of the shared embedding? Unused if not using G_shared
    self.shared_dim = shared_dim if shared_dim > 0 else dim_z
    # Hierarchical latent space?
    self.hier = hier
    # Cross replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # nonlinearity for residual blocks
    self.activation = G_activation
    # Initialization style
    self.init = G_init
    # Parameterization style
    self.G_param = G_param
    # Normalization style
    self.norm_style = norm_style
    # Epsilon for BatchNorm?
    self.BN_eps = BN_eps
    # Epsilon for Spectral Norm?
    ## https://zhuanlan.zhihu.com/p/68081406
    self.SN_eps = SN_eps
    # fp16?
    self.fp16 = G_fp16
    # Architecture dict
    self.arch = G_arch(self.ch, self.attention)[resolution]

    # If using hierarchical latents, adjust z
    if self.hier:
        # Number of places z slots into
        self.num_slots = len(self.arch['in_channels']) + 1
        self.z_chunk_size = (self.dim_z // self.num_slots)
        # Recalculate latent dimensionality for even splitting into chunks
        self.dim_z = self.z_chunk_size * self.num_slots
    else:
        self.num_slots = 1
        self.z_chunk_size = 0

    # Which convs, batchnorms, and linear layers to use
    if self.G_param == 'SN':
        self.which_conv = functools.partial(layers.SNConv2d,
                                            kernel_size=3, padding=1,
                                            num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                            eps=self.SN_eps)
        self.which_linear = functools.partial(layers.SNLinear,
                                              num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                              eps=self.SN_eps)
    else:
        self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
        self.which_linear = nn.Linear

    # We use a non-spectral-normed embedding here regardless;
    # For some reason applying SN to G's embedding seems to randomly cripple G
    ## *** fluid.dygraph.Embedding == nn.Embedding
    self.which_embedding = nn.Embedding
    bn_linear = (functools.partial(self.which_linear, bias=False)
                 if self.G_shared else self.which_embedding)
    self.which_bn = functools.partial(layers.ccbn,
                                      which_linear=bn_linear,
                                      cross_replica=self.cross_replica,
                                      mybn=self.mybn,
                                      input_size=(self.shared_dim + self.z_chunk_size
                                                  if self.G_shared else self.n_classes),
                                      norm_style=self.norm_style,
                                      eps=self.BN_eps)

    # Prepare model
    # If not using shared embeddings, self.shared is just a passthrough
    self.shared = (self.which_embedding(n_classes, self.shared_dim)
                   if G_shared else layers.identity())
    # First linear layer
    self.linear = self.which_linear(self.dim_z // self.num_slots,
                                    self.arch['in_channels'][0] * (self.bottom_width ** 2))

    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    # while the inner loop is over a given block
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
        self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
                                       out_channels=self.arch['out_channels'][index],
                                       which_conv=self.which_conv,
                                       which_bn=self.which_bn,
                                       activation=self.activation,
                                       upsample=(functools.partial(F.interpolate, scale_factor=2)
                                                 if self.arch['upsample'][index] else None))]]
        # If attention on this block, attach it to the end
        if self.arch['attention'][self.arch['resolution'][index]]:
            print('Adding attention layer in G at resolution %d' %
                  self.arch['resolution'][index])
            self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
                                                 self.which_conv)]

    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    # output layer: batchnorm-relu-conv.
    # Consider using a non-spectral conv here
    self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                cross_replica=self.cross_replica,
                                                mybn=self.mybn),
                                      self.activation,
                                      self.which_conv(self.arch['out_channels'][-1], 3))

    # Initialize weights. Optionally skip init for testing.
    if not skip_init:
        self.init_weights()

    # Set up optimizer
    # If this is an EMA copy, no need for an optim, so just return now
    if no_optim:
        return
    self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
    if G_mixed_precision:
        print('Using fp16 adam in G...')
        import utils
        self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                  betas=(self.B1, self.B2), weight_decay=0,
                                  eps=self.adam_eps)
    else:
        self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                betas=(self.B1, self.B2), weight_decay=0,
                                eps=self.adam_eps)