def discriminate_pix2pix(discrim_inputs, discrim_targets, num_classes, labels=None,
                         reuse=False, data_format='NCHW', scope_name=None):
    print("Pix2pix Discriminator")

    assert data_format == 'NCHW'
    size = SIZE
    sn = Config.sn
    if data_format == 'NCHW':
        channel_axis = 1
    else:
        channel_axis = 3
    if type(discrim_targets) is list:
        discrim_targets = discrim_targets[-1]
    output_dim = 1

    with tf.variable_scope(scope_name) as scope:
        if reuse:
            scope.reuse_variables()

        n_layers = 3
        layers = []

        # 2x [batch, 3, height, width] => [batch, 3 * 2, height, width]
        input_fusion = tf.concat([discrim_inputs, discrim_targets], axis=channel_axis)

        # layer_1: [batch, 6, 192, 192] => [batch, 64, 96, 96]
        with tf.variable_scope("layer_1"):
            convolved = nchw_conv(input_fusion, size, stride=2)
            rectified = lrelu(convolved, 0.2)
            layers.append(rectified)

        # layer_2: [batch, 64, 96, 96] => [batch, 128, 48, 48]
        # layer_3: [batch, 128, 48, 48] => [batch, 256, 24, 24]
        # layer_4: [batch, 256, 24, 24] => [batch, 512, 23, 23]
        for i in range(n_layers):
            with tf.variable_scope("layer_%d" % (len(layers) + 1)):
                out_channels = size * min(2 ** (i + 1), 8)
                stride = 1 if i == n_layers - 1 else 2  # last layer here has stride 1
                convolved = nchw_conv(layers[-1], out_channels, stride=stride)
                normalized = batchnorm(convolved, data_format=data_format)
                rectified = lrelu(normalized, 0.2)
                layers.append(rectified)

        # layer_5: [batch, 512, 23, 23] => [batch, 1, 22, 22]  (discriminator end)
        with tf.variable_scope("layer_%d" % (len(layers) + 1)):
            disc = nchw_conv(rectified, out_channels=output_dim, stride=1)

        # classification end
        img = tf.reduce_mean(rectified, axis=(2, 3) if data_format == 'NCHW' else (1, 2))
        logits = fully_connected(img, num_classes, sn=sn, activation_fn=None, normalizer_fn=None)

        return disc, logits
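
# Usage sketch (illustrative only, not part of the training pipeline): how the PatchGAN
# discriminator above might be wired with shared weights for real and fake pairs. The
# 192x192 NCHW shapes follow the layer comments; batch size 8 and num_classes=125 are
# hypothetical values chosen for this example.
def _demo_discriminate_pix2pix():
    sketch = tf.placeholder(tf.float32, [8, 3, 192, 192], name='demo_sketch')
    photo_real = tf.placeholder(tf.float32, [8, 3, 192, 192], name='demo_photo_real')
    photo_fake = tf.placeholder(tf.float32, [8, 3, 192, 192], name='demo_photo_fake')
    # The first call creates the discriminator variables; the second reuses them.
    disc_real, logits_real = discriminate_pix2pix(sketch, photo_real, num_classes=125,
                                                  scope_name='demo_discriminator')
    disc_fake, logits_fake = discriminate_pix2pix(sketch, photo_fake, num_classes=125,
                                                  reuse=True, scope_name='demo_discriminator')
    return disc_real, disc_fake, logits_real, logits_fake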
def generate_residual(z, text_vocab_indices, LSTM_hybrid, output_channel, num_classes, vocab_size,
                      reuse=False, data_format='NCHW', labels=None, scope_name=None):
    print("Residual Generator")

    size = SIZE
    sn = False
    input_dims = z.get_shape().as_list()
    if data_format == 'NCHW':
        height = input_dims[2]
        width = input_dims[3]
    else:
        height = input_dims[1]
        width = input_dims[2]
    if data_format == 'NCHW':
        concat_axis = 1
    else:
        concat_axis = 3

    if normalizer_params_g is not None and normalizer_fn_g != ly.batch_norm and normalizer_fn_g != ly.layer_norm:
        normalizer_params_g['labels'] = labels
        normalizer_params_g['n_labels'] = num_classes

    with tf.variable_scope(scope_name) as scope:
        if reuse:
            scope.reuse_variables()

        num_residual_units = [3, 4, 6, 3]
        z_encoded = image_encoder_residual(z, num_residual_units, num_classes=num_classes, reuse=reuse,
                                           data_format=data_format, labels=labels,
                                           scope_name=scope_name)  # list of hidden states
        input_e_dims = z_encoded[-1].get_shape().as_list()
        batch_size = input_e_dims[0]

        # z_encoded[-1].shape = [N, 512, 6, 6], text_vocab_indices.shape = [N, 15]
        if LSTM_hybrid:
            # Add text LSTM
            lstm_output = encode_feat_with_text(z_encoded[-1], text_vocab_indices, input_e_dims, vocab_size)
            feat_encoded_final = lstm_output  # [N, 512, 6, 6]
        else:
            feat_encoded_final = z_encoded[-1]

        channel_depth = int(input_e_dims[concat_axis] / 8.)
        if data_format == 'NCHW':
            noise_dims = [batch_size, channel_depth, int(input_e_dims[2]), int(input_e_dims[3])]
        else:
            noise_dims = [batch_size, int(input_e_dims[1]), int(input_e_dims[2]), channel_depth]
        noise_vec = tf.random_normal(shape=(batch_size, 256), dtype=tf.float32)
        noise = fully_connected(noise_vec, int(np.prod(noise_dims[1:])), sn=sn,
                                activation_fn=activation_fn_g,
                                # normalizer_fn=normalizer_fn_g,
                                # normalizer_params=normalizer_params_g
                                )
        noise = tf.reshape(noise, shape=noise_dims)

        # decoder
        layer_specs = [
            (size * 8, 0.0),  # decoder_5: [batch, 512 * 2, 6, 6] => [batch, 512, 12, 12]
            (size * 4, 0.0),  # decoder_4: [batch, 512 * 2, 12, 12] => [batch, 256, 24, 24]
            (size * 2, 0.0),  # decoder_3: [batch, 256 * 2, 24, 24] => [batch, 128, 48, 48]
            (size, 0.0),      # decoder_2: [batch, 128 * 2, 48, 48] => [batch, 64, 96, 96]
        ]

        num_encoder_layers = len(z_encoded)
        for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
            skip_layer = num_encoder_layers - decoder_layer - 1
            with tf.variable_scope("decoder_%d_0" % (skip_layer + 1)):
                if decoder_layer == 0:
                    input = tf.concat([feat_encoded_final, noise], axis=concat_axis)
                else:
                    input = tf.concat([z_encoded[-1], z_encoded[skip_layer]], axis=concat_axis)
                output = bottleneck_residual_de(input, out_channels)
            for uId in range(1, num_residual_units[skip_layer - 1]):
                with tf.variable_scope("decoder_%d_%d" % (skip_layer + 1, uId)):
                    output = bottleneck_residual_pu(output, out_channels, False)
            z_encoded.append(output)

        # decoder_1: [batch, 64 * 2, 96, 96] => [batch, 3, 192, 192]
        with tf.variable_scope("decoder_1"):
            input = tf.concat([z_encoded[-1], z_encoded[0]], axis=concat_axis)
            output = nchw_deconv(input, output_channel)
            output = batchnorm(output, data_format=data_format)
            output = tf.tanh(output)
            z_encoded.append(output)

        if output.get_shape().as_list()[2] != height:
            raise ValueError('Output height %d does not match input height %d'
                             % (output.get_shape().as_list()[2], height))

        return output, noise_vec
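
# Illustrative sketch of instantiating the residual generator together with its text
# branch. The sketch/text shapes mirror the [N, 3, 192, 192] and [N, 15] comments above;
# vocab_size=4000 and num_classes=125 are hypothetical values for the example.
def _demo_generate_residual():
    sketch = tf.placeholder(tf.float32, [8, 3, 192, 192], name='demo_sketch')
    text_idx = tf.placeholder(tf.int32, [8, 15], name='demo_text_indices')
    fake, noise_vec = generate_residual(sketch, text_idx, LSTM_hybrid=True, output_channel=3,
                                        num_classes=125, vocab_size=4000,
                                        scope_name='demo_generator')
    # fake is tanh-bounded in [-1, 1]; noise_vec is the [8, 256] latent actually sampled.
    return fake, noise_vec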
def discriminate_mru(discrim_inputs, discrim_targets, num_classes, labels=None,
                     reuse=False, data_format='NCHW', scope_name=None):
    print("MRU Discriminator")

    assert data_format == 'NCHW'
    size = SIZE
    num_blocks = NUM_BLOCKS
    resize_func = tf.image.resize_bilinear
    sn = Config.sn
    if data_format == 'NCHW':
        channel_axis = 1
    else:
        channel_axis = 3
    if type(discrim_targets) is list:
        discrim_targets = discrim_targets[-1]

    if data_format == 'NCHW':
        # Pyramid of the target image from full down to 1/32 resolution, coarsest first.
        x_list = []
        resized_ = discrim_targets
        x_list.append(resized_)
        for i in range(5):
            resized_ = mean_pool(resized_, data_format=data_format)
            x_list.append(resized_)
        x_list = x_list[::-1]
    else:
        raise NotImplementedError

    output_dim = 1

    with tf.variable_scope(scope_name) as scope:
        if reuse:
            scope.reuse_variables()

        h0 = conv2d(x_list[-1], 8, kernel_size=7, sn=sn, stride=1, data_format=data_format,
                    activation_fn=activation_fn_d, normalizer_fn=normalizer_fn_d,
                    normalizer_params=normalizer_params_d,
                    weights_initializer=weight_initializer)

        # Initial memory state
        hidden_state_shape = h0.get_shape().as_list()
        batch_size = hidden_state_shape[0]
        hidden_state_shape[0] = 1
        hts_0 = [h0]
        for i in range(1, num_blocks):
            h0 = tf.tile(tf.get_variable("initial_hidden_state_%d" % i, shape=hidden_state_shape,
                                         dtype=tf.float32, initializer=tf.zeros_initializer()),
                         [batch_size, 1, 1, 1])
            hts_0.append(h0)

        hts_1 = mru_conv(x_list[-1], hts_0, size * 2, sn=sn, stride=2, dilate_rate=1,
                         data_format=data_format, num_blocks=num_blocks, last_unit=False,
                         activation_fn=activation_fn_d, normalizer_fn=normalizer_fn_d,
                         normalizer_params=normalizer_params_d,
                         weights_initializer=weight_initializer, unit_num=1)
        hts_2 = mru_conv(x_list[-2], hts_1, size * 4, sn=sn, stride=2, dilate_rate=1,
                         data_format=data_format, num_blocks=num_blocks, last_unit=False,
                         activation_fn=activation_fn_d, normalizer_fn=normalizer_fn_d,
                         normalizer_params=normalizer_params_d,
                         weights_initializer=weight_initializer, unit_num=2)
        hts_3 = mru_conv(x_list[-3], hts_2, size * 8, sn=sn, stride=2, dilate_rate=1,
                         data_format=data_format, num_blocks=num_blocks, last_unit=False,
                         activation_fn=activation_fn_d, normalizer_fn=normalizer_fn_d,
                         normalizer_params=normalizer_params_d,
                         weights_initializer=weight_initializer, unit_num=3)
        hts_4 = mru_conv(x_list[-4], hts_3, size * 12, sn=sn, stride=2, dilate_rate=1,
                         data_format=data_format, num_blocks=num_blocks, last_unit=True,
                         activation_fn=activation_fn_d, normalizer_fn=normalizer_fn_d,
                         normalizer_params=normalizer_params_d,
                         weights_initializer=weight_initializer, unit_num=4)

        img = hts_4[-1]
        img_shape = img.get_shape().as_list()

        # discriminator end
        disc = conv2d(img, output_dim, kernel_size=1, sn=sn, stride=1, data_format=data_format,
                      activation_fn=None, normalizer_fn=None,
                      weights_initializer=weight_initializer)

        if Config.proj_d:
            # Projection discriminator
            assert labels is not None and (len(labels.get_shape()) == 1
                                           or labels.get_shape().as_list()[-1] == 1)
            class_embeddings = embed_labels(labels, num_classes, img_shape[channel_axis], sn=sn)
            class_embeddings = tf.reshape(class_embeddings,
                                          (img_shape[0], img_shape[channel_axis], 1, 1))  # NCHW
            disc += tf.reduce_sum(img * class_embeddings, axis=1, keep_dims=True)
            logits = None
        else:
            # classification end
            img = tf.reduce_mean(img, axis=(2, 3) if data_format == 'NCHW' else (1, 2))
            logits = fully_connected(img, num_classes, sn=sn, activation_fn=None, normalizer_fn=None)

        return disc, logits
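
# A minimal hinge-loss sketch over the MRU discriminator's patch output. Hinge loss is
# a common pairing for spectral-norm/projection discriminators; it is an assumption for
# this example, not a loss taken from this file.
def _demo_mru_hinge_loss(disc_real, disc_fake):
    d_loss = (tf.reduce_mean(tf.nn.relu(1.0 - disc_real)) +
              tf.reduce_mean(tf.nn.relu(1.0 + disc_fake)))
    g_loss = -tf.reduce_mean(disc_fake)
    return d_loss, g_loss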
def generate_mru(z, text_vocab_indices, LSTM_hybrid, output_channel, num_classes, vocab_size,
                 reuse=False, data_format='NCHW', labels=None, scope_name=None):
    print("MRU Generator")

    size = SIZE
    num_blocks = NUM_BLOCKS
    sn = False
    input_dims = z.get_shape().as_list()
    resize_method = tf.image.ResizeMethod.AREA
    if data_format == 'NCHW':
        height = input_dims[2]
        width = input_dims[3]
    else:
        height = input_dims[1]
        width = input_dims[2]

    # Pyramid of the input at full down to 1/32 resolution, coarsest first.
    resized_z = [tf.identity(z)]
    for i in range(5):
        resized_z.append(image_resize(z, [int(height / 2 ** (i + 1)), int(width / 2 ** (i + 1))],
                                      resize_method, data_format))
    resized_z = resized_z[::-1]

    if data_format == 'NCHW':
        concat_axis = 1
    else:
        concat_axis = 3

    if normalizer_params_g is not None and normalizer_fn_g != ly.batch_norm and normalizer_fn_g != ly.layer_norm:
        normalizer_params_g['labels'] = labels
        normalizer_params_g['n_labels'] = num_classes

    with tf.variable_scope(scope_name) as scope:
        if reuse:
            scope.reuse_variables()

        z_encoded = image_encoder_mru(z, num_classes=num_classes, reuse=reuse,
                                      data_format=data_format, labels=labels,
                                      scope_name=scope_name)
        input_e_dims = z_encoded[-1].get_shape().as_list()
        batch_size = input_e_dims[0]

        # z_encoded[-1].shape = [N, 512, 6, 6], text_vocab_indices.shape = [N, 15]
        if LSTM_hybrid:
            # Add text LSTM
            lstm_output = encode_feat_with_text(z_encoded[-1], text_vocab_indices, input_e_dims, vocab_size)
            feat_encoded_final = lstm_output  # [N, 512, 6, 6]
        else:
            feat_encoded_final = z_encoded[-1]

        channel_depth = int(input_e_dims[concat_axis] / 8.)
        if data_format == 'NCHW':
            noise_dims = [batch_size, channel_depth, int(input_e_dims[2] * 2), int(input_e_dims[3] * 2)]
        else:
            noise_dims = [batch_size, int(input_e_dims[1] * 2), int(input_e_dims[2] * 2), channel_depth]
        noise_vec = tf.random_normal(shape=(batch_size, 256), dtype=tf.float32)
        noise = fully_connected(noise_vec, int(np.prod(noise_dims[1:])), sn=sn,
                                activation_fn=activation_fn_g,
                                # normalizer_fn=normalizer_fn_g,
                                # normalizer_params=normalizer_params_g
                                )
        noise = tf.reshape(noise, shape=noise_dims)

        # Initial memory state
        hidden_state_shape = z_encoded[-1].get_shape().as_list()
        hidden_state_shape[0] = 1
        hts_0 = [feat_encoded_final]

        input_0 = tf.concat([resized_z[1], noise], axis=concat_axis)
        hts_1 = mru_deconv(input_0, hts_0, size * 6, sn=sn, stride=2, data_format=data_format,
                           num_blocks=num_blocks, last_unit=False,
                           activation_fn=activation_fn_g, normalizer_fn=normalizer_fn_g,
                           normalizer_params=normalizer_params_g,
                           weights_initializer=weight_initializer, unit_num=0)
        input_1 = tf.concat([resized_z[2], z_encoded[-3]], axis=concat_axis)
        hts_2 = mru_deconv(input_1, hts_1, size * 4, sn=sn, stride=2, data_format=data_format,
                           num_blocks=num_blocks, last_unit=False,
                           activation_fn=activation_fn_g, normalizer_fn=normalizer_fn_g,
                           normalizer_params=normalizer_params_g,
                           weights_initializer=weight_initializer, unit_num=2)
        input_2 = tf.concat([resized_z[3], z_encoded[-4]], axis=concat_axis)
        hts_3 = mru_deconv(input_2, hts_2, size * 2, sn=sn, stride=2, data_format=data_format,
                           num_blocks=num_blocks, last_unit=False,
                           activation_fn=activation_fn_g, normalizer_fn=normalizer_fn_g,
                           normalizer_params=normalizer_params_g,
                           weights_initializer=weight_initializer, unit_num=4)
        input_3 = tf.concat([resized_z[4], z_encoded[-5]], axis=concat_axis)
        hts_4 = mru_deconv(input_3, hts_3, size * 2, sn=sn, stride=2, data_format=data_format,
                           num_blocks=num_blocks, last_unit=False,
                           activation_fn=activation_fn_g, normalizer_fn=normalizer_fn_g,
                           normalizer_params=normalizer_params_g,
                           weights_initializer=weight_initializer, unit_num=6)
        hts_5 = mru_deconv(resized_z[5], hts_4, size * 1, sn=sn, stride=2, data_format=data_format,
                           num_blocks=num_blocks, last_unit=True,
                           activation_fn=activation_fn_g, normalizer_fn=normalizer_fn_g,
                           normalizer_params=normalizer_params_g,
                           weights_initializer=weight_initializer, unit_num=8)

        out = conv2d(hts_5[-1], output_channel, 7, sn=sn, stride=1, data_format=data_format,
                     normalizer_fn=None, activation_fn=tf.nn.tanh,
                     weights_initializer=weight_initializer)

        if out.get_shape().as_list()[2] != height:
            raise ValueError('Output height %d does not match input height %d'
                             % (out.get_shape().as_list()[2], height))

        return out, noise_vec
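
# End-to-end sketch pairing the MRU generator with the MRU discriminator above.
# Shapes, num_classes=125, and vocab_size=4000 are hypothetical; labels matter only
# when Config.proj_d routes them through the projection head.
def _demo_mru_pipeline():
    sketch = tf.placeholder(tf.float32, [8, 3, 192, 192], name='demo_sketch')
    photo = tf.placeholder(tf.float32, [8, 3, 192, 192], name='demo_photo')
    text_idx = tf.placeholder(tf.int32, [8, 15], name='demo_text_indices')
    labels = tf.placeholder(tf.int32, [8], name='demo_labels')
    fake, _ = generate_mru(sketch, text_idx, LSTM_hybrid=True, output_channel=3,
                           num_classes=125, vocab_size=4000, scope_name='demo_generator')
    disc_real, _ = discriminate_mru(sketch, photo, num_classes=125, labels=labels,
                                    scope_name='demo_discriminator')
    disc_fake, _ = discriminate_mru(sketch, fake, num_classes=125, labels=labels,
                                    reuse=True, scope_name='demo_discriminator')
    return fake, disc_real, disc_fake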