def van_enc_2d(x, first_depth, reuse=False):
  """Encodes the VAN's higher-level structure vector into a feature map.

  Unlike the image encoder, the input here is a vector rather than an image.
  It is projected onto a small 4x4 grid and then grown with transposed
  convolutions.

  Args:
    x: The higher level structure to encode. It is concatenated along axis 1
      with rank-2 tensors, so it must itself be rank 2.
    first_depth: The depth of the first layer. Depth is increased in
      subsequent layers.
    reuse: Whether to reuse variables in the 'van_enc' variable scope.

  Returns:
    A tuple (enc, van_higher_level):
      enc: the final feature map (first_depth * 4 channels).
      van_higher_level: x concatenated with flattened copies of the
        first_depth * 2 and first_depth * 4 feature maps.
  """
  # Initial spatial grid; depends on the inputs size.
  grid_h, grid_w = 4, 4

  def deconv(tensor, depth, stride):
    return tf.layers.conv2d_transpose(
        tensor, depth, 3, padding='same', activation=tf.nn.relu,
        strides=stride)

  def norm(tensor):
    return contrib.layers().layer_norm(tensor)

  with tf.variable_scope('van_enc', reuse=reuse):
    net = tf.layers.dense(tf.nn.relu(x), first_depth * grid_h * grid_w,
                          tf.nn.relu)
    net = norm(net)
    net = tf.reshape(net, [-1, grid_h, grid_w, first_depth])
    net = norm(deconv(net, first_depth, 1))
    net = deconv(net, first_depth * 2, 2)

    # The single stride-2 deconv doubles both spatial dims.
    upsampled_area = grid_h * 2 * grid_w * 2
    flat_depth2 = tf.reshape(net, [-1, upsampled_area * first_depth * 2])

    net = norm(deconv(net, first_depth * 2, 1))
    net = deconv(net, first_depth * 4, 1)
    flat_depth4 = tf.reshape(net, [-1, upsampled_area * first_depth * 4])

    return net, tf.concat([x, flat_depth2, flat_depth4], 1)
def call_controller(self, input_value, read_values, prev_state, batch_size):
  """Runs one step of the neural stack controller.

  See Section 3.1 of Grefenstette et al., 2015.

  Args:
    input_value: The input to the neural stack cell; should be a tf.float32
      tensor with shape [batch_size, 1, embedding_size].
    read_values: The values of the read heads at the previous timestep.
    prev_state: The hidden state from the previous time step.
    batch_size: The size of the current batch of input values.

  Returns:
    A NeuralStackControllerInterface built from the push strengths, pop
    strengths, write values, outputs and new RNN state, each reshaped to the
    matching entry of self.get_controller_shape(batch_size).
  """

  def affine(inputs, weights, bias):
    return tf.nn.bias_add(tf.matmul(inputs, weights), bias)

  with tf.name_scope("controller"):
    # The controller sees the current input together with what the read
    # heads produced on the previous timestep.
    controller_inputs = tf.concat(
        [contrib.layers().flatten(value)
         for value in (input_value, read_values)],
        axis=1)

    rnn_input = tf.tanh(
        affine(controller_inputs, self._input_proj, self._input_bias))
    rnn_output, state = self.rnn(rnn_input, prev_state)

    push_strengths = tf.sigmoid(
        affine(rnn_output, self._push_proj, self._push_bias))
    pop_strengths = tf.sigmoid(
        affine(rnn_output, self._pop_proj, self._pop_bias))
    write_values = tf.tanh(
        affine(rnn_output, self._value_proj, self._value_bias))
    outputs = tf.tanh(
        affine(rnn_output, self._output_proj, self._output_bias))

    # Give every output the shape declared by get_controller_shape().
    pieces = (push_strengths, pop_strengths, write_values, outputs, state)
    reshaped = [
        tf.reshape(piece, shape=target_shape)
        for piece, target_shape in zip(
            pieces, self.get_controller_shape(batch_size))
    ]
    return NeuralStackControllerInterface(*reshaped)
def image_embedding(images,
                    model_fn=resnet_v1_152,
                    trainable=True,
                    is_training=True,
                    weight_decay=0.0001,
                    batch_norm_decay=0.997,
                    batch_norm_epsilon=1e-5,
                    batch_norm_scale=True,
                    add_summaries=False,
                    reuse=False):
  """Extracts image features from a pretrained ResNet model.

  Args:
    images: Input image batch, passed straight to `model_fn`.
    model_fn: slim-style network constructor (default resnet_v1_152). Its
      __name__ is used as the variable scope name.
    trainable: Whether the network's variables are trainable. When False, no
      weight regularizer is attached and batch norm is not in training mode.
    is_training: Whether the model is being trained. It only takes effect
      when `trainable` is also True.
    weight_decay: L2 regularization scale for conv weights (only when
      `trainable`).
    batch_norm_decay: Batch norm moving-average decay.
    batch_norm_epsilon: Batch norm epsilon.
    batch_norm_scale: Whether batch norm uses a learned scale.
    add_summaries: Whether to add activation summaries for every end point.
    reuse: Whether to reuse variables.

  Returns:
    The feature map returned by `model_fn` (called with num_classes=None and
    global_pool=False).
  """
  is_resnet_training = trainable and is_training

  batch_norm_params = {
      "is_training": is_resnet_training,
      "trainable": trainable,
      "decay": batch_norm_decay,
      "epsilon": batch_norm_epsilon,
      "scale": batch_norm_scale,
  }

  if trainable:
    weights_regularizer = contrib.layers().l2_regularizer(weight_decay)
  else:
    weights_regularizer = None

  # `values` must be passed by keyword: the second positional parameter of
  # tf.variable_scope is `default_name`, not the list of input tensors.
  with tf.variable_scope(
      model_fn.__name__, values=[images], reuse=reuse) as scope:
    with slim.arg_scope(
        [slim.conv2d],
        weights_regularizer=weights_regularizer,
        trainable=trainable):
      with slim.arg_scope(
          [slim.conv2d],
          weights_initializer=slim.variance_scaling_initializer(),
          activation_fn=tf.nn.relu,
          normalizer_fn=slim.batch_norm,
          normalizer_params=batch_norm_params):
        with slim.arg_scope([slim.batch_norm],
                            is_training=is_resnet_training,
                            trainable=trainable):
          with slim.arg_scope([slim.max_pool2d], padding="SAME"):
            net, end_points = model_fn(
                images,
                num_classes=None,
                global_pool=False,
                is_training=is_resnet_training,
                reuse=reuse,
                scope=scope)

  if add_summaries:
    for v in end_points.values():
      contrib.layers().summaries.summarize_activation(v)

  return net
def transformer_revnet_encoder(encoder_input,
                               encoder_self_attention_bias,
                               hparams,
                               name="encoder"):
  """A stack of reversible transformer encoder layers.

  The input is split in half along its last axis. The halves go through a
  reversible block whose F is self-attention and whose G is the feed-forward
  layer. Both are built with hparams.hidden_size temporarily halved.

  Args:
    encoder_input: a Tensor; its last dimension is split into two halves.
    encoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model. hparams.hidden_size is halved while
      each sub-layer is built and is always restored afterwards, even if
      building raises.
    name: a string, the variable scope name.

  Returns:
    y: a Tensor, both halves concatenated and passed through
      layer_preprocess.
  """

  def _at_half_hidden_size(build_fn):
    """Calls build_fn() with hparams.hidden_size halved, then restores it."""
    old_hid_size = hparams.hidden_size
    hparams.hidden_size = old_hid_size // 2
    try:
      return build_fn()
    finally:
      # Restore even on failure so the shared hparams object is not left
      # corrupted for later layers/models.
      hparams.hidden_size = old_hid_size

  def f(x, side_input):
    """f(x) for reversible layer, self-attention layer."""
    # The bias arrives through rev_block's f_side_input list.
    attention_bias = side_input[0]

    def build():
      with tf.variable_scope("self_attention"):
        y = common_attention.multihead_attention(
            common_layers.layer_preprocess(x, hparams), None, attention_bias,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
        return common_layers.layer_postprocess(x, y, hparams)

    return _at_half_hidden_size(build)

  def g(x):
    """g(x) for reversible layer, feed-forward layer."""

    def build():
      with tf.variable_scope("ffn"):
        y = transformer.transformer_ffn_layer(
            common_layers.layer_preprocess(x, hparams), hparams)
        return common_layers.layer_postprocess(x, y, hparams)

    return _at_half_hidden_size(build)

  x1, x2 = tf.split(encoder_input, 2, axis=-1)

  with tf.variable_scope(name):
    y1, y2 = contrib.layers().rev_block(
        x1,
        x2,
        f,
        g,
        num_layers=hparams.num_hidden_layers,
        f_side_input=[encoder_self_attention_bias],
        is_training=hparams.mode == tf.estimator.ModeKeys.TRAIN)
    y = tf.concat([y1, y2], axis=-1)
    return common_layers.layer_preprocess(y, hparams)
def fc_layer(x, num_out, dropout_rate, name="fc", training=False):
  """Dense -> layer norm -> ReLU -> dropout.

  Args:
    x: Input tensor.
    num_out: Number of units of the dense layer.
    dropout_rate: Fraction of units to drop when `training` is True.
    name: Variable scope name.
    training: Whether dropout is active. tf.layers.dropout is the identity
      when training is False, which is also its own default. The default
      here therefore leaves existing callers unchanged; pass training=True
      to actually apply dropout.

  Returns:
    The output tensor with `num_out` units.
  """
  with tf.variable_scope(name):
    out = tf.layers.dense(x, num_out)
    out = contrib.layers().layer_norm(out)
    out = tf.nn.relu(out)
    return tf.layers.dropout(out, dropout_rate, training=training)
def unit(x1, x2, block_num, depth, num_layers, dim='2d',
         bottleneck=True, first_batch_norm=True, stride=1, training=True):
  """Implements bottleneck RevNet unit from authors' RevNet architecture.

  The unit first applies a hand-written reversible downsampling step (scopes
  'downsampling/x1' and 'downsampling/x2'). It then runs `num_layers`
  reversible layers through rev_block, using the same residual function for
  both F and G.

  Args:
    x1: [N, H, W, C] tensor of network activations.
    x2: [N, H, W, C] tensor of network activations.
    block_num: integer ID of block
    depth: First depth in bottleneck residual unit.
    num_layers: Number of layers in the RevNet block.
    dim: '2d' if 2-dimensional, '3d' if 3-dimensional.
    bottleneck: Should a bottleneck layer be used.
    first_batch_norm: Whether to keep the first batch norm layer or not.
      Typically used in the first RevNet block.
    stride: Stride for the residual function.
    training: True for train phase, False for eval phase.

  Returns:
    Two [N, H, W, C] output activation tensors.
  """
  # Bottleneck units expand the output depth 4x; plain units keep it.
  depth1 = depth
  depth2 = depth * 4 if bottleneck else depth

  residual = wrapped_partial(f,
                             depth1=depth1,
                             depth2=depth2,
                             dim=dim,
                             training=training,
                             bottleneck=bottleneck)
  downsample = downsample_bottleneck if bottleneck else downsample_residual

  with tf.variable_scope('unit_%d' % block_num):
    # Manual implementation of downsampling.
    with tf.variable_scope('downsampling'):
      with tf.variable_scope('x1'):
        shortcut = downsample(x1, depth2, dim=dim, stride=stride)
        x1 = shortcut + residual(x2, stride=stride,
                                 first_batch_norm=first_batch_norm)
      with tf.variable_scope('x2'):
        shortcut = downsample(x2, depth2, dim=dim, stride=stride)
        x2 = shortcut + residual(x1)

    # Full block using memory-efficient rev_block implementation.
    with tf.variable_scope('full_block'):
      out1, out2 = contrib.layers().rev_block(
          x1, x2, residual, residual, num_layers=num_layers)
      return out1, out2
def analogy_computation_2d(f_first_enc,
                           f_first_frame,
                           f_current_enc,
                           first_depth):
  """Implements the deep analogy computation.

  The difference (f_first_frame - f_first_enc) and f_current_enc are each
  passed through a 3x3 conv. The two results are concatenated along axis 3
  and refined by three more 3x3 convs, with a layer norm after the first.

  Args:
    f_first_enc: Encoding of the first frame.
    f_first_frame: Features of the first frame; must be subtractable from
      f_first_enc.
    f_current_enc: Encoding of the current frame.
    first_depth: Base depth; every conv here uses first_depth * 4 filters.

  Returns:
    The analogy feature map with first_depth * 4 channels.
  """
  num_filters = first_depth * 4

  def conv(inputs):
    return tf.layers.conv2d(inputs, num_filters, 3, padding='same',
                            activation=tf.nn.relu, strides=1)

  with tf.variable_scope('analogy_computation'):
    diff_features = conv(f_first_frame - f_first_enc)
    current_features = conv(f_current_enc)
    merged = conv(tf.concat([diff_features, current_features], 3))
    merged = conv(contrib.layers().layer_norm(merged))
    return conv(merged)
from __future__ import division from __future__ import print_function from tensor2tensor.layers import common_layers from tensor2tensor.layers import common_video from tensor2tensor.layers import discretization from tensor2tensor.models.video import base from tensor2tensor.models.video import base_vae from tensor2tensor.utils import contrib from tensor2tensor.utils import registry import tensorflow.compat.v1 as tf tfl = tf.layers tfcl = contrib.layers() @registry.register_model class NextFrameSv2p(base.NextFrameBase, base_vae.NextFrameBaseVae): """Stochastic Variational Video Prediction From Basic Model!""" @property def is_recurrent_model(self): return True def tinyify(self, array): return common_video.tinyify(array, self.hparams.tiny_mode, self.hparams.small_mode) def bottom_part_tower(self, input_image,
def encoder(self, inputs, n_layers=3):
  """Convnet that encodes inputs into mean and log-variance of a gaussian.

  Args:
    inputs: 5-D Tensor, shape (batch_size, num_frames, width, height, channels)
    n_layers: Number of layers.

  Returns:
    z_mu: Mean of the latent gaussians, (batch_size, num_frames, z_dim).
    z_log_var: log(var) of the latent gaussians, clipped to [-10, 10],
      same shape as z_mu.

  Raises:
    ValueError: If inputs is not a 5-D tensor or not float32.
  """
  latent_dims = self.hparams.z_dim

  shape_as_list = inputs.shape.as_list()
  if len(shape_as_list) != 5:
    raise ValueError("Expected inputs to be a 5-D, got %d" %
                     len(shape_as_list))
  if inputs.dtype != tf.float32:
    raise ValueError("Expected dtype tf.float32, got %s" % inputs.dtype)

  # The static batch size is None when the batch dimension is not known at
  # graph-construction time. Fall back to the dynamic size so the final
  # reshape never receives None.
  batch_size = shape_as_list[0]
  if batch_size is None:
    batch_size = tf.shape(inputs)[0]

  # Flatten (N,T,W,H,C) into (NT,W,H,C)
  inputs = tf.reshape(inputs, [-1] + list(inputs.shape)[2:])
  n_filters = 64
  net = inputs

  # Applies an n_layers conv-net with padding, instance normalization
  # and leaky relu as per the encoder in
  # https://github.com/alexlee-gk/video_prediction
  padding = [[0, 0], [1, 1], [1, 1], [0, 0]]
  for i in range(n_layers):
    with tf.variable_scope("layer_%d" % (i + 1)):
      # NOTE(review): this compounds (64, 128, 512 for n_layers=3). Confirm
      # it is intended; changing it would alter the conv kernel shapes.
      n_filters *= 2**i
      padded = tf.pad(net, padding)
      convolved = tf.layers.conv2d(padded, filters=n_filters, kernel_size=4,
                                   strides=2, padding="VALID")
      normalized = contrib.layers().instance_norm(convolved)
      net = tf.nn.leaky_relu(normalized, alpha=0.2)

  # Mean pooling across all spatial dimensions.
  pooled = tf.nn.avg_pool(
      net, [1] + net.shape[1:3].as_list() + [1],
      strides=[1, 1, 1, 1], padding="VALID")
  squeezed = tf.squeeze(pooled, [1, 2])

  # Down-project and output the mean and log of the standard deviation of
  # the latents.
  with tf.variable_scope("z_mu"):
    z_mu = tf.layers.dense(squeezed, latent_dims)
  with tf.variable_scope("z_log_sigma_sq"):
    z_log_var = tf.layers.dense(squeezed, latent_dims)
    z_log_var = tf.clip_by_value(z_log_var, -10, 10)

  # Reshape to (batch_size X num_frames X latent_dims)
  z_mu = tf.reshape(z_mu, [batch_size, -1, latent_dims])
  z_log_var = tf.reshape(z_log_var, [batch_size, -1, latent_dims])
  return z_mu, z_log_var
def van_image_enc_2d(x, first_depth, reuse=False, hparams=None):
  """The image encoder for the VAN.

  Similar architecture as Ruben's paper
  (http://proceedings.mlr.press/v70/villegas17a/villegas17a.pdf).

  Args:
    x: The image to encode.
    first_depth: The depth of the first layer. Depth is doubled in the
      second stage and quadrupled in the third.
    reuse: To reuse in variable scope or not.
    hparams: The python hparams. hparams.van_keep_prob is read, so this must
      not be None.

  Returns:
    A tuple (enc, enc_history): the encoded image after three 2x2 max-pool
    stages, and a list holding x followed by the outputs of the first two
    stages.
  """

  def conv(inputs, depth):
    return tf.layers.conv2d(inputs, depth, 3, padding='same',
                            activation=tf.nn.relu, strides=1)

  def pool(inputs):
    return tf.nn.max_pool(inputs, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')

  def pool_dropout_norm(inputs):
    dropped = tf.nn.dropout(pool(inputs), hparams.van_keep_prob)
    return contrib.layers().layer_norm(dropped)

  with tf.variable_scope('van_image_enc', reuse=reuse):
    enc_history = [x]

    # Stage 1: conv -> norm -> conv -> pool/dropout/norm.
    net = contrib.layers().layer_norm(conv(x, first_depth))
    net = pool_dropout_norm(conv(net, first_depth))
    enc_history.append(net)

    # Stage 2: two convs -> pool/dropout/norm.
    net = conv(conv(net, first_depth * 2), first_depth * 2)
    net = pool_dropout_norm(net)
    enc_history.append(net)

    # Stage 3: three convs -> pool.
    for _ in range(3):
      net = conv(net, first_depth * 4)
    return pool(net), enc_history
def predictor(enc_flat,
              action,
              lstm_states,
              pred_depth,
              reuse=False,
              scope_prefix='',
              hparams=None):
  """LSTM predictor network.

  Projects [enc_flat, action] to `initial_size` features (the sum of their
  widths). It feeds them through a stack of peephole LSTM cells with
  residual additions and layer norm, then projects the result back to the
  width of enc_flat.

  Args:
    enc_flat: Rank-2 encoded frame. Its static second dimension is the
      output width.
    action: Rank-2 action tensor. Its static second dimension is added to
      the LSTM size.
    lstm_states: Mutable list of per-layer LSTM states (entries may be None).
      It is updated IN PLACE: entry 0 and the entries touched by the loop
      (1 .. pred_depth - 1 when pred_depth is odd) are overwritten.
      Entry pred_depth - 2 from the previous call is fed back into the first
      cell as a "back connection".
    pred_depth: Number of LSTM layers. The loop steps by two starting at 1,
      so pred_depth is presumably expected to be odd; an even value also
      writes lstm_states[pred_depth]. TODO confirm.
    reuse: Whether to reuse variables in the scope_prefix + 'predict' scope.
    scope_prefix: Prefix for the variable scope name.
    hparams: The python hparams. pred_noise_std and enc_pred_use_l2norm are
      read, so this must not be None.

  Returns:
    A rank-2 tensor with enc_flat's static width, L2-normalized along axis 1
    when hparams.enc_pred_use_l2norm is set.
  """
  with tf.variable_scope(scope_prefix + 'predict', reuse=reuse):
    enc_final_size = enc_flat.get_shape().as_list()[1]
    action_size = action.get_shape().as_list()[1]
    initial_size = (enc_final_size + action_size)

    batch_size = tf.shape(enc_flat)[0]

    init_stddev = 1e-2

    pre_pred = tf.concat([enc_flat, action], 1)
    pre_pred = tf.layers.dense(
        pre_pred,
        initial_size,
        kernel_initializer=tf.truncated_normal_initializer(
            stddev=init_stddev))

    # This is only needed for the GAN version.
    if hparams.pred_noise_std > 0:
      # Add the noise like this so a pretrained model can be used.
      # (100-dim gaussian noise projected to initial_size.)
      pred_noise = tf.random_normal(shape=[batch_size, 100],
                                    stddev=hparams.pred_noise_std)
      pre_pred += tf.layers.dense(
          pred_noise,
          initial_size,
          kernel_initializer=tf.truncated_normal_initializer(
              stddev=init_stddev),
          name='noise_dense')

    pre_pred = tf.nn.relu(pre_pred)

    if lstm_states[pred_depth - 2] is None:
      # First step: use a learned back connection, tiled across the batch.
      # Width initial_size * 2 presumably matches a non-tuple LSTM state
      # (cell + projected output) from common_video.lstm_cell -- confirm.
      back_connect = tf.tile(
          tf.get_variable(
              'back_connect_init',
              shape=[1, initial_size * 2],
              initializer=tf.truncated_normal_initializer(
                  stddev=init_stddev)),
          (batch_size, 1))
    else:
      back_connect = lstm_states[pred_depth - 2]

    lstm_init_stddev = 1e-4

    # First LSTM layer sees the projected input plus the back connection.
    part_pred, lstm_states[0] = common_video.lstm_cell(
        tf.concat([pre_pred, back_connect], 1),
        lstm_states[0],
        initial_size,
        use_peepholes=True,
        initializer=tf.truncated_normal_initializer(
            stddev=lstm_init_stddev),
        num_proj=initial_size)
    part_pred = contrib.layers().layer_norm(part_pred)
    pred = part_pred

    # Remaining layers come in pairs; each adds its output residually. The
    # second layer of each pair also sees pre_pred (skip from the input).
    for pred_layer_num in range(1, pred_depth, 2):
      part_pred, lstm_states[pred_layer_num] = common_video.lstm_cell(
          pred,
          lstm_states[pred_layer_num],
          initial_size,
          use_peepholes=True,
          initializer=tf.truncated_normal_initializer(
              stddev=lstm_init_stddev),
          num_proj=initial_size)
      pred += part_pred

      part_pred, lstm_states[
          pred_layer_num + 1] = common_video.lstm_cell(
              tf.concat([pred, pre_pred], 1),
              lstm_states[pred_layer_num + 1],
              initial_size,
              use_peepholes=True,
              initializer=tf.truncated_normal_initializer(
                  stddev=lstm_init_stddev),
              num_proj=initial_size)
      part_pred = contrib.layers().layer_norm(part_pred)
      pred += part_pred

    # Project back to the encoder's width.
    pred = tf.layers.dense(
        pred,
        enc_final_size,
        kernel_initializer=tf.truncated_normal_initializer(
            stddev=init_stddev))

    if hparams.enc_pred_use_l2norm:
      pred = tf.nn.l2_normalize(pred, 1)

    return pred
def van_dec_2d(x, skip_connections, output_shape, first_depth, hparams=None):
  """The VAN decoder.

  Upsamples the analogy features with transposed convolutions. The last
  layer emits output_shape[3] image channels plus one mask channel, and the
  predicted image is blended with skip_connections[0] through the sigmoid
  of that mask.

  Args:
    x: The analogy information to decode.
    skip_connections: The encoder layers which can be used as skip
      connections. Only skip_connections[0] is used, as the image the
      prediction is blended with.
    output_shape: The shape of the desired output image. output_shape[3] is
      the number of image channels.
    first_depth: The depth of the first layer of the van image encoder.
    hparams: The python hparams. hparams.van_keep_prob is read, so this must
      not be None.

  Returns:
    The decoded image prediction with output_shape[3] channels.
  """
  num_channels = output_shape[3]

  def deconv(inputs, depth, stride):
    return tf.layers.conv2d_transpose(inputs, depth, 3, padding='same',
                                      activation=tf.nn.relu, strides=stride)

  def dropout(inputs):
    return tf.nn.dropout(inputs, hparams.van_keep_prob)

  with tf.variable_scope('van_dec'):
    dec = dropout(deconv(x, first_depth * 4, 2))
    dec = contrib.layers().layer_norm(dec)
    dec = dropout(deconv(dec, first_depth * 4, 1))
    dec = dropout(deconv(dec, first_depth * 2, 1))
    dec = contrib.layers().layer_norm(dec)
    dec = dropout(deconv(dec, first_depth * 2, 2))
    dec = dropout(deconv(dec, first_depth, 1))
    dec = contrib.layers().layer_norm(dec)
    dec = dropout(deconv(dec, num_channels + 1, 2))

    # Linear output: num_channels image channels followed by one mask channel.
    out_mask = tf.layers.conv2d_transpose(dec, num_channels + 1, 3, strides=1,
                                          padding='same', activation=None)

    # Slice by num_channels rather than assuming a 3-channel (RGB) image;
    # the layer above always emits num_channels + 1 channels.
    mask = tf.nn.sigmoid(out_mask[:, :, :, num_channels:num_channels + 1])
    out = out_mask[:, :, :, :num_channels]

    return out * mask + skip_connections[0] * (1 - mask)
def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None):
  """Minimize loss.

  Adds weight decay / noise to the loss and logs (optionally summarizes)
  variables. It then builds a ConditionalOptimizer, wrapped for TPU or for
  GPU automatic mixed precision when requested, and returns the op from
  contrib.layers().optimize_loss.

  Args:
    loss: Scalar loss tensor.
    learning_rate: Learning rate (scalar or tensor).
    hparams: Hyperparameters. Reads summarize_vars, optimizer,
      summarize_grads, clip_grad_norm, grad_noise_scale and, if present,
      gpu_automatic_mixed_precision.
    use_tpu: Whether to wrap the optimizer in a CrossShardOptimizer.
    variables: Variables to optimize; defaults to tf.trainable_variables().

  Returns:
    The training op.

  Raises:
    RuntimeError: If GPU automatic mixed precision is requested together
      with TPU or with manual mixed precision.
  """
  loss = weight_decay_and_noise(loss, hparams, learning_rate)
  loss = tf.identity(loss, name="total_loss")
  if variables is None:
    variables = tf.trainable_variables()
  # Print trainable variables.
  log_variable_sizes(variables, verbose=hparams.summarize_vars)
  # Print non-trainable variables. Sorting by name makes the order (and so
  # the order in which summaries are created) deterministic instead of
  # depending on set iteration order.
  non_trainable_variables = sorted(
      set(tf.global_variables()) - set(variables), key=lambda v: v.name)
  log_variable_sizes(non_trainable_variables, tag="Non-trainable variables",
                     verbose=hparams.summarize_vars)
  if hparams.summarize_vars:
    summarize_variables(variables)
    # Summarize non-trainable variables as well
    summarize_variables(non_trainable_variables, tag="Non-trainable variables")
  diet_vars = [
      v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
  ]
  log_variable_sizes(
      diet_vars, "Diet Variables", verbose=hparams.summarize_vars)
  opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams, use_tpu)
  if use_tpu:
    opt = contrib.tpu().CrossShardOptimizer(opt)
  if getattr(hparams, "gpu_automatic_mixed_precision", False):
    if use_tpu:
      raise RuntimeError(
          "GPU auto mixed precision cannot be used with TPU")
    elif _mixed_precision_is_enabled(hparams):
      raise RuntimeError(
          "GPU auto mixed precision cannot be used with manual mixed precision"
      )
    else:
      # Provide the attributes Optimizer.__init__ would normally set before
      # wrapping. _use_locking must be a real bool: TF op `use_locking`
      # attributes reject the string "True".
      setattr(opt, "_use_locking", True)
      setattr(opt, "_name", "ConditionalOptimizer")
      opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)

  opt_summaries = []
  if common_layers.should_generate_summaries():
    tf.summary.scalar("learning_rate", learning_rate)
    opt_summaries.append("loss")
    if hparams.summarize_grads:
      tf.logging.info("Summarizing gradients")
      opt_summaries.extend(
          ["gradients", "gradient_norm", "global_gradient_norm"])

  if hparams.clip_grad_norm:
    tf.logging.info("Clipping gradients, norm: %0.5f", hparams.clip_grad_norm)
  if hparams.grad_noise_scale:
    tf.logging.info("Adding noise to gradients, noise scale: %0.5f",
                    hparams.grad_noise_scale)

  train_op = contrib.layers().optimize_loss(
      name="training",
      loss=loss,
      global_step=tf.train.get_or_create_global_step(),
      learning_rate=learning_rate,
      clip_gradients=hparams.clip_grad_norm or None,
      gradient_noise_scale=hparams.grad_noise_scale or None,
      optimizer=opt,
      summaries=opt_summaries,
      colocate_gradients_with_ops=True,
      variables=variables)
  return train_op
@registry.register_optimizer def adafactor(learning_rate, hparams): return adafactor_lib.adafactor_optimizer_from_hparams( hparams, learning_rate) def _register_base_optimizer(name, opt): key = misc_utils.camelcase_to_snakecase(name) if key in registry.Registries.optimizers: return registry.register_optimizer(key)( lambda learning_rate, hparams: opt(learning_rate)) for _name, _opt in contrib.layers().OPTIMIZER_CLS_NAMES.items(): _register_base_optimizer(_name, _opt) class ConditionalOptimizer(tf.compat.v1.train.Optimizer): """Conditional optimizer.""" def __init__(self, optimizer_name, lr, hparams, use_tpu=False): # pylint: disable=super-init-not-called tf.logging.info("Using optimizer %s", optimizer_name) mlperf_log.transformer_print(key=mlperf_log.OPT_NAME, value=optimizer_name, hparams=hparams) mlperf_log.transformer_print(key=mlperf_log.OPT_HP_ADAM_BETA1, value=hparams.optimizer_adam_beta1, hparams=hparams) mlperf_log.transformer_print(key=mlperf_log.OPT_HP_ADAM_BETA2,
def conv_latent_tower(images, time_axis, latent_channels=1, min_logvar=-5,
                      is_training=False, random_latent=False,
                      tiny_mode=False, small_mode=False):
  """Builds convolutional latent tower for stochastic model.

  At training time this tower generates a latent distribution (mean and log
  variance) conditioned on the entire video. This latent variable will be
  fed to the main tower as an extra variable to be used for future frames
  prediction.
  At inference time, the tower is disabled and returns zero mean and zero
  log variance, i.e. the parameters of N(0, 1).

  Args:
    images: tensor of ground truth image sequences
    time_axis: the time axis in images tensor
    latent_channels: number of latent channels
    min_logvar: minimum value for log_var
    is_training: whether or not it is training mode
    random_latent: whether to return the N(0, 1) parameters (zeros) instead
      of the predicted ones. May be a Python bool or a boolean scalar tensor.
    tiny_mode: whether or not it is tiny_mode. tiny_mode sets the number
        of conv channels to 1 at each layer. useful for testing the
        integration tests.
    small_mode: whether or not it is small_mode. small mode is the same model
        with less conv and lstm layers and also lower number of channels.
        suitable for videos with less complexity and testing.

  Returns:
    latent_mean: predicted latent mean
    latent_logvar: predicted latent log variance. When predicted it is a
      ReLU output shifted by min_logvar, so it is >= min_logvar.
  """
  conv_size = tinyify([32, 64, 64], tiny_mode, small_mode)
  with tf.variable_scope("latent", reuse=tf.AUTO_REUSE):
    images = tf.to_float(images)
    images = tf.unstack(images, axis=time_axis)
    images = tf.concat(images, axis=3)

    x = images
    x = common_layers.make_even_size(x)
    x = tfl.conv2d(x, conv_size[0], [3, 3], strides=(2, 2),
                   padding="SAME", activation=tf.nn.relu, name="latent_conv1")
    x = contrib.layers().layer_norm(x)
    if not small_mode:
      x = tfl.conv2d(x, conv_size[1], [3, 3], strides=(2, 2),
                     padding="SAME", activation=tf.nn.relu,
                     name="latent_conv2")
      x = contrib.layers().layer_norm(x)
    x = tfl.conv2d(x, conv_size[2], [3, 3], strides=(1, 1),
                   padding="SAME", activation=tf.nn.relu, name="latent_conv3")
    x = contrib.layers().layer_norm(x)

    nc = latent_channels
    mean = tfl.conv2d(x, nc, [3, 3], strides=(2, 2),
                      padding="SAME", activation=None, name="latent_mean")
    logv = tfl.conv2d(x, nc, [3, 3], strides=(2, 2),
                      padding="SAME", activation=tf.nn.relu, name="latent_std")
    logvar = logv + min_logvar

    # No latent tower at inference time, just standard gaussian.
    if not is_training:
      return tf.zeros_like(mean), tf.zeros_like(logvar)

    # No latent in the first phase.
    if isinstance(random_latent, bool):
      # tf.cond rejects a Python bool predicate (the default here), so
      # resolve the choice statically instead.
      if random_latent:
        return tf.zeros_like(mean), tf.zeros_like(logvar)
      return mean, logvar

    ret_mean, ret_logvar = tf.cond(
        random_latent,
        lambda: (tf.zeros_like(mean), tf.zeros_like(logvar)),
        lambda: (mean, logvar))

    return ret_mean, ret_logvar