def inject_latent(self, layer, features, filters):
    """Inject a deterministic latent based on the target frame.

    In PREDICT mode, random +1/-1 bits are projected and added to `layer`.
    In every other mode the target frame and the inputs are embedded,
    reduced to `hparams.bottleneck_bits` values, squashed with tanh and
    binarized (sign on the forward pass, tanh gradient on the backward
    pass) before being projected and added to `layer`.

    Args:
      layer: tensor the latent is added to; its last dimension sets the
        size of the "unbottleneck" projection.
      features: mapping read at the keys "cur_target_frame" and "inputs".
      filters: ignored; the embedding width is `hparams.hidden_size`.

    Returns:
      A pair `(layer + latent, 0.0)`.
    """
    hp = self.hparams
    out_channels = common_layers.shape_list(layer)[-1]
    # The `filters` argument is deliberately unused; start from hidden_size.
    width = hp.hidden_size
    kernel_size = (4, 4)

    def to_signs(values, threshold):
        """Return +1.0 where `threshold < values` and -1.0 elsewhere."""
        return 2.0 * tf.to_float(tf.less(threshold, values)) - 1.0

    def project_and_add(bits):
        """Project `bits` to the channel count of `layer` and add them."""
        latent = tf.layers.dense(bits, out_channels, name="unbottleneck")
        return layer + latent, 0.0

    if hp.mode == tf.estimator.ModeKeys.PREDICT:
        shape = common_layers.shape_list(layer)
        if hp.full_latent_tower:
            sample_shape = shape[:-1] + [hp.bottleneck_bits]
        else:
            sample_shape = shape[:-3] + [1, 1, hp.bottleneck_bits]
        uniform = tf.random_uniform(sample_shape)
        return project_and_add(to_signs(uniform, 0.5))

    # Embed the target frame together with the inputs.
    stacked = tf.concat(
        [features["cur_target_frame"], features["inputs"]], axis=-1)
    hidden = tf.layers.dense(
        stacked,
        width,
        name="latent_embed",
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    hidden = common_attention.add_timing_signal_nd(hidden)

    if not hp.full_latent_tower:
        hidden = common_layers.double_discriminator(hidden)
        hidden = tf.expand_dims(tf.expand_dims(hidden, axis=1), axis=1)
    else:
        for step in range(hp.num_compress_steps):
            with tf.variable_scope("latent_downstride" + str(step)):
                hidden = common_layers.make_even_size(hidden)
                # Only the first `filter_double_steps` steps widen the conv.
                if step < hp.filter_double_steps:
                    width = width * 2
                hidden = common_attention.add_timing_signal_nd(hidden)
                hidden = tf.layers.conv2d(
                    hidden,
                    width,
                    kernel_size,
                    activation=common_layers.belu,
                    strides=(2, 2),
                    padding="SAME")
                hidden = common_layers.layer_norm(hidden)

    pre_tanh = tf.layers.dense(hidden, hp.bottleneck_bits, name="bottleneck")
    squashed = tf.tanh(pre_tanh)
    # Forward value is the sign of `squashed`; the gradient is that of
    # `squashed` because the correction term sits inside stop_gradient.
    bits = squashed + tf.stop_gradient(to_signs(squashed, 0.0) - squashed)

    if hp.mode == tf.estimator.ModeKeys.TRAIN:
        # Flip the sign wherever the uniform draw is not above
        # hparams.bottleneck_noise.
        uniform = tf.random_uniform(common_layers.shape_list(squashed))
        flips = to_signs(uniform, hp.bottleneck_noise)
        bits = bits * flips

    return project_and_add(bits)
def discriminate(x):
    """Run a discriminator depending on the hparams.

    `hparams` and `is_training` are not parameters of this function; they
    are looked up in the surrounding scope when it is called.

    Args:
      x: input forwarded unchanged to the selected discriminator.

    Returns:
      Whatever the selected `common_layers` discriminator returns.

    Raises:
      ValueError: if `hparams.discriminator` is not one of "default",
        "patched", "single" or "double".
    """
    if hparams.discriminator == "default":
        return common_layers.deep_discriminator(
            x, hparams.discriminator_batchnorm, is_training)
    elif hparams.discriminator == "patched":
        return common_layers.patch_discriminator(x)
    elif hparams.discriminator == "single":
        return common_layers.single_discriminator(
            x,
            hparams.discriminator_size,
            hparams.discriminator_kernel_size,
            hparams.discriminator_strides,
            pure_mean=hparams.discriminator_pure_mean)
    elif hparams.discriminator == "double":
        return common_layers.double_discriminator(
            x,
            hparams.discriminator_size,
            hparams.discriminator_kernel_size,
            hparams.discriminator_strides,
            pure_mean=hparams.discriminator_pure_mean)
    else:
        # ValueError is a subclass of Exception, so existing handlers that
        # catch Exception keep working; the message text is unchanged.
        raise ValueError("Unknown discriminator %s" % hparams.discriminator)
def inject_latent(self, layer, inputs, target, action):
    """Inject a deterministic latent based on the target frame.

    When `self.is_training` is false the bits are either random +1/-1
    values (`hparams.full_latent_tower`) or produced by
    `discretization.predict_bits_with_lstm`, and the result is returned
    with a loss of 0.0.

    During training the input frames and the target frame are embedded
    (optionally together with `action`), reduced by either a stack of
    strided convolutions or `common_layers.double_discriminator`, and
    passed through `discretization.tanh_discrete_bottleneck`. Without the
    full latent tower, an LSTM predictor is additionally trained on the
    clean bits and its loss is returned.

    Args:
      layer: tensor the latent is merged into; its last dimension sets the
        size of the "unbottleneck" projections.
      inputs: list of frame tensors; it is list-concatenated with
        `[target]` before `tf.concat` along the last axis.
      target: target frame tensor.
      action: optional action input; injected into the embedding when it is
        not None.

    Returns:
      A pair `(res, pred_loss)`. `pred_loss` is the LSTM predictor loss
      when training without the full latent tower, and 0.0 otherwise.
    """
    hparams = self.hparams
    final_filters = common_layers.shape_list(layer)[-1]
    filters = hparams.hidden_size
    kernel = (4, 4)
    layer_shape = common_layers.shape_list(layer)
    # belu is the default activation; "relu" is the only override handled.
    activation_fn = common_layers.belu
    if hparams.activation_fn == "relu":
        activation_fn = tf.nn.relu

    def add_bits(layer, bits):
        """Project `bits` to `final_filters` channels and merge into `layer`.

        Plain addition by default; with `hparams.complex_addn` the layer is
        first gated by a sigmoid of one projection, then a second projection
        is added.
        """
        z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
        if not hparams.complex_addn:
            return layer + z_mul
        layer *= tf.nn.sigmoid(z_mul)
        z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
        layer += z_add
        return layer

    if not self.is_training:
        if hparams.full_latent_tower:
            rand = tf.random_uniform(
                layer_shape[:-1] + [hparams.bottleneck_bits])
            bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
        else:
            bits, _ = discretization.predict_bits_with_lstm(
                layer,
                hparams.latent_predictor_state_size,
                hparams.bottleneck_bits,
                temperature=hparams.latent_predictor_temperature)
            bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
        return add_bits(layer, bits), 0.0

    # Embed.
    frames = tf.concat(inputs + [target], axis=-1)
    x = tfl.dense(
        frames,
        filters,
        name="latent_embed",
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    x = common_attention.add_timing_signal_nd(x)

    # Add embedded action if present.
    if action is not None:
        x = common_video.inject_additional_input(
            x, action, "action_enc_latent", hparams.action_injection)

    if hparams.full_latent_tower:
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("latent_downstride%d" % i):
                x = common_layers.make_even_size(x)
                if i < hparams.filter_double_steps:
                    filters *= 2
                x = common_attention.add_timing_signal_nd(x)
                x = tfl.conv2d(
                    x,
                    filters,
                    kernel,
                    activation=activation_fn,
                    strides=(2, 2),
                    padding="SAME")
                x = common_layers.layer_norm(x)
    else:
        x = common_layers.double_discriminator(x)
        x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

    bits, bits_clean = discretization.tanh_discrete_bottleneck(
        x,
        hparams.bottleneck_bits,
        hparams.bottleneck_noise,
        hparams.discretize_warmup_steps,
        hparams.mode)

    # Default so that the final return is defined when the full latent tower
    # is used: no predictor is trained on that path. Previously this raised
    # UnboundLocalError. 0.0 matches the loss returned outside of training.
    pred_loss = 0.0
    if not hparams.full_latent_tower:
        _, pred_loss = discretization.predict_bits_with_lstm(
            layer,
            hparams.latent_predictor_state_size,
            hparams.bottleneck_bits,
            target_bits=bits_clean)
        # Mix bits from latent with predicted bits on forward pass as a noise.
        if hparams.latent_rnn_max_sampling > 0.0:
            # Reuse the predictor variables created by the call above.
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                bits_pred, _ = discretization.predict_bits_with_lstm(
                    layer,
                    hparams.latent_predictor_state_size,
                    hparams.bottleneck_bits,
                    temperature=hparams.latent_predictor_temperature)
            bits_pred = tf.expand_dims(
                tf.expand_dims(bits_pred, axis=1), axis=2)
            # Be bits_pred on the forward pass but bits on the backward one.
            bits_pred = bits_clean + tf.stop_gradient(bits_pred - bits_clean)
            # Select which bits to take from pred sampling with bit_p
            # probability.
            which_bit = tf.random_uniform(common_layers.shape_list(bits))
            bit_p = common_layers.inverse_lin_decay(
                hparams.latent_rnn_warmup_steps)
            bit_p *= hparams.latent_rnn_max_sampling
            bits = tf.where(which_bit < bit_p, bits_pred, bits)

    res = add_bits(layer, bits)
    # During training, sometimes skip the latent to help action-conditioning.
    # One uniform draw is made per entry of the first dimension of `layer`;
    # entries whose draw is not below res_p keep the original `layer`.
    res_p = common_layers.inverse_lin_decay(
        hparams.latent_rnn_warmup_steps / 2)
    res_p *= hparams.latent_use_max_probability
    res_rand = tf.random_uniform([layer_shape[0]])
    res = tf.where(res_rand < res_p, res, layer)
    return res, pred_loss
def inject_latent(self, layer, features, filters):
    """Inject a deterministic latent based on the target frame.

    When `self.is_predicting` is true the bits are either random
    (`hparams.full_latent_tower`) or sampled 8 bits at a time from an LSTM
    that is initialised from `layer`; the result is returned with a loss
    of 0.0.

    Otherwise the target frame and the inputs are embedded, reduced to
    `hparams.bottleneck_bits` values and binarized. Without the full
    latent tower, the same LSTM is trained to predict the binarized
    values in groups of 8 bits (one 256-way class per group) and its
    cross-entropy loss is returned.

    Args:
      layer: tensor the latent is merged into; its last dimension sets the
        size of the "unbottleneck" projections.
      features: mapping read at the keys "cur_target_frame" and "inputs".
      filters: ignored; the embedding width is `hparams.hidden_size`.

    Returns:
      A pair `(merged_layer, pred_loss)`; `pred_loss` is 0.0 unless the
      LSTM predictor is trained.
    """
    hp = self.hparams
    out_channels = common_layers.shape_list(layer)[-1]
    # The `filters` argument is deliberately unused; start from hidden_size.
    width = hp.hidden_size
    kernel_size = (4, 4)
    shape = common_layers.shape_list(layer)
    batch = shape[0]
    rnn_size = hp.latent_predictor_state_size
    # The bottleneck is handled as groups of 8 bits (256 classes each).
    num_groups = hp.bottleneck_bits // 8
    cell = tf.contrib.rnn.LSTMCell(rnn_size)
    to_logits = tf.layers.Dense(256, name="discrete_predict")
    embed_symbol = tf.layers.Dense(rnn_size, name="discrete_embed")

    def to_signs(values, threshold):
        """Return +1.0 where `threshold < values` and -1.0 elsewhere."""
        return 2.0 * tf.to_float(tf.less(threshold, values)) - 1.0

    def merge(bits):
        """Project `bits` to the channel count of `layer` and merge them.

        Plain addition by default; with `hparams.complex_addn` the layer is
        gated by a sigmoid of one projection and a second one is added.
        """
        scale = tf.layers.dense(bits, out_channels, name="unbottleneck_mul")
        if not hp.complex_addn:
            return layer + scale
        gated = layer * tf.nn.sigmoid(scale)
        shift = tf.layers.dense(bits, out_channels, name="unbottleneck_add")
        return gated + shift

    def predictor_start():
        """Build the first LSTM input and the initial state from `layer`."""
        flat = tf.reshape(layer, [batch, prod(shape[1:])])
        first_input = tf.layers.dense(flat, rnn_size, name="istate")
        c_state = tf.layers.dense(flat, rnn_size, name="cstate")
        m_state = tf.layers.dense(flat, rnn_size, name="mstate")
        return first_input, (c_state, m_state)

    if self.is_predicting:
        if hp.full_latent_tower:
            raw = tf.random_uniform(shape[:-1] + [hp.bottleneck_bits])
        else:
            step_input, state = predictor_start()
            sampled = []
            for _ in range(num_groups):
                step_output, state = cell(step_input, state)
                symbols = common_layers.sample_with_temperature(
                    to_logits(step_output), hp.latent_predictor_temperature)
                sampled.append(tf.expand_dims(symbols, axis=1))
                # Feed the sampled symbol back in as the next input.
                step_input = embed_symbol(tf.one_hot(symbols, 256))
            symbol_bits = discretization.int_to_bit(
                tf.concat(sampled, axis=1), 8)
            raw = tf.reshape(symbol_bits, [batch, 1, 1, hp.bottleneck_bits])
        return merge(to_signs(raw, 0.5)), 0.0

    # Embed the target frame together with the inputs.
    stacked = tf.concat(
        [features["cur_target_frame"], features["inputs"]], axis=-1)
    hidden = tf.layers.dense(
        stacked,
        width,
        name="latent_embed",
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    hidden = common_attention.add_timing_signal_nd(hidden)

    if not hp.full_latent_tower:
        hidden = common_layers.double_discriminator(hidden)
        hidden = tf.expand_dims(tf.expand_dims(hidden, axis=1), axis=1)
    else:
        for step in range(hp.num_compress_steps):
            with tf.variable_scope("latent_downstride" + str(step)):
                hidden = common_layers.make_even_size(hidden)
                if step < hp.filter_double_steps:
                    width = width * 2
                hidden = common_attention.add_timing_signal_nd(hidden)
                hidden = tf.layers.conv2d(
                    hidden,
                    width,
                    kernel_size,
                    activation=common_layers.belu,
                    strides=(2, 2),
                    padding="SAME")
                hidden = common_layers.layer_norm(hidden)

    pre_tanh = tf.layers.dense(hidden, hp.bottleneck_bits, name="bottleneck")
    squashed = tf.tanh(pre_tanh)
    # Forward value is the sign of `squashed`; the gradient is that of
    # `squashed` because the correction term sits inside stop_gradient.
    bits = squashed + tf.stop_gradient(to_signs(squashed, 0.0) - squashed)

    pred_loss = 0.0
    if not hp.full_latent_tower:
        # Clamp -1 to 0 so each group of 8 values reads as a 0/1 bit pattern.
        zero_one = tf.reshape(
            tf.maximum(tf.stop_gradient(bits), 0), [batch, num_groups, 8])
        targets = discretization.bit_to_int(zero_one, 8)
        tf.summary.histogram("d_int", tf.reshape(targets, [-1]))
        target_embeddings = embed_symbol(tf.one_hot(targets, 256, axis=-1))
        first_input, state = predictor_start()
        # Step i sees the layer-derived input (i == 0) or the embedding of
        # the previous target symbol; the last embedding is never consumed.
        step_inputs = tf.concat(
            [tf.expand_dims(first_input, axis=1), target_embeddings], axis=1)
        step_outputs = []
        for step in range(num_groups):
            step_output, state = cell(step_inputs[:, step, :], state)
            step_outputs.append(tf.expand_dims(step_output, axis=1))
        predicted_logits = to_logits(tf.concat(step_outputs, axis=1))
        pred_loss = tf.reduce_mean(
            tf.losses.sparse_softmax_cross_entropy(
                logits=predicted_logits, labels=targets))

    if hp.mode == tf.estimator.ModeKeys.TRAIN:
        # Recompute the bits from a noised pre-activation.
        noisy = pre_tanh + tf.truncated_normal(
            common_layers.shape_list(pre_tanh), mean=0.0, stddev=0.2)
        noisy = tf.tanh(noisy)
        uniform = tf.random_uniform(common_layers.shape_list(noisy))
        noisy = noisy * to_signs(uniform, hp.bottleneck_noise)
        hard = noisy + tf.stop_gradient(to_signs(noisy, 0.0) - noisy)
        # Per batch entry, use the binarized values with probability
        # `keep_prob` and the continuous ones otherwise.
        keep_prob = common_layers.inverse_lin_decay(hp.discrete_warmup_steps)
        bits = tf.where(
            tf.less(tf.random_uniform([batch]), keep_prob), hard, noisy)

    return merge(bits), pred_loss
def inject_latent(self, layer, inputs, target):
    """Inject a deterministic latent based on the target frame.

    When `self.is_training` is false the bits are either random +1/-1
    values (`hparams.full_latent_tower`) or produced by
    `discretization.predict_bits_with_lstm`, and the result is returned
    with a loss of 0.0.

    During training the input frames and the target frame are embedded,
    reduced by either a stack of strided convolutions or
    `common_layers.double_discriminator`, and passed through
    `discretization.tanh_discrete_bottleneck`. Without the full latent
    tower, an LSTM predictor is additionally trained on the clean bits and
    its loss is returned.

    Args:
      layer: tensor the latent is merged into; its last dimension sets the
        size of the "unbottleneck" projections.
      inputs: list of frame tensors; it is list-concatenated with
        `[target]` before `tf.concat` along the last axis.
      target: target frame tensor.

    Returns:
      A pair `(layer_with_latent, pred_loss)`. `pred_loss` is the LSTM
      predictor loss when training without the full latent tower, and 0.0
      otherwise.
    """
    hparams = self.hparams
    final_filters = common_layers.shape_list(layer)[-1]
    filters = hparams.hidden_size
    kernel = (4, 4)
    layer_shape = common_layers.shape_list(layer)

    def add_bits(layer, bits):
        """Project `bits` to `final_filters` channels and merge into `layer`.

        Plain addition by default; with `hparams.complex_addn` the layer is
        first gated by a sigmoid of one projection, then a second projection
        is added.
        """
        z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
        if not hparams.complex_addn:
            return layer + z_mul
        layer *= tf.nn.sigmoid(z_mul)
        z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
        layer += z_add
        return layer

    if not self.is_training:
        if hparams.full_latent_tower:
            rand = tf.random_uniform(
                layer_shape[:-1] + [hparams.bottleneck_bits])
            bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
        else:
            bits, _ = discretization.predict_bits_with_lstm(
                layer,
                hparams.latent_predictor_state_size,
                hparams.bottleneck_bits,
                temperature=hparams.latent_predictor_temperature)
            bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
        return add_bits(layer, bits), 0.0

    # Embed.
    frames = tf.concat(inputs + [target], axis=-1)
    x = tfl.dense(
        frames,
        filters,
        name="latent_embed",
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    x = common_attention.add_timing_signal_nd(x)

    if hparams.full_latent_tower:
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("latent_downstride%d" % i):
                x = common_layers.make_even_size(x)
                if i < hparams.filter_double_steps:
                    filters *= 2
                x = common_attention.add_timing_signal_nd(x)
                x = tfl.conv2d(
                    x,
                    filters,
                    kernel,
                    activation=common_layers.belu,
                    strides=(2, 2),
                    padding="SAME")
                x = common_layers.layer_norm(x)
    else:
        x = common_layers.double_discriminator(x)
        x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

    bits, bits_clean = discretization.tanh_discrete_bottleneck(
        x,
        hparams.bottleneck_bits,
        hparams.bottleneck_noise,
        hparams.discretize_warmup_steps,
        hparams.mode)

    # Default so that the final return is defined when the full latent tower
    # is used: no predictor is trained on that path. Previously this raised
    # UnboundLocalError. 0.0 matches the loss returned outside of training.
    pred_loss = 0.0
    if not hparams.full_latent_tower:
        _, pred_loss = discretization.predict_bits_with_lstm(
            layer,
            hparams.latent_predictor_state_size,
            hparams.bottleneck_bits,
            target_bits=bits_clean)

    return add_bits(layer, bits), pred_loss
def inject_latent(self, layer, inputs, target, action):
    """Inject a deterministic latent based on the target frame.

    When `self.is_training` is false the bits are either random +1/-1
    values (`hparams.full_latent_tower`) or produced by
    `discretization.predict_bits_with_lstm`, and the result is returned
    with a loss of 0.0.

    During training the input frames and the target frame are embedded
    (optionally together with `action`), reduced by either a stack of
    strided convolutions or `common_layers.double_discriminator`, and
    passed through `discretization.tanh_discrete_bottleneck`. Without the
    full latent tower, an LSTM predictor is additionally trained on the
    clean bits and its loss is returned.

    Args:
      layer: tensor the latent is merged into; its last dimension sets the
        size of the "unbottleneck" projections.
      inputs: list of frame tensors; it is list-concatenated with
        `[target]` before `tf.concat` along the last axis.
      target: target frame tensor.
      action: optional action input; injected into the embedding when it is
        not None.

    Returns:
      A pair `(res, pred_loss)`. `pred_loss` is the LSTM predictor loss
      when training without the full latent tower, and 0.0 otherwise.
    """
    hparams = self.hparams
    final_filters = common_layers.shape_list(layer)[-1]
    filters = hparams.hidden_size
    kernel = (4, 4)
    layer_shape = common_layers.shape_list(layer)

    def add_bits(layer, bits):
        """Project `bits` to `final_filters` channels and merge into `layer`.

        Plain addition by default; with `hparams.complex_addn` the layer is
        first gated by a sigmoid of one projection, then a second projection
        is added.
        """
        z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
        if not hparams.complex_addn:
            return layer + z_mul
        layer *= tf.nn.sigmoid(z_mul)
        z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
        layer += z_add
        return layer

    if not self.is_training:
        if hparams.full_latent_tower:
            rand = tf.random_uniform(
                layer_shape[:-1] + [hparams.bottleneck_bits])
            bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
        else:
            bits, _ = discretization.predict_bits_with_lstm(
                layer,
                hparams.latent_predictor_state_size,
                hparams.bottleneck_bits,
                temperature=hparams.latent_predictor_temperature)
            bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
        return add_bits(layer, bits), 0.0

    # Embed.
    frames = tf.concat(inputs + [target], axis=-1)
    x = tfl.dense(
        frames,
        filters,
        name="latent_embed",
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    x = common_attention.add_timing_signal_nd(x)

    # Add embedded action if present.
    if action is not None:
        x = common_video.inject_additional_input(
            x, action, "action_enc_latent", hparams.action_injection)

    if hparams.full_latent_tower:
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("latent_downstride%d" % i):
                x = common_layers.make_even_size(x)
                if i < hparams.filter_double_steps:
                    filters *= 2
                x = common_attention.add_timing_signal_nd(x)
                x = tfl.conv2d(
                    x,
                    filters,
                    kernel,
                    activation=common_layers.belu,
                    strides=(2, 2),
                    padding="SAME")
                x = common_layers.layer_norm(x)
    else:
        x = common_layers.double_discriminator(x)
        x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

    bits, bits_clean = discretization.tanh_discrete_bottleneck(
        x,
        hparams.bottleneck_bits,
        hparams.bottleneck_noise,
        hparams.discretize_warmup_steps,
        hparams.mode)

    # Default so that the final return is defined when the full latent tower
    # is used: no predictor is trained on that path. Previously this raised
    # UnboundLocalError. 0.0 matches the loss returned outside of training.
    pred_loss = 0.0
    if not hparams.full_latent_tower:
        _, pred_loss = discretization.predict_bits_with_lstm(
            layer,
            hparams.latent_predictor_state_size,
            hparams.bottleneck_bits,
            target_bits=bits_clean)
        # Mix bits from latent with predicted bits on forward pass as a noise.
        if hparams.latent_rnn_max_sampling > 0.0:
            # Reuse the predictor variables created by the call above.
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                bits_pred, _ = discretization.predict_bits_with_lstm(
                    layer,
                    hparams.latent_predictor_state_size,
                    hparams.bottleneck_bits,
                    temperature=hparams.latent_predictor_temperature)
            bits_pred = tf.expand_dims(
                tf.expand_dims(bits_pred, axis=1), axis=2)
            # Be bits_pred on the forward pass but bits on the backward one.
            bits_pred = bits_clean + tf.stop_gradient(bits_pred - bits_clean)
            # Select which bits to take from pred sampling with bit_p
            # probability.
            which_bit = tf.random_uniform(common_layers.shape_list(bits))
            bit_p = common_layers.inverse_lin_decay(
                hparams.latent_rnn_warmup_steps)
            bit_p *= hparams.latent_rnn_max_sampling
            bits = tf.where(which_bit < bit_p, bits_pred, bits)

    res = add_bits(layer, bits)
    # During training, sometimes skip the latent to help action-conditioning.
    # One uniform draw is made per entry of the first dimension of `layer`;
    # entries whose draw is not below res_p keep the original `layer`.
    res_p = common_layers.inverse_lin_decay(
        hparams.latent_rnn_warmup_steps / 2)
    res_p *= hparams.latent_use_max_probability
    res_rand = tf.random_uniform([layer_shape[0]])
    res = tf.where(res_rand < res_p, res, layer)
    return res, pred_loss