def compress_self_attention_layer(x, hparams, name=None):
  """Self-attention layer over (compressed) latents."""
  with tf.variable_scope(name, default_name="compress_self_attention"):
    x, xshape, _ = cia.maybe_reshape_4d_to_3d(x)
    y = common_attention.multihead_attention(
        common_layers.layer_preprocess(x, hparams),
        None,
        None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout)
    res = common_layers.layer_postprocess(x, y, hparams)
    return tf.reshape(res, xshape)
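# Illustrative sketch, not tensor2tensor code: the core of what
# common_attention.multihead_attention computes over the flattened
# [batch, length, hidden] latents is scaled dot-product attention. A minimal
# single-head numpy version, simplified by assuming identity query/key/value
# projections (the real layer learns these projections):
import numpy as np

def scaled_dot_product_attention_sketch(x):
  """Single-head self-attention; x has shape [batch, length, hidden]."""
  q, k, v = x, x, x  # The real layer projects x to queries/keys/values.
  logits = np.matmul(q, np.transpose(k, [0, 2, 1])) / np.sqrt(x.shape[-1])
  weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
  weights /= weights.sum(axis=-1, keepdims=True)  # Softmax over positions.
  return np.matmul(weights, v)

# Example: out = scaled_dot_product_attention_sketch(np.random.randn(2, 4, 8))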
def attend(x, source, hparams, name):
  """Attend function."""
  with tf.variable_scope(name):
    # x = tf.squeeze(x, axis=2)
    x, xshape, _ = cia.maybe_reshape_4d_to_3d(x)
    if len(source.get_shape()) > 3:
      source = tf.squeeze(source, axis=2)
    source = common_attention.add_timing_signal_1d(source)
    y = common_attention.multihead_attention(
        common_layers.layer_preprocess(x, hparams),
        source,
        None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout)
    res = common_layers.layer_postprocess(x, y, hparams)
    return tf.reshape(res, xshape)
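# Illustrative sketch, not tensor2tensor code: attend() adds a 1D sinusoidal
# timing signal to `source` before cross-attending to it. A minimal numpy
# version of the standard construction, assuming even `channels` and the t2t
# convention of concatenating the sin and cos halves along the channel axis:
import numpy as np

def timing_signal_1d_sketch(length, channels, max_timescale=1.0e4):
  """Returns a [length, channels] position signal."""
  position = np.arange(length, dtype=np.float32)
  num_timescales = channels // 2
  log_increment = np.log(max_timescale) / max(num_timescales - 1, 1)
  inv_timescales = np.exp(-log_increment * np.arange(num_timescales))
  scaled_time = position[:, None] * inv_timescales[None, :]
  return np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)

# Usage: source += timing_signal_1d_sketch(source_length, hidden_size)[None]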
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """AE Transformer, main step used for training."""
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  if inputs is not None:
    batch_size = common_layers.shape_list(inputs)[0]
  else:
    batch_size = common_layers.shape_list(targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    inputs_ex, ed_ex = inputs, ed
  else:
    ed, inputs_ex, ed_ex = None, None, None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    # flatten here
    original_targets_shape = tf.shape(targets)
    if hparams.task == "image":
      # NOTE: return value unused; targets stays [batch, length, 1, hidden],
      # which the masking code below relies on.
      cia.maybe_reshape_4d_to_3d(targets)
    if hparams.task == "translate":
      max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    else:
      assert hparams.task == "image"
      max_targets_len_from_inputs = targets
    targets, _ = common_layers.pad_to_same_length(
        targets,
        max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, inputs, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, embed = hparams.bottleneck(
          x=targets_c,
          filter_size=hparams.compress_filter_size,
          name="vc",
          mode=hparams.mode)
      if _DO_SUMMARIES:
        tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
      pc = common_layers.inverse_exp_decay(hparams.startup_steps)
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            inputs_ex, ed_ex, embed(latents_discrete), hparams, "extra",
            task="translate")
        _, latent_pred_loss = ae_latent_softmax(
            latents_pred, tf.stop_gradient(latents_discrete), hparams)
        losses["latent_pred"] = tf.reduce_mean(
            latent_pred_loss * tf.to_float(cond))
      else:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean((inputs_c - targets_c)**2) * 20

        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _ = hparams.bottleneck(
                x=inputs_c,
                filter_size=hparams.compress_filter_size,
                name="vc",
                mode=hparams.mode)
          return bn

        # Re-embed inputs_c through the bottleneck, reusing the "vc" variables.
        inputs_c = bn_inputs()
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(
            tf.less(tf.random_uniform([batch_size]), ptc),
            latents_dense, inputs_c)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _ = hparams.bottleneck(
            x=inputs_c,
            filter_size=hparams.compress_filter_size,
            name="vc",
            mode=hparams.mode)
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed = hparams.bottleneck(
            x=targets_c, filter_size=hparams.compress_filter_size, name="vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex, embed, 16,
                                   hparams)
        latents_dense = embed(cache)
    # Postprocess.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # decompressing the dense latents
    for i in range(hparams.num_compress_steps):
      j = hparams.num_compress_steps - i - 1
      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
      if hparams.do_attend_decompress:
        d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
      d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
      masking *= common_layers.inverse_exp_decay(
          hparams.mask_startup_steps // 4)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * hparams.unmasked_percentage
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.use_predict_mask:
        masking = predict_mask
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking,
                     tf.random_uniform(common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)

      # targets is always [batch, length, 1, depth]
      targets = mask * targets + (1.0 - mask) * d
      # reshape back to 4d here
      if hparams.task == "image":
        targets = tf.reshape(targets, original_targets_shape)
    if hparams.task == "translate":
      targets = tf.concat([tf.reverse(latents_dense, [1]), targets], axis=1)

  res = decode_transformer(inputs, ed, targets, hparams, "decoder",
                           causal=hparams.causal)
  if hparams.do_ae:
    if hparams.task == "translate":
      res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :]
    if hparams.do_mask and hparams.do_refine:

      def refine_res():
        # return residual_conv(res, 1, (5, 1), hparams, "refine")
        r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                      "refine_enc")
        return tf.expand_dims(r, axis=2)

      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # We'll start training the extra model of latents after mask_startup_steps.
    nonlatent_steps = hparams.mask_startup_steps
    latent_time = tf.less(nonlatent_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
  return res, losses, cache
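# Illustrative sketch of the bottleneck gating above, under our reading of
# common_layers.inverse_exp_decay: it rises from ~min_value at step 0 to 1.0
# at max_step, so early in training most examples bypass the bottleneck and
# keep targets_c, and the discrete latents are phased in gradually. In numpy:
import numpy as np

def inverse_exp_decay_sketch(max_step, step, min_value=0.01):
  """Rises from min_value at step 0 to 1.0 at max_step."""
  inv_base = min_value ** (1.0 / max_step)
  return inv_base ** max(float(max_step) - step, 0.0)

def mix_latents_sketch(latents_dense, targets_c, step, startup_steps):
  """Per example: use the bottlenecked latents with probability pc."""
  pc = inverse_exp_decay_sketch(startup_steps, step)
  cond = np.random.uniform(size=latents_dense.shape[0]) < pc
  # Broadcast the per-example choice over [batch, length, 1, hidden].
  return np.where(cond[:, None, None, None], latents_dense, targets_c)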
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """Auto-encoder using transformer decoder and prior over latents."""
  # Define losses.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}

  # Reshape image targets as 4d tensor.
  original_targets_shape = common_layers.shape_list(targets)
  if len(original_targets_shape) == 4:
    compress_fn = compress_encoder_2d
    decompress_fn = decompress_decoder_2d
  else:
    compress_fn = compress_encoder_1d
    decompress_fn = decompress_decoder_1d

  # Encoder decoder attention bias.
  ed_attention_bias = None

  # Input Encoder if present.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed_attention_bias = transformer_text_encoder(
        inputs, target_space, hparams, "input_enc")

  # Encode targets to compute targets compressed.
  targets_c = compress_fn(targets, hparams, "compress")
  targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)

  # The following creates a warm-up coefficient that ramps up exponentially
  # from ~0 to 1 (inverse exponential decay); we rescale the loss values by it.
  batch_size = common_layers.shape_list(targets_c)[0]
  pc = common_layers.inverse_exp_decay(hparams.startup_steps)
  pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
  cond = tf.less(tf.random_uniform([batch_size]), pc)

  # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
  # Call bottleneck layer to get the latents.
  # Returns embedded latents, discrete latent codes, and the bottleneck loss.
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    latents_dense, latents_discrete, extra_loss = bottleneck_layer(
        targets_c, hparams)
    extra_loss = tf.reduce_mean(extra_loss) * tf.to_float(cond)

    # Call the autoregressive latent prediction model.
    _, latents_pred_loss = latent_prediction_model(
        inputs, ed_attention_bias, latents_discrete, latents_dense, hparams,
        name="latent_pred")
    latents_pred_loss = tf.reduce_mean(latents_pred_loss) * tf.to_float(cond)

    # Assign latent losses.
    losses["latent_pred"] = latents_pred_loss
    losses["extra_loss"] = extra_loss
  else:
    latent_len = (hparams.img_len * hparams.img_len *
                  hparams.num_latents) // 2**hparams.num_compress_steps
    embed = functools.partial(
        discretization.parametrized_unbottleneck, hparams=hparams)
    latents_dense = tf.zeros([batch_size, latent_len, 1, hparams.hidden_size])
    if cache is None:
      cache = ae_latent_sample_beam(latents_dense, inputs, ed_attention_bias,
                                    embed, hparams)
    latents_dense = embed(
        tf.one_hot(cache, depth=2**hparams.bottleneck_bits),
        hparams.hidden_size)

  latents_decoder = latents_dense
  if len(original_targets_shape) == 4:
    cmp_img_len = hparams.img_len // (2**(hparams.num_compress_steps // 2))
    latents_decoder = tf.reshape(latents_decoder, [
        batch_size, cmp_img_len, cmp_img_len,
        hparams.num_latents * hparams.hidden_size
    ])

  # Decompress either using 1D or 2D upconvs.
  latents_decoder = decompress_fn(latents_decoder, hparams, name="decompress")
  # if we're operating in 2d space on images, then we're assuming that the
  # last dimension will not be a multiple of channels
  latents_decoder = tf.reshape(
      latents_decoder,
      shape=[-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

  if hparams.use_gold_targets:
    latents_decoder, _, _ = cia.maybe_reshape_4d_to_3d(latents_decoder)
    masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      masking = predict_mask
    mask = tf.less(masking,
                   tf.random_uniform(common_layers.shape_list(targets)[:-1]))
    mask = tf.expand_dims(tf.to_float(mask), 2)
    targets = mask * targets + (1.0 - mask) * latents_decoder
  else:
    targets = latents_decoder
  # reshape back to 4d here
  targets = tf.reshape(targets, original_targets_shape)

  if hparams.decode_autoregressive:
    # Transformer decoder, that goes from inputs->targets
    res = transformer_image_decoder(inputs, ed_attention_bias, targets,
                                    hparams, "decoder")
  else:
    res = targets

  # We'll start training the extra model of latents after mask_startup_steps.
  latent_time = tf.less(hparams.mask_startup_steps,
                        tf.to_int32(tf.train.get_global_step()))
  losses["latent_pred"] *= tf.to_float(latent_time)
  return res, losses, cache
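# Worked example of the latent length used above (hypothetical hparams): with
# img_len=32, num_latents=1 and num_compress_steps=4, the decoder samples
#   latent_len = 32 * 32 * 1 // 2**4 = 64
# latent positions, consistent with the 2D layout: cmp_img_len is
# 32 // 2**(4 // 2) = 8 on each side, and 8 * 8 * 1 = 64.
img_len, num_latents, num_compress_steps = 32, 1, 4
assert img_len * img_len * num_latents // 2**num_compress_steps == 64
assert (img_len // 2**(num_compress_steps // 2))**2 * num_latents == 64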
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """AE Transformer, main step used for training."""
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  if inputs is not None:
    batch_size = common_layers.shape_list(inputs)[0]
  else:
    batch_size = common_layers.shape_list(targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    inputs_ex, ed_ex = inputs, ed
  else:
    ed, inputs_ex, ed_ex = None, None, None

  # Autoencoding.
  losses = {
      "extra": tf.constant(0.0),
      "latent_pred": tf.constant(0.0),
      "neg_q_entropy": tf.constant(0.0)
  }
  if hparams.do_ae:
    # flatten here
    original_targets = targets
    original_targets_shape = tf.shape(original_targets)
    if hparams.task == "image":
      cia.maybe_reshape_4d_to_3d(targets)
    if hparams.task == "translate":
      if inputs is not None:
        max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
      else:
        max_targets_len_from_inputs = targets
    else:
      assert hparams.task == "image"
      max_targets_len_from_inputs = targets
    if hparams.word_shuffle:
      tf.logging.info("Using word shuffle with rate = {}".format(
          hparams.word_shuffle))
      targets_idx = tf.range(
          start=0, limit=common_layers.shape_list(targets)[1], delta=1)
      targets_idx = tf.to_float(targets_idx)
      noise = tf.random_uniform(
          shape=common_layers.shape_list(targets_idx),
          minval=0,
          maxval=1 + hparams.word_shuffle)
      targets_idx += noise
      permutation = contrib.framework().argsort(targets_idx)
      targets_permuted = tf.gather(targets, indices=permutation, axis=1)
      targets = targets_permuted
    targets, _ = common_layers.pad_to_same_length(
        targets,
        max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)

    # Add positional information.
    targets_shape = common_layers.shape_list(targets)
    targets = tf.reshape(
        targets, [targets_shape[0], targets_shape[1], targets_shape[3]])
    targets = common_attention.add_positional_embedding(
        targets, hparams.max_length, name="targets_position")
    targets = tf.reshape(targets, shape=targets_shape)

    if hparams.word_dropout:
      mask = tf.random_uniform(
          shape=common_layers.shape_list(targets), minval=0.0, maxval=1.0)
      targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                               tf.zeros_like(targets))
    else:
      targets_noisy = targets
    targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, embed, neg_q_entropy = (
          hparams.bottleneck(
              inputs=targets_c,
              filter_size=hparams.compress_filter_size,
              mode=hparams.mode,
              name="vc"))
      if _DO_SUMMARIES:
        tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
      pc = common_layers.inverse_exp_decay(hparams.startup_steps)
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            inputs_ex, ed_ex, embed(latents_discrete), hparams, "extra",
            task="translate")
        _, latent_pred_loss = ae_latent_softmax(
            latents_pred, tf.stop_gradient(latents_discrete), hparams)

        # Scale by latent dimension for summary so we can compare across
        # batches.
        if _DO_SUMMARIES:
          tf.summary.scalar("latent_pred_loss_mean",
                            tf.reduce_mean(latent_pred_loss))
        if hparams.sum_over_latents:
          latent_pred_loss = tf.reduce_sum(latent_pred_loss, [1, 2])

        losses["latent_pred"] = tf.reduce_mean(
            latent_pred_loss * tf.to_float(cond)) * hparams.prior_scale
        losses["neg_q_entropy"] = neg_q_entropy * hparams.entropy_scale
      else:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean(
            tf.squared_difference(inputs_c, targets_c)) * 20

        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _, _ = hparams.bottleneck(
                inputs=inputs_c,
                filter_size=hparams.compress_filter_size,
                mode=hparams.mode,
                name="vc")
          return bn

        inputs_c = bn_inputs()
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(
            tf.less(tf.random_uniform([batch_size]), ptc),
            latents_dense, inputs_c)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _, _ = hparams.bottleneck(
            inputs=inputs_c,
            filter_size=hparams.compress_filter_size,
            mode=hparams.mode,
            name="vc")
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed, _ = hparams.bottleneck(
            inputs=targets_c,
            filter_size=hparams.compress_filter_size,
            name="vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex, embed, 16,
                                   hparams)
        latents_dense = embed(cache)
    # Postprocess.
    d = latents_dense
    d_shape = common_layers.shape_list(d)
    d = tf.reshape(d, [d_shape[0], d_shape[1], d_shape[3]])
    d = common_attention.add_positional_embedding(
        d, hparams.max_length, name="latents_position")
    d = tf.reshape(d, shape=d_shape)

    # decompressing the dense latents
    for i in range(hparams.num_compress_steps):
      j = hparams.num_compress_steps - i - 1
      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
      if inputs is not None and hparams.do_attend_decompress:
        d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
      d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
      masking *= common_layers.inverse_exp_decay(
          hparams.mask_startup_steps // 4)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * hparams.unmasked_percentage
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.use_predict_mask:
        masking = predict_mask
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking,
                     tf.random_uniform(common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)

      # targets is always [batch, length, 1, depth]
      targets = mask * targets + (1.0 - mask) * d
      # reshape back to 4d here
      if hparams.task == "image":
        targets = tf.reshape(targets, original_targets_shape)
    else:
      targets = d

  res = decode_transformer(inputs, ed, targets, hparams, "decoder",
                           causal=hparams.causal)
  if hparams.do_ae:
    if hparams.do_mask and hparams.do_refine:

      def refine_res():
        # return residual_conv(res, 1, (5, 1), hparams, "refine")
        r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                      "refine_enc")
        return tf.expand_dims(r, axis=2)

      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # We'll start training the extra model of latents after mask_startup_steps.
    nonlatent_steps = hparams.mask_startup_steps
    latent_time = tf.less(nonlatent_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)

  # res was generated from padded targets, which means it has some extra
  # elements. These can cause shape problems when computing loss with respect
  # to the original (unpadded) targets. So we remove their extra elements here.
  res = res[:, :original_targets_shape[1], :, :]

  data_dim = common_layers.shape_list(res)[1]
  latent_dim = common_layers.shape_list(targets_c)[1]
  return res, losses, cache, data_dim, latent_dim
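# Illustrative sketch, not tensor2tensor code: the word-shuffle block above
# adds uniform noise in [0, 1 + word_shuffle) to each position index and
# argsorts, so a token can move at most about `word_shuffle` slots from its
# original position, giving a local shuffle (as in noise-based denoising
# autoencoders for text). A minimal numpy version:
import numpy as np

def word_shuffle_sketch(targets, word_shuffle_rate):
  """targets: [batch, length, ...]; returns a locally permuted copy."""
  length = targets.shape[1]
  idx = np.arange(length, dtype=np.float32)
  noise = np.random.uniform(0, 1 + word_shuffle_rate, size=length)
  permutation = np.argsort(idx + noise)
  return targets[:, permutation]

# Example: word_shuffle_sketch(np.arange(10)[None, :], word_shuffle_rate=3.0)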
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """Auto-encoder using a Transformer decoder and a prior over latent sequences.

  Args:
    inputs: Tensor of shape [batch, length, 1, hparams.hidden_size] or None.
    targets: Tensor of shape [batch, ..., channels]. Ellipses may be 1 or 2
      dimensions denoting sequence length.
    target_space: int. Used for encoding inputs under a target space id.
    hparams: HParams.
    cache: Tensor of shape [batch, length] or None.
    predict_mask: Tensor masking whether to use gold targets or predictions.

  Returns:
    decoder_output: Tensor of shape [batch, ..., hparams.hidden_size]
      presenting pre-logit activations. After a transformation (`top` in
      `T2TModel`), it is used with targets to compute the "training"
      (reconstruction) loss.
    losses: dict of str to Tensors. There are three loss terms: "extra",
      "extra_loss", and "latent_pred". The first is hard-coded to 0. The latter
      two are Tensors of shape [batch].
    cache: Tensor of shape [batch, length], either the same as cache, or newly
      computed if the cache input is None.
  """
  original_targets_shape = common_layers.shape_list(targets)
  batch_size = original_targets_shape[0]
  if len(original_targets_shape) == 4:
    compress_fn = compress_encoder_2d
    decompress_fn = decompress_decoder_2d
  else:
    compress_fn = compress_encoder_1d
    decompress_fn = decompress_decoder_1d

  ed_attention_bias = None
  if inputs is not None:
    inputs, ed_attention_bias = transformer_text_encoder(
        inputs, target_space, hparams, name="input_encoder")

  losses = {"extra": 0., "extra_loss": 0., "latent_pred": 0.}
  # Compress targets up front: training uses it for the bottleneck, and
  # predict mode still needs it to build the bottleneck variables for embed_fn.
  targets_compressed = compress_fn(targets, hparams, name="compress")
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
      scale = common_layers.inverse_exp_decay(hparams.startup_steps)
    else:
      scale = 1.0
    scale = tf.to_float(tf.less(tf.random_uniform([batch_size]), scale))

    latents_dense, latents_discrete, extra_loss, _ = bottleneck_layer(
        targets_compressed, hparams)
    extra_loss = scale * tf.reduce_mean(extra_loss)

    _, latents_pred_loss = latent_prediction_model(
        inputs, ed_attention_bias, latents_discrete, latents_dense, hparams,
        name="latent_pred")
    latent_time = tf.less(hparams.mask_startup_steps,
                          tf.to_int32(tf.train.get_global_step()))
    latents_pred_loss = scale * tf.reduce_mean(latents_pred_loss)
    latents_pred_loss *= tf.to_float(latent_time)

    # Apply dropout noise for each data point and time step.
    latents_dense_shape = common_layers.shape_list(latents_dense)
    latents_dense = tf.nn.dropout(
        latents_dense,
        keep_prob=1 - hparams.latent_dropout,
        noise_shape=[latents_dense_shape[0], latents_dense_shape[1], 1])

    # TODO(trandustin): Can we combine extra and extra_loss?
    losses = {
        "extra": 0.,
        "extra_loss": extra_loss,
        "latent_pred": latents_pred_loss
    }
  else:
    # Set the latent length, which is num_latents times the number of latent
    # pixels. The number of latent pixels is determined by a compression
    # factor on the number of image pixels.
    latent_len = ((hparams.img_len * hparams.img_len * hparams.num_latents) //
                  (2**hparams.num_compress_steps))
    _, _, _, embed_fn = bottleneck_layer(targets_compressed, hparams)
    latents_dense = tf.zeros([batch_size, latent_len, 1, hparams.hidden_size])
    if cache is None:
      cache = ae_latent_sample_beam(latents_dense, inputs, ed_attention_bias,
                                    embed_fn, hparams)
    cache_one_hot = tf.one_hot(cache, depth=2**hparams.bottleneck_bits)
    latents_dense = embed_fn(cache_one_hot, hparams.hidden_size)

  if len(original_targets_shape) == 4:
    compressed_img_len = (hparams.img_len //
                          2**(hparams.num_compress_steps // 2))
    latents_dense = tf.reshape(latents_dense, [
        batch_size, compressed_img_len, compressed_img_len,
        hparams.num_latents * hparams.hidden_size
    ])

  latents_dense = decompress_fn(latents_dense, hparams, name="decompress")
  latents_dense = tf.reshape(
      latents_dense,
      [-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

  if hparams.use_gold_targets:
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      masking = predict_mask
    else:
      masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
    targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)
    mask = tf.less(masking,
                   tf.random_uniform(common_layers.shape_list(targets)[:-1]))
    mask = tf.expand_dims(tf.to_float(mask), 2)
    latents_dense = mask * targets + (1.0 - mask) * latents_dense

  latents_dense = tf.reshape(latents_dense, original_targets_shape)
  if hparams.decode_autoregressive:
    decoder_output = transformer_image_decoder(
        latents_dense, inputs, ed_attention_bias, hparams, name="decoder")
  else:
    decoder_output = latents_dense
  return decoder_output, losses, cache
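# Illustrative sketch, not tensorflow code: with use_gold_targets, a boolean
# mask of shape [batch, length, 1] (broadcast over the hidden axis) decides
# per timestep whether the decoder sees the gold target or the decompressed
# latent; `masking` is the fraction replaced by predictions. In numpy, for
# 3D [batch, length, hidden] tensors:
import numpy as np

def mix_gold_and_predicted_sketch(targets, predicted, masking):
  """masking in [0, 1]: expected fraction of timesteps using predictions."""
  keep_gold = masking < np.random.uniform(size=targets.shape[:2] + (1,))
  return np.where(keep_gold, targets, predicted)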
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """AE Transformer, main step used for training."""
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  if inputs is not None:
    batch_size = common_layers.shape_list(inputs)[0]
  else:
    batch_size = common_layers.shape_list(targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    inputs_ex, ed_ex = inputs, ed
  else:
    ed, inputs_ex, ed_ex = None, None, None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    # flatten here
    original_targets_shape = tf.shape(targets)
    if hparams.task == "image":
      cia.maybe_reshape_4d_to_3d(targets)
    if hparams.task == "translate":
      max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    else:
      assert hparams.task == "image"
      max_targets_len_from_inputs = targets
    targets, _ = common_layers.pad_to_same_length(
        targets,
        max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    if hparams.word_dropout:
      mask = tf.random_uniform(
          shape=common_layers.shape_list(targets), minval=0.0, maxval=1.0)
      targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                               tf.zeros_like(targets))
    else:
      targets_noisy = targets
    targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, embed = hparams.bottleneck(
          x=targets_c,
          filter_size=hparams.compress_filter_size,
          name="vc",
          mode=hparams.mode)
      if _DO_SUMMARIES:
        tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
      pc = common_layers.inverse_exp_decay(hparams.startup_steps)
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            inputs_ex, ed_ex, embed(latents_discrete), hparams, "extra",
            task="translate")
        _, latent_pred_loss = ae_latent_softmax(
            latents_pred, tf.stop_gradient(latents_discrete), hparams)
        losses["latent_pred"] = tf.reduce_mean(
            latent_pred_loss * tf.to_float(cond))
      else:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean((inputs_c - targets_c)**2) * 20

        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _ = hparams.bottleneck(
                x=inputs_c,
                filter_size=hparams.compress_filter_size,
                name="vc",
                mode=hparams.mode)
          return bn

        inputs_c = bn_inputs()
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(
            tf.less(tf.random_uniform([batch_size]), ptc),
            latents_dense, inputs_c)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _ = hparams.bottleneck(
            x=inputs_c,
            filter_size=hparams.compress_filter_size,
            name="vc",
            mode=hparams.mode)
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed = hparams.bottleneck(
            x=targets_c, filter_size=hparams.compress_filter_size, name="vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex, embed, 16,
                                   hparams)
        latents_dense = embed(cache)
    # Postprocess.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # decompressing the dense latents
    for i in range(hparams.num_compress_steps):
      j = hparams.num_compress_steps - i - 1
      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
      if hparams.do_attend_decompress:
        d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
      d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
      masking *= common_layers.inverse_exp_decay(
          hparams.mask_startup_steps // 4)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * hparams.unmasked_percentage
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.use_predict_mask:
        masking = predict_mask
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking,
                     tf.random_uniform(common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)

      # targets is always [batch, length, 1, depth]
      targets = mask * targets + (1.0 - mask) * d
      # reshape back to 4d here
      if hparams.task == "image":
        targets = tf.reshape(targets, original_targets_shape)

  res = decode_transformer(inputs, ed, targets, hparams, "decoder",
                           causal=hparams.causal)
  if hparams.do_ae:
    if hparams.do_mask and hparams.do_refine:

      def refine_res():
        # return residual_conv(res, 1, (5, 1), hparams, "refine")
        r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                      "refine_enc")
        return tf.expand_dims(r, axis=2)

      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # We'll start training the extra model of latents after mask_startup_steps.
    nonlatent_steps = hparams.mask_startup_steps
    latent_time = tf.less(nonlatent_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
  return res, losses, cache
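# Illustrative sketch, under our reading of the t2t schedules: the masking
# ratio above is inverse_lin_decay(mask_startup_steps), damped early on by
# inverse_exp_decay(mask_startup_steps // 4), then clipped to [0, 1]. It is
# tiny near step 0 (mostly gold targets) and ~1.0 after mask_startup_steps
# (mostly model predictions d). In numpy, assuming min_value=0.01:
import numpy as np

def masking_schedule_sketch(step, mask_startup_steps, min_value=0.01):
  """Fraction of timesteps whose gold targets are replaced by predictions."""
  # inverse_lin_decay: min_value at step 0, linearly up to 1.0 at max_step.
  lin = (min(step / float(mask_startup_steps), 1.0) * (1.0 - min_value)
         + min_value)
  # inverse_exp_decay over a quarter of the steps damps the start further.
  exp_max_step = mask_startup_steps // 4
  inv_base = min_value ** (1.0 / exp_max_step)
  exp = inv_base ** max(float(exp_max_step) - step, 0.0)
  return float(np.clip(lin * exp, 0.0, 1.0))

# masking_schedule_sketch(0, 10000) ~ 1e-4; masking_schedule_sketch(10000,
# 10000) == 1.0.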
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """Auto-encoder using transformer decoder and prior over latents."""
  losses = {"extra": 0., "latent_pred": 0.}

  # Reshape image targets as 4d tensor.
  original_targets_shape = common_layers.shape_list(targets)
  batch_size = original_targets_shape[0]
  if len(original_targets_shape) == 4:
    compress_fn = compress_encoder_2d
    decompress_fn = decompress_decoder_2d
  else:
    compress_fn = compress_encoder_1d
    decompress_fn = decompress_decoder_1d

  # Input Encoder if present.
  ed_attention_bias = None
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed_attention_bias = transformer_text_encoder(
        inputs, target_space, hparams, "input_enc")

  # Encode targets to compute targets compressed.
  targets_c = compress_fn(targets, hparams, "compress")
  targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)

  # The following creates a warm-up coefficient that ramps up exponentially
  # from ~0 to 1 (inverse exponential decay); we rescale the loss values by it.
  pc = common_layers.inverse_exp_decay(hparams.startup_steps)
  pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
  cond = tf.less(tf.random_uniform([batch_size]), pc)

  # Call bottleneck layer, which takes the encoder output and returns embedded
  # latents, discrete latent codes, the bottleneck loss, and the embedding fn.
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    latents_dense, latents_discrete, extra_loss, _ = bottleneck_layer(
        targets_c, hparams)
    extra_loss = tf.reduce_mean(extra_loss) * tf.to_float(cond)

    _, latents_pred_loss = latent_prediction_model(
        inputs, ed_attention_bias, latents_discrete, latents_dense, hparams,
        name="latent_pred")
    latents_pred_loss = tf.reduce_mean(latents_pred_loss) * tf.to_float(cond)

    latents_shape = common_layers.shape_list(latents_dense)
    latents_dense = tf.nn.dropout(
        latents_dense, 1 - hparams.latent_dropout,
        noise_shape=[latents_shape[0], latents_shape[1], 1])

    losses["extra_loss"] = extra_loss
    losses["latent_pred"] = latents_pred_loss

    # We'll start training the extra model of latents after mask_startup_steps.
    latent_time = tf.less(hparams.mask_startup_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
  else:
    latent_len = (hparams.img_len * hparams.img_len *
                  hparams.num_latents) // 2**hparams.num_compress_steps
    _, _, _, embed = bottleneck_layer(targets_c, hparams)
    latents_dense = tf.zeros([batch_size, latent_len, 1, hparams.hidden_size])
    if cache is None:
      cache = ae_latent_sample_beam(latents_dense, inputs, ed_attention_bias,
                                    embed, hparams)
    latents_dense = embed(
        tf.one_hot(cache, depth=2**hparams.bottleneck_bits),
        hparams.hidden_size)

  latents_decoder = latents_dense
  if len(original_targets_shape) == 4:
    compressed_img_len = (hparams.img_len //
                          2**(hparams.num_compress_steps // 2))
    latents_decoder = tf.reshape(latents_decoder, [
        batch_size, compressed_img_len, compressed_img_len,
        hparams.num_latents * hparams.hidden_size
    ])

  latents_decoder = decompress_fn(latents_decoder, hparams, name="decompress")
  # if we're operating in 2d space on images, then we're assuming that the
  # last dimension will not be a multiple of channels
  output = tf.reshape(
      latents_decoder,
      shape=[-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

  if hparams.use_gold_targets:
    masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      masking = predict_mask
    mask = tf.less(masking,
                   tf.random_uniform(common_layers.shape_list(targets)[:-1]))
    mask = tf.expand_dims(tf.to_float(mask), 2)
    output = mask * targets + (1.0 - mask) * output

  # reshape back to 4d here
  output = tf.reshape(output, original_targets_shape)
  if hparams.decode_autoregressive:
    # Transformer decoder, that goes from inputs->targets
    decoder_output = transformer_image_decoder(
        output, inputs, ed_attention_bias, hparams, "decoder")
  else:
    decoder_output = output
  return decoder_output, losses, cache
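# Illustrative sketch, not tensorflow's implementation: tf.nn.dropout with
# noise_shape=[batch, length, 1] samples one keep/drop decision per example
# and timestep and broadcasts it across the hidden axis, so whole latent
# vectors are zeroed and the survivors are rescaled by 1/keep_prob. In numpy,
# for 3D [batch, length, hidden] latents:
import numpy as np

def latent_dropout_sketch(latents, latent_dropout):
  keep_prob = 1.0 - latent_dropout
  keep = np.random.uniform(size=latents.shape[:2] + (1,)) < keep_prob
  return latents * keep / keep_prob  # Inverted-dropout rescaling.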
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """Auto-encoder using transformer decoder and prior over latents."""
  losses = {"extra": 0., "latent_pred": 0.}

  # Reshape image targets as 4d tensor.
  original_targets_shape = common_layers.shape_list(targets)
  batch_size = original_targets_shape[0]
  if len(original_targets_shape) == 4:
    compress_fn = compress_encoder_2d
    decompress_fn = decompress_decoder_2d
  else:
    compress_fn = compress_encoder_1d
    decompress_fn = decompress_decoder_1d

  # Input Encoder if present.
  ed_attention_bias = None
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed_attention_bias = transformer_text_encoder(
        inputs, target_space, hparams, "input_enc")

  # Encode targets to compute targets compressed.
  targets_c = compress_fn(targets, hparams, "compress")
  targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)

  # The following creates a warm-up coefficient that ramps up exponentially
  # from ~0 to 1 (inverse exponential decay); we rescale the loss values by it.
  pc = common_layers.inverse_exp_decay(hparams.startup_steps)
  pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
  cond = tf.less(tf.random_uniform([batch_size]), pc)

  # Call bottleneck layer, which takes the encoder output and returns embedded
  # latents, discrete latent codes, and the bottleneck loss.
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    latents_dense, latents_discrete, extra_loss = bottleneck_layer(
        targets_c, hparams)
    extra_loss = tf.reduce_mean(extra_loss) * tf.to_float(cond)

    _, latents_pred_loss = latent_prediction_model(
        inputs, ed_attention_bias, latents_discrete, latents_dense, hparams,
        name="latent_pred")
    latents_pred_loss = tf.reduce_mean(latents_pred_loss) * tf.to_float(cond)

    latents_shape = common_layers.shape_list(latents_dense)
    latents_dense = tf.nn.dropout(
        latents_dense, 1 - hparams.latent_dropout,
        noise_shape=[latents_shape[0], latents_shape[1], 1])

    losses["extra_loss"] = extra_loss
    losses["latent_pred"] = latents_pred_loss

    # We'll start training the extra model of latents after mask_startup_steps.
    latent_time = tf.less(hparams.mask_startup_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
  else:
    latent_len = (hparams.img_len * hparams.img_len *
                  hparams.num_latents) // 2**hparams.num_compress_steps
    embed = functools.partial(
        discretization.parametrized_unbottleneck, hparams=hparams)
    latents_dense = tf.zeros([batch_size, latent_len, 1, hparams.hidden_size])
    if cache is None:
      cache = ae_latent_sample_beam(latents_dense, inputs, ed_attention_bias,
                                    embed, hparams)
    latents_dense = embed(
        tf.one_hot(cache, depth=2**hparams.bottleneck_bits),
        hparams.hidden_size)

  latents_decoder = latents_dense
  if len(original_targets_shape) == 4:
    compressed_img_len = (hparams.img_len //
                          2**(hparams.num_compress_steps // 2))
    latents_decoder = tf.reshape(latents_decoder, [
        batch_size, compressed_img_len, compressed_img_len,
        hparams.num_latents * hparams.hidden_size
    ])

  latents_decoder = decompress_fn(latents_decoder, hparams, name="decompress")
  # if we're operating in 2d space on images, then we're assuming that the
  # last dimension will not be a multiple of channels
  output = tf.reshape(
      latents_decoder,
      shape=[-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

  if hparams.use_gold_targets:
    masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      masking = predict_mask
    mask = tf.less(masking,
                   tf.random_uniform(common_layers.shape_list(targets)[:-1]))
    mask = tf.expand_dims(tf.to_float(mask), 2)
    output = mask * targets + (1.0 - mask) * output

  # reshape back to 4d here
  output = tf.reshape(output, original_targets_shape)
  if hparams.decode_autoregressive:
    # Transformer decoder, that goes from inputs->targets
    decoder_output = transformer_image_decoder(
        output, inputs, ed_attention_bias, hparams, "decoder")
  else:
    decoder_output = output
  return decoder_output, losses, cache
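# Illustrative sketch with hypothetical shapes: at predict time the sampled
# cache of discrete codes is embedded by one-hot encoding over
# 2**bottleneck_bits classes and projecting to hidden_size; with a linear
# projection this is equivalent to an embedding lookup. In numpy:
import numpy as np

def embed_codes_sketch(cache, bottleneck_bits, hidden_size):
  """cache: int array [batch, length] of code ids -> dense latents."""
  vocab_size = 2**bottleneck_bits
  embedding = np.random.randn(vocab_size, hidden_size).astype(np.float32)
  one_hot = np.eye(vocab_size, dtype=np.float32)[cache]  # [b, len, vocab]
  return one_hot @ embedding  # [b, len, hidden]; same rows as embedding[cache]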