def simple_discrete_latent_tower(self, input_image, target_image):
    """Produce latent bits for an (input, target) pair of frames.

    When `self.is_predicting` is set no encoder is run: every bit is drawn
    independently and takes the value -1.0 or +1.0. Otherwise the two images
    are concatenated on the last axis, passed through the "posterior_enc"
    conv net, flattened and handed to
    `discretization.tanh_discrete_bottleneck`.

    Args:
      input_image: tensor of input frame(s); its first dimension is used as
        the batch size when sampling random bits.
      target_image: tensor concatenated with `input_image` on the last axis.

    Returns:
      The bits tensor. In the predicting case it has shape
      [batch, hparams.bottleneck_bits]; otherwise it is the first output of
      `tanh_discrete_bottleneck`.
    """
    hp = self.hparams

    if not self.is_predicting:
        # Posterior path: encode both frames together and discretize.
        filters = self.tinyify([64, 32, 32, 1])
        frames = tf.concat([input_image, target_image], axis=-1)
        encoded = tfl.flatten(
            self.basic_conv_net(frames, filters, "posterior_enc"))
        noisy_bits, _ = discretization.tanh_discrete_bottleneck(
            encoded, hp.bottleneck_bits, hp.bottleneck_noise,
            hp.discretize_warmup_steps, hp.mode)
        return noisy_bits

    # Prediction path: the target is not used, so sample the bits uniformly.
    num_examples = common_layers.shape_list(input_image)[0]
    uniform = tf.random_uniform([num_examples, hp.bottleneck_bits])
    # Map the {0, 1} indicator onto {-1, +1}.
    return 2.0 * tf.to_float(tf.less(0.5, uniform)) - 1.0
def learned_discrete_tower(self, input_image, target_image):
    """Build the discrete latent tower with an optional learned prior.

    The input frames are encoded by the "prior_enc" conv net. During
    training the input and target frames are additionally encoded by
    "posterior_enc" and discretized with
    `discretization.tanh_discrete_bottleneck`. Unless
    `hparams.full_latent_tower` is set, an LSTM bit predictor is trained on
    the prior encoding to match the clean posterior bits.

    At prediction time the bits are sampled uniformly from {-1.0, +1.0}
    (`full_latent_tower`) or generated by the LSTM predictor from the prior
    encoding.

    Args:
      input_image: tensor of input frame(s).
      target_image: tensor concatenated with `input_image` on the last axis
        to form the posterior encoder input.

    Returns:
      A pair (decoded_bits, pred_loss). `decoded_bits` is the result of
      `common_video.encode_to_shape` applied to the bits with the shape of
      the prior encoder output. `pred_loss` is 0.0 when predicting or when
      `full_latent_tower` is set, otherwise the loss returned by
      `discretization.predict_bits_with_lstm`.
    """
    hp = self.hparams
    filters = [64, 32, 32, 1]

    # Encode the input frames into a prior encoding.
    prior = self.basic_conv_net(input_image, filters, "prior_enc")
    enc_shape = common_layers.shape_list(prior)
    prior = tfl.flatten(prior)

    def to_tower_shape(latent_bits):
        return common_video.encode_to_shape(
            latent_bits, enc_shape, "bits_dec")

    if self.is_predicting:
        if hp.full_latent_tower:
            uniform = tf.random_uniform([enc_shape[0], hp.bottleneck_bits])
            sampled = 2.0 * tf.to_float(tf.less(0.5, uniform)) - 1.0
        else:
            # Generate bits using the learned prior at inference time.
            sampled, _ = discretization.predict_bits_with_lstm(
                prior, hp.latent_predictor_state_size, hp.bottleneck_bits,
                temperature=hp.latent_predictor_temperature)
        return to_tower_shape(sampled), 0.0

    # Encode the input and target frames into the posterior.
    posterior = tf.concat([input_image, target_image], axis=-1)
    posterior = self.basic_conv_net(posterior, filters, "posterior_enc")
    posterior = tfl.flatten(posterior)
    noisy_bits, clean_bits = discretization.tanh_discrete_bottleneck(
        posterior, hp.bottleneck_bits, hp.bottleneck_noise,
        hp.discretize_warmup_steps, hp.mode)

    if hp.full_latent_tower:
        prior_loss = 0.0
    else:
        # Learn the prior by matching the posterior.
        _, prior_loss = discretization.predict_bits_with_lstm(
            prior, hp.latent_predictor_state_size, hp.bottleneck_bits,
            target_bits=clean_bits)

    return to_tower_shape(noisy_bits), prior_loss
def simple_discrete_latent_tower(self, input_image, target_image):
    """Compute discrete latent bits from the input and target frames.

    In prediction mode the frames are not encoded; each bit is sampled
    independently as -1.0 or +1.0. In every other mode the frames are
    concatenated on the last axis, run through the "posterior_enc" conv
    net, flattened and discretized by
    `discretization.tanh_discrete_bottleneck`.

    Args:
      input_image: tensor of input frame(s); its leading dimension gives
        the number of rows of sampled bits.
      target_image: tensor joined with `input_image` on the last axis.

    Returns:
      The bits tensor (first output of `tanh_discrete_bottleneck`, or a
      [batch, hparams.bottleneck_bits] tensor of random -1.0/+1.0 values
      when predicting).
    """
    hparams = self.hparams
    n_bits = hparams.bottleneck_bits

    if self.is_predicting:
        rows = common_layers.shape_list(input_image)[0]
        noise = tf.random_uniform([rows, n_bits])
        is_one = tf.to_float(tf.less(0.5, noise))
        # Rescale the 0/1 indicator to -1/+1.
        return 2.0 * is_one - 1.0

    channels = self.tinyify([64, 32, 32, 1])
    stacked = tf.concat([input_image, target_image], axis=-1)
    features = self.basic_conv_net(stacked, channels, "posterior_enc")
    features = tfl.flatten(features)
    latent_bits, _ = discretization.tanh_discrete_bottleneck(
        features, n_bits, hparams.bottleneck_noise,
        hparams.discretize_warmup_steps, hparams.mode)
    return latent_bits
def inject_latent(self, layer, inputs, target, action):
    """Inject a deterministic latent based on the target frame.

    Outside of training the bits are either sampled uniformly from
    {-1.0, +1.0} (`hparams.full_latent_tower`) or generated by the LSTM bit
    predictor from `layer`. During training the input and target frames
    are embedded, optionally combined with the action, compressed and
    discretized by `discretization.tanh_discrete_bottleneck`; the
    resulting bits are then merged into `layer`.

    Args:
      layer: tensor the latent is injected into; its last dimension sets
        the size of the dense projections of the bits.
      inputs: list of input frame tensors (concatenated with `target` on
        the last axis).
      target: target frame tensor.
      action: optional action tensor; when not None it is injected into
        the frame embedding via `common_video.inject_additional_input`.

    Returns:
      A pair (result, pred_loss). `result` is `layer` with the latent
      applied. `pred_loss` is the bit-predictor loss, or 0.0 when no
      predictor is trained (not training, or `full_latent_tower` set).
    """
    hparams = self.hparams
    final_filters = common_layers.shape_list(layer)[-1]
    filters = hparams.hidden_size
    kernel = (4, 4)
    layer_shape = common_layers.shape_list(layer)
    activation_fn = common_layers.belu
    if hparams.activation_fn == "relu":
        activation_fn = tf.nn.relu

    def add_bits(layer, bits):
        """Merge `bits` into `layer` additively or with a sigmoid gate."""
        z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
        if not hparams.complex_addn:
            return layer + z_mul
        layer *= tf.nn.sigmoid(z_mul)
        z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
        layer += z_add
        return layer

    if not self.is_training:
        if hparams.full_latent_tower:
            rand = tf.random_uniform(
                layer_shape[:-1] + [hparams.bottleneck_bits])
            bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
        else:
            bits, _ = discretization.predict_bits_with_lstm(
                layer, hparams.latent_predictor_state_size,
                hparams.bottleneck_bits,
                temperature=hparams.latent_predictor_temperature)
            bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
        return add_bits(layer, bits), 0.0

    # Embed.
    frames = tf.concat(inputs + [target], axis=-1)
    x = tfl.dense(
        frames, filters, name="latent_embed",
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    x = common_attention.add_timing_signal_nd(x)

    # Add embedded action if present.
    if action is not None:
        x = common_video.inject_additional_input(
            x, action, "action_enc_latent", hparams.action_injection)

    if hparams.full_latent_tower:
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("latent_downstride%d" % i):
                x = common_layers.make_even_size(x)
                if i < hparams.filter_double_steps:
                    filters *= 2
                x = common_attention.add_timing_signal_nd(x)
                x = tfl.conv2d(x, filters, kernel,
                               activation=activation_fn,
                               strides=(2, 2), padding="SAME")
                x = common_layers.layer_norm(x)
    else:
        x = common_layers.double_discriminator(x)
        x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

    bits, bits_clean = discretization.tanh_discrete_bottleneck(
        x, hparams.bottleneck_bits, hparams.bottleneck_noise,
        hparams.discretize_warmup_steps, hparams.mode)

    # No predictor is trained with the full latent tower, so the loss must
    # default to 0.0; otherwise the final return would reference an
    # unbound name.
    pred_loss = 0.0
    if not hparams.full_latent_tower:
        _, pred_loss = discretization.predict_bits_with_lstm(
            layer, hparams.latent_predictor_state_size,
            hparams.bottleneck_bits, target_bits=bits_clean)
        # Mix bits from latent with predicted bits on forward pass as a
        # noise.
        if hparams.latent_rnn_max_sampling > 0.0:
            # Reuse the predictor variables created just above.
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                bits_pred, _ = discretization.predict_bits_with_lstm(
                    layer, hparams.latent_predictor_state_size,
                    hparams.bottleneck_bits,
                    temperature=hparams.latent_predictor_temperature)
                bits_pred = tf.expand_dims(
                    tf.expand_dims(bits_pred, axis=1), axis=2)
            # Be bits_pred on the forward pass but bits on the backward
            # one.
            bits_pred = bits_clean + tf.stop_gradient(bits_pred - bits_clean)
            # Select which bits to take from pred sampling with bit_p
            # probability.
            which_bit = tf.random_uniform(common_layers.shape_list(bits))
            bit_p = common_layers.inverse_lin_decay(
                hparams.latent_rnn_warmup_steps)
            bit_p *= hparams.latent_rnn_max_sampling
            bits = tf.where(which_bit < bit_p, bits_pred, bits)

    res = add_bits(layer, bits)
    # During training, sometimes skip the latent to help
    # action-conditioning.
    res_p = common_layers.inverse_lin_decay(
        hparams.latent_rnn_warmup_steps / 2)
    res_p *= hparams.latent_use_max_probability
    res_rand = tf.random_uniform([layer_shape[0]])
    res = tf.where(res_rand < res_p, res, layer)
    return res, pred_loss
def inject_latent(self, layer, inputs, target):
    """Inject a deterministic latent based on the target frame.

    Outside of training the bits are either sampled uniformly from
    {-1.0, +1.0} (`hparams.full_latent_tower`) or generated by the LSTM bit
    predictor from `layer`. During training the input and target frames
    are embedded, compressed and discretized by
    `discretization.tanh_discrete_bottleneck`; the resulting bits are then
    merged into `layer`.

    Args:
      layer: tensor the latent is injected into; its last dimension sets
        the size of the dense projections of the bits.
      inputs: list of input frame tensors (concatenated with `target` on
        the last axis).
      target: target frame tensor.

    Returns:
      A pair (result, pred_loss). `result` is `layer` with the latent
      applied. `pred_loss` is the bit-predictor loss, or 0.0 when no
      predictor is trained (not training, or `full_latent_tower` set).
    """
    hparams = self.hparams
    final_filters = common_layers.shape_list(layer)[-1]
    filters = hparams.hidden_size
    kernel = (4, 4)
    layer_shape = common_layers.shape_list(layer)

    def add_bits(layer, bits):
        """Merge `bits` into `layer` additively or with a sigmoid gate."""
        z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
        if not hparams.complex_addn:
            return layer + z_mul
        layer *= tf.nn.sigmoid(z_mul)
        z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
        layer += z_add
        return layer

    if not self.is_training:
        if hparams.full_latent_tower:
            rand = tf.random_uniform(
                layer_shape[:-1] + [hparams.bottleneck_bits])
            bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
        else:
            bits, _ = discretization.predict_bits_with_lstm(
                layer, hparams.latent_predictor_state_size,
                hparams.bottleneck_bits,
                temperature=hparams.latent_predictor_temperature)
            bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
        return add_bits(layer, bits), 0.0

    # Embed.
    frames = tf.concat(inputs + [target], axis=-1)
    x = tfl.dense(
        frames, filters, name="latent_embed",
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    x = common_attention.add_timing_signal_nd(x)

    if hparams.full_latent_tower:
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("latent_downstride%d" % i):
                x = common_layers.make_even_size(x)
                if i < hparams.filter_double_steps:
                    filters *= 2
                x = common_attention.add_timing_signal_nd(x)
                x = tfl.conv2d(x, filters, kernel,
                               activation=common_layers.belu,
                               strides=(2, 2), padding="SAME")
                x = common_layers.layer_norm(x)
    else:
        x = common_layers.double_discriminator(x)
        x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

    bits, bits_clean = discretization.tanh_discrete_bottleneck(
        x, hparams.bottleneck_bits, hparams.bottleneck_noise,
        hparams.discretize_warmup_steps, hparams.mode)

    # No predictor is trained with the full latent tower, so the loss must
    # default to 0.0; otherwise the return below would reference an unbound
    # name.
    pred_loss = 0.0
    if not hparams.full_latent_tower:
        _, pred_loss = discretization.predict_bits_with_lstm(
            layer, hparams.latent_predictor_state_size,
            hparams.bottleneck_bits, target_bits=bits_clean)

    return add_bits(layer, bits), pred_loss
def inject_latent(self, layer, inputs, target, action):
    """Inject a deterministic latent based on the target frame.

    Outside of training the bits are either sampled uniformly from
    {-1.0, +1.0} (`hparams.full_latent_tower`) or generated by the LSTM bit
    predictor from `layer`. During training the input and target frames
    are embedded, optionally combined with the action, compressed and
    discretized by `discretization.tanh_discrete_bottleneck`; the
    resulting bits are then merged into `layer`.

    Args:
      layer: tensor the latent is injected into; its last dimension sets
        the size of the dense projections of the bits.
      inputs: list of input frame tensors (concatenated with `target` on
        the last axis).
      target: target frame tensor.
      action: optional action tensor; when not None it is injected into
        the frame embedding via `common_video.inject_additional_input`.

    Returns:
      A pair (result, pred_loss). `result` is `layer` with the latent
      applied. `pred_loss` is the bit-predictor loss, or 0.0 when no
      predictor is trained (not training, or `full_latent_tower` set).
    """
    hparams = self.hparams
    final_filters = common_layers.shape_list(layer)[-1]
    filters = hparams.hidden_size
    kernel = (4, 4)
    layer_shape = common_layers.shape_list(layer)

    def add_bits(layer, bits):
        """Merge `bits` into `layer` additively or with a sigmoid gate."""
        z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
        if not hparams.complex_addn:
            return layer + z_mul
        layer *= tf.nn.sigmoid(z_mul)
        z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
        layer += z_add
        return layer

    if not self.is_training:
        if hparams.full_latent_tower:
            rand = tf.random_uniform(
                layer_shape[:-1] + [hparams.bottleneck_bits])
            bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
        else:
            bits, _ = discretization.predict_bits_with_lstm(
                layer, hparams.latent_predictor_state_size,
                hparams.bottleneck_bits,
                temperature=hparams.latent_predictor_temperature)
            bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
        return add_bits(layer, bits), 0.0

    # Embed.
    frames = tf.concat(inputs + [target], axis=-1)
    x = tfl.dense(
        frames, filters, name="latent_embed",
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    x = common_attention.add_timing_signal_nd(x)

    # Add embedded action if present.
    if action is not None:
        x = common_video.inject_additional_input(
            x, action, "action_enc_latent", hparams.action_injection)

    if hparams.full_latent_tower:
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("latent_downstride%d" % i):
                x = common_layers.make_even_size(x)
                if i < hparams.filter_double_steps:
                    filters *= 2
                x = common_attention.add_timing_signal_nd(x)
                x = tfl.conv2d(x, filters, kernel,
                               activation=common_layers.belu,
                               strides=(2, 2), padding="SAME")
                x = common_layers.layer_norm(x)
    else:
        x = common_layers.double_discriminator(x)
        x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

    bits, bits_clean = discretization.tanh_discrete_bottleneck(
        x, hparams.bottleneck_bits, hparams.bottleneck_noise,
        hparams.discretize_warmup_steps, hparams.mode)

    # No predictor is trained with the full latent tower, so the loss must
    # default to 0.0; otherwise the final return would reference an
    # unbound name.
    pred_loss = 0.0
    if not hparams.full_latent_tower:
        _, pred_loss = discretization.predict_bits_with_lstm(
            layer, hparams.latent_predictor_state_size,
            hparams.bottleneck_bits, target_bits=bits_clean)
        # Mix bits from latent with predicted bits on forward pass as a
        # noise.
        if hparams.latent_rnn_max_sampling > 0.0:
            # Reuse the predictor variables created just above.
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                bits_pred, _ = discretization.predict_bits_with_lstm(
                    layer, hparams.latent_predictor_state_size,
                    hparams.bottleneck_bits,
                    temperature=hparams.latent_predictor_temperature)
                bits_pred = tf.expand_dims(
                    tf.expand_dims(bits_pred, axis=1), axis=2)
            # Be bits_pred on the forward pass but bits on the backward
            # one.
            bits_pred = bits_clean + tf.stop_gradient(bits_pred - bits_clean)
            # Select which bits to take from pred sampling with bit_p
            # probability.
            which_bit = tf.random_uniform(common_layers.shape_list(bits))
            bit_p = common_layers.inverse_lin_decay(
                hparams.latent_rnn_warmup_steps)
            bit_p *= hparams.latent_rnn_max_sampling
            bits = tf.where(which_bit < bit_p, bits_pred, bits)

    res = add_bits(layer, bits)
    # During training, sometimes skip the latent to help
    # action-conditioning.
    res_p = common_layers.inverse_lin_decay(
        hparams.latent_rnn_warmup_steps / 2)
    res_p *= hparams.latent_use_max_probability
    res_rand = tf.random_uniform([layer_shape[0]])
    res = tf.where(res_rand < res_p, res, layer)
    return res, pred_loss