Example #1
def compress_self_attention_layer(x, hparams, name=None):
    """Attend function."""
    with tf.variable_scope(name, default_name="compress_self_attention"):
        x, xshape, _ = cia.maybe_reshape_4d_to_3d(x)
        y = common_attention.multihead_attention(
            common_layers.layer_preprocess(x, hparams), None, None,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
        res = common_layers.layer_postprocess(x, y, hparams)
        return tf.reshape(res, xshape)
Example #2
def compress_self_attention_layer(x, hparams, name=None):
  """Attend function."""
  with tf.variable_scope(name, default_name="compress_self_attention"):
    x, xshape, _ = cia.maybe_reshape_4d_to_3d(x)
    y = common_attention.multihead_attention(
        common_layers.layer_preprocess(x, hparams),
        None,
        None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size, hparams.num_heads,
        hparams.attention_dropout)
    res = common_layers.layer_postprocess(x, y, hparams)
    return tf.reshape(res, xshape)
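A minimal usage sketch for the layer above (not part of the original file; it assumes TF1 graph mode and a standard tensor2tensor hyperparameter set such as `transformer.transformer_base()`):

import tensorflow as tf
from tensor2tensor.models import transformer

hparams = transformer.transformer_base()
# Activations in the usual [batch, length, 1, hidden_size] layout.
x = tf.random_normal([8, 64, 1, hparams.hidden_size])
y = compress_self_attention_layer(x, hparams, name="self_att_0")
# The layer flattens to 3-D for attention and reshapes back, so y has the
# same shape as x.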
Example #3
def attend(x, source, hparams, name):
    """Attend function."""
    with tf.variable_scope(name):
        # x = tf.squeeze(x, axis=2)
        x, xshape, _ = cia.maybe_reshape_4d_to_3d(x)
        if len(source.get_shape()) > 3:
            source = tf.squeeze(source, axis=2)
        source = common_attention.add_timing_signal_1d(source)
        y = common_attention.multihead_attention(
            common_layers.layer_preprocess(x, hparams), source, None,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
        res = common_layers.layer_postprocess(x, y, hparams)
        return tf.reshape(res, xshape)
Example #4
def attend(x, source, hparams, name):
  """Attend function."""
  with tf.variable_scope(name):
    # x = tf.squeeze(x, axis=2)
    x, xshape, _ = cia.maybe_reshape_4d_to_3d(x)
    if len(source.get_shape()) > 3:
      source = tf.squeeze(source, axis=2)
    source = common_attention.add_timing_signal_1d(source)
    y = common_attention.multihead_attention(
        common_layers.layer_preprocess(x, hparams),
        source,
        None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size, hparams.num_heads,
        hparams.attention_dropout)
    res = common_layers.layer_postprocess(x, y, hparams)
    return tf.reshape(res, xshape)
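A similar usage sketch for attend (again not from the original file; shapes and hparams are illustrative):

import tensorflow as tf
from tensor2tensor.models import transformer

hparams = transformer.transformer_base()
d = tf.random_normal([8, 16, 1, hparams.hidden_size])       # e.g. compressed targets
source = tf.random_normal([8, 64, 1, hparams.hidden_size])  # e.g. encoder output
d = attend(d, source, hparams, name="decompress_attend_0")
# source is squeezed to 3-D and given a timing signal, then used as the
# attention memory; the result is reshaped back to the shape of d.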
Example #5
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """AE Transformer, main step used for training."""
    # Summaries break with the do_refine cond, turn them off in that case.
    global _DO_SUMMARIES
    if hparams.do_refine:
        _DO_SUMMARIES = False

    # Prepare.
    if inputs is not None:
        batch_size = common_layers.shape_list(inputs)[0]
    else:
        batch_size = common_layers.shape_list(targets)[0]
    targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

    # Encoder.
    if inputs is not None:
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed = encode(inputs, target_space, hparams, "input_enc")
        inputs_ex, ed_ex = inputs, ed
    else:
        ed, inputs_ex, ed_ex = None, None, None

    # Autoencoding.
    losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
    if hparams.do_ae:
        # flatten here
        original_targets_shape = tf.shape(targets)
        if hparams.task == "image":
            cia.maybe_reshape_4d_to_3d(targets)
        if hparams.task == "translate":
            max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
        else:
            assert hparams.task == "image"
            max_targets_len_from_inputs = targets
        targets, _ = common_layers.pad_to_same_length(
            targets,
            max_targets_len_from_inputs,
            final_length_divisible_by=2**hparams.num_compress_steps)
        targets_c = compress(targets, inputs, False, hparams, "compress")
        if hparams.mode != tf.estimator.ModeKeys.PREDICT:
            # Compress and bottleneck.
            latents_dense, latents_discrete, extra_loss, embed = hparams.bottleneck(
                x=targets_c,
                filter_size=hparams.compress_filter_size,
                name="vc",
                mode=hparams.mode)
            if _DO_SUMMARIES:
                tf.summary.histogram(
                    "b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
            pc = common_layers.inverse_exp_decay(hparams.startup_steps)
            pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
            cond = tf.less(tf.random_uniform([batch_size]), pc)
            latents_dense = tf.where(cond, latents_dense, targets_c)
            # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
            losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
            # Extra loss predicting latent code from input. Discrete only.
            if hparams.bottleneck_kind not in ["dense", "vae"]:
                latents_pred = decode_transformer(inputs_ex,
                                                  ed_ex,
                                                  embed(latents_discrete),
                                                  hparams,
                                                  "extra",
                                                  task="translate")
                _, latent_pred_loss = ae_latent_softmax(
                    latents_pred, tf.stop_gradient(latents_discrete), hparams)
                losses["latent_pred"] = tf.reduce_mean(latent_pred_loss *
                                                       tf.to_float(cond))
            else:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                losses["latent_pred"] = tf.reduce_mean(
                    (inputs_c - targets_c)**2) * 20

                def bn_inputs():
                    with tf.variable_scope(tf.get_variable_scope(),
                                           reuse=True):
                        bn, _, _, _ = hparams.bottleneck(
                            x=inputs_c,
                            filter_size=hparams.compress_filter_size,
                            name="vc",
                            mode=hparams.mode)
                    return bn

                inputs_c = bn_inputs()
                ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
                ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
                latents_dense = tf.where(
                    tf.less(tf.random_uniform([batch_size]), ptc),
                    latents_dense, inputs_c)
        else:
            if hparams.bottleneck_kind in ["dense", "vae"]:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                latents_dense, _, _, _ = hparams.bottleneck(
                    x=inputs_c,
                    filter_size=hparams.compress_filter_size,
                    name="vc",
                    mode=hparams.mode)
            else:
                latent_len = common_layers.shape_list(targets_c)[1]
                _, _, _, embed = hparams.bottleneck(
                    x=targets_c,
                    filter_size=hparams.compress_filter_size,
                    name="vc")
                latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
                if cache is None:
                    cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex,
                                             embed, 16, hparams)
                latents_dense = embed(cache)
        # Postprocess.
        d = latents_dense
        pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
        pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
        latents_dense = tf.pad(latents_dense,
                               [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

        # decompressing the dense latents
        for i in range(hparams.num_compress_steps):
            j = hparams.num_compress_steps - i - 1
            d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
            if hparams.do_attend_decompress:
                d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
            d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

        # Masking.
        if hparams.do_mask:
            masking = common_layers.inverse_lin_decay(
                hparams.mask_startup_steps)
            masking *= common_layers.inverse_exp_decay(
                hparams.mask_startup_steps // 4)  # Not much at start.
            if not hparams.do_refine:
                masking -= tf.random_uniform([]) * hparams.unmasked_percentage
            masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
            if hparams.use_predict_mask:
                masking = predict_mask
            if hparams.mode == tf.estimator.ModeKeys.PREDICT:
                masking = predict_mask
            mask = tf.less(
                masking,
                tf.random_uniform(common_layers.shape_list(targets)[:-1]))
            mask = tf.expand_dims(tf.to_float(mask), 3)

            # targets is always [batch, length, 1, depth]
            targets = mask * targets + (1.0 - mask) * d
            # reshape back to 4d here
            if hparams.task == "image":
                targets = tf.reshape(targets, original_targets_shape)
        if hparams.task == "translate":
            targets = tf.concat([tf.reverse(latents_dense, [1]), targets],
                                axis=1)

    res = decode_transformer(inputs,
                             ed,
                             targets,
                             hparams,
                             "decoder",
                             causal=hparams.causal)
    if hparams.do_ae:
        if hparams.task == "translate":
            res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :]
        if hparams.do_mask and hparams.do_refine:

            def refine_res():
                # return residual_conv(res, 1, (5, 1), hparams, "refine")
                r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                              "refine_enc")
                return tf.expand_dims(r, axis=2)

            masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
            all_masked = tf.less(masked_batches, 0.1)
            res = tf.where(all_masked, refine_res(), res)
        # We'll start training the extra model of latents after mask_startup_steps.
        nonlatent_steps = hparams.mask_startup_steps
        latent_time = tf.less(nonlatent_steps,
                              tf.to_int32(tf.train.get_global_step()))
        losses["latent_pred"] *= tf.to_float(latent_time)
    return res, losses, cache
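For orientation, a hedged sketch of the hyperparameter fields this function reads off `hparams`. Only the field names come from the code above; the base hparams set, the default values, and the helper name are illustrative assumptions (the bottleneck itself must additionally be supplied as a callable in `hparams.bottleneck`).

from tensor2tensor.models import transformer

def ae_hparams_sketch():
  """Illustrative hparams carrying the fields read by ae_transformer_internal."""
  hparams = transformer.transformer_base()
  hparams.add_hparam("do_ae", True)
  hparams.add_hparam("do_mask", True)
  hparams.add_hparam("do_refine", False)
  hparams.add_hparam("do_attend_decompress", False)
  hparams.add_hparam("task", "translate")        # or "image"
  hparams.add_hparam("causal", True)
  hparams.add_hparam("num_compress_steps", 3)
  hparams.add_hparam("compress_filter_size", 2048)
  hparams.add_hparam("startup_steps", 10000)
  hparams.add_hparam("mask_startup_steps", 50000)
  hparams.add_hparam("unmasked_percentage", 0.1)
  hparams.add_hparam("use_predict_mask", False)
  hparams.add_hparam("bottleneck_kind", "vq")    # anything but "dense"/"vae" takes the discrete path
  return hparams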
Example #6
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """AE Transformer, main step used for training."""
    # Define losses
    losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}

    # Reshape image targets as 4d tensor.
    original_targets_shape = common_layers.shape_list(targets)
    if len(original_targets_shape) == 4:
        compress_fn = compress_encoder_2d
        decompress_fn = decompress_decoder_2d
    else:
        compress_fn = compress_encoder_1d
        decompress_fn = decompress_decoder_1d

    # Encoder decoder attention bias.
    ed_attention_bias = None

    # Input Encoder if present.
    if inputs is not None:
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed_attention_bias = transformer_text_encoder(
            inputs, target_space, hparams, "input_enc")

    # Encode targets to compute targets compressed.
    targets_c = compress_fn(targets, hparams, "compress")
    targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)

    # The following code creates an exponentially decaying variable based on
    # which we rescale the loss values.
    batch_size = common_layers.shape_list(targets_c)[0]
    pc = common_layers.inverse_exp_decay(hparams.startup_steps)
    pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
    cond = tf.less(tf.random_uniform([batch_size]), pc)

    # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
    # Call the bottleneck layer to get the latents.
    # Returns embedded latents, discrete latents, and the loss.
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
        latents_dense, latents_discrete, extra_loss = (bottleneck_layer(
            targets_c, hparams))
        extra_loss = tf.reduce_mean(extra_loss) * tf.to_float(cond)

        # Call the autoregressive latent prediction model.
        _, latents_pred_loss = latent_prediction_model(inputs,
                                                       ed_attention_bias,
                                                       latents_discrete,
                                                       latents_dense,
                                                       hparams,
                                                       name="latent_pred")
        latents_pred_loss = tf.reduce_mean(latents_pred_loss) * tf.to_float(
            cond)
        # Assign latent loss
        losses["latent_pred"] = latents_pred_loss
        losses["extra_loss"] = extra_loss
    else:
        latent_len = (hparams.img_len * hparams.img_len *
                      hparams.num_latents) / 2**(hparams.num_compress_steps)
        embed = functools.partial(discretization.parametrized_unbottleneck,
                                  hparams=hparams)
        latents_dense = tf.zeros(
            [batch_size, latent_len, 1, hparams.hidden_size])
        if cache is None:
            cache = ae_latent_sample_beam(latents_dense, inputs,
                                          ed_attention_bias, embed, hparams)
        latents_dense = embed(
            tf.one_hot(cache, depth=2**hparams.bottleneck_bits),
            hparams.hidden_size)

    latents_decoder = latents_dense
    if len(original_targets_shape) == 4:
        cmp_img_len = hparams.img_len / (2**(hparams.num_compress_steps // 2))
        latents_decoder = tf.reshape(latents_decoder, [
            batch_size, cmp_img_len, cmp_img_len,
            hparams.num_latents * hparams.hidden_size
        ])

    # Decompress either using 1D or 2D upconvs.
    latents_decoder = decompress_fn(latents_decoder,
                                    hparams,
                                    name="decompress")
    # if we're operating in 2d space on images, then we're assuming that the
    # last dimension will not be a multiple of channels
    latents_decoder = tf.reshape(
        latents_decoder,
        shape=[-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

    if hparams.use_gold_targets:
        latents_decoder, _, _ = cia.maybe_reshape_4d_to_3d(latents_decoder)
        masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
            masking = predict_mask
        mask = tf.less(
            masking, tf.random_uniform(common_layers.shape_list(targets)[:-1]))
        mask = tf.expand_dims(tf.to_float(mask), 2)
        targets = mask * targets + (1.0 - mask) * latents_decoder
    else:
        targets = latents_decoder
    # reshape back to 4d here
    targets = tf.reshape(targets, original_targets_shape)
    if hparams.decode_autoregressive:
        # Transformer decoder that goes from inputs to targets.
        res = transformer_image_decoder(inputs, ed_attention_bias, targets,
                                        hparams, "decoder")
    else:
        res = targets

    # We'll start training the extra model of latents after mask_startup_steps.
    latent_time = tf.less(hparams.mask_startup_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
    return res, losses, cache
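A worked example of the compression arithmetic used in the PREDICT branch above (the numbers are illustrative, not library defaults):

img_len = 32              # hparams.img_len
num_latents = 1           # hparams.num_latents
num_compress_steps = 4    # hparams.num_compress_steps

# Total number of latent positions: one per 2**num_compress_steps pixels.
latent_len = (img_len * img_len * num_latents) // 2**num_compress_steps   # 64

# In the 2-D case each spatial side is halved num_compress_steps // 2 times.
cmp_img_len = img_len // 2**(num_compress_steps // 2)                     # 8
assert cmp_img_len * cmp_img_len * num_latents == latent_len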
Example #7
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """AE Transformer, main step used for training."""
    # Summaries break with the do_refine cond, turn them off in that case.
    global _DO_SUMMARIES
    if hparams.do_refine:
        _DO_SUMMARIES = False

    # Prepare.
    if inputs is not None:
        batch_size = common_layers.shape_list(inputs)[0]
    else:
        batch_size = common_layers.shape_list(targets)[0]
    targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

    # Encoder.
    if inputs is not None:
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed = encode(inputs, target_space, hparams, "input_enc")
        inputs_ex, ed_ex = inputs, ed
    else:
        ed, inputs_ex, ed_ex = None, None, None

    # Autoencoding.
    losses = {
        "extra": tf.constant(0.0),
        "latent_pred": tf.constant(0.0),
        "neg_q_entropy": tf.constant(0.0)
    }
    if hparams.do_ae:
        # flatten here
        original_targets = targets
        original_targets_shape = tf.shape(original_targets)
        if hparams.task == "image":
            cia.maybe_reshape_4d_to_3d(targets)
        if hparams.task == "translate":
            if inputs is not None:
                max_targets_len_from_inputs = tf.concat([inputs, inputs],
                                                        axis=1)
            else:
                max_targets_len_from_inputs = targets
        else:
            assert hparams.task == "image"
            max_targets_len_from_inputs = targets
        if hparams.word_shuffle:
            tf.logging.info("Using word shuffle with rate = {}".format(
                hparams.word_shuffle))
            targets_idx = tf.range(start=0,
                                   limit=common_layers.shape_list(targets)[1],
                                   delta=1)
            targets_idx = tf.to_float(targets_idx)
            noise = tf.random_uniform(
                shape=common_layers.shape_list(targets_idx),
                minval=0,
                maxval=1 + hparams.word_shuffle)
            targets_idx += noise
            permutation = contrib.framework().argsort(targets_idx)
            targets_permuted = tf.gather(targets, indices=permutation, axis=1)
            targets = targets_permuted
        targets, _ = common_layers.pad_to_same_length(
            targets,
            max_targets_len_from_inputs,
            final_length_divisible_by=2**hparams.num_compress_steps)
        # Add positional information
        targets_shape = common_layers.shape_list(targets)
        targets = tf.reshape(
            targets, [targets_shape[0], targets_shape[1], targets_shape[3]])
        targets = common_attention.add_positional_embedding(
            targets, hparams.max_length, name="targets_position")
        targets = tf.reshape(targets, shape=targets_shape)
        if hparams.word_dropout:
            mask = tf.random_uniform(shape=common_layers.shape_list(targets),
                                     minval=0.0,
                                     maxval=1.0)
            targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                                     tf.zeros_like(targets))
        else:
            targets_noisy = targets

        targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
        if hparams.mode != tf.estimator.ModeKeys.PREDICT:
            # Compress and bottleneck.
            latents_dense, latents_discrete, extra_loss, embed, neg_q_entropy = (
                hparams.bottleneck(inputs=targets_c,
                                   filter_size=hparams.compress_filter_size,
                                   mode=hparams.mode,
                                   name="vc"))
            if _DO_SUMMARIES:
                tf.summary.histogram(
                    "b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
            pc = common_layers.inverse_exp_decay(hparams.startup_steps)
            pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
            cond = tf.less(tf.random_uniform([batch_size]), pc)
            latents_dense = tf.where(cond, latents_dense, targets_c)
            # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
            losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
            # Extra loss predicting latent code from input. Discrete only.
            if hparams.bottleneck_kind not in ["dense", "vae"]:
                latents_pred = decode_transformer(inputs_ex,
                                                  ed_ex,
                                                  embed(latents_discrete),
                                                  hparams,
                                                  "extra",
                                                  task="translate")
                _, latent_pred_loss = ae_latent_softmax(
                    latents_pred, tf.stop_gradient(latents_discrete), hparams)

                # Scale by latent dimension for summary so we can compare across
                # batches.
                if _DO_SUMMARIES:
                    tf.summary.scalar("latent_pred_loss_mean",
                                      tf.reduce_mean(latent_pred_loss))
                if hparams.sum_over_latents:
                    latent_pred_loss = tf.reduce_sum(latent_pred_loss, [1, 2])

                losses["latent_pred"] = tf.reduce_mean(
                    latent_pred_loss * tf.to_float(cond)) * hparams.prior_scale
                losses["neg_q_entropy"] = neg_q_entropy * hparams.entropy_scale
            else:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                losses["latent_pred"] = tf.reduce_mean(
                    tf.squared_difference(inputs_c, targets_c)) * 20

                def bn_inputs():
                    with tf.variable_scope(tf.get_variable_scope(),
                                           reuse=True):
                        bn, _, _, _, _ = hparams.bottleneck(
                            inputs=inputs_c,
                            filter_size=hparams.compress_filter_size,
                            mode=hparams.mode,
                            name="vc")
                    return bn

                inputs_c = bn_inputs()
                ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
                ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
                latents_dense = tf.where(
                    tf.less(tf.random_uniform([batch_size]), ptc),
                    latents_dense, inputs_c)
        else:
            if hparams.bottleneck_kind in ["dense", "vae"]:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                latents_dense, _, _, _, _ = hparams.bottleneck(
                    inputs=inputs_c,
                    filter_size=hparams.compress_filter_size,
                    mode=hparams.mode,
                    name="vc")
            else:
                latent_len = common_layers.shape_list(targets_c)[1]
                _, _, _, embed, _ = hparams.bottleneck(
                    inputs=targets_c,
                    filter_size=hparams.compress_filter_size,
                    name="vc")
                latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
                if cache is None:
                    cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex,
                                             embed, 16, hparams)
                latents_dense = embed(cache)
        # Postprocess.
        d = latents_dense
        d_shape = common_layers.shape_list(d)
        d = tf.reshape(d, [d_shape[0], d_shape[1], d_shape[3]])
        d = common_attention.add_positional_embedding(d,
                                                      hparams.max_length,
                                                      name="latents_position")
        d = tf.reshape(d, shape=d_shape)

        # decompressing the dense latents
        for i in range(hparams.num_compress_steps):
            j = hparams.num_compress_steps - i - 1
            d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
            if inputs is not None and hparams.do_attend_decompress:
                d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
            d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

        # Masking.
        if hparams.do_mask:
            masking = common_layers.inverse_lin_decay(
                hparams.mask_startup_steps)
            masking *= common_layers.inverse_exp_decay(
                hparams.mask_startup_steps // 4)  # Not much at start.
            if not hparams.do_refine:
                masking -= tf.random_uniform([]) * hparams.unmasked_percentage
            masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
            if hparams.use_predict_mask:
                masking = predict_mask
            if hparams.mode == tf.estimator.ModeKeys.PREDICT:
                masking = predict_mask
            mask = tf.less(
                masking,
                tf.random_uniform(common_layers.shape_list(targets)[:-1]))
            mask = tf.expand_dims(tf.to_float(mask), 3)

            # targets is always [batch, length, 1, depth]
            targets = mask * targets + (1.0 - mask) * d
            # reshape back to 4d here
            if hparams.task == "image":
                targets = tf.reshape(targets, original_targets_shape)
        else:
            targets = d

    res = decode_transformer(inputs,
                             ed,
                             targets,
                             hparams,
                             "decoder",
                             causal=hparams.causal)
    if hparams.do_ae:
        if hparams.do_mask and hparams.do_refine:

            def refine_res():
                # return residual_conv(res, 1, (5, 1), hparams, "refine")
                r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                              "refine_enc")
                return tf.expand_dims(r, axis=2)

            masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
            all_masked = tf.less(masked_batches, 0.1)
            res = tf.where(all_masked, refine_res(), res)
        # We'll start training the extra model of latents after mask_startup_steps.
        nonlatent_steps = hparams.mask_startup_steps
        latent_time = tf.less(nonlatent_steps,
                              tf.to_int32(tf.train.get_global_step()))
        losses["latent_pred"] *= tf.to_float(latent_time)

    # res was generated from padded targets, which means it has some extra
    # elements. These can cause shape problems when computing the loss against
    # the original (unpadded) targets, so we remove the extra elements here.
    res = res[:, :original_targets_shape[1], :, :]

    data_dim = common_layers.shape_list(res)[1]
    latent_dim = common_layers.shape_list(targets_c)[1]
    return res, losses, cache, data_dim, latent_dim
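Example #7 additionally perturbs the targets with word shuffle and word dropout before compression. Below is a small standalone sketch of the word-shuffle step (shapes are illustrative, and `tf.argsort` stands in for the `contrib.framework().argsort` call used above):

import tensorflow as tf

word_shuffle = 3.0  # hparams.word_shuffle: roughly the max displacement per token
targets = tf.random_normal([2, 10, 1, 8])    # [batch, length, 1, depth]

# Add bounded uniform noise to the position indices and sort them: each token
# can only move by a few positions, giving a local shuffle.
idx = tf.to_float(tf.range(10))
noise = tf.random_uniform([10], minval=0, maxval=1 + word_shuffle)
permutation = tf.argsort(idx + noise)
targets_shuffled = tf.gather(targets, permutation, axis=1)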
Example #8
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """Auto-encoder using a Transformer decoder and a prior over latent sequences.

  Args:
    inputs: Tensor of shape [batch, length, 1, hparams.hidden_size] or None.
    targets: Tensor of shape [batch, ..., channels]. Ellipses may be 1 or 2
      dimensions denoting sequence length.
    target_space: int. Used for encoding inputs under a target space id.
    hparams: HParams.
    cache: Tensor of shape [batch, length] or None.
    predict_mask: Tensor masking whether to use gold targets or predictions.

  Returns:
    decoder_output: Tensor of shape [batch, ..., hparams.hidden_size] representing
      pre-logit activations. After a transformation (`top` in `T2TModel`), it is
      used with targets to compute the "training" (reconstruction) loss.
    losses: dict of str to Tensors. There are three loss terms: "extra",
      "extra_loss", and "latent_pred". The first is hard-coded to 0. The latter
      two are Tensors of shape [batch].
    cache: Tensor of shape [batch, length], either the same as cache, or newly
      computed if the cache input is None.
  """
    original_targets_shape = common_layers.shape_list(targets)
    batch_size = original_targets_shape[0]
    if len(original_targets_shape) == 4:
        compress_fn = compress_encoder_2d
        decompress_fn = decompress_decoder_2d
    else:
        compress_fn = compress_encoder_1d
        decompress_fn = decompress_decoder_1d

    ed_attention_bias = None
    if inputs is not None:
        inputs, ed_attention_bias = transformer_text_encoder(
            inputs, target_space, hparams, name="input_encoder")

    losses = {"extra": 0., "extra_loss": 0., "latent_pred": 0.}
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
        targets_compressed = compress_fn(targets, hparams, name="compress")

        if hparams.mode == tf.estimator.ModeKeys.TRAIN:
            scale = common_layers.inverse_exp_decay(hparams.startup_steps)
        else:
            scale = 1.0
        scale = tf.to_float(tf.less(tf.random_uniform([batch_size]), scale))

        latents_dense, latents_discrete, extra_loss, _ = bottleneck_layer(
            targets_compressed, hparams)
        extra_loss = scale * tf.reduce_mean(extra_loss)

        _, latents_pred_loss = latent_prediction_model(inputs,
                                                       ed_attention_bias,
                                                       latents_discrete,
                                                       latents_dense,
                                                       hparams,
                                                       name="latent_pred")
        latent_time = tf.less(hparams.mask_startup_steps,
                              tf.to_int32(tf.train.get_global_step()))
        latents_pred_loss = scale * tf.reduce_mean(latents_pred_loss)
        latents_pred_loss *= tf.to_float(latent_time)

        # Apply dropout noise for each data point and time step.
        latents_dense_shape = common_layers.shape_list(latents_dense)
        latents_dense = tf.nn.dropout(
            latents_dense,
            keep_prob=1 - hparams.latent_dropout,
            noise_shape=[latents_dense_shape[0], latents_dense_shape[1], 1])

        # TODO(trandustin): Can we combine extra and extra_loss?
        losses = {
            "extra": 0.,
            "extra_loss": extra_loss,
            "latent_pred": latents_pred_loss
        }
    else:
        # Set the latent length, which is num_latents times the number of latent
        # pixels. The number of latent pixels is determined by a compression factor
        # on the number of image pixels.
        latent_len = (
            (hparams.img_len * hparams.img_len * hparams.num_latents) /
            (2**hparams.num_compress_steps))
        _, _, _, embed_fn = bottleneck_layer(targets_compressed, hparams)
        latents_dense = tf.zeros(
            [batch_size, latent_len, 1, hparams.hidden_size])
        if cache is None:
            cache = ae_latent_sample_beam(latents_dense, inputs,
                                          ed_attention_bias, embed_fn, hparams)
        cache_one_hot = tf.one_hot(cache, depth=2**hparams.bottleneck_bits)
        latents_dense = embed_fn(cache_one_hot, hparams.hidden_size)

    if len(original_targets_shape) == 4:
        compressed_img_len = (hparams.img_len //
                              2**(hparams.num_compress_steps // 2))
        latents_dense = tf.reshape(latents_dense, [
            batch_size, compressed_img_len, compressed_img_len,
            hparams.num_latents * hparams.hidden_size
        ])

    latents_dense = decompress_fn(latents_dense, hparams, name="decompress")
    latents_dense = tf.reshape(
        latents_dense,
        [-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

    if hparams.use_gold_targets:
        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
            masking = predict_mask
        else:
            masking = common_layers.inverse_exp_decay(
                hparams.mask_startup_steps)
        targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)
        mask = tf.less(
            masking, tf.random_uniform(common_layers.shape_list(targets)[:-1]))
        mask = tf.expand_dims(tf.to_float(mask), 2)
        latents_dense = mask * targets + (1.0 - mask) * latents_dense

    latents_dense = tf.reshape(latents_dense, original_targets_shape)
    if hparams.decode_autoregressive:
        decoder_output = transformer_image_decoder(latents_dense,
                                                   inputs,
                                                   ed_attention_bias,
                                                   hparams,
                                                   name="decoder")
    else:
        decoder_output = latents_dense
    return decoder_output, losses, cache
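The latent dropout in Example #8 uses `noise_shape` so that whole latent vectors are dropped per position rather than individual units. A minimal standalone sketch, with illustrative shapes:

import tensorflow as tf

latent_dropout = 0.3
latents_dense = tf.random_normal([4, 16, 512])   # [batch, latent_length, hidden]

# The keep/drop flag is drawn once per (batch, position) and broadcast over
# the hidden axis, so an entire latent vector is either kept or zeroed.
latents_dense = tf.nn.dropout(
    latents_dense,
    keep_prob=1 - latent_dropout,
    noise_shape=[4, 16, 1])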
Example #9
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """AE Transformer, main step used for training."""
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  if inputs is not None:
    batch_size = common_layers.shape_list(inputs)[0]
  else:
    batch_size = common_layers.shape_list(targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    inputs_ex, ed_ex = inputs, ed
  else:
    ed, inputs_ex, ed_ex = None, None, None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    # flatten here
    original_targets_shape = tf.shape(targets)
    if hparams.task == "image":
      cia.maybe_reshape_4d_to_3d(targets)
    if hparams.task == "translate":
      max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    else:
      assert hparams.task == "image"
      max_targets_len_from_inputs = targets
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    if hparams.word_dropout:
      mask = tf.random_uniform(shape=common_layers.shape_list(targets),
                               minval=0.0, maxval=1.0)
      targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                               tf.zeros_like(targets))
    else:
      targets_noisy = targets
    targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, embed = hparams.bottleneck(
          x=targets_c,
          filter_size=hparams.compress_filter_size,
          name="vc",
          mode=hparams.mode)
      if _DO_SUMMARIES:
        tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
      pc = common_layers.inverse_exp_decay(hparams.startup_steps)
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            inputs_ex, ed_ex,
            embed(latents_discrete), hparams, "extra",
            task="translate")
        _, latent_pred_loss = ae_latent_softmax(
            latents_pred, tf.stop_gradient(latents_discrete), hparams)
        losses["latent_pred"] = tf.reduce_mean(
            latent_pred_loss * tf.to_float(cond))
      else:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean((inputs_c - targets_c)**2) * 20
        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _ = hparams.bottleneck(
                x=inputs_c,
                filter_size=hparams.compress_filter_size,
                name="vc",
                mode=hparams.mode)
          return bn
        inputs_c = bn_inputs()
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(tf.less(tf.random_uniform([batch_size]), ptc),
                                 latents_dense, inputs_c)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _ = hparams.bottleneck(
            x=inputs_c,
            filter_size=hparams.compress_filter_size,
            name="vc",
            mode=hparams.mode)
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed = hparams.bottleneck(
            x=targets_c, filter_size=hparams.compress_filter_size, name="vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(
              latents_dense, inputs_ex, ed_ex, embed, 16, hparams)
        latents_dense = embed(cache)
    # Postprocess.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # decompressing the dense latents
    for i in range(hparams.num_compress_steps):
      j = hparams.num_compress_steps - i - 1
      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
      if hparams.do_attend_decompress:
        d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
      d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
      masking *= common_layers.inverse_exp_decay(
          hparams.mask_startup_steps // 4)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * hparams.unmasked_percentage
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.use_predict_mask:
        masking = predict_mask
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking, tf.random_uniform(
          common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)

      # targets is always [batch, length, 1, depth]
      targets = mask * targets + (1.0 - mask) * d
      # reshape back to 4d here
      if hparams.task == "image":
        targets = tf.reshape(targets, original_targets_shape)

  res = decode_transformer(inputs, ed, targets, hparams, "decoder",
                           causal=hparams.causal)
  if hparams.do_ae:
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        # return residual_conv(res, 1, (5, 1), hparams, "refine")
        r, _ = encode(tf.squeeze(res, axis=[2]),
                      target_space, hparams, "refine_enc")
        return tf.expand_dims(r, axis=2)
      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # We'll start training the extra model of latents after mask_startup_steps.
    nonlatent_steps = hparams.mask_startup_steps
    latent_time = tf.less(nonlatent_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
  return res, losses, cache
Example #10
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """Auto-encoder using transformer decoder and prior over latents."""
  losses = {"extra": 0., "latent_pred": 0.}

  # Reshape image targets as 4d tensor.
  original_targets_shape = common_layers.shape_list(targets)
  batch_size = original_targets_shape[0]
  if len(original_targets_shape) == 4:
    compress_fn = compress_encoder_2d
    decompress_fn = decompress_decoder_2d
  else:
    compress_fn = compress_encoder_1d
    decompress_fn = decompress_decoder_1d

  # Input Encoder if present.
  ed_attention_bias = None
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed_attention_bias = transformer_text_encoder(
        inputs, target_space, hparams, "input_enc")

  # Encode targets to compute targets compressed.
  targets_c = compress_fn(targets, hparams, "compress")
  targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)

  # Following code creates an exponentially decaying variable based on which
  # we rescale the loss values.
  pc = common_layers.inverse_exp_decay(hparams.startup_steps)
  pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
  cond = tf.less(tf.random_uniform([batch_size]), pc)

  # Call the bottleneck layer, which takes the encoder output and outputs the
  # latents. Returns embedded latents, discrete latent codes, and a loss.
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    latents_dense, latents_discrete, extra_loss, _ = (
        bottleneck_layer(targets_c, hparams))
    extra_loss = tf.reduce_mean(extra_loss) * tf.to_float(cond)

    _, latents_pred_loss = latent_prediction_model(
        inputs,
        ed_attention_bias,
        latents_discrete,
        latents_dense,
        hparams,
        name="latent_pred")
    latents_pred_loss = tf.reduce_mean(latents_pred_loss) * tf.to_float(cond)

    latents_shape = common_layers.shape_list(latents_dense)
    latents_dense = tf.nn.dropout(
        latents_dense, 1 - hparams.latent_dropout,
        noise_shape=[latents_shape[0], latents_shape[1], 1])

    losses["extra_loss"] = extra_loss
    losses["latent_pred"] = latents_pred_loss

    # We'll start training the extra model of latents after mask_startup_steps.
    latent_time = tf.less(hparams.mask_startup_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
  else:
    latent_len = (
        hparams.img_len * hparams.img_len * hparams.num_latents) / 2**(
            hparams.num_compress_steps)
    _, _, _, embed = (
        bottleneck_layer(targets_c, hparams))
    latents_dense = tf.zeros([batch_size, latent_len, 1, hparams.hidden_size])
    if cache is None:
      cache = ae_latent_sample_beam(latents_dense, inputs, ed_attention_bias,
                                    embed, hparams)
    latents_dense = embed(
        tf.one_hot(cache, depth=2**hparams.bottleneck_bits),
        hparams.hidden_size)

  latents_decoder = latents_dense
  if len(original_targets_shape) == 4:
    compressed_img_len = hparams.img_len / 2**(hparams.num_compress_steps // 2)
    latents_decoder = tf.reshape(latents_decoder,
                                 [batch_size,
                                  compressed_img_len,
                                  compressed_img_len,
                                  hparams.num_latents * hparams.hidden_size])

  latents_decoder = decompress_fn(latents_decoder, hparams, name="decompress")
  # if we're operating in 2d space on images, then we're assuming that the
  # last dimension will not be a multiple of channels
  output = tf.reshape(
      latents_decoder,
      shape=[-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

  if hparams.use_gold_targets:
    masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      masking = predict_mask
    mask = tf.less(masking, tf.random_uniform(
        common_layers.shape_list(targets)[:-1]))
    mask = tf.expand_dims(tf.to_float(mask), 2)
    output = mask * targets + (1.0 - mask) * output

  # reshape back to 4d here
  output = tf.reshape(output, original_targets_shape)
  if hparams.decode_autoregressive:
    # Transformer decoder that goes from inputs to targets.
    decoder_output = transformer_image_decoder(
        output, inputs, ed_attention_bias, hparams, "decoder")
  else:
    decoder_output = output
  return decoder_output, losses, cache
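The `use_gold_targets` branch mixes gold targets with the decoded latents position by position. A standalone sketch of that mixing with a fixed masking rate (shapes illustrative):

import tensorflow as tf

masking = 0.7                              # grows towards 1.0 over training
targets = tf.random_normal([4, 64, 512])   # [batch, length, hidden] gold targets
decoded = tf.random_normal([4, 64, 512])   # decompressed latents

# Positions whose uniform draw exceeds `masking` keep the gold target; the
# rest take the decoded value, so higher `masking` means fewer gold positions.
mask = tf.less(masking, tf.random_uniform(tf.shape(targets)[:-1]))
mask = tf.expand_dims(tf.to_float(mask), 2)
mixed = mask * targets + (1.0 - mask) * decoded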
Example #11
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """Auto-encoder using a Transformer decoder and a prior over latent sequences.

  Args:
    inputs: Tensor of shape [batch, length, 1, hparams.hidden_size] or None.
    targets: Tensor of shape [batch, ..., channels]. Ellipses may be 1 or 2
      dimensions denoting sequence length.
    target_space: int. Used for encoding inputs under a target space id.
    hparams: tf.contrib.training.HParams.
    cache: Tensor of shape [batch, length] or None.
    predict_mask: Tensor masking whether to use gold targets or predictions.

  Returns:
    decoder_output: Tensor of shape [batch, ..., hparams.hidden_size] representing
      pre-logit activations. After a transformation (`top` in `T2TModel`), it is
      used with targets to compute the "training" (reconstruction) loss.
    losses: dict of str to Tensors. There are three loss terms: "extra",
      "extra_loss", and "latent_pred". The first is hard-coded to 0. The latter
      two are Tensors of shape [batch].
    cache: Tensor of shape [batch, length], either the same as cache, or newly
      computed if the cache input is None.
  """
  original_targets_shape = common_layers.shape_list(targets)
  batch_size = original_targets_shape[0]
  if len(original_targets_shape) == 4:
    compress_fn = compress_encoder_2d
    decompress_fn = decompress_decoder_2d
  else:
    compress_fn = compress_encoder_1d
    decompress_fn = decompress_decoder_1d

  ed_attention_bias = None
  if inputs is not None:
    inputs, ed_attention_bias = transformer_text_encoder(
        inputs, target_space, hparams, name="input_encoder")

  losses = {"extra": 0.,
            "extra_loss": 0.,
            "latent_pred": 0.}
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    targets_compressed = compress_fn(targets, hparams, name="compress")

    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
      scale = common_layers.inverse_exp_decay(hparams.startup_steps)
    else:
      scale = 1.0
    scale = tf.to_float(tf.less(tf.random_uniform([batch_size]), scale))

    latents_dense, latents_discrete, extra_loss, _ = bottleneck_layer(
        targets_compressed, hparams)
    extra_loss = scale * tf.reduce_mean(extra_loss)

    _, latents_pred_loss = latent_prediction_model(
        inputs, ed_attention_bias, latents_discrete, latents_dense, hparams,
        name="latent_pred")
    latent_time = tf.less(hparams.mask_startup_steps,
                          tf.to_int32(tf.train.get_global_step()))
    latents_pred_loss = scale * tf.reduce_mean(latents_pred_loss)
    latents_pred_loss *= tf.to_float(latent_time)

    # Apply dropout noise for each data point and time step.
    latents_dense_shape = common_layers.shape_list(latents_dense)
    latents_dense = tf.nn.dropout(
        latents_dense,
        keep_prob=1 - hparams.latent_dropout,
        noise_shape=[latents_dense_shape[0], latents_dense_shape[1], 1])

    # TODO(trandustin): Can we combine extra and extra_loss?
    losses = {"extra": 0.,
              "extra_loss": extra_loss,
              "latent_pred": latents_pred_loss}
  else:
    # Set the latent length, which is num_latents times the number of latent
    # pixels. The number of latent pixels is determined by a compression factor
    # on the number of image pixels.
    latent_len = ((hparams.img_len * hparams.img_len * hparams.num_latents) /
                  (2**hparams.num_compress_steps))
    _, _, _, embed_fn = bottleneck_layer(targets_compressed, hparams)
    latents_dense = tf.zeros([batch_size, latent_len, 1, hparams.hidden_size])
    if cache is None:
      cache = ae_latent_sample_beam(latents_dense,
                                    inputs,
                                    ed_attention_bias,
                                    embed_fn,
                                    hparams)
    cache_one_hot = tf.one_hot(cache, depth=2**hparams.bottleneck_bits)
    latents_dense = embed_fn(cache_one_hot, hparams.hidden_size)

  if len(original_targets_shape) == 4:
    compressed_img_len = (hparams.img_len //
                          2**(hparams.num_compress_steps // 2))
    latents_dense = tf.reshape(latents_dense,
                               [batch_size,
                                compressed_img_len,
                                compressed_img_len,
                                hparams.num_latents * hparams.hidden_size])

  latents_dense = decompress_fn(latents_dense, hparams, name="decompress")
  latents_dense = tf.reshape(
      latents_dense,
      [-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

  if hparams.use_gold_targets:
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      masking = predict_mask
    else:
      masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
    targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)
    mask = tf.less(masking,
                   tf.random_uniform(common_layers.shape_list(targets)[:-1]))
    mask = tf.expand_dims(tf.to_float(mask), 2)
    latents_dense = mask * targets + (1.0 - mask) * latents_dense

  latents_dense = tf.reshape(latents_dense, original_targets_shape)
  if hparams.decode_autoregressive:
    decoder_output = transformer_image_decoder(
        latents_dense, inputs, ed_attention_bias, hparams, name="decoder")
  else:
    decoder_output = latents_dense
  return decoder_output, losses, cache
Example #12
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """Auto-encoder using transformer decoder and prior over latents."""
  losses = {"extra": 0., "latent_pred": 0.}

  # Reshape image targets as 4d tensor.
  original_targets_shape = common_layers.shape_list(targets)
  batch_size = original_targets_shape[0]
  if len(original_targets_shape) == 4:
    compress_fn = compress_encoder_2d
    decompress_fn = decompress_decoder_2d
  else:
    compress_fn = compress_encoder_1d
    decompress_fn = decompress_decoder_1d

  # Input Encoder if present.
  ed_attention_bias = None
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed_attention_bias = transformer_text_encoder(
        inputs, target_space, hparams, "input_enc")

  # Encode targets to compute targets compressed.
  targets_c = compress_fn(targets, hparams, "compress")
  targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)

  # Following code creates an exponentially decaying variable based on which
  # we rescale the loss values.
  pc = common_layers.inverse_exp_decay(hparams.startup_steps)
  pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
  cond = tf.less(tf.random_uniform([batch_size]), pc)

  # Call the bottleneck layer, which takes the encoder output and outputs the
  # latents. Returns embedded latents, discrete latent codes, and a loss.
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    latents_dense, latents_discrete, extra_loss = (
        bottleneck_layer(targets_c, hparams))
    extra_loss = tf.reduce_mean(extra_loss) * tf.to_float(cond)

    _, latents_pred_loss = latent_prediction_model(
        inputs,
        ed_attention_bias,
        latents_discrete,
        latents_dense,
        hparams,
        name="latent_pred")
    latents_pred_loss = tf.reduce_mean(latents_pred_loss) * tf.to_float(cond)

    latents_shape = common_layers.shape_list(latents_dense)
    latents_dense = tf.nn.dropout(
        latents_dense, 1 - hparams.latent_dropout,
        noise_shape=[latents_shape[0], latents_shape[1], 1])

    losses["extra_loss"] = extra_loss
    losses["latent_pred"] = latents_pred_loss

    # We'll start training the extra model of latents after mask_startup_steps.
    latent_time = tf.less(hparams.mask_startup_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
  else:
    latent_len = (
        hparams.img_len * hparams.img_len * hparams.num_latents) / 2**(
            hparams.num_compress_steps)
    embed = functools.partial(
        discretization.parametrized_unbottleneck, hparams=hparams)
    latents_dense = tf.zeros([batch_size, latent_len, 1, hparams.hidden_size])
    if cache is None:
      cache = ae_latent_sample_beam(latents_dense, inputs, ed_attention_bias,
                                    embed, hparams)
    latents_dense = embed(
        tf.one_hot(cache, depth=2**hparams.bottleneck_bits),
        hparams.hidden_size)

  latents_decoder = latents_dense
  if len(original_targets_shape) == 4:
    compressed_img_len = hparams.img_len / 2**(hparams.num_compress_steps // 2)
    latents_decoder = tf.reshape(latents_decoder,
                                 [batch_size,
                                  compressed_img_len,
                                  compressed_img_len,
                                  hparams.num_latents * hparams.hidden_size])

  latents_decoder = decompress_fn(latents_decoder, hparams, name="decompress")
  # if we're operating in 2d space on images, then we're assuming that the
  # last dimension will not be a multiple of channels
  output = tf.reshape(
      latents_decoder,
      shape=[-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

  if hparams.use_gold_targets:
    masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      masking = predict_mask
    mask = tf.less(masking, tf.random_uniform(
        common_layers.shape_list(targets)[:-1]))
    mask = tf.expand_dims(tf.to_float(mask), 2)
    output = mask * targets + (1.0 - mask) * output

  # reshape back to 4d here
  output = tf.reshape(output, original_targets_shape)
  if hparams.decode_autoregressive:
    # Transformer decoder that goes from inputs to targets.
    decoder_output = transformer_image_decoder(
        output, inputs, ed_attention_bias, hparams, "decoder")
  else:
    decoder_output = output
  return decoder_output, losses, cache