def discriminator(inp, is_training, init=False, reuse=False, getter=None):
    with tf.variable_scope('discriminator_model', reuse=reuse, custom_getter=getter):
        counter = {}
        x = tf.reshape(inp, [-1, 32, 32, 3])

        x = tf.layers.dropout(x, rate=0.2, training=is_training, name='dropout_0')
        x = nn.conv2d(x, 96, nonlinearity=leakyReLu, init=init, counters=counter)
        x = nn.conv2d(x, 96, nonlinearity=leakyReLu, init=init, counters=counter)
        x = nn.conv2d(x, 96, stride=[2, 2], nonlinearity=leakyReLu, init=init, counters=counter)

        x = tf.layers.dropout(x, rate=0.5, training=is_training, name='dropout_1')
        x = nn.conv2d(x, 192, nonlinearity=leakyReLu, init=init, counters=counter)
        x = nn.conv2d(x, 192, nonlinearity=leakyReLu, init=init, counters=counter)
        x = nn.conv2d(x, 192, stride=[2, 2], nonlinearity=leakyReLu, init=init, counters=counter)

        x = tf.layers.dropout(x, rate=0.5, training=is_training, name='dropout_2')
        x = nn.conv2d(x, 192, pad='VALID', nonlinearity=leakyReLu, init=init, counters=counter)
        x = nn.nin(x, 192, counters=counter, nonlinearity=leakyReLu, init=init)
        x = nn.nin(x, 192, counters=counter, nonlinearity=leakyReLu, init=init)

        # global 6x6 pooling; note the layer is named 'avg_pool_0' but is max pooling
        x = tf.layers.max_pooling2d(x, pool_size=6, strides=1, name='avg_pool_0')
        x = tf.squeeze(x, [1, 2])

        intermediate_layer = x
        logits = nn.dense(x, 10, nonlinearity=None, init=init, counters=counter, init_scale=0.1)
        return logits, intermediate_layer
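# The leakyReLu nonlinearity passed above is not defined in this snippet; a
# minimal sketch of the usual leaky ReLU (the 0.2 slope is an assumption):
def leakyReLu(x, alpha=0.2, name=None):
    return tf.maximum(alpha * x, x, name=name)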
def classifier(inp, is_training, init=False, reuse=False, getter=None, category=125):
    with tf.variable_scope('discriminator_model', reuse=reuse, custom_getter=getter):
        counter = {}
        # x = tf.reshape(inp, [-1, 32, 32, 3])
        x = tf.reshape(inp, [-1, 200, 30, 3])

        x = tf.layers.dropout(x, rate=0.2, training=is_training, name='dropout_0')
        x = nn.conv2d(x, 96, nonlinearity=leakyReLu, init=init, counters=counter)    # 64*200*30*96
        x = nn.conv2d(x, 96, nonlinearity=leakyReLu, init=init, counters=counter)    # 64*200*30*96
        # x = nn.conv2d(x, 96, stride=[2, 2], nonlinearity=leakyReLu, init=init, counters=counter)
        x = nn.conv2d(x, 96, stride=[5, 2], nonlinearity=leakyReLu, init=init, counters=counter)   # 64*40*15*96

        x = tf.layers.dropout(x, rate=0.5, training=is_training, name='dropout_1')   # 64*40*15*96
        x = nn.conv2d(x, 192, nonlinearity=leakyReLu, init=init, counters=counter)   # 64*40*15*192
        x = nn.conv2d(x, 192, nonlinearity=leakyReLu, init=init, counters=counter)   # 64*40*15*192
        # x = nn.conv2d(x, 192, stride=[2, 2], nonlinearity=leakyReLu, init=init, counters=counter)
        x = nn.conv2d(x, 192, stride=[5, 2], nonlinearity=leakyReLu, init=init, counters=counter)  # 64*8*8*192

        x = tf.layers.dropout(x, rate=0.5, training=is_training, name='dropout_2')   # 64*8*8*192
        x = nn.conv2d(x, 192, pad='VALID', nonlinearity=leakyReLu, init=init, counters=counter)    # 64*6*6*192
        x = nn.nin(x, 192, counters=counter, nonlinearity=leakyReLu, init=init)      # 64*6*6*192
        x = nn.nin(x, 192, counters=counter, nonlinearity=leakyReLu, init=init)      # 64*6*6*192

        x = tf.layers.max_pooling2d(x, pool_size=6, strides=1, name='avg_pool_0')    # 64*1*1*192
        x = tf.squeeze(x, [1, 2])                                                    # 64*192

        intermediate_layer = x
        # logits = nn.dense(x, 10, nonlinearity=None, init=init, counters=counter, init_scale=0.1)
        logits = nn.dense(x, category, nonlinearity=None, init=init, counters=counter, init_scale=0.1)  # 64*125
        print('logits:', logits)  # debug: inspect the logits tensor
        return logits, intermediate_layer
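# Sanity check for the shape annotations above. This sketch assumes 3x3
# SAME-padded convolutions (so only the strided layers change H and W) and a
# 3x3 kernel for the VALID conv; both are assumptions based on the comments.
import math

def same_out(size, stride):
    return math.ceil(size / stride)  # SAME padding: out = ceil(in / stride)

h, w = 200, 30
h, w = same_out(h, 5), same_out(w, 2)  # 40, 15 after the first strided conv
h, w = same_out(h, 5), same_out(w, 2)  # 8, 8 after the second strided conv
h, w = h - 2, w - 2                    # 6, 6 after the 3x3 VALID conv
h, w = h - 5, w - 5                    # 1, 1 after the 6x6 pool with stride 1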
def dec_down(gs, zs_posterior, training, init=False, dropout_p=0.5, n_scales=1,
             n_residual_blocks=2, activation="elu", n_latent_scales=2):
    assert n_residual_blocks % 2 == 0
    gs = list(gs)
    zs_posterior = list(zs_posterior)
    with model_arg_scope(init=init, dropout_p=dropout_p, activation=activation):
        # outputs
        hs = []  # hidden units
        ps = []  # priors
        zs = []  # prior samples

        # prepare input
        n_filters = gs[-1].shape.as_list()[-1]
        h = nn.nin(gs[-1], n_filters)

        for l in range(n_scales):
            # level module

            ## hidden units
            for i in range(n_residual_blocks // 2):
                h = nn.residual_block(h, gs.pop())
                hs.append(h)

            if l < n_latent_scales:
                ## prior
                # spatial_shape is unused in this simplified variant
                # (the fuller dec_down below branches on it)
                spatial_shape = h.shape.as_list()[1]
                n_h_channels = h.shape.as_list()[-1]

                ### no spatial correlations
                p = latent_parameters(h)
                ps.append(p)
                z_prior = latent_sample(p)
                zs.append(z_prior)

                if training:
                    ## posterior
                    z = zs_posterior.pop(0)
                else:
                    ## prior
                    z = z_prior

                for i in range(n_residual_blocks // 2):
                    n_h_channels = h.shape.as_list()[-1]
                    h = tf.concat([h, z], axis=-1)
                    h = nn.nin(h, n_h_channels)
                    h = nn.residual_block(h, gs.pop())
                    hs.append(h)
            else:
                for i in range(n_residual_blocks // 2):
                    h = nn.residual_block(h, gs.pop())
                    hs.append(h)

            # prepare input to next level
            if l + 1 < n_scales:
                n_filters = gs[-1].shape.as_list()[-1]
                h = nn.upsample(h, n_filters)

        assert not gs
        if training:
            assert not zs_posterior
        return hs, ps, zs
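# latent_parameters and latent_sample are external helpers. One common
# Gaussian reparameterization consistent with how they are used here (an
# assumption, not the repository's definition):
def latent_parameters(h, num_filters=None):
    num_filters = num_filters or h.shape.as_list()[-1]
    # predict per-pixel mean and log-variance, stacked along channels
    return tf.layers.conv2d(h, 2 * num_filters, 3, padding='same')

def latent_sample(p):
    mean, logvar = tf.split(p, 2, axis=-1)
    eps = tf.random_normal(tf.shape(mean))
    return mean + tf.exp(0.5 * logvar) * eps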
def build(self, input_shape):
    B, H, W, C = input_shape
    self.normalize = normalize(name='norm')
    self.nin_q = nn.nin(name='q', num_units=C)
    self.nin_k = nn.nin(name='k', num_units=C)
    self.nin_v = nn.nin(name='v', num_units=C)
    self.nin_proj_out = nn.nin(name='proj_out', num_units=C, init_scale=0.)
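# A hedged sketch of the forward pass this build() prepares for: single-head
# self-attention over the H*W spatial positions. The actual call() is not
# shown in this snippet, so the details below are assumptions:
def call(self, x):
    B, H, W, C = x.shape.as_list()
    h = self.normalize(x)
    q, k, v = self.nin_q(h), self.nin_k(h), self.nin_v(h)
    w = tf.einsum('bhwc,bHWc->bhwHW', q, k) * (C ** -0.5)  # scaled dot products
    w = tf.nn.softmax(tf.reshape(w, [-1, H, W, H * W]), axis=-1)
    w = tf.reshape(w, [-1, H, W, H, W])
    h = tf.einsum('bhwHW,bHWc->bhwc', w, v)  # attention-weighted values
    return x + self.nin_proj_out(h)          # residual connection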
def enc_up(x, c, init=False, dropout_p=0.5, n_scales=1, n_residual_blocks=2,
           activation="elu", n_filters=64, max_filters=128):
    with model_arg_scope(init=init, dropout_p=dropout_p, activation=activation):
        # outputs
        hs = []

        # prepare input
        # This line is odd -- why would x and c be concatenated here?
        # xc = tf.concat([x, c], axis=-1)
        xc = x
        h = nn.nin(xc, n_filters)

        for l in range(n_scales):
            # level module
            for i in range(n_residual_blocks):
                h = nn.residual_block(h)
                hs.append(h)
            # prepare input to next level
            if l + 1 < n_scales:
                # Note: with max_filters=128 the channel count saturates at
                # 128 and never grows beyond the first doubling.
                n_filters = min(2 * n_filters, max_filters)
                h = nn.downsample(h, n_filters)

        return hs
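# nn.downsample is external; a minimal sketch of strided-conv downsampling
# (an assumption -- the repository's version may use a different kernel):
def downsample(x, num_filters):
    return tf.layers.conv2d(x, num_filters, 3, strides=2, padding='same')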
def build(self, input_shape):
    B, H, W, C = input_shape
    if self.out_ch is None:
        self.out_ch = C
    self.normalize_1 = normalize('norm1')
    self.normalize_2 = normalize('norm2')
    self.dense = nn.dense(name='temb_proj', num_units=self.out_ch,
                          spec_norm=self.spec_norm)
    self.conv2d_1 = nn.conv2d(name='conv1', num_units=self.out_ch,
                              spec_norm=self.spec_norm)
    self.conv2d_2 = nn.conv2d(name='conv2', num_units=self.out_ch, init_scale=0.,
                              spec_norm=self.spec_norm, use_scale=self.use_scale)
    if self.conv_shortcut:
        self.conv2d_shortcut = nn.conv2d(name='conv_shortcut', num_units=self.out_ch,
                                         spec_norm=self.spec_norm)
    else:
        self.nin_shortcut = nn.nin(name='nin_shortcut', num_units=self.out_ch,
                                   spec_norm=self.spec_norm)
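# A hedged sketch of the residual forward pass these sublayers suggest (the
# actual call() is not shown; the swish nonlinearity and the timestep-embedding
# broadcast are assumptions in line with common diffusion-model ResNet blocks):
def call(self, x, temb):
    h = tf.nn.swish(self.normalize_1(x))
    h = self.conv2d_1(h)
    h += self.dense(tf.nn.swish(temb))[:, None, None, :]  # inject timestep embedding
    h = tf.nn.swish(self.normalize_2(h))
    h = self.conv2d_2(h)
    if x.shape[-1] != self.out_ch:  # match channels on the skip path
        x = self.conv2d_shortcut(x) if self.conv_shortcut else self.nin_shortcut(x)
    return x + h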
def dec_up(c, init=False, dropout_p=0.5, n_scales=1, n_residual_blocks=2,
           activation="elu", n_filters=64, max_filters=128):
    with model_arg_scope(init=init, dropout_p=dropout_p, activation=activation):
        # outputs
        hs = []
        # prepare input
        h = nn.nin(c, n_filters)
        for l in range(n_scales):
            # level module
            for i in range(n_residual_blocks):
                h = nn.residual_block(h)
                hs.append(h)
            # prepare input to next level
            if l + 1 < n_scales:
                n_filters = min(2 * n_filters, max_filters)
                h = nn.downsample(h, n_filters)
        return hs
def enc_up(x, c, init=False, dropout_p=0.5, n_scales=1, n_residual_blocks=2,
           activation="elu", n_filters=64, max_filters=128):
    """c is actually not used"""
    with model_arg_scope(init=init, dropout_p=dropout_p, activation=activation):
        # outputs
        hs = []
        # prepare input
        # xc = tf.concat([x, c], axis=-1)
        xc = x
        h = nn.nin(xc, n_filters)
        for l in range(n_scales):
            # level module
            for i in range(n_residual_blocks):
                h = nn.residual_block(h)
                hs.append(h)
            # prepare input to next level
            if l + 1 < n_scales:
                n_filters = min(2 * n_filters, max_filters)
                h = nn.downsample(h, n_filters)
        return hs
def classifier(x, n_out, init=False, dropout_p=0.5, activation="elu"):
    with model_arg_scope(init=init, dropout_p=dropout_p, activation=activation):
        # outputs
        hs = []
        # prepare input
        x_shape = x.shape.as_list()  # tf.shape(x)
        h = tf.reshape(x, [x_shape[0], 1, 1, x_shape[1] * x_shape[2] * x_shape[3]])
        h = nn.activate(h)
        h = nn.nin(h, 1024)
        h = nn.activate(h)
        h = nn.nin(h, n_out)
        h = tf.reshape(h, [x_shape[0], n_out])
        return h
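# Because h is reshaped to [N, 1, 1, D] first, each nn.nin above acts as a
# plain fully connected layer. An equivalent sketch in bare TF (assuming
# nn.activate applies the "elu" activation from the arg scope):
def classifier_dense_equivalent(x, n_out):
    h = tf.layers.flatten(x)          # [N, H*W*C]
    h = tf.nn.elu(h)
    h = tf.layers.dense(h, 1024)
    h = tf.nn.elu(h)
    return tf.layers.dense(h, n_out)  # [N, n_out]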
def cfn(x, init=False, dropout_p=0.5, n_scales=1, n_residual_blocks=2,
        activation="elu", n_filters=64, max_filters=128):
    with model_arg_scope(init=init, dropout_p=dropout_p, activation=activation):
        # outputs
        hs = []
        # prepare input
        xc = x
        h = nn.nin(xc, n_filters)
        for l in range(n_scales):
            # level module
            for i in range(n_residual_blocks):
                h = nn.residual_block(h)
                hs.append(h)
            # prepare input to next level
            if l + 1 < n_scales:
                n_filters = min(2 * n_filters, max_filters)
                h = nn.downsample(h, n_filters)
        h_shape = h.shape.as_list()
        h = tf.reshape(h, [h_shape[0], 1, 1, h_shape[1] * h_shape[2] * h_shape[3]])
        h = nn.nin(h, 2 * max_filters)
        hs.append(h)
        return hs
def enc_down(gs, init=False, dropout_p=0.5, n_scales=1, n_residual_blocks=2,
             activation="elu", n_latent_scales=2):
    assert n_residual_blocks % 2 == 0
    gs = list(gs)
    with model_arg_scope(init=init, dropout_p=dropout_p, activation=activation):
        # outputs
        hs = []  # hidden units
        qs = []  # posteriors
        zs = []  # samples from posterior

        # prepare input
        n_filters = gs[-1].shape.as_list()[-1]
        h = nn.nin(gs[-1], n_filters)

        for l in range(n_scales):
            # level module

            ## hidden units
            for i in range(n_residual_blocks // 2):
                h = nn.residual_block(h, gs.pop())
                hs.append(h)

            if l < n_latent_scales:
                ## posterior parameters
                q = latent_parameters(h)
                qs.append(q)

                ## posterior sample
                z = latent_sample(q)
                zs.append(z)

                ## sample feedback
                for i in range(n_residual_blocks // 2):
                    gz = tf.concat([gs.pop(), z], axis=-1)
                    h = nn.residual_block(h, gz)
                    hs.append(h)
            else:
                # no need to go down any further
                # for i in range(n_residual_blocks // 2):
                #     h = nn.residual_block(h, gs.pop())
                #     hs.append(h)
                break

            # prepare input to next level
            if l + 1 < n_scales:
                n_filters = gs[-1].shape.as_list()[-1]
                h = nn.upsample(h, n_filters)

        # assert not gs  # not true anymore since we break out of the loop
        return hs, qs, zs
def dec_up(c, init=False, dropout_p=0.5, n_scales=1, n_residual_blocks=2,
           activation="elu", n_filters=64, max_filters=128):
    with model_arg_scope(init=init, dropout_p=dropout_p, activation=activation):
        hs = []
        h = nn.nin(c, n_filters)
        for l in range(n_scales):
            for i in range(n_residual_blocks):
                h = nn.residual_block(h)
                hs.append(h)
            if l + 1 < n_scales:
                n_filters = min(2 * n_filters, max_filters)
                h = nn.downsample(h, n_filters)
        return hs
def enc_down(gs, init=False, dropout_p=0.5, n_scales=1, n_residual_blocks=2,
             activation="elu", n_latent_scales=2):
    assert n_residual_blocks % 2 == 0
    gs = list(gs)
    with model_arg_scope(init=init, dropout_p=dropout_p, activation=activation):
        hs = []  # hidden units
        qs = []  # posteriors
        zs = []  # samples from posterior

        n_filters = gs[-1].shape.as_list()[-1]
        h = nn.nin(gs[-1], n_filters)

        for l in range(n_scales):
            for i in range(n_residual_blocks // 2):
                h = nn.residual_block(h, gs.pop())
                hs.append(h)

            if l < n_latent_scales:
                q = latent_parameters(h)  # posterior parameters
                qs.append(q)
                z = latent_sample(q)  # posterior sample
                zs.append(z)
                for i in range(n_residual_blocks // 2):
                    gz = tf.concat([gs.pop(), z], axis=-1)
                    h = nn.residual_block(h, gz)
                    hs.append(h)
            else:
                break

            if l + 1 < n_scales:
                n_filters = gs[-1].shape.as_list()[-1]
                h = nn.upsample(h, n_filters)

        return hs, qs, zs
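# Hypothetical end-to-end wiring of the pieces above during training (names
# and flow are assumptions based on the signatures, not repository code):
def forward(x, c, init=False):
    gs = enc_up(x, c, init=init)                      # bottom-up image features
    hs_e, qs, zs_posterior = enc_down(gs, init=init)  # posterior params and samples
    gs_c = dec_up(c, init=init)                       # bottom-up conditioning features
    hs_d, ps, zs_prior = dec_down(gs_c, zs_posterior, training=True, init=init)
    return hs_d, ps, qs, zs_prior, zs_posterior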
def model_spec(x, init=False, ema=None, dropout_p=args.dropout_p):
    counters = {}
    with scopes.arg_scope([nn.conv2d, nn.deconv2d, nn.gated_resnet,
                           nn.aux_gated_resnet, nn.dense],
                          counters=counters, init=init, ema=ema, dropout_p=dropout_p):

        # ////////// up pass through pixelCNN ////////
        xs = nn.int_shape(x)
        # add channel of ones to distinguish image from padding later on
        x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)

        # stream for pixels above
        u_list = [nn.down_shift(nn.down_shifted_conv2d(
            x_pad, num_filters=args.nr_filters, filter_size=[2, 3]))]
        # stream for up and to the left
        ul_list = [nn.down_shift(nn.down_shifted_conv2d(
                       x_pad, num_filters=args.nr_filters, filter_size=[2, 3])) +
                   nn.right_shift(nn.down_right_shifted_conv2d(
                       x_pad, num_filters=args.nr_filters, filter_size=[2, 1]))]

        for rep in range(args.nr_resnet):
            u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
            ul_list.append(nn.aux_gated_resnet(ul_list[-1], u_list[-1],
                                               conv=nn.down_right_shifted_conv2d))

        u_list.append(nn.down_shifted_conv2d(
            u_list[-1], num_filters=args.nr_filters, stride=[2, 2]))
        ul_list.append(nn.down_right_shifted_conv2d(
            ul_list[-1], num_filters=args.nr_filters, stride=[2, 2]))

        for rep in range(args.nr_resnet):
            u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
            ul_list.append(nn.aux_gated_resnet(ul_list[-1], u_list[-1],
                                               conv=nn.down_right_shifted_conv2d))

        u_list.append(nn.down_shifted_conv2d(
            u_list[-1], num_filters=args.nr_filters, stride=[2, 2]))
        ul_list.append(nn.down_right_shifted_conv2d(
            ul_list[-1], num_filters=args.nr_filters, stride=[2, 2]))

        for rep in range(args.nr_resnet):
            u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
            ul_list.append(nn.aux_gated_resnet(ul_list[-1], u_list[-1],
                                               conv=nn.down_right_shifted_conv2d))

        # /////// down pass ////////
        u = u_list.pop()
        ul = ul_list.pop()

        for rep in range(args.nr_resnet):
            u = nn.aux_gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
            ul = nn.aux_gated_resnet(ul, tf.concat([u, ul_list.pop()], 3),
                                     conv=nn.down_right_shifted_conv2d)

        u = nn.down_shifted_deconv2d(u, num_filters=args.nr_filters, stride=[2, 2])
        ul = nn.down_right_shifted_deconv2d(ul, num_filters=args.nr_filters, stride=[2, 2])

        for rep in range(args.nr_resnet + 1):
            u = nn.aux_gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
            ul = nn.aux_gated_resnet(ul, tf.concat([u, ul_list.pop()], 3),
                                     conv=nn.down_right_shifted_conv2d)

        u = nn.down_shifted_deconv2d(u, num_filters=args.nr_filters, stride=[2, 2])
        ul = nn.down_right_shifted_deconv2d(ul, num_filters=args.nr_filters, stride=[2, 2])

        for rep in range(args.nr_resnet + 1):
            u = nn.aux_gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
            ul = nn.aux_gated_resnet(ul, tf.concat([u, ul_list.pop()], 3),
                                     conv=nn.down_right_shifted_conv2d)

        x_out = nn.nin(tf.nn.elu(ul), 10 * args.nr_logistic_mix)

        assert len(u_list) == 0
        assert len(ul_list) == 0

        return x_out
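# down_shift and right_shift keep the convolutions causal by dropping the
# last row/column and zero-padding at the top/left. A standard sketch
# (matching the PixelCNN++ idea; the repository's nn versions may differ):
def down_shift(x):
    xs = x.shape.as_list()
    pad = tf.zeros_like(x[:, :1, :, :])  # one zero row at the top
    return tf.concat([pad, x[:, :xs[1] - 1, :, :]], axis=1)

def right_shift(x):
    xs = x.shape.as_list()
    pad = tf.zeros_like(x[:, :, :1, :])  # one zero column at the left
    return tf.concat([pad, x[:, :, :xs[2] - 1, :]], axis=2)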
def dec_down(gs, zs_posterior, training, init=False, dropout_p=0.5, n_scales=1,
             n_residual_blocks=2, activation="elu", n_latent_scales=2):
    assert n_residual_blocks % 2 == 0
    gs = list(gs)
    zs_posterior = list(zs_posterior)
    with model_arg_scope(init=init, dropout_p=dropout_p, activation=activation):
        # outputs
        hs = []  # hidden units
        ps = []  # priors
        zs = []  # prior samples

        # prepare input
        n_filters = gs[-1].shape.as_list()[-1]
        h = nn.nin(gs[-1], n_filters)

        for l in range(n_scales):
            # level module

            ## hidden units
            for i in range(n_residual_blocks // 2):
                h = nn.residual_block(h, gs.pop())
                hs.append(h)

            if l < n_latent_scales:
                ## prior
                spatial_shape = h.shape.as_list()[1]
                n_h_channels = h.shape.as_list()[-1]
                if spatial_shape == 1:
                    ### no spatial correlations
                    p = latent_parameters(h)
                    ps.append(p)
                    z_prior = latent_sample(p)
                    zs.append(z_prior)
                else:
                    ### four autoregressively modeled groups
                    if training:
                        z_posterior_groups = nn.split_groups(zs_posterior[0])
                    p_groups = []
                    z_groups = []
                    p_features = tf.space_to_depth(nn.residual_block(h), 2)
                    for i in range(4):
                        p_group = latent_parameters(p_features, num_filters=n_h_channels)
                        p_groups.append(p_group)
                        z_group = latent_sample(p_group)
                        z_groups.append(z_group)
                        # ar feedback sampled from
                        if training:
                            feedback = z_posterior_groups.pop(0)
                        else:
                            feedback = z_group
                        # prepare input for next group
                        if i + 1 < 4:
                            p_features = nn.residual_block(p_features, feedback)
                    if training:
                        assert not z_posterior_groups
                    # complete prior parameters
                    p = nn.merge_groups(p_groups)
                    ps.append(p)
                    # complete prior sample
                    z_prior = nn.merge_groups(z_groups)
                    zs.append(z_prior)

                ## vae feedback sampled from
                if training:
                    ## posterior
                    z = zs_posterior.pop(0)
                else:
                    ## prior
                    z = z_prior

                for i in range(n_residual_blocks // 2):
                    n_h_channels = h.shape.as_list()[-1]
                    h = tf.concat([h, z], axis=-1)
                    h = nn.nin(h, n_h_channels)
                    h = nn.residual_block(h, gs.pop())
                    hs.append(h)
            else:
                for i in range(n_residual_blocks // 2):
                    h = nn.residual_block(h, gs.pop())
                    hs.append(h)

            # prepare input to next level
            if l + 1 < n_scales:
                n_filters = gs[-1].shape.as_list()[-1]
                h = nn.upsample(h, n_filters)

        assert not gs
        if training:
            assert not zs_posterior
        return hs, ps, zs
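# nn.split_groups / nn.merge_groups partition each 2x2 spatial block into four
# autoregressive groups. A sketch consistent with their use above (assumed,
# not the repository's definition):
def split_groups(x, bs=2):
    return tf.split(tf.space_to_depth(x, bs), bs ** 2, axis=3)

def merge_groups(groups, bs=2):
    return tf.depth_to_space(tf.concat(groups, axis=3), bs)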
def model_spec(x, h=None, init=False, ema=None, dropout_p=0.5, nr_resnet=5,
               nr_filters=160, nr_logistic_mix=10,
               resnet_nonlinearity='concat_elu', attention=False, nr_attn_block=1):
    """
    We receive a Tensor x of shape (N,H,W,D1) (e.g. (12,32,32,3)) and produce
    a Tensor x_out of shape (N,H,W,D2) (e.g. (12,32,32,100)), where each fiber
    of the x_out tensor describes the predictive distribution for the RGB at
    that position. 'h' is an optional N x K matrix of values to condition our
    generative model on.
    """
    counters = {}
    with arg_scope([nn.conv2d, nn.deconv2d, nn.gated_resnet, nn.dense, nn.nin],
                   counters=counters, init=init, ema=ema, dropout_p=dropout_p):

        # parse resnet nonlinearity argument
        if resnet_nonlinearity == 'concat_elu':
            resnet_nonlinearity = nn.concat_elu
        elif resnet_nonlinearity == 'elu':
            resnet_nonlinearity = tf.nn.elu
        elif resnet_nonlinearity == 'relu':
            resnet_nonlinearity = tf.nn.relu
        else:
            raise ValueError('resnet nonlinearity ' + resnet_nonlinearity + ' is not supported')

        with arg_scope([nn.gated_resnet], nonlinearity=resnet_nonlinearity, h=h):

            # ////////// up pass through pixelCNN ////////
            xs = nn.int_shape(x)
            # normalized (row, column) coordinate channels, used as positional
            # information by the attention blocks below
            background = tf.concat([
                ((tf.range(xs[1], dtype=tf.float32) - xs[1] / 2) / xs[1])[None, :, None, None] + 0. * x,
                ((tf.range(xs[2], dtype=tf.float32) - xs[2] / 2) / xs[2])[None, None, :, None] + 0. * x,
            ], axis=3)

            # add channel of ones to distinguish image from padding later on
            x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)

            # stream for pixels above
            u_list = [nn.down_shift(nn.down_shifted_conv2d(
                x_pad, num_filters=nr_filters, filter_size=[2, 3]))]
            # stream for up and to the left
            ul_list = [nn.down_shift(nn.down_shifted_conv2d(
                           x_pad, num_filters=nr_filters, filter_size=[1, 3])) +
                       nn.right_shift(nn.down_right_shifted_conv2d(
                           x_pad, num_filters=nr_filters, filter_size=[2, 1]))]

            for attn_rep in range(nr_attn_block):
                for rep in range(nr_resnet):
                    u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
                    ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1],
                                                   conv=nn.down_right_shifted_conv2d))

                if attention:
                    ul = ul_list[-1]
                    raw_content = tf.concat([x, ul, background], axis=3)
                    key, mixin = tf.split(
                        nn.nin(nn.gated_resnet(raw_content, conv=nn.nin), nr_filters * 2),
                        2, axis=3)
                    query = nn.nin(
                        nn.gated_resnet(tf.concat([ul, background], axis=3), conv=nn.nin),
                        nr_filters)
                    mixed = nn.causal_attention(key, mixin, query)
                    ul_list.append(nn.gated_resnet(ul, mixed, conv=nn.nin))

            x_out = nn.nin(tf.nn.elu(ul_list[-1]), 10 * nr_logistic_mix)
            return x_out
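# nn.causal_attention is external. A minimal sketch of one way to realize it:
# flatten key/mixin/query over space and let each pixel attend only to strictly
# earlier pixels in raster order (all details here are assumptions, and the
# repository's actual implementation may differ):
def causal_attention(key, mixin, query):
    ks = key.shape.as_list()
    n = ks[1] * ks[2]
    k = tf.reshape(key, [-1, n, ks[3]])
    m = tf.reshape(mixin, [-1, n, mixin.shape.as_list()[3]])
    q = tf.reshape(query, [-1, n, query.shape.as_list()[3]])
    logits = tf.matmul(q, k, transpose_b=True)                      # [B, n, n]
    mask = tf.matrix_band_part(tf.ones([n, n]), -1, 0) - tf.eye(n)  # strictly lower triangular
    logits = logits * mask - 1e9 * (1. - mask)
    weights = tf.nn.softmax(logits, axis=-1) * mask  # first pixel attends to nothing
    out = tf.matmul(weights, m)
    return tf.reshape(out, [-1, ks[1], ks[2], mixin.shape.as_list()[3]])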