def ae_transformer_internal(inputs, targets, target_space, hparams):
    """AE Transformer, main step used for training."""
    with tf.variable_scope("ae_transformer"):
        # Prepare inputs, targets, k.
        k = 2**hparams.num_compress_steps
        _, targets = common_layers.pad_to_same_length(
            targets, targets, final_length_divisible_by=k)
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed = encode(inputs, target_space, hparams, "input_enc")

        # Compress and ae.
        ae, hot, kl = ae_compress(targets, hparams.is_2d, hparams, "ae")
        tf.summary.histogram("hot", tf.reshape(tf.argmax(hot, axis=-1), [-1]))
        emb = ae_embed(hot, hparams, "ae", reuse=True)

        # Compress context and run autoregressive decoder on emb-hot.
        emb_flat = tf.expand_dims(common_layers.flatten4d3d(emb), axis=2)
        dec_c = decode(None, None, emb_flat, inputs, ed, hparams)
        dec_c = tf.reshape(dec_c, tf.shape(emb))
        c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context")
        reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits(labels=hot,
                                                                   logits=c_z)
        # If not training, use the predicted z instead of the autoregressive one.
        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
            hot = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size)

        # Decompress, pass for ae loss.
        z = ae_decompress(emb, ae, targets, hparams.is_2d, hparams, "ae")
        kl *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.8))
        reconstruct_loss *= common_layers.inverse_exp_decay(
            hparams.startup_steps)
        losses = {"kl": kl, "reconstruction": reconstruct_loss}
        return z, losses
def ae_decompress(z, ae, x, is_2d, hparams, name, reuse=None):
    """Decompress from z, leaking from ae."""
    with tf.variable_scope(name + "_decompress", reuse=reuse):
        # Leak at the beginning to help train.
        z = mix(z, ae, hparams.startup_steps)
        prob_z = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.8
        prob_z = prob_z if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN else 1.0
        z = tf.cond(tf.less(tf.random_uniform([]), prob_z), lambda: z,
                    lambda: ae)

        # Dropout for better autoencoding.
        z = tf.nn.dropout(z, keep_prob=1.0 - hparams.z_dropout)

        # Decompress.
        d = z
        for i in xrange(hparams.num_compress_steps):
            j = hparams.num_compress_steps - i - 1
            d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
            d = decompress_step(d, None, hparams, i > 0, is_2d,
                                "decompress_%d" % j)

        k = 2**hparams.num_compress_steps
        z_batch = tf.reshape(z, [-1, 1, 1, hparams.hidden_size])
        x_batch = tf.reshape(x, [-1, k, 1, hparams.hidden_size])
        d_batch = tf.reshape(d, [-1, k, 1, hparams.hidden_size])
        dec_batch = decode(z_batch, d_batch, x_batch, None, None, hparams)
        z = tf.reshape(dec_batch, [-1, tf.shape(x)[1], 1, hparams.hidden_size])

    return z
def dae(x, hparams, name):
  """Discretization bottleneck with a Gumbel-softmax sample over v_size codes."""
  with tf.variable_scope(name):
    m = tf.layers.dense(x, hparams.v_size, name="mask")
    if hparams.softmax_k > 0:
      m, kl = top_k_softmax(m, hparams.softmax_k)
      return m, m, 1.0 - tf.reduce_mean(kl)
    logsm = tf.nn.log_softmax(m)
    # Gumbel-softmax sample.
    gumbel_samples = gumbel_sample(common_layers.shape_list(m))
    steps = hparams.kl_warmup_steps
    gumbel_samples *= common_layers.inverse_exp_decay(steps // 5) * 0.5
    temperature = 1.2 - common_layers.inverse_lin_decay(steps)
    # 10% of the time keep reasonably high temperature to keep learning.
    temperature = tf.cond(tf.less(tf.random_uniform([]), 0.9),
                          lambda: temperature,
                          lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
    s = tf.nn.softmax((logsm + gumbel_samples) / temperature)
    m = tf.nn.softmax(m)
    kl = - tf.reduce_max(logsm, axis=-1)
    if _DO_SUMMARIES:
      tf.summary.histogram("max-log", tf.reshape(kl, [-1]))
    # Calculate the argmax and construct hot vectors.
    maxvec = tf.reshape(tf.argmax(m, axis=-1), [-1])
    maxvhot = tf.stop_gradient(tf.one_hot(maxvec, hparams.v_size))
    # Add losses that prevent too few being used.
    distrib = tf.reshape(logsm, [-1, hparams.v_size]) * maxvhot
    d_mean = tf.reduce_mean(distrib, axis=[0], keep_dims=True)
    d_variance = tf.reduce_mean(tf.square(distrib - d_mean), axis=[0])
    d_dev = - tf.reduce_mean(d_variance)
    ret = s
    if hparams.mode != tf.contrib.learn.ModeKeys.TRAIN:
      ret = tf.reshape(maxvhot, common_layers.shape_list(s))  # Just hot @eval.
    return m, ret, d_dev * 5.0 + tf.reduce_mean(kl) * 0.002
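For intuition on the temperature schedule above, assuming inverse_lin_decay rises from a small minimum (about 0.01) at step 0 to 1.0 at kl_warmup_steps: the temperature starts near 1.2 - 0.01 = 1.19 and anneals down to 1.2 - 1.0 = 0.2, so the Gumbel-softmax samples sharpen as training progresses, except on the roughly 10% of steps that draw a fresh temperature from U(0.5, 1.0).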
Example #4
def vae_transformer_internal(inputs, targets, target_space, hparams):
  """VAE Transformer, main step used for training."""
  with tf.variable_scope("vae_transformer"):
    # Prepare inputs, targets, and k.
    inputs = common_layers.flatten4d3d(inputs)
    input_len = tf.shape(inputs)[1]  # Double input size to cover targets.
    inputs = tf.pad(inputs, [[0, 0], [0, input_len], [0, 0]])
    inputs.set_shape([None, None, hparams.hidden_size])
    targets = common_layers.flatten4d3d(targets)
    k = 2**hparams.num_compress_steps
    inputs, targets = common_layers.pad_to_same_length(
        inputs, targets, final_length_divisible_by=k)
    inputs = encode(inputs, target_space, hparams, "input_enc")

    # Compress and vae.
    z, kl_loss, _, _ = vae_compress(tf.expand_dims(targets, axis=2),
                                    tf.expand_dims(inputs, axis=2),
                                    hparams, "vae_compress", "vae_decompress")

    # Join z with inputs, run decoder.
    to_decode = common_layers.conv_block(
        tf.concat([z, tf.expand_dims(inputs, axis=2)], axis=3),
        hparams.hidden_size, [((1, 1), (1, 1))], name="join_z")
    ret = encode(tf.squeeze(to_decode, axis=2), target_space, hparams, "dec")

    # For experiments with one-sided decoder:
    # decoder_in = tf.squeeze(to_decode, axis=2)
    # (decoder_input, decoder_self_attention_bias) = (
    #     transformer.transformer_prepare_decoder(decoder_in, hparams))
    # ret = transformer.transformer_decoder(
    #     decoder_input, inputs, decoder_self_attention_bias, None, hparams)

    kl_loss *= common_layers.inverse_exp_decay(hparams.kl_warmup_steps) * 3.0
    losses = {"kl": kl_loss}
    return tf.expand_dims(ret, axis=2), losses
Example #6
    def model_fn_body(self, features):
        hparams = self._hparams
        # TODO(rshin): Give identity_module lower weight by default.
        multi_conv = multi_conv_module(
            kernel_sizes=[(3, 3), (5, 5), (7, 7)], seps=[0, 1])
        conv_modules = [multi_conv, identity_module]
        activation_modules = [
            identity_module, lambda x, _: tf.nn.relu(x),
            lambda x, _: tf.nn.elu(x),
            lambda x, _: tf.tanh(x)
        ]
        norm_modules = [identity_module, layernorm_module, noamnorm_module]
        binary_modules = [
            first_binary_module, second_binary_module, sum_binary_module,
            shakeshake_binary_module
        ]
        inputs = features["inputs"]

        def run_unary(x, name):
            """A single step of unary modules."""
            x_shape = x.get_shape()
            with tf.variable_scope(name):
                with tf.variable_scope("norm"):
                    x = run_unary_modules(norm_modules, x, hparams)
                    x.set_shape(x_shape)
                with tf.variable_scope("activation"):
                    x = run_unary_modules(activation_modules, x, hparams)
                    x.set_shape(x_shape)
                with tf.variable_scope("conv"):
                    x = run_unary_modules(conv_modules, x, hparams)
                    x.set_shape(x_shape)
            return tf.nn.dropout(x, 1.0 - hparams.dropout), batch_deviation(x)

        cur1, cur2, cur3, extra_loss = inputs, inputs, inputs, 0.0
        cur_shape = inputs.get_shape()
        for i in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % i):
                cur1, loss1 = run_unary(cur1, "unary1")
                cur2, loss2 = run_unary(cur2, "unary2")
                cur3, loss3 = run_unary(cur3, "unary3")
                extra_loss += (loss1 + loss2 + loss3) / float(
                    hparams.num_hidden_layers)
                with tf.variable_scope("binary1"):
                    next1 = run_binary_modules(binary_modules, cur1, cur2,
                                               hparams)
                    next1.set_shape(cur_shape)
                with tf.variable_scope("binary2"):
                    next2 = run_binary_modules(binary_modules, cur1, cur3,
                                               hparams)
                    next2.set_shape(cur_shape)
                with tf.variable_scope("binary3"):
                    next3 = run_binary_modules(binary_modules, cur2, cur3,
                                               hparams)
                    next3.set_shape(cur_shape)
                cur1, cur2, cur3 = next1, next2, next3

        anneal = common_layers.inverse_exp_decay(hparams.anneal_until)
        extra_loss *= hparams.batch_deviation_loss_factor * anneal
        return cur1, extra_loss
def get_exp_sched_prob():
    """Inverse exponential decay used to mix datasets."""
    with tf.control_dependencies([problem_step.assign_add(1)]):
        inv_exp_decay = common_layers.inverse_exp_decay(
            max_step=hparams.multiproblem_schedule_max_examples,
            min_value=1e-4,
            step=tf.to_float(problem_step))
        # inv_exp_decay is bounded above by 1.0
        return inv_exp_decay * hparams.multiproblem_schedule_threshold
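A hedged sketch of how such a probability could gate dataset mixing; get_primary_example and get_secondary_example are hypothetical callables, not part of the code above:

prob = get_exp_sched_prob()
pick_secondary = tf.less(tf.random_uniform([]), prob)
example = tf.cond(pick_secondary, get_secondary_example, get_primary_example)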
    def get_scheduled_sample_func(self, batch_size):
        """Creates a function for scheduled sampling based on given hparams."""
        with tf.variable_scope("scheduled_sampling_func", reuse=tf.AUTO_REUSE):
            iter_num = self.get_iteration_num()

            # Simple function to bypass scheduled sampling in gt or pred only modes.
            def scheduled_sampling_simple(ground_truth_x, generated_x,
                                          batch_size, scheduled_sample_var):
                del batch_size
                if scheduled_sample_var:
                    return ground_truth_x
                return generated_x

            mode = self.hparams.scheduled_sampling_mode
            if mode == "ground_truth_only":
                scheduled_sampling_func = scheduled_sampling_simple
                scheduled_sampling_func_var = True
            elif mode == "prediction_only":
                scheduled_sampling_func = scheduled_sampling_simple
                scheduled_sampling_func_var = False
            elif mode == "prob":
                decay_steps = self.hparams.scheduled_sampling_decay_steps
                probability = tf.train.polynomial_decay(
                    1.0, iter_num, decay_steps, 0.0)
                scheduled_sampling_func = common_video.scheduled_sample_prob
                scheduled_sampling_func_var = probability
            elif mode == "prob_inverse_exp":
                decay_steps = self.hparams.scheduled_sampling_decay_steps
                probability = common_layers.inverse_exp_decay(decay_steps,
                                                              step=iter_num)
                probability *= self.hparams.scheduled_sampling_max_prob
                probability = 1.0 - probability
                scheduled_sampling_func = common_video.scheduled_sample_prob
                scheduled_sampling_func_var = probability
            elif mode == "count":
                # Calculate number of ground-truth frames to pass in.
                k = self.hparams.scheduled_sampling_k
                num_ground_truth = tf.to_int32(
                    tf.round(
                        tf.to_float(batch_size) *
                        (k /
                         (k + tf.exp(tf.to_float(iter_num) / tf.to_float(k)))))
                )
                scheduled_sampling_func = common_video.scheduled_sample_count
                scheduled_sampling_func_var = num_ground_truth
            else:
                raise ValueError("unknown scheduled sampling method: %s" %
                                 mode)

            if isinstance(scheduled_sampling_func_var, tf.Tensor):
                tf.summary.scalar("scheduled_sampling_var",
                                  scheduled_sampling_func_var)
            partial_func = partial(
                scheduled_sampling_func,
                batch_size=batch_size,
                scheduled_sample_var=scheduled_sampling_func_var)
            return partial_func
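Hypothetical usage of the returned partial (frame names assumed): the two remaining arguments are the ground-truth and generated frames, mixed according to the bound schedule variable.

sample_func = self.get_scheduled_sample_func(batch_size)
mixed_frame = sample_func(ground_truth_x=gt_frame, generated_x=pred_frame)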
Example #9
def mix(x1, x2, steps, min_prob=0.0, max_prob=1.0, mode="lin"):
  if mode == "lin":
    alpha_p = common_layers.inverse_lin_decay(steps) + 0.001
  else:
    alpha_p = common_layers.inverse_exp_decay(steps) + 0.001
  alpha_p = alpha_p * (max_prob - min_prob) + min_prob
  alpha = tf.random_uniform(tf.shape(x1))
  alpha = tf.to_float(tf.less(alpha, alpha_p))
  return alpha * x1 + (1.0 - alpha) * x2
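Every example on this page leans on common_layers.inverse_exp_decay and its linear counterpart. A minimal sketch of both, assuming the tensor2tensor semantics of rising from min_value at step 0 to 1.0 at max_step:

import tensorflow as tf

def inverse_exp_decay(max_step, min_value=0.01, step=None):
  """Rise exponentially from min_value to 1.0, reached at max_step."""
  inv_base = tf.exp(tf.log(min_value) / float(max_step))
  if step is None:
    step = tf.train.get_global_step()
  if step is None:
    return 1.0  # No global step available: behave as fully warmed up.
  step = tf.to_float(step)
  return inv_base**tf.maximum(float(max_step) - step, 0.0)

def inverse_lin_decay(max_step, min_value=0.01, step=None):
  """Rise linearly from min_value to 1.0, reached at max_step."""
  if step is None:
    step = tf.train.get_global_step()
  if step is None:
    return 1.0
  progress = tf.minimum(tf.to_float(step) / float(max_step), 1.0)
  return progress * (1.0 - min_value) + min_value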
Example #10
def bottleneck(x, hparams, filter_size, name):
  """Bottleneck."""
  def embed1(x):
    if hparams.bottleneck_kind == "semhash":
      c = int_to_bit(x, c_size)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      return h1a + h1b
    elif hparams.bottleneck_kind == "gumbel-softmax":
      hot = tf.one_hot(x, hparams.v_size)
      with tf.variable_scope(name, reuse=True):
        return tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")

  def embed(x):
    with tf.variable_scope(name, reuse=True):
      h1 = embed1(x)
      h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
      res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
    return res

  with tf.variable_scope(name):
    c_size = hparams.c_size
    l = tf.constant(0.0)
    if hparams.bottleneck_kind == "dense":
      c = tf.layers.dense(x, c_size, name="vcc")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    if hparams.bottleneck_kind == "semhash":
      c = tf.layers.dense(x, c_size, name="vcc")
      y_clean = common_layers.saturating_sigmoid(c)
      tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
      # l = tf.reduce_mean(y_clean * (1.0 - y_clean))
      if hparams.noise_dev > 0 and hparams.mode == tf.estimator.ModeKeys.TRAIN:
        dev = hparams.noise_dev
        noise = tf.truncated_normal(tf.shape(c), mean=0.0, stddev=dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
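      # Straight-through estimator: y_discrete forwards the hard bits d while
      # letting gradients flow through the smooth y.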
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(hparams.startup_steps * 2)
      pd *= hparams.d_mix
      pd = pd if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.cond(tf.less(tf.random_uniform([]), pd),
                  lambda: y_discrete, lambda: y)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, c_size)
    if hparams.bottleneck_kind == "gumbel-softmax":
      _, hot, l = dae(x, hparams, name)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
    h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
    res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
    return res, c, l, embed
Example #12
def scheduled_sampling(hparams, problem_hparams, dp, sharded_logits, losses,
                       sharded_features, transformed_features, model):
    """Scheduled sampling."""
    target_modality = problem_hparams.target_modality

    def sample(x):
        """Multinomial sampling from a n-dimensional tensor."""
        vocab_size = target_modality.top_dimensionality
        samples = tf.multinomial(tf.reshape(x, [-1, vocab_size]), 1)
        reshaped_samples = tf.reshape(samples,
                                      common_layers.shape_list(x)[:-1])
        return tf.to_int32(reshaped_samples)

    def mix_gold_sampled(gold_targets, sampled_targets):
        return tf.where(
            tf.less(
                tf.random_uniform(common_layers.shape_list(sampled_targets)),
                hparams.scheduled_sampling_gold_mixin_prob), gold_targets,
            sampled_targets)

    def sampled_results():
        """Generate scheduled sampling results."""
        sampled_targets = dp(sample, sharded_logits)
        new_targets = dp(mix_gold_sampled, sharded_features["targets"],
                         sampled_targets)
        new_features = transformed_features
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            with tf.variable_scope(target_modality.name):
                new_features[
                    "targets"] = target_modality.targets_bottom_sharded(
                        new_targets, dp)
            with tf.variable_scope("body"):
                body_outputs, losses = model.model_fn_sharded(new_features)
                if not isinstance(losses,
                                  dict):  # If it's a single extra loss.
                    losses = {"extra": losses}
            with tf.variable_scope(target_modality.name):
                new_sharded_logits = target_modality.top_sharded(
                    body_outputs, sharded_features["targets"], dp)
                if "training" not in losses:
                    training_loss = target_modality.loss_sharded(
                        sharded_logits, sharded_features["targets"], dp)
                    training_loss *= problem_hparams.loss_multiplier
                    losses["training"] = training_loss
        return new_sharded_logits, losses

    # Run the above conditionally.
    prob = hparams.scheduled_sampling_prob
    prob *= common_layers.inverse_exp_decay(
        hparams.scheduled_sampling_warmup_steps, min_value=0.001)
    sharded_logits, losses = tf.cond(tf.less(tf.random_uniform([]),
                                             prob), sampled_results, lambda:
                                     (sharded_logits, losses))
    return sharded_logits, losses
Example #13
  def body(self, features):
    hparams = self._hparams
    # TODO(rshin): Give identity_module lower weight by default.
    multi_conv = multi_conv_module(
        kernel_sizes=[(3, 3), (5, 5), (7, 7)], seps=[0, 1])
    conv_modules = [multi_conv, identity_module]
    activation_modules = [
        identity_module, lambda x, _: tf.nn.relu(x), lambda x, _: tf.nn.elu(x),
        lambda x, _: tf.tanh(x)
    ]
    norm_modules = [identity_module, layernorm_module, noamnorm_module]
    binary_modules = [
        first_binary_module, second_binary_module, sum_binary_module,
        shakeshake_binary_module
    ]
    inputs = features["inputs"]

    def run_unary(x, name):
      """A single step of unary modules."""
      x_shape = x.get_shape()
      with tf.variable_scope(name):
        with tf.variable_scope("norm"):
          x = run_unary_modules(norm_modules, x, hparams)
          x.set_shape(x_shape)
        with tf.variable_scope("activation"):
          x = run_unary_modules(activation_modules, x, hparams)
          x.set_shape(x_shape)
        with tf.variable_scope("conv"):
          x = run_unary_modules(conv_modules, x, hparams)
          x.set_shape(x_shape)
      return tf.nn.dropout(x, 1.0 - hparams.dropout), batch_deviation(x)

    cur1, cur2, cur3, extra_loss = inputs, inputs, inputs, 0.0
    cur_shape = inputs.get_shape()
    for i in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % i):
        cur1, loss1 = run_unary(cur1, "unary1")
        cur2, loss2 = run_unary(cur2, "unary2")
        cur3, loss3 = run_unary(cur3, "unary3")
        extra_loss += (loss1 + loss2 + loss3) / float(hparams.num_hidden_layers)
        with tf.variable_scope("binary1"):
          next1 = run_binary_modules(binary_modules, cur1, cur2, hparams)
          next1.set_shape(cur_shape)
        with tf.variable_scope("binary2"):
          next2 = run_binary_modules(binary_modules, cur1, cur3, hparams)
          next2.set_shape(cur_shape)
        with tf.variable_scope("binary3"):
          next3 = run_binary_modules(binary_modules, cur2, cur3, hparams)
          next3.set_shape(cur_shape)
        cur1, cur2, cur3 = next1, next2, next3

    anneal = common_layers.inverse_exp_decay(hparams.anneal_until)
    extra_loss *= hparams.batch_deviation_loss_factor * anneal
    return cur1, extra_loss
def cycle_vae_gan_internal(inputs, targets, _, hparams):
    """Cycle GAN, main step used for training."""
    with tf.variable_scope("cycle_vae_gan"):
        # Embed inputs and targets.
        inputs_orig, targets_orig = tf.to_int32(inputs), tf.to_int32(targets)
        k = 2 ** hparams.num_compress_steps
        inputs_orig, targets_orig = common_layers.pad_to_same_length(
            inputs_orig, targets_orig, final_length_divisible_by=k)
        inputs = common_layers.embedding(
            inputs_orig, hparams.vocab_size, hparams.hidden_size, "embed")
        targets = common_layers.embedding(
            targets_orig, hparams.vocab_size, hparams.hidden_size,
            "embed", reuse=True)

        # Split the batch into input-input and target-target parts.
        inputs1, _ = split_on_batch(inputs)
        _, targets2 = split_on_batch(targets)

        # Input-input part.
        inp1_back, kl_loss1, inp1_mu, inp1_log_sigma = transformer_vae.vae_compress(
            inputs1, None, hparams, "inp2hyp", "hyp2inp")
        inp1_hyp = tf.concat([inp1_mu, inp1_log_sigma], axis=3)

        # Target-target part.
        tgt2_back, kl_loss2, tgt2_mu, tgt2_log_sigma = transformer_vae.vae_compress(
            targets2, None, hparams, "tgt2hyp", "hyp2tgt")
        tgt2_hyp = tf.concat([tgt2_mu, tgt2_log_sigma], axis=3)

        # Reconstruction losses.
        inp1_orig, _ = split_on_batch(inputs_orig)
        _, tgt2_orig = split_on_batch(targets_orig)
        inp1_loss = reconstruct_loss(
            inp1_back, tf.squeeze(inp1_orig, axis=3), hparams)
        tgt2_loss = reconstruct_loss(
            tgt2_back, tf.squeeze(tgt2_orig, axis=3), hparams, reuse=True)

        # Discriminator loss.
        dloss = discriminate_loss(inp1_hyp, tgt2_hyp, False, hparams, "dloss")

        # Reconstruct targets from inputs.
        tgt, _, _, _ = transformer_vae.vae_compress(
            inputs, None, hparams, "inp2hyp", "hyp2tgt", reuse=True)
        tgt = tf.layers.dense(tgt, hparams.vocab_size, name="softmax",
                              reuse=True)
        # We use the reconstruction only for tracking progress, no gradients here!
        tgt = tf.stop_gradient(tf.expand_dims(tgt, axis=2))

        kl_rev_decay = common_layers.inverse_exp_decay(hparams.kl_warmup_steps)
        losses = {"input_input": hparams.cycle_loss_multiplier * inp1_loss,
                  "target_target": hparams.cycle_loss_multiplier * tgt2_loss,
                  "input_kl": kl_loss1 * kl_rev_decay * 15.0,
                  "target_kl": kl_loss2 * kl_rev_decay * 15.0,
                  "discriminator": dloss}
        return tgt, losses
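split_on_batch is defined elsewhere in the cycle-GAN code; a minimal sketch of the assumed behavior, halving the batch so one half trains the input autoencoder and the other half the target autoencoder:

def split_on_batch(x):
    i = tf.shape(x)[0] // 2
    return x[:i], x[i:]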
Example #15
def run_unary_modules_basic(modules, cur, hparams):
  """Run unary modules."""
  selection_weights = create_selection_weights(
      "selection",
      "softmax",
      shape=[len(modules)],
      inv_t=100.0 * common_layers.inverse_exp_decay(
          hparams.anneal_until, min_value=0.01))
  all_res = [modules[n](cur, hparams) for n in xrange(len(modules))]
  all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0)
  res = all_res * tf.reshape(selection_weights.normalized, [-1, 1, 1, 1, 1])
  return tf.reduce_sum(res, axis=0)
def mix(x1, x2, steps, min_prob=0.0, max_prob=1.0, mode="lin", simple=False):
    """Mix starting with x2, mixing mixing, going towards x1."""
    if mode == "lin":
        alpha_p = common_layers.inverse_lin_decay(steps)
    else:
        alpha_p = common_layers.inverse_exp_decay(steps)
    alpha_p = alpha_p * (max_prob - min_prob) + min_prob
    if simple:
        return alpha_p * x1 + (1.0 - alpha_p) * x2
    alpha = tf.random_uniform(tf.shape(x1))
    alpha = tf.to_float(tf.less(alpha, alpha_p))
    return alpha * x1 + (1.0 - alpha) * x2
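Hypothetical usage: with simple=True the result is the deterministic blend alpha_p * x1 + (1 - alpha_p) * x2, while the default draws each element from x1 with probability alpha_p and from x2 otherwise.

z_leaky = mix(z, ae, hparams.startup_steps)               # per-element coin flips
z_blend = mix(z, ae, hparams.startup_steps, simple=True)  # deterministic blend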
Example #18
def vae_transformer_internal(inputs, targets, target_space, hparams):
    """VAE Transformer, main step used for training."""
    with tf.variable_scope("vae_transformer"):
        is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
        # Prepare inputs, targets, and k.
        inputs = common_layers.flatten4d3d(inputs)
        input_len = tf.shape(inputs)[1]  # Double input size to cover targets.
        inputs = tf.pad(inputs, [[0, 0], [0, input_len], [0, 0]])
        inputs.set_shape([None, None, hparams.hidden_size])
        targets = common_layers.flatten4d3d(targets)
        k = 2**hparams.num_compress_steps
        inputs, targets = common_layers.pad_to_same_length(
            inputs, targets, final_length_divisible_by=k)
        inputs = encode(inputs, target_space, hparams, "input_enc")

        # Dropout targets or swap for zeros 5% of the time.
        targets_nodrop = targets
        max_prestep = hparams.kl_warmup_steps
        prob_targets = 0.95 if is_training else 1.0
        targets_dropout_max = common_layers.inverse_lin_decay(
            max_prestep) - 0.01
        targets = dropmask(targets, targets_dropout_max * 0.7, is_training)
        targets = tf.cond(tf.less(tf.random_uniform([]), prob_targets),
                          lambda: targets, lambda: tf.zeros_like(targets))
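        # Note: this restores the undropped targets, so the dropout and
        # zeroing above are effectively disabled.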
        targets = targets_nodrop

        # Compress and vae.
        z = tf.get_variable("z", [hparams.hidden_size])
        z = tf.reshape(z, [1, 1, 1, -1])
        z = tf.tile(z, [tf.shape(inputs)[0], 1, 1, 1])

        z = attend(z, inputs, hparams, "z_attendsi")
        z = ffn(z, hparams, "zff2")
        z = attend(z, targets, hparams, "z_attendst2")
        z = ffn(z, hparams, "zff3")
        z, kl_loss, _, _ = vae(z, hparams, name="vae")
        z = tf.layers.dense(z, hparams.hidden_size, name="z_to_dense")

        # z, kl_loss, _, _ = vae_compress(
        #     tf.expand_dims(targets, axis=2), tf.expand_dims(inputs, axis=2),
        #     hparams, "vae_compress", "vae_decompress")

        decoder_in = tf.squeeze(z, axis=2) + tf.zeros_like(targets)
        (decoder_input, decoder_self_attention_bias) = (
            transformer.transformer_prepare_decoder(decoder_in, hparams))
        ret = transformer.transformer_decoder(decoder_input, inputs,
                                              decoder_self_attention_bias,
                                              None, hparams)

        kl_loss *= common_layers.inverse_exp_decay(int(
            max_prestep * 1.5)) * 5.0
        losses = {"kl": kl_loss}
        return tf.expand_dims(ret, axis=2), losses
Example #19
def vae_transformer_internal(inputs, targets, target_space, hparams):
    """VAE Transformer, main step used for training."""
    with tf.variable_scope("vae_transformer"):
        # Prepare inputs, targets, and k.
        inputs = common_layers.flatten4d3d(inputs)
        input_len = tf.shape(inputs)[1]  # Double input size to cover targets.
        inputs = tf.pad(inputs, [[0, 0], [0, input_len], [0, 0]])
        inputs.set_shape([None, None, hparams.hidden_size])
        targets = common_layers.flatten4d3d(targets)
        k = 2**hparams.num_compress_steps
        inputs, targets = common_layers.pad_to_same_length(
            inputs, targets, final_length_divisible_by=k)
        inputs, ed_bias = encode(inputs, target_space, hparams, "input_enc")

        # Compress and vae.
        z, kl, r = vae_compress(tf.expand_dims(targets, axis=2),
                                tf.expand_dims(inputs, axis=2), ed_bias,
                                hparams, "vae_compress", "vae_decompress")
        kl *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.5))
        r *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 2.0))
        losses = {"kl": kl, "reconstruction": r}
        return z, losses
def dae(x, hparams, name):
    """Discrete autoencoder bottleneck with a Gumbel-softmax sample."""
    with tf.variable_scope(name):
        m = tf.layers.dense(x, hparams.v_size, name="mask")
        logsm = tf.nn.log_softmax(m)
        # Gumbel-softmax sample.
        gumbel_samples = gumbel_sample(tf.shape(m))
        steps = hparams.kl_warmup_steps
        gumbel_samples *= common_layers.inverse_exp_decay(steps) * 0.1
        temperature = 1.2 - common_layers.inverse_lin_decay(steps)
        s = tf.nn.softmax((logsm + gumbel_samples) / temperature)
        m = tf.nn.softmax(m)
        kl = -tf.reduce_max(logsm, axis=-1)
        tf.summary.histogram("max-log", tf.reshape(kl, [-1]))
        return m, s, tf.reduce_mean(kl)
Example #22
def run_unary_modules_sample(modules, cur, hparams, k):
  """Run modules, sampling k."""
  selection_weights = create_selection_weights(
      "selection", ("softmax_topk", k),
      shape=[len(modules)],
      inv_t=100.0 * common_layers.inverse_exp_decay(
          hparams.anneal_until, min_value=0.01))
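  # Note the lambda i=n default argument below: it freezes the loop variable n
  # at definition time, so each cond branch calls its own module rather than
  # the last one.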
  all_res = [
      tf.cond(
          tf.less(selection_weights.normalized[n], 1e-6),
          lambda: tf.zeros_like(cur),
          lambda i=n: modules[i](cur, hparams)) for n in xrange(len(modules))
  ]
  all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0)
  res = all_res * tf.reshape(selection_weights.normalized, [-1, 1, 1, 1, 1])
  return tf.reduce_sum(res, axis=0)
Example #24
def ae_compress(x, is_2d, hparams, name, reuse=None):
    """Compress, then AE."""
    with tf.variable_scope(name, reuse=reuse):
        cur = compress(x, None, is_2d, hparams, "compress")
        # Convolve and ReLu to get state.
        cur = common_layers.conv_block(cur,
                                       hparams.hidden_size, [((1, 1), (1, 1))],
                                       name="mid_conv")
        means_size = hparams.z_size if hparams.do_vae else hparams.v_size
        means = tf.get_variable("z_to_dense",
                                [means_size, hparams.hidden_size])
        if hparams.do_vae:
            if hparams.bit_vae:
                hot, loss = bit_vae(cur, hparams, "bvae")
            else:
                hot, loss, _, _ = vae(cur, hparams.z_size, "vae")
            # Do a second level vae with some probability.
            if hparams.z_size2 > 0:
                prob_z2 = common_layers.inverse_exp_decay(
                    hparams.startup_steps * 2) * 0.8
                if hparams.mode != tf.contrib.learn.ModeKeys.TRAIN:
                    prob_z2 = 1.0

                def vae2():
                    hot2, loss2, _, _ = vae(hot, hparams.z_size2, "vae2")
                    ret = tf.layers.dense(hot2, hparams.z_size)
                    return mix(ret, hot, hparams.startup_steps * 2), loss2

                hot, loss2 = tf.cond(tf.less(tf.random_uniform([]), prob_z2),
                                     vae2, lambda: (hot, tf.constant(0.0)))
                loss += loss2 * 0.1
            return cur, hot, loss
        if hparams.use_gumbel_softmax:
            _, hot, loss = dae(cur, hparams, "dae")
            return cur, hot, loss
        # Using k-means part. L2-normalizing to use fast cosine distance.
        cur = mix(tf.nn.l2_normalize(cur, dim=3),
                  cur,
                  hparams.startup_steps // 3,
                  mode="exp",
                  simple=True)
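        # The forward value of cur_n equals cur, but its gradient is scaled by
        # kmeans_lr_factor; the stop_gradient part contributes no gradient.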
        cur_n = hparams.kmeans_lr_factor * cur
        cur_n += (1.0 - hparams.kmeans_lr_factor) * tf.stop_gradient(cur)
        hot, loss = kmeans(cur_n, means, hparams, name="kmeans")
        # We need a linear layer to undo the l2-normalization.
        cur = tf.layers.dense(cur, hparams.hidden_size, name="unnormalize")
        return cur, hot, loss
Example #25
def ae_decompress(z, ae, x, is_2d, hparams, name, reuse=None):
    """Decompress from z, leaking from ae."""
    with tf.variable_scope(name + "_decompress", reuse=reuse):
        if hparams.use_gumbel_softmax or hparams.do_vae:
            # Leak at the beginning to help train.
            z = mix(z, ae, hparams.startup_steps)
        else:
            # Gradients flow to ae while the value is z.
            z = tf.stop_gradient(z) + ae - tf.stop_gradient(ae)
        # Leak during training to keep the full dense autoencoder.
        prob_z = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.8
        prob_z = prob_z if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN else 1.0
        z = tf.cond(tf.less(tf.random_uniform([]), prob_z), lambda: z,
                    lambda: ae)

        # Dropout for better autoencoding.
        z = tf.nn.dropout(z, keep_prob=1.0 - hparams.z_dropout)

        # Decompress.
        d = z
        k = (3, 3) if is_2d else (3, 1)
        for i in xrange(hparams.num_compress_steps):
            j = hparams.num_compress_steps - i - 1
            d = residual_conv(d, 1, k, hparams, "decompress_rc_%d" % j)
            d = decompress_step(d, None, hparams, i > 0, is_2d,
                                "decompress_%d" % j)

        # Autoregressive part.
        if hparams.decode_autoregressive:
            k = 2**(hparams.num_compress_steps * (2 if is_2d else 1))
            x_batch = tf.reshape(x, [-1, k, 1, hparams.hidden_size])
            x_batch = tf.stop_gradient(x_batch)
            z_batch = tf.reshape(z, [-1, 1, 1, hparams.hidden_size])
            d_batch = tf.reshape(d, [-1, k, 1, hparams.hidden_size])
            dec_batch = decode(z_batch, d_batch, x_batch, None, None, hparams,
                               "dar")
        else:  # For non-autoregressive.
            dec_batch = d
        z = tf.reshape(
            dec_batch,
            [-1, tf.shape(x)[1],
             tf.shape(x)[2], hparams.hidden_size])
        if is_2d:
            z = tf.layers.dense(z, hparams.hidden_size * 3)
    return z
def ae_decompress(z, ae, x, is_2d, hparams, name, reuse=None):
    """Decompress from z, leaking from ae."""
    with tf.variable_scope(name + "_decompress", reuse=reuse):
        # Leak at the beginning to help train.
        z = mix(z, ae, hparams.startup_steps)
        prob_z = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.8
        prob_z = prob_z if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        z = tf.cond(tf.less(tf.random_uniform([]), prob_z), lambda: z,
                    lambda: ae)

        # Dropout for better autoencoding.
        z = tf.nn.dropout(z, keep_prob=1.0 - hparams.z_dropout)

        # Decompress.
        d = z
        k = (3, 3) if is_2d else (3, 1)
        for i in xrange(hparams.num_compress_steps):
            j = hparams.num_compress_steps - i - 1
            d = residual_conv(d, 1, k, hparams, "decompress_rc_%d" % j)
            d = decompress_step(d, None, hparams, i > 0, is_2d,
                                "decompress_%d" % j)

        # Autoregressive part.
        if not is_2d:  # Currently we don't do it autoregressively for 2d problems.
            k = 2**(hparams.num_compress_steps * (2 if is_2d else 1))
            z_batch = tf.reshape(z, [-1, 1, 1, hparams.hidden_size])
            x_batch = tf.reshape(x, [-1, k, 1, hparams.hidden_size])
            d_batch = tf.reshape(d, [-1, k, 1, hparams.hidden_size])
            dec_batch = decode(z_batch, d_batch, x_batch, None, None, hparams)
        else:  # For non-autoregressive.
            dec_batch = d
        z = tf.reshape(
            dec_batch,
            [-1, tf.shape(x)[1],
             tf.shape(x)[2], hparams.hidden_size])
        if is_2d:
            z = tf.layers.dense(z, hparams.hidden_size * 3)
    return z
Example #27
def gumbel_softmax(x,
                   name,
                   z_size,
                   mode,
                   softmax_k=0,
                   kl_warmup_steps=150000,
                   summary=True):
    """Gumbel softmax discretization bottleneck.

  Args:
    x: Input to the discretization bottleneck.
    name: Name for the bottleneck scope.
    z_size: Number of bits used to produce discrete code; discrete codes range
      from 1 to 2**z_size.
    mode: Mode key indicating whether we are training or testing, for
      bottlenecks that differ in behavior between the two (Default: None).
    softmax_k: If > 0 then do top-k softmax (Default: 0).
    kl_warmup_steps: Number of steps for kl warmup (Default: 150000).
    summary: If True, then write summaries (Default: True).

  Returns:
    Dense code (softmax probabilities), discrete or sampled code, and loss.
  """
    with tf.variable_scope(name):
        m = tf.layers.dense(x, 2**z_size, name="mask")
        if softmax_k > 0:
            m, kl = top_k_softmax(m, softmax_k)
            return m, m, 1.0 - tf.reduce_mean(kl)
        logsm = tf.nn.log_softmax(m)

        # Gumbel-softmax sample.
        gumbel_samples = gumbel_sample(common_layers.shape_list(m))
        steps = kl_warmup_steps
        gumbel_samples *= common_layers.inverse_exp_decay(steps // 5) * 0.5
        temperature = 1.2 - common_layers.inverse_lin_decay(steps)

        # 10% of the time keep reasonably high temperature to keep learning.
        temperature = tf.cond(
            tf.less(tf.random_uniform([]), 0.9), lambda: temperature,
            lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
        s = tf.nn.softmax((logsm + gumbel_samples) / temperature)
        m = tf.nn.softmax(m)
        kl = -tf.reduce_max(logsm, axis=-1)

        if summary:
            tf.summary.histogram("max-log", tf.reshape(kl, [-1]))

        # Calculate the argmax and construct hot vectors.
        maxvec = tf.reshape(tf.argmax(m, axis=-1), [-1])
        maxvhot = tf.stop_gradient(tf.one_hot(maxvec, 2**z_size))

        # Add losses that prevent too few being used.
        distrib = tf.reshape(logsm, [-1, 2**z_size]) * maxvhot
        d_mean = tf.reduce_mean(distrib, axis=[0], keep_dims=True)
        d_variance = tf.reduce_mean(tf.square(distrib - d_mean), axis=[0])
        d_dev = -tf.reduce_mean(d_variance)
        ret = s

        if mode != tf.contrib.learn.ModeKeys.TRAIN:
            ret = tf.reshape(maxvhot,
                             common_layers.shape_list(s))  # Just hot @eval.
        return m, ret, d_dev * 5.0 + tf.reduce_mean(kl) * 0.002
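Hypothetical call, assuming a [batch, length, hidden] input and a 14-bit code space:

dense_probs, code, disc_loss = gumbel_softmax(
    x, name="gs_bottleneck", z_size=14,
    mode=tf.contrib.learn.ModeKeys.TRAIN, kl_warmup_steps=150000)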
Example #28
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """AE Transformer, main step used for training."""
    # Summaries break with the do_refine cond, turn them off in that case.
    global _DO_SUMMARIES
    if hparams.do_refine:
        _DO_SUMMARIES = False

    # Prepare.
    if inputs is not None:
        batch_size = common_layers.shape_list(inputs)[0]
    else:
        batch_size = common_layers.shape_list(targets)[0]
    targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

    # Encoder.
    if inputs is not None:
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed = encode(inputs, target_space, hparams, "input_enc")
        inputs_ex, ed_ex = inputs, ed
    else:
        ed, inputs_ex, ed_ex = None, None, None

    # Autoencoding.
    losses = {
        "extra": tf.constant(0.0),
        "latent_pred": tf.constant(0.0),
        "neg_q_entropy": tf.constant(0.0)
    }
    if hparams.do_ae:
        # flatten here
        original_targets = targets
        original_targets_shape = tf.shape(original_targets)
        if hparams.task == "image":
            cia.maybe_reshape_4d_to_3d(targets)
        if hparams.task == "translate":
            if inputs is not None:
                max_targets_len_from_inputs = tf.concat([inputs, inputs],
                                                        axis=1)
            else:
                max_targets_len_from_inputs = targets
        else:
            assert hparams.task == "image"
            max_targets_len_from_inputs = targets
        if hparams.word_shuffle:
            tf.logging.info("Using word shuffle with rate = {}".format(
                hparams.word_shuffle))
            targets_idx = tf.range(start=0,
                                   limit=common_layers.shape_list(targets)[1],
                                   delta=1)
            targets_idx = tf.to_float(targets_idx)
            noise = tf.random_uniform(
                shape=common_layers.shape_list(targets_idx),
                minval=0,
                maxval=1 + hparams.word_shuffle)
            targets_idx += noise
            permutation = contrib.framework().argsort(targets_idx)
            targets_permuted = tf.gather(targets, indices=permutation, axis=1)
            targets = targets_permuted
        targets, _ = common_layers.pad_to_same_length(
            targets,
            max_targets_len_from_inputs,
            final_length_divisible_by=2**hparams.num_compress_steps)
        # Add positional information
        targets_shape = common_layers.shape_list(targets)
        targets = tf.reshape(
            targets, [targets_shape[0], targets_shape[1], targets_shape[3]])
        targets = common_attention.add_positional_embedding(
            targets, hparams.max_length, name="targets_position")
        targets = tf.reshape(targets, shape=targets_shape)
        if hparams.word_dropout:
            mask = tf.random_uniform(shape=common_layers.shape_list(targets),
                                     minval=0.0,
                                     maxval=1.0)
            targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                                     tf.zeros_like(targets))
        else:
            targets_noisy = targets

        targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
        if hparams.mode != tf.estimator.ModeKeys.PREDICT:
            # Compress and bottleneck.
            latents_dense, latents_discrete, extra_loss, embed, neg_q_entropy = (
                hparams.bottleneck(inputs=targets_c,
                                   filter_size=hparams.compress_filter_size,
                                   mode=hparams.mode,
                                   name="vc"))
            if _DO_SUMMARIES:
                tf.summary.histogram(
                    "b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
            pc = common_layers.inverse_exp_decay(hparams.startup_steps)
            pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
            cond = tf.less(tf.random_uniform([batch_size]), pc)
            latents_dense = tf.where(cond, latents_dense, targets_c)
            # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
            losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
            # Extra loss predicting latent code from input. Discrete only.
            if hparams.bottleneck_kind not in ["dense", "vae"]:
                latents_pred = decode_transformer(inputs_ex,
                                                  ed_ex,
                                                  embed(latents_discrete),
                                                  hparams,
                                                  "extra",
                                                  task="translate")
                _, latent_pred_loss = ae_latent_softmax(
                    latents_pred, tf.stop_gradient(latents_discrete), hparams)

                # Scale by latent dimension for summary so we can compare across
                # batches.
                if _DO_SUMMARIES:
                    tf.summary.scalar("latent_pred_loss_mean",
                                      tf.reduce_mean(latent_pred_loss))
                if hparams.sum_over_latents:
                    latent_pred_loss = tf.reduce_sum(latent_pred_loss, [1, 2])

                losses["latent_pred"] = tf.reduce_mean(
                    latent_pred_loss * tf.to_float(cond)) * hparams.prior_scale
                losses["neg_q_entropy"] = neg_q_entropy * hparams.entropy_scale
            else:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                losses["latent_pred"] = tf.reduce_mean(
                    tf.squared_difference(inputs_c, targets_c)) * 20

                def bn_inputs():
                    with tf.variable_scope(tf.get_variable_scope(),
                                           reuse=True):
                        bn, _, _, _, _ = hparams.bottleneck(
                            inputs=inputs_c,
                            filter_size=hparams.compress_filter_size,
                            mode=hparams.mode,
                            name="vc")
                    return bn

                inputs_c = bn_inputs()
                ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
                ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
                latents_dense = tf.where(
                    tf.less(tf.random_uniform([batch_size]), ptc),
                    latents_dense, inputs_c)
        else:
            if hparams.bottleneck_kind in ["dense", "vae"]:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                latents_dense, _, _, _, _ = hparams.bottleneck(
                    inputs=inputs_c,
                    filter_size=hparams.compress_filter_size,
                    mode=hparams.mode,
                    name="vc")
            else:
                latent_len = common_layers.shape_list(targets_c)[1]
                _, _, _, embed, _ = hparams.bottleneck(
                    inputs=targets_c,
                    filter_size=hparams.compress_filter_size,
                    name="vc")
                latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
                if cache is None:
                    cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex,
                                             embed, 16, hparams)
                latents_dense = embed(cache)
        # Postprocess.
        d = latents_dense
        d_shape = common_layers.shape_list(d)
        d = tf.reshape(d, [d_shape[0], d_shape[1], d_shape[3]])
        d = common_attention.add_positional_embedding(d,
                                                      hparams.max_length,
                                                      name="latents_position")
        d = tf.reshape(d, shape=d_shape)

        # decompressing the dense latents
        for i in range(hparams.num_compress_steps):
            j = hparams.num_compress_steps - i - 1
            d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
            if inputs is not None and hparams.do_attend_decompress:
                d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
            d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

        # Masking.
        if hparams.do_mask:
            masking = common_layers.inverse_lin_decay(
                hparams.mask_startup_steps)
            masking *= common_layers.inverse_exp_decay(
                hparams.mask_startup_steps // 4)  # Not much at start.
            if not hparams.do_refine:
                masking -= tf.random_uniform([]) * hparams.unmasked_percentage
            masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
            if hparams.use_predict_mask:
                masking = predict_mask
            if hparams.mode == tf.estimator.ModeKeys.PREDICT:
                masking = predict_mask
            mask = tf.less(
                masking,
                tf.random_uniform(common_layers.shape_list(targets)[:-1]))
            mask = tf.expand_dims(tf.to_float(mask), 3)

            # targets is always [batch, length, 1, depth]
            targets = mask * targets + (1.0 - mask) * d
            # reshape back to 4d here
            if hparams.task == "image":
                targets = tf.reshape(targets, original_targets_shape)
        else:
            targets = d

    res = decode_transformer(inputs,
                             ed,
                             targets,
                             hparams,
                             "decoder",
                             causal=hparams.causal)
    if hparams.do_ae:
        if hparams.do_mask and hparams.do_refine:

            def refine_res():
                # return residual_conv(res, 1, (5, 1), hparams, "refine")
                r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                              "refine_enc")
                return tf.expand_dims(r, axis=2)

            masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
            all_masked = tf.less(masked_batches, 0.1)
            res = tf.where(all_masked, refine_res(), res)
        # We'll start training the extra model of latents after mask_startup_steps.
        nonlatent_steps = hparams.mask_startup_steps
        latent_time = tf.less(nonlatent_steps,
                              tf.to_int32(tf.train.get_global_step()))
        losses["latent_pred"] *= tf.to_float(latent_time)

    # res was generated from padded targets, which means it has some extra
    # elements. These can cause shape problems when computing loss with respect to
    # the original (unpadded) targets. So we remove their extra elements here.
    res = res[:, :original_targets_shape[1], :, :]

    data_dim = common_layers.shape_list(res)[1]
    latent_dim = common_layers.shape_list(targets_c)[1]
    return res, losses, cache, data_dim, latent_dim
    def body(self, features):
        hparams = self.hparams
        is_predicting = hparams.mode == tf.estimator.ModeKeys.PREDICT

        # TODO(lukaszkaiser): the split axes and the argmax below heavily depend on
        # using the default (a bit strange) video modality - we should change that.

        # Split inputs and targets into lists.
        input_frames = tf.unstack(features["inputs"], axis=1)
        target_frames = tf.unstack(features["targets"], axis=1)
        all_frames = input_frames + target_frames
        if "input_action" in features:
            input_actions = list(
                tf.split(features["input_action"],
                         hparams.video_num_input_frames,
                         axis=1))
            target_actions = list(
                tf.split(features["target_action"],
                         hparams.video_num_target_frames,
                         axis=1))
            all_actions = input_actions + target_actions

        orig_frame_shape = common_layers.shape_list(all_frames[0])

        # Run a number of steps.
        res_frames, sampled_frames, sampled_frames_raw = [], [], []
        if "target_reward" in features:
            res_rewards, extra_loss = [], 0.0
        sample_prob = common_layers.inverse_exp_decay(
            hparams.scheduled_sampling_warmup_steps)
        sample_prob *= hparams.scheduled_sampling_prob
        for i in range(hparams.video_num_target_frames):
            cur_frames = all_frames[i:i + hparams.video_num_input_frames]
            features["inputs"] = tf.concat(cur_frames, axis=-1)
            features["cur_target_frame"] = all_frames[
                i + hparams.video_num_input_frames]
            if "input_action" in features:
                cur_actions = all_actions[i:i + hparams.video_num_input_frames]
                features["input_action"] = tf.concat(cur_actions, axis=1)

            # Run model.
            with tf.variable_scope(tf.get_variable_scope(), reuse=i > 0):
                if "target_reward" not in features:
                    res_frame = self.body_single(features)
                else:
                    res_dict, res_extra_loss = self.body_single(features)
                    extra_loss += res_extra_loss
                    res_frame = res_dict["targets"]
                    res_reward = res_dict["target_reward"]
                    res_rewards.append(res_reward)
            res_frames.append(res_frame)

            # Only for Softmax loss: sample frame so we can keep iterating.
            sampled_frame_raw = self.get_sampled_frame(res_frame)
            sampled_frames_raw.append(sampled_frame_raw)
            # TODO(lukaszkaiser): this should be consistent with modality.bottom()
            sampled_frame = common_layers.standardize_images(sampled_frame_raw)
            sampled_frames.append(sampled_frame)

            if is_predicting:
                all_frames[i + hparams.video_num_input_frames] = sampled_frame

            # Scheduled sampling during training.
            if (hparams.scheduled_sampling_prob > 0.0 and self.is_training):
                do_sample = tf.less(tf.random_uniform([orig_frame_shape[0]]),
                                    sample_prob)
                orig_frame = all_frames[i + hparams.video_num_input_frames]
                sampled_frame = tf.where(do_sample, sampled_frame, orig_frame)
                all_frames[i + hparams.video_num_input_frames] = sampled_frame

        # Concatenate results and return them.
        frames = tf.stack(res_frames, axis=1)

        if "target_reward" not in features:
            return frames
        rewards = tf.concat(res_rewards, axis=1)
        return {"targets": frames, "target_reward": rewards}, extra_loss
Example #30
  def body(self, features):
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    vocab_size = self._problem_hparams.vocab_size["targets"]
    if hasattr(self._hparams, "vocab_divisor"):
      vocab_size += (-vocab_size) % self._hparams.vocab_divisor
    encoder_layers = None
    self.is1d = hparams.sample_width == 1
    if (hparams.mode != tf.estimator.ModeKeys.PREDICT
        or self._encode_on_predict):
      labels = features["targets_raw"]
      labels_shape = common_layers.shape_list(labels)
      # handle videos
      if len(labels.shape) == 5:
        labels = time_to_channels(labels)
      shape = common_layers.shape_list(labels)
      x = tf.one_hot(labels, vocab_size)
      x = self.embed(x)
      target_codes = x
      if shape[2] == 1:
        self.is1d = True
      # Run encoder.
      x, encoder_layers = self.encoder(x)
      # Bottleneck.
      b, b_loss = self.bottleneck(x)
      xb_loss = 0.0
      b_shape = common_layers.shape_list(b)
      self._cur_bottleneck_tensor = b
      res_size = common_layers.shape_list(x)[-1]
      b = self.unbottleneck(b, res_size)
      if not is_training:
        x = b
      else:
        l = 2**hparams.num_hidden_layers
        warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
        nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
        if common_layers.should_generate_summaries():
          tf.summary.scalar("nomix_p_bottleneck", nomix_p)
        rand = tf.random_uniform(common_layers.shape_list(x))
        # This is the distance between b and x. Having this as loss helps learn
        # the bottleneck function, but if we back-propagated to x it would be
        # minimized by just setting x=0 and b=0 -- so we don't want too much
        # of the influence of this, and we stop-gradient to not zero-out x.
        x_stop = tf.stop_gradient(x)
        xb_loss = tf.reduce_mean(tf.reduce_sum(
            tf.squared_difference(x_stop, b), axis=-1))
        # To prevent this loss from exploding we clip at 1, but anneal clipping.
        clip_max = 1.0 / common_layers.inverse_exp_decay(
            warm_step, min_value=0.001)
        xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
        xb_loss *= clip_max / xb_clip
        x = tf.where(tf.less(rand, nomix_p), b, x)
      if hparams.gan_loss_factor != 0.0:
        # Add a purely sampled batch on which we'll compute the GAN loss.
        g = self.unbottleneck(
            self.sample(shape=b_shape),
            common_layers.shape_list(x)[-1],
            reuse=True)
        x = tf.concat([x, g], axis=0)
    else:
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      self._cur_bottleneck_tensor = b
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x, encoder_layers)

    # Cut to the right size and mix before returning.
    res = x
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      res = x[:, :shape[1], :shape[2], :]

    # Final dense layer.
    res = tf.layers.dense(
        res, self.num_channels * hparams.hidden_size, name="res_dense")

    output_shape = common_layers.shape_list(res)[:-1] + [
        self.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hparams.use_vq_loss:
        (reconstr, _, _, _, _) = discretization.vq_loss(res, labels, vocab_size)
      else:
        reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      return reconstr, {"bottleneck_loss": 0.0}

    if hparams.gan_loss_factor != 0.0:
      res, res_gan = tf.split(res, 2, axis=0)

    # Losses.
    losses = {
        "bottleneck_extra": b_loss,
        "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
    }

    if hparams.use_vq_loss:
      vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.2,
          min_value=hparams.vq_temperature * 2)
      if hparams.mode != tf.estimator.ModeKeys.TRAIN:
        vq_temperature = None
      with tf.variable_scope("vq_loss"):
        (reconstr, _, target_codes, code_loss,
         targets_loss) = discretization.vq_loss(
             res, labels, vocab_size, temperature=vq_temperature)
      losses["code_loss"] = code_loss * hparams.code_loss_factor
      losses["training"] = targets_loss
    else:
      reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      targets_loss = tf.losses.sparse_softmax_cross_entropy(
          logits=tf.reshape(reconstr, labels_shape + [vocab_size]),
          labels=tf.reshape(labels, labels_shape))
      losses["training"] = targets_loss

    # GAN losses.
    if hparams.gan_loss_factor != 0.0:
      update_means_factor = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps, min_value=0.0001)
      if hparams.use_vq_loss:
        with tf.variable_scope("vq_loss", reuse=True):
          update_means = tf.less(tf.random_uniform([]), update_means_factor)
          reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
              res_gan,
              labels,
              vocab_size,
              do_update=update_means,
              temperature=vq_temperature)
          reconstr_gan_nonoise = reconstr_gan
          code_loss_gan *= hparams.code_loss_factor * update_means_factor
          losses["code_loss_gan"] = code_loss_gan
      else:
        reconstr_gan = tf.layers.dense(
            res_gan, vocab_size, name="autoencoder_final", reuse=True)
        reconstr_gan_nonoise = reconstr_gan
        reconstr_gan = self.gumbel_sample(reconstr_gan)
        # Embed to codes.
        gan_codes = self.embed(reconstr_gan)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
      self.image_summary("gan", reconstr_gan_nonoise)

      def discriminate(x):
        """Run a dioscriminator depending on the hparams."""
        if hparams.discriminator == "default":
          return common_layers.deep_discriminator(
              x, hparams.discriminator_batchnorm, is_training)
        elif hparams.discriminator == "patched":
          return common_layers.patch_discriminator(x)
        elif hparams.discriminator == "single":
          return common_layers.single_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        elif hparams.discriminator == "double":
          return common_layers.double_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        else:
          raise Exception("Unknown discriminator %s" % hparams.discriminator)

      tc_shape = common_layers.shape_list(target_codes)
      if len(tc_shape) > 4:
        target_codes = tf.reshape(target_codes,
                                  tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
        gan_codes = tf.reshape(gan_codes,
                               tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
      gan_lr = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.5)
      rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
      gan_loss = common_layers.sliced_gan_loss(
          target_codes,
          rev_grad_gan_codes,
          discriminate,
          self.hparams.num_sliced_vecs,
          do_tanh=hparams.sliced_do_tanh)
      gan_loss *= hparams.gan_loss_factor * update_means_factor
      losses["gan_loss"] = -gan_loss

    self.image_summary("ae", reconstr)

    logits = tf.reshape(reconstr, labels_shape + [vocab_size])
    return logits, losses
Example #31
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """Auto-encoder using a Transformer decoder and a prior over latent sequences.

  Args:
    inputs: Tensor of shape [batch, length, 1, hparams.hidden_size] or None.
    targets: Tensor of shape [batch, ..., channels]. Ellipses may be 1 or 2
      dimensions denoting sequence length.
    target_space: int. Used for encoding inputs under a target space id.
    hparams: HParams.
    cache: Tensor of shape [batch, length] or None.
    predict_mask: Tensor masking whether to use gold targets or predictions.

  Returns:
    decoder_output: Tensor of shape [batch, ..., hparams.hidden_size] containing
      pre-logit activations. After a transformation (`top` in `T2TModel`), it is
      used with targets to compute the "training" (reconstruction) loss.
    losses: dict of str to Tensors. There are three loss terms: "extra",
      "extra_loss", and "latent_pred". The first is hard-coded to 0. The latter
      two are Tensors of shape [batch].
    cache: Tensor of shape [batch, length], either the same as cache, or newly
      computed if the cache input is None.
  """
    original_targets_shape = common_layers.shape_list(targets)
    batch_size = original_targets_shape[0]
    if len(original_targets_shape) == 4:
        compress_fn = compress_encoder_2d
        decompress_fn = decompress_decoder_2d
    else:
        compress_fn = compress_encoder_1d
        decompress_fn = decompress_decoder_1d

    ed_attention_bias = None
    if inputs is not None:
        inputs, ed_attention_bias = transformer_text_encoder(
            inputs, target_space, hparams, name="input_encoder")

    losses = {"extra": 0., "extra_loss": 0., "latent_pred": 0.}
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
        targets_compressed = compress_fn(targets, hparams, name="compress")

        if hparams.mode == tf.estimator.ModeKeys.TRAIN:
            scale = common_layers.inverse_exp_decay(hparams.startup_steps)
        else:
            scale = 1.0
        scale = tf.to_float(tf.less(tf.random_uniform([batch_size]), scale))

        latents_dense, latents_discrete, extra_loss, _ = bottleneck_layer(
            targets_compressed, hparams)
        extra_loss = scale * tf.reduce_mean(extra_loss)

        _, latents_pred_loss = latent_prediction_model(inputs,
                                                       ed_attention_bias,
                                                       latents_discrete,
                                                       latents_dense,
                                                       hparams,
                                                       name="latent_pred")
        latent_time = tf.less(hparams.mask_startup_steps,
                              tf.to_int32(tf.train.get_global_step()))
        latents_pred_loss = scale * tf.reduce_mean(latents_pred_loss)
        latents_pred_loss *= tf.to_float(latent_time)

        # Apply dropout noise for each data point and time step.
        latents_dense_shape = common_layers.shape_list(latents_dense)
        latents_dense = tf.nn.dropout(
            latents_dense,
            keep_prob=1 - hparams.latent_dropout,
            noise_shape=[latents_dense_shape[0], latents_dense_shape[1], 1])

        # TODO(trandustin): Can we combine extra and extra_loss?
        losses = {
            "extra": 0.,
            "extra_loss": extra_loss,
            "latent_pred": latents_pred_loss
        }
    else:
        # Set the latent length, which is num_latents times the number of latent
        # pixels. The number of latent pixels is determined by a compression factor
        # on the number of image pixels.
        latent_len = (
            (hparams.img_len * hparams.img_len * hparams.num_latents) //
            (2**hparams.num_compress_steps))
        _, _, _, embed_fn = bottleneck_layer(targets_compressed, hparams)
        latents_dense = tf.zeros(
            [batch_size, latent_len, 1, hparams.hidden_size])
        if cache is None:
            cache = ae_latent_sample_beam(latents_dense, inputs,
                                          ed_attention_bias, embed_fn, hparams)
        cache_one_hot = tf.one_hot(cache, depth=2**hparams.bottleneck_bits)
        latents_dense = embed_fn(cache_one_hot, hparams.hidden_size)

    if len(original_targets_shape) == 4:
        compressed_img_len = (hparams.img_len //
                              2**(hparams.num_compress_steps // 2))
        latents_dense = tf.reshape(latents_dense, [
            batch_size, compressed_img_len, compressed_img_len,
            hparams.num_latents * hparams.hidden_size
        ])

    latents_dense = decompress_fn(latents_dense, hparams, name="decompress")
    latents_dense = tf.reshape(
        latents_dense,
        [-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

    if hparams.use_gold_targets:
        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
            masking = predict_mask
        else:
            masking = common_layers.inverse_exp_decay(
                hparams.mask_startup_steps)
        targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)
        mask = tf.less(
            masking, tf.random_uniform(common_layers.shape_list(targets)[:-1]))
        mask = tf.expand_dims(tf.to_float(mask), 2)
        latents_dense = mask * targets + (1.0 - mask) * latents_dense

    latents_dense = tf.reshape(latents_dense, original_targets_shape)
    if hparams.decode_autoregressive:
        decoder_output = transformer_image_decoder(latents_dense,
                                                   inputs,
                                                   ed_attention_bias,
                                                   hparams,
                                                   name="decoder")
    else:
        decoder_output = latents_dense
    return decoder_output, losses, cache
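
Several knobs above (extra_loss scaling, latent dropout, gold-target masking)
ride on common_layers.inverse_exp_decay and inverse_lin_decay, which ramp a
scalar from near zero to 1.0 over a warmup horizon. A minimal sketch of the
exponential variant, assuming it starts at min_value at step 0 and reaches 1.0
at max_step:

import tensorflow as tf

def inverse_exp_decay_sketch(max_step, min_value=0.01):
  """Exponential ramp: min_value at step 0, 1.0 once step >= max_step."""
  step = tf.to_float(tf.train.get_or_create_global_step())
  inv_base = tf.exp(tf.log(min_value) / float(max_step))
  # inv_base**max_step == min_value; the exponent shrinks to 0 as step grows.
  return inv_base ** tf.maximum(float(max_step) - step, 0.0)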
Example #32
def gumbel_softmax_discrete_bottleneck(x,
                                       bottleneck_bits,
                                       beta=0.25,
                                       decay=0.999,
                                       epsilon=1e-5,
                                       startup_steps=15000,
                                       hard=False,
                                       summary=True):
  """VQ-VAE using Gumbel-Softmax.

  Unlike the `gumbel_softmax()` function, this function calculates the KL
  using the discrete entropy instead of taking the argmax, and it updates the
  codebook with an exponential moving average, whereas `gumbel_softmax()`
  includes no codebook update.

  Args:
    x: A `float`-like `Tensor` containing the latent vectors to be compared to
      the codebook, whose squared difference is used as the Gumbel-Softmax
      logits.
    bottleneck_bits: An `int` that sets the size of the bottleneck in `log_2`.
    beta: Beta factor for commitment loss (Default: 0.25).
    decay: Decay factor for exponential moving average (Default: 0.999).
    epsilon: Small value to avoid dividing by zero in EMA update
      (Default: 1e-5).
    startup_steps: Number of steps for KL warmup (Default: 15000).
    hard: When `True`, we use hard Gumbel-Softmax samples and force
      discrete latents by taking the argmax. When `False`, we use soft samples,
      which we treat as codebook weights (Default: False).
    summary: When `True`, we save histogram summaries of the KL term (Default:
      True).

  Returns:
    x_means_assignments: A `float`-like `Tensor` containing the codebook
      assignments. When `hard == True`, this is one-hot, containing the arg-max
      of the Gumbel-Softmax samples (and we use the straight-through gradient).
      Otherwise, it contains the Gumbel-Softmax samples exactly, which are
      values from the `(K-1)`-simplex where `K` is the bottleneck size.
    loss: The loss, which is the sum of the KL between the Gumbel-Softmax and
      the uniform prior and the commitment loss multiplied by the beta factor.
      We approximate the KL by using the entropy of a categorical distribution
      instead of the Gumbel-Softmax.

  """
  bottleneck_size = 2**bottleneck_bits
  x_shape = common_layers.shape_list(x)
  hidden_size = x_shape[-1]
  means, ema_means, ema_count = get_vq_bottleneck(bottleneck_size, hidden_size)
  x = tf.reshape(x, [-1, hidden_size])

  bottleneck_size = common_layers.shape_list(means)[0]
  x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
  means_norm_sq = tf.reduce_sum(tf.square(means), axis=-1, keepdims=True)
  scalar_prod = tf.matmul(x, means, transpose_b=True)
  dist = x_norm_sq + tf.transpose(means_norm_sq) - 2 * scalar_prod

  class_probs = tf.nn.softmax(dist)
  log_class_probs = tf.nn.log_softmax(dist)
  gumbel_samples = gumbel_sample(common_layers.shape_list(dist))
  gumbel_samples *= common_layers.inverse_exp_decay(startup_steps // 5) * 0.5
  temperature = 1.2 - common_layers.inverse_lin_decay(startup_steps)

  # 10% of the time keep reasonably high temperature to keep learning.
  temperature = tf.cond(
      tf.less(tf.random_uniform([]), 0.9), lambda: temperature,
      lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
  gumbel_softmax_samples = tf.nn.softmax(
      (log_class_probs + gumbel_samples) / temperature)

  # Calculate KL between q and a uniform prior.
  kl = tf.reduce_sum(class_probs * (log_class_probs -
                                    tf.log(1.0/bottleneck_size)), -1)
  if summary:
    tf.summary.histogram("KL", tf.reshape(kl, [-1]))

  # Straight-through gradient estimation when we're using hard assignments.
  if hard:
    x_means_idx = tf.reshape(tf.argmax(gumbel_softmax_samples, axis=-1), [-1])
    x_means_hot = tf.one_hot(x_means_idx, bottleneck_size)
    x_means_assignments = gumbel_softmax_samples + tf.stop_gradient(
        x_means_hot - gumbel_softmax_samples)
  else:
    x_means_assignments = gumbel_softmax_samples
  x_means_assignments_flat = tf.reshape(
      x_means_assignments, [-1, bottleneck_size])
  x_means = tf.matmul(x_means_assignments_flat, means)
  commitment_loss = tf.reduce_mean(tf.square(x - tf.stop_gradient(x_means)))

  # Update the ema variables.
  updated_ema_count = moving_averages.assign_moving_average(
      ema_count,
      tf.reduce_sum(
          tf.reshape(x_means_assignments, shape=[-1, bottleneck_size]), axis=0),
      decay,
      zero_debias=False)

  dw = tf.matmul(x_means_assignments, x, transpose_a=True)
  updated_ema_means = tf.identity(moving_averages.assign_moving_average(
      ema_means, dw, decay, zero_debias=False))
  n = tf.reduce_sum(updated_ema_count, axis=-1, keepdims=True)
  updated_ema_count = (
      (updated_ema_count + epsilon) / (n + bottleneck_size * epsilon) * n)
  updated_ema_means /= tf.expand_dims(updated_ema_count, axis=-1)
  with tf.control_dependencies([commitment_loss]):
    update_means = means.assign(updated_ema_means)
    with tf.control_dependencies([update_means]):
      loss = beta * commitment_loss

  # Add KL loss.
  loss += tf.reduce_mean(kl)

  x_means_assignments = tf.reshape(
      x_means_assignments, x_shape[:-1] + [bottleneck_size])
  return x_means_assignments, loss
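
The Gumbel noise above comes from gumbel_sample. A minimal sketch, assuming the
standard inverse-CDF construction with the uniform draw clipped away from 0 and
1 for numerical safety:

import tensorflow as tf

def gumbel_sample_sketch(shape, eps=1e-5):
  """Standard Gumbel noise: -log(-log(U)) with U ~ Uniform(eps, 1 - eps)."""
  u = tf.random_uniform(shape, minval=eps, maxval=1.0 - eps)
  return -tf.log(-tf.log(u))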
Example #33
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None):
    """Main step used for training."""
    # Prepare.
    if inputs is not None:
        batch_size = common_layers.shape_list(inputs)[0]
    else:
        batch_size = common_layers.shape_list(targets)[0]
    targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

    # Encoder.
    if inputs is not None:
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed = encode(inputs, target_space, hparams, "input_enc")
        inputs_ex, ed_ex = inputs, ed
    else:
        ed, inputs_ex, ed_ex = None, None, None

    # Autoencoding.
    losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}

    max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)

    targets, _ = common_layers.pad_to_same_length(
        targets,
        max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, hparams, "compress")

    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
        # Compress and bottleneck.
        latents_discrete_hot, extra_loss = vq_discrete_bottleneck(
            x=targets_c, hparams=hparams)
        latents_dense = vq_discrete_unbottleneck(latents_discrete_hot, hparams)
        latents_discrete = tf.argmax(latents_discrete_hot, axis=-1)
        tf.summary.histogram("codes",
                             tf.reshape(latents_discrete[:, 0, :], [-1]))
        pc = common_layers.inverse_exp_decay(hparams.startup_steps)
        pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        cond = tf.less(tf.random_uniform([batch_size]), pc)
        latents_dense = tf.where(cond, latents_dense, targets_c)
        losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))

        # Extra loss predicting latent code from input. Discrete only.
        latents_pred = decode_transformer(inputs_ex, ed_ex, latents_dense,
                                          hparams, "extra")
        latent_pred_loss = get_latent_pred_loss(latents_pred,
                                                latents_discrete_hot, hparams)
        losses["latent_pred"] = tf.reduce_mean(latent_pred_loss *
                                               tf.to_float(cond))
    else:
        latent_len = common_layers.shape_list(targets_c)[1]
        embed = functools.partial(vq_discrete_unbottleneck, hparams=hparams)
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
            cache = ae_latent_sample_beam(latents_dense, inputs_ex, ed_ex,
                                          embed, hparams)
        latents_dense = embed(
            tf.one_hot(cache, depth=2**hparams.bottleneck_bits))

    # Postprocess.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # Decompress the dense latents.
    for i in range(hparams.num_compress_steps):
        j = hparams.num_compress_steps - i - 1
        d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
        d = decompress_step(d, hparams, i > 0, "decompress_%d" % j)

    res = decode_transformer(inputs, ed, targets, hparams, "decoder")
    # We'll start training the extra model of latents after mask_startup_steps.
    nonlatent_steps = hparams.mask_startup_steps
    latent_time = tf.less(nonlatent_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
    return res, losses, cache
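
The postprocess step in Example #33 shifts the dense latents right by one
position, so the decoder at position i conditions only on latents before i, and
adds a learned positional table cropped to the sequence length. A standalone
sketch of that shift with hypothetical sizes:

import tensorflow as tf

# Hypothetical sizes: batch 8, 16 latents, hidden size 512, max length 1000.
latents = tf.zeros([8, 16, 1, 512])
pos = tf.get_variable("pos_sketch", [1, 1000, 1, 512])
length = tf.shape(latents)[1]
# Pad one zero step at the front, then add positions cropped to length + 1.
shifted = tf.pad(latents, [[0, 0], [1, 0], [0, 0], [0, 0]])
shifted += pos[:, :length + 1, :, :]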
Example #34
def ae_transformer_internal(inputs, targets, target_space, hparams,
                            beam_size, cache=None, predict_mask=1.0):
  """AE Transformer, main step used for training."""
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  orig_targets = targets
  batch_size = tf.shape(orig_targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
  else:
    ed = None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      t_c, t_bit, vc_loss, _ = bottleneck(targets_c, hparams, 2*2048, "vc")
      if _DO_SUMMARIES:
        tf.summary.histogram("bit0", tf.reshape(t_bit[:, 0, :], [-1]))
      pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([]), pc)
      t_c = tf.cond(cond, lambda: t_c, lambda: targets_c)
      losses["extra"] = vc_loss * tf.to_float(cond)
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        t_pred = decode_transformer(
            inputs, ed, tf.stop_gradient(t_c), hparams, "extra")
        t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
        losses["latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=t_bit, logits=t_pred)
        losses["latent_pred"] = tf.reduce_mean(
            losses["latent_pred"]) * 0.5 * tf.to_float(cond)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        targets_rand = tf.random_uniform(tf.shape(targets_c))
        t_c, _, _, _ = bottleneck(targets_rand, hparams, 2*2048, "vc")
      else:
        latent_len = tf.shape(targets_c)[1]
        _, _, _, embed = bottleneck(targets_c, hparams, 2*2048, "vc")
        t_c = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(t_c, inputs, ed, embed, 8, hparams)
          cache = cache[0, :, :]
          cache = tf.reshape(cache, [1, latent_len, 1])
          cache = tf.tile(cache, [beam_size, 1, 1])
        t_c = embed(cache)
    # Postprocess.
    d = t_c
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :tf.shape(t_c)[1] + 1, :, :]
    t_c = tf.pad(t_c, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(100000)
      masking *= common_layers.inverse_exp_decay(25000)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * 0.3
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking, tf.random_uniform(tf.shape(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)
      for i in xrange(hparams.num_compress_steps):
        j = hparams.num_compress_steps - i - 1
        d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
        d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
      targets = mask * targets + (1.0 - mask) * d
    targets = tf.concat([tf.reverse(t_c, [1]), targets], axis=1)

  res = decode_transformer(inputs, ed, targets, hparams, "decoder")
  if hparams.do_ae:
    res = res[:, tf.shape(t_c)[1]:, :, :]
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        return residual_conv(res, 1, (5, 1), hparams, "refine")
      all_masked = tf.less(tf.reduce_sum(mask), 0.1)
      res = tf.cond(all_masked, refine_res, lambda: res)
  return res, losses, cache
def bottleneck(x, hparams, filter_size, name):
  """Bottleneck."""
  def embed(x):
    """Embedding function; must be compatible with the code later."""
    with tf.variable_scope(name, reuse=True):
      if hparams.bottleneck_kind == "semhash":
        c = int_to_bit(x, z_size)
        h1a = tf.layers.dense(c, filter_size, name="vch1a")
        h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
        h1 = h1a + h1b
      elif hparams.bottleneck_kind == "gumbel-softmax":
        hot = tf.one_hot(x, hparams.v_size)
        h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
      elif hparams.bottleneck_kind == "vq-vae":
        means = tf.get_variable(name="means",
                                shape=[hparams.v_size, hparams.hidden_size])
        h1 = tf.gather(means, x)
      elif hparams.bottleneck_kind == "rounding":
        h1 = x

      h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
      return tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")

  with tf.variable_scope(name):
    z_size = hparams.z_size
    l = tf.constant(0.0)
    if hparams.bottleneck_kind == "dense":
      c = tf.layers.dense(x, z_size, name="vcc")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    if hparams.bottleneck_kind == "vae":
      c, l, _, _ = vae(x, z_size, "vae")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    if hparams.bottleneck_kind == "semhash":
      c = tf.layers.dense(x, z_size, name="vcc")
      y_clean = common_layers.saturating_sigmoid(c)
      if _DO_SUMMARIES:
        tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
      if hparams.noise_dev > 0 and hparams.mode == tf.estimator.ModeKeys.TRAIN:
        dev = hparams.noise_dev
        noise = tf.truncated_normal(common_layers.shape_list(c),
                                    mean=0.0, stddev=dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(hparams.startup_steps * 2)
      pd *= hparams.d_mix
      pd = pd if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.where(tf.less(tf.random_uniform(
          [common_layers.shape_list(y)[0]]), pd), y_discrete, y)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, z_size)
    if hparams.bottleneck_kind == "gumbel-softmax":
      _, hot, l = dae(x, hparams, name)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
    if hparams.bottleneck_kind == "vq-vae":
      means = tf.get_variable(name="means", shape=[hparams.v_size,
                                                   hparams.hidden_size])
      x_means_hot, x_means, l = kmeans(x, means, hparams, name="vq-vae-kmeans")
      h1 = tf.stop_gradient(x_means) + x - tf.stop_gradient(x)
      c = tf.argmax(x_means_hot, axis=-1)
    if hparams.bottleneck_kind == "rounding":
      h = tf.layers.dense(x, 1, name="vcc")

      # Make h between 0 and 1
      h = tf.sigmoid(h)

      # Multiply by z_size to get it between [0, z_size]
      h *= hparams.v_size

      # Use the rounding bottleneck
      h1 = h + tf.stop_gradient(tf.round(h) - h)
      c = tf.squeeze(tf.round(h), axis=-1)
      c = tf.to_int32(c)
    h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
    res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
    return res, c, l, embed
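
The semhash branch above round-trips between bit vectors and integer codes via
int_to_bit and bit_to_int. A minimal sketch of the pair, assuming little-endian
digits in an arbitrary base:

import tensorflow as tf

def int_to_bit_sketch(x_int, num_bits, base=2):
  """Expands integer codes into base-`base` digits along a new last axis."""
  x = tf.to_int32(tf.expand_dims(x_int, axis=-1))
  digits = [tf.mod(x // base**i, base) for i in range(num_bits)]
  return tf.to_float(tf.concat(digits, axis=-1))

def bit_to_int_sketch(x_bit, num_bits, base=2):
  """Inverse of int_to_bit_sketch: folds the digit axis back into integers."""
  x = tf.to_int32(x_bit)
  terms = [x[..., i] * base**i for i in range(num_bits)]
  return tf.add_n(terms)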
Example #36
def discrete_bottleneck(x,
                        hidden_size,
                        z_size,
                        filter_size,
                        name,
                        mode=None,
                        startup_steps=50000,
                        bottleneck_kind="dvq",
                        num_blocks=2,
                        num_residuals=1,
                        reshape_method="slice",
                        projection_tensors=None,
                        means=None,
                        beta=0.25,
                        noise_dev=1.,
                        decay=0.999,
                        discrete_mix=0.5,
                        random_top_k=1,
                        soft_em=False,
                        num_samples=1,
                        epsilon=1e-5,
                        softmax_k=0,
                        kl_warmup_steps=150000,
                        ema=True,
                        ema_count=None,
                        ema_means=None,
                        summary=True):
    """Discretization bottleneck for latent variables.

  Args:
    x: Input to the discretization bottleneck.
    hidden_size: Dimension of the latent state.
    z_size: Number of bits used to produce discrete code; discrete codes range
      from 0 to 2**z_size - 1.
    filter_size: Filter size to be used for the embedding function.
    name: Name for the bottleneck scope.
    mode: Mode represents whether we are training or testing for bottlenecks
      that differ in behavior (Default: None).
    startup_steps: Number of steps after which latent predictor is trained
      (Default: 50000).
    bottleneck_kind: Kind of discretization bottleneck to use; one of dense,
      vae, dvq, semhash, gumbel-softmax (Default: dvq).
    num_blocks: Number of blocks to use for decomposed vector
      quantization (Default: 2).
    num_residuals: Number of residual units used to compute nearest
      neighbors (Default: 1).
    reshape_method: Method to reshape for DVQ (Default: slice).
    projection_tensors: If the reshape method is project, then these are the
      tensors used to project (Default: None).
    means: The embedding table for dvq (Default: None).
    beta: Beta factor for the DVQ loss (Default: 0.25).
    noise_dev: Stddev for noise added for semhash (Default: 1.0).
    decay: Decay factor for the exponential moving average (Default: 0.999).
    discrete_mix: Factor for mixing discrete and non-discrete input for semhash
      (Default: 0.5).
    random_top_k: Noisy top-k for DVQ (Default: 1).
    soft_em: If True then use soft EM rather than hard EM (Default: False).
    num_samples: Number of samples for soft EM (Default: 1).
    epsilon: Epsilon parameter for DVQ (Default: 1e-5).
    softmax_k: If > 1 then do top-k softmax (Default: 0).
    kl_warmup_steps: Number of steps for kl warmup (Default: 150000).
    ema: If True update embeddings using exponential moving averages (Default:
      True).
    ema_count: Table of counts for each embedding corresponding to how many
      examples in a batch it was the closest to (Default: None).
    ema_means: Exponentially averaged version of the embeddings (Default: None).
    summary: If True, then write summaries (Default: True).

  Returns:
    Embedding to pass to the decoder, discrete latent, loss, and the embedding
    function.

  Raises:
    ValueError: If projection_tensors is None for reshape_method project, or
    ema_count or ema_means is None if we are using ema, or unknown args.
  """
    tf.logging.info("Shape of x = {}".format(common_layers.shape_list(x)))
    block_v_size = None
    if bottleneck_kind == "dvq":
        # Define the dvq parameters
        assert means is not None

        # Check block dimensions add up
        if hidden_size % num_blocks != 0:
            raise ValueError("num_blocks does not divide hidden size")

        if z_size % num_residuals != 0:
            raise ValueError(
                "num_residuals does not divide embedding table size")

        z_size_per_residual = int(z_size / num_residuals)

        if z_size_per_residual % num_blocks != 0:
            raise ValueError("num_blocks does not divide embedding table size")

        block_v_size = 2**(z_size_per_residual / num_blocks)
        block_v_size = int(block_v_size)

        # Set the reshape method corresponding to projections or slices
        if reshape_method == "slice":
            reshape_fn = partial(slice_hidden,
                                 hidden_size=hidden_size,
                                 num_blocks=num_blocks)
        elif reshape_method == "project":
            if projection_tensors is None:
                raise ValueError(
                    "Projection tensors is None for reshape_method project")
            reshape_fn = partial(project_hidden,
                                 projection_tensors=projection_tensors,
                                 hidden_size=hidden_size,
                                 num_blocks=num_blocks)
        else:
            raise ValueError("Unknown reshape_method")

        # Check if the ema settings make sense
        if ema:
            if ema_count is None:
                raise ValueError("ema_count is None but ema is True")
            if ema_means is None:
                raise ValueError("ema_means is None but ema is True")

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        l = tf.constant(0.0)
        if bottleneck_kind == "dense":
            c = tf.layers.dense(x, z_size, name="vcc")
            h1 = tf.layers.dense(c, filter_size, name="vch1")
        elif bottleneck_kind == "vae":
            c, l, _, _ = vae(x, z_size, "vae")
            h1 = tf.layers.dense(c, filter_size, name="vch1")
        elif bottleneck_kind == "semhash":
            c = tf.layers.dense(x, z_size, name="vcc")
            y_clean = common_layers.saturating_sigmoid(c)
            if summary:
                tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
            if noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
                noise = tf.truncated_normal(common_layers.shape_list(c),
                                            mean=0.0,
                                            stddev=noise_dev)
                y = common_layers.saturating_sigmoid(c + noise)
            else:
                y = y_clean
            d = tf.to_float(tf.less(0.5, y))
            y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
            pd = common_layers.inverse_exp_decay(startup_steps * 2)
            pd *= discrete_mix
            pd = pd if mode == tf.estimator.ModeKeys.TRAIN else 1.0
            c = tf.where(
                tf.less(tf.random_uniform([common_layers.shape_list(y)[0]]),
                        pd), y_discrete, y)
            h1a = tf.layers.dense(c, filter_size, name="vch1a")
            h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
            h1 = h1a + h1b
            dx = tf.to_int32(tf.stop_gradient(d))
            c = bit_to_int(dx, z_size)
        elif bottleneck_kind == "gumbel-softmax":
            _, hot, l = gumbel_softmax(x, name, z_size, mode, softmax_k,
                                       kl_warmup_steps, summary)
            c = tf.argmax(hot, axis=-1)
            h1 = tf.layers.dense(hot, hidden_size, name="dae_dense")
        elif bottleneck_kind == "dvq":
            x_reshaped = reshape_fn(x)
            x_res = x_reshaped
            x_means_hot = []
            x_means = 0
            l = 0
            for i in range(num_residuals):
                x_means_hot_res, x_means_res, q_loss_res, e_loss_res = embedding_lookup(
                    x_res, means[i], num_blocks, block_v_size, random_top_k,
                    soft_em, num_samples)

                # Update the ema variables
                if ema:
                    tf.logging.info("Using EMA with beta = {}".format(beta))
                    updated_ema_count_res = moving_averages.assign_moving_average(
                        ema_count[i],
                        tf.reduce_sum(tf.reshape(
                            x_means_hot_res,
                            shape=[-1, num_blocks, block_v_size]),
                                      axis=0),
                        decay,
                        zero_debias=False)

                    dw = tf.matmul(
                        tf.transpose(x_means_hot_res, perm=[1, 2, 0]),
                        tf.transpose(x_res, perm=[1, 0, 2]))

                    updated_ema_means_res = moving_averages.assign_moving_average(
                        ema_means[i], dw, decay, zero_debias=False)
                    n = tf.reduce_sum(updated_ema_count_res,
                                      axis=-1,
                                      keep_dims=True)
                    updated_ema_count_res = (
                        (updated_ema_count_res + epsilon) /
                        (n + 2**z_size * epsilon) * n)
                    updated_ema_means_res /= tf.expand_dims(
                        updated_ema_count_res, axis=-1)

                    with tf.control_dependencies([e_loss_res]):
                        update_means_res = tf.assign(means[i],
                                                     updated_ema_means_res)
                        with tf.control_dependencies([update_means_res]):
                            l += beta * e_loss_res
                else:
                    l += q_loss_res + beta * e_loss_res

                # Update the residuals
                x_res -= x_means_res
                x_means += x_means_res
                x_means_hot.append(x_means_hot_res)

            # Get the discrete latent representation
            x_means_hot = tf.stack(x_means_hot, axis=1)
            x_means_idx = tf.argmax(x_means_hot, axis=-1)

            # Get the binary representation
            x_means_bits = int_to_bit(
                x_means_idx,
                num_bits=int(z_size / (num_residuals * num_blocks)),
                base=2)
            shape = common_layers.shape_list(x_means_bits)
            new_shape = shape[:-2]
            new_shape[-1] = z_size
            x_means_bits = tf.reshape(x_means_bits, shape=new_shape)
            c = bit_to_int(tf.to_int32(x_means_bits), num_bits=z_size, base=2)

            # Adjust shape of c
            shape_x = common_layers.shape_list(x)
            new_shape = shape_x[:-1]
            c = tf.reshape(c, new_shape)

            x_means = tf.reshape(x_means, shape_x)
            x_reshaped = tf.reshape(x_reshaped, shape_x)
            h1 = x_reshaped + tf.stop_gradient(x_means - x_reshaped)
        else:
            raise ValueError("Unknown discretization method.")

        h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
        res = tf.layers.dense(tf.nn.relu(h2), hidden_size, name="vcfin")

        embed_fn = partial(embed,
                           hidden_size=hidden_size,
                           z_size=z_size,
                           filter_size=filter_size,
                           name=name,
                           bottleneck_kind=bottleneck_kind,
                           num_blocks=num_blocks,
                           num_residuals=num_residuals,
                           block_v_size=block_v_size,
                           means=means)
        return res, c, l, embed_fn
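
The EMA branch in Example #36 keeps per-code usage counts and feature sums,
Laplace-smooths the counts, and renormalizes the means. A condensed sketch of
one such update for a single flat codebook (the example applies it per residual
and per block, routed through moving_averages.assign_moving_average):

import tensorflow as tf

def ema_vq_update_sketch(assignments, x, ema_sums, ema_count,
                         decay=0.999, epsilon=1e-5):
  """One smoothed EMA codebook step; returns (new_means, new_count).

  assignments: [batch, K] one-hot codebook assignments.
  x: [batch, hidden] encoder outputs.
  """
  k = tf.to_float(tf.shape(assignments)[1])
  count = decay * ema_count + (1.0 - decay) * tf.reduce_sum(assignments, 0)
  dw = tf.matmul(assignments, x, transpose_a=True)  # per-code feature sums
  sums = decay * ema_sums + (1.0 - decay) * dw
  # Laplace smoothing keeps rarely used codes from collapsing toward zero.
  n = tf.reduce_sum(count)
  count = (count + epsilon) / (n + k * epsilon) * n
  return sums / tf.expand_dims(count, -1), count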
Example #37
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """AE Transformer, main step used for training."""
    # Define losses
    losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}

    # Reshape image targets as 4d tensor.
    original_targets_shape = common_layers.shape_list(targets)
    if len(original_targets_shape) == 4:
        compress_fn = compress_encoder_2d
        decompress_fn = decompress_decoder_2d
    else:
        compress_fn = compress_encoder_1d
        decompress_fn = decompress_decoder_1d

    # Encoder decoder attention bias.
    ed_attention_bias = None

    # Input Encoder if present.
    if inputs is not None:
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed_attention_bias = transformer_text_encoder(
            inputs, target_space, hparams, "input_enc")

    # Encode targets to compute targets compressed.
    targets_c = compress_fn(targets, hparams, "compress")
    targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)

    # The following code creates an exponentially decaying variable by which
    # we rescale the loss values.
    batch_size = common_layers.shape_list(targets_c)[0]
    pc = common_layers.inverse_exp_decay(hparams.startup_steps)
    pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
    cond = tf.less(tf.random_uniform([batch_size]), pc)

    # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
    # Call bottleneck layer to get the latents.
    # Returns embedded latents, discrete latents, loss and the embedding function.
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
        latents_dense, latents_discrete, extra_loss = (bottleneck_layer(
            targets_c, hparams))
        extra_loss = tf.reduce_mean(extra_loss) * tf.to_float(cond)

        # Call the autoregressive latent prediction model.
        _, latents_pred_loss = latent_prediction_model(inputs,
                                                       ed_attention_bias,
                                                       latents_discrete,
                                                       latents_dense,
                                                       hparams,
                                                       name="latent_pred")
        latents_pred_loss = tf.reduce_mean(latents_pred_loss) * tf.to_float(
            cond)
        # Assign latent loss
        losses["latent_pred"] = latents_pred_loss
        losses["extra_loss"] = extra_loss
    else:
        latent_len = (hparams.img_len * hparams.img_len *
                      hparams.num_latents) // 2**hparams.num_compress_steps
        embed = functools.partial(discretization.parametrized_unbottleneck,
                                  hparams=hparams)
        latents_dense = tf.zeros(
            [batch_size, latent_len, 1, hparams.hidden_size])
        if cache is None:
            cache = ae_latent_sample_beam(latents_dense, inputs,
                                          ed_attention_bias, embed, hparams)
        latents_dense = embed(
            tf.one_hot(cache, depth=2**hparams.bottleneck_bits),
            hparams.hidden_size)

    latents_decoder = latents_dense
    if len(original_targets_shape) == 4:
        cmp_img_len = hparams.img_len // (2**(hparams.num_compress_steps // 2))
        latents_decoder = tf.reshape(latents_decoder, [
            batch_size, cmp_img_len, cmp_img_len,
            hparams.num_latents * hparams.hidden_size
        ])

    # Decompress either using 1D or 2D upconvs.
    latents_decoder = decompress_fn(latents_decoder,
                                    hparams,
                                    name="decompress")
    # If we're operating in 2D space on images, we assume that the last
    # dimension will not be a multiple of channels.
    latents_decoder = tf.reshape(
        latents_decoder,
        shape=[-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

    if hparams.use_gold_targets:
        latents_decoder, _, _ = cia.maybe_reshape_4d_to_3d(latents_decoder)
        masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
            masking = predict_mask
        mask = tf.less(
            masking, tf.random_uniform(common_layers.shape_list(targets)[:-1]))
        mask = tf.expand_dims(tf.to_float(mask), 2)
        targets = mask * targets + (1.0 - mask) * latents_decoder
    else:
        targets = latents_decoder
    # reshape back to 4d here
    targets = tf.reshape(targets, original_targets_shape)
    if hparams.decode_autoregressive:
        # Transformer decoder, that goes from inputs->targets
        res = transformer_image_decoder(inputs, ed_attention_bias, targets,
                                        hparams, "decoder")
    else:
        res = targets

    # We'll start training the extra model of latents after mask_startup_steps.
    latent_time = tf.less(hparams.mask_startup_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
    return res, losses, cache
Example #38
  def body(self, features):
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    vocab_size = self._problem_hparams.modality["targets"].top_dimensionality
    encoder_layers = None
    self.is1d = hparams.sample_width == 1
    if (hparams.mode != tf.estimator.ModeKeys.PREDICT
        or self._encode_on_predict):
      labels = features["targets_raw"]
      labels_shape = common_layers.shape_list(labels)
      # handle videos
      if len(labels.shape) == 5:
        labels = time_to_channels(labels)
      shape = common_layers.shape_list(labels)
      x = tf.one_hot(labels, vocab_size)
      x = self.embed(x)
      target_codes = x
      if shape[2] == 1:
        self.is1d = True
      # Run encoder.
      x, encoder_layers = self.encoder(x)
      # Bottleneck.
      b, b_loss = self.bottleneck(x)
      xb_loss = 0.0
      b_shape = common_layers.shape_list(b)
      self._cur_bottleneck_tensor = b
      res_size = common_layers.shape_list(x)[-1]
      b = self.unbottleneck(b, res_size)
      if not is_training:
        x = b
      else:
        l = 2**hparams.num_hidden_layers
        warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
        nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
        if common_layers.should_generate_summaries():
          tf.summary.scalar("nomix_p_bottleneck", nomix_p)
        rand = tf.random_uniform(common_layers.shape_list(x))
        # This is the distance between b and x. Having this as loss helps learn
        # the bottleneck function, but if we back-propagated to x it would be
        # minimized by just setting x=0 and b=0 -- so we don't want too much
        # of the influence of this, and we stop-gradient to not zero-out x.
        x_stop = tf.stop_gradient(x)
        xb_loss = tf.reduce_mean(tf.reduce_sum(tf.square(x_stop - b), axis=-1))
        # To prevent this loss from exploding we clip at 1, but anneal clipping.
        clip_max = 1.0 / common_layers.inverse_exp_decay(
            warm_step, min_value=0.001)
        xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
        xb_loss *= clip_max / xb_clip
        x = tf.where(tf.less(rand, nomix_p), b, x)
      if hparams.gan_loss_factor != 0.0:
        # Add a purely sampled batch on which we'll compute the GAN loss.
        g = self.unbottleneck(
            self.sample(shape=b_shape),
            common_layers.shape_list(x)[-1],
            reuse=True)
        x = tf.concat([x, g], axis=0)
    else:
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      self._cur_bottleneck_tensor = b
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x, encoder_layers)

    # Cut to the right size and mix before returning.
    res = x
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      res = x[:, :shape[1], :shape[2], :]

    # Final dense layer.
    res = tf.layers.dense(
        res, self.num_channels * hparams.hidden_size, name="res_dense")

    output_shape = common_layers.shape_list(res)[:-1] + [
        self.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hparams.use_vq_loss:
        (reconstr, _, _, _, _) = discretization.vq_loss(res, labels, vocab_size)
      else:
        reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      return reconstr, {"bottleneck_loss": 0.0}

    if hparams.gan_loss_factor != 0.0:
      res, res_gan = tf.split(res, 2, axis=0)

    # Losses.
    losses = {
        "bottleneck_extra": b_loss,
        "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
    }

    if hparams.use_vq_loss:
      vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.2,
          min_value=hparams.vq_temperature * 2)
      if hparams.mode != tf.estimator.ModeKeys.TRAIN:
        vq_temperature = None
      with tf.variable_scope("vq_loss"):
        (reconstr, _, target_codes, code_loss,
         targets_loss) = discretization.vq_loss(
             res, labels, vocab_size, temperature=vq_temperature)
      losses["code_loss"] = code_loss * hparams.code_loss_factor
      losses["training"] = targets_loss
    else:
      reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      targets_loss = tf.losses.sparse_softmax_cross_entropy(
          logits=tf.reshape(reconstr, labels_shape + [vocab_size]),
          labels=tf.reshape(labels, labels_shape))
      losses["training"] = targets_loss

    # GAN losses.
    if hparams.gan_loss_factor != 0.0:
      update_means_factor = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps, min_value=0.0001)
      if hparams.use_vq_loss:
        with tf.variable_scope("vq_loss", reuse=True):
          update_means = tf.less(tf.random_uniform([]), update_means_factor)
          reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
              res_gan,
              labels,
              vocab_size,
              do_update=update_means,
              temperature=vq_temperature)
          reconstr_gan_nonoise = reconstr_gan
          code_loss_gan *= hparams.code_loss_factor * update_means_factor
          losses["code_loss_gan"] = code_loss_gan
      else:
        reconstr_gan = tf.layers.dense(
            res_gan, vocab_size, name="autoencoder_final", reuse=True)
        reconstr_gan_nonoise = reconstr_gan
        reconstr_gan = self.gumbel_sample(reconstr_gan)
        # Embed to codes.
        gan_codes = self.embed(reconstr_gan)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
      self.image_summary("gan", reconstr_gan_nonoise)

      def discriminate(x):
        """Run a dioscriminator depending on the hparams."""
        if hparams.discriminator == "default":
          return common_layers.deep_discriminator(
              x, hparams.discriminator_batchnorm, is_training)
        elif hparams.discriminator == "patched":
          return common_layers.patch_discriminator(x)
        elif hparams.discriminator == "single":
          return common_layers.single_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        elif hparams.discriminator == "double":
          return common_layers.double_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        else:
          raise ValueError("Unknown discriminator %s" % hparams.discriminator)

      tc_shape = common_layers.shape_list(target_codes)
      if len(tc_shape) > 4:
        target_codes = tf.reshape(target_codes,
                                  tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
        gan_codes = tf.reshape(gan_codes,
                               tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
      gan_lr = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.5)
      rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
      gan_loss = common_layers.sliced_gan_loss(
          target_codes,
          rev_grad_gan_codes,
          discriminate,
          self.hparams.num_sliced_vecs,
          do_tanh=hparams.sliced_do_tanh)
      gan_loss *= hparams.gan_loss_factor * update_means_factor
      losses["gan_loss"] = -gan_loss

    self.image_summary("ae", reconstr)

    logits = tf.reshape(reconstr, labels_shape + [vocab_size])
    return logits, losses
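
A note on the GAN branch above: `reverse_gradient` keeps the forward value of `gan_codes` intact while negating (and scaling by `gan_lr`) the gradient flowing back into the generator, so a single `sliced_gan_loss` can train generator and discriminator adversarially. A minimal sketch of such a layer, assuming the standard stop-gradient contract (not necessarily the exact library code):

import tensorflow as tf  # TF 1.x, matching the surrounding code.

def reverse_gradient(x, lr=1.0):
  """Identity in the forward pass; scales the gradient by -lr going back."""
  # Forward: -lr * x + (1 + lr) * x == x.  Backward: only -lr * x contributes.
  return -lr * x + tf.stop_gradient((1.0 + lr) * x)
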
def ae_transformer_internal(inputs, targets, target_space, hparams,
                            cache=None, predict_mask=1.0):
  """AE Transformer, main step used for training."""
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  batch_size = common_layers.shape_list(inputs)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
  else:
    ed = None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, _ = bottleneck(
          targets_c, hparams, 2*2048, "vc")
      if _DO_SUMMARIES:
        tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
      pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            tf.stop_gradient(inputs), tf.stop_gradient(ed),
            tf.stop_gradient(latents_dense), hparams, "extra")
        latents_pred = tf.layers.dense(latents_pred, 2**16, name="extra_logits")
        losses["latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=latents_discrete, logits=latents_pred)
        losses["latent_pred"] = tf.reduce_mean(
            losses["latent_pred"] * 0.5 * tf.to_float(cond))
      else:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean((inputs_c - targets_c)**2) * 20
        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _ = bottleneck(inputs_c, hparams, 2*2048, "vc")
          return bn
        pbn = 0.8 if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        inputs_c = tf.cond(tf.less(tf.random_uniform([]), pbn),
                           bn_inputs, lambda: inputs_c)
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(tf.less(tf.random_uniform([batch_size]), ptc),
                                 latents_dense, inputs_c)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _ = bottleneck(inputs_c, hparams, 2*2048, "vc")
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed = bottleneck(targets_c, hparams, 2*2048, "vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(latents_dense, inputs, ed, embed, 8, hparams)
        latents_dense = embed(cache)
    # Postprocess.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(100000)
      masking *= common_layers.inverse_exp_decay(25000)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * 0.3
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking, tf.random_uniform(
          common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)
      for i in xrange(hparams.num_compress_steps):
        j = hparams.num_compress_steps - i - 1
        d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
        d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
      targets = mask * targets + (1.0 - mask) * d
    targets = tf.concat([tf.reverse(latents_dense, [1]), targets], axis=1)

  res = decode_transformer(inputs, ed, targets, hparams, "decoder")
  if hparams.do_ae:
    res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :]
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        return residual_conv(res, 1, (5, 1), hparams, "refine")
      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # After 300K steps we train only the extra latent model: latent_pred turns
    # on, extra turns off, and gradients through res are scaled by
    # (1 - decreased_lr), reaching zero at 400K steps.
    latent_time = tf.less(300000, tf.to_int32(tf.train.get_global_step()))
    decreased_lr = common_layers.inverse_lin_decay(400000)
    losses["latent_pred"] *= tf.to_float(latent_time)
    losses["extra"] *= 1.0 - tf.to_float(latent_time)
    decreased_lr_res = tf.stop_gradient(decreased_lr * res)
    decreased_lr_res += (1.0 - decreased_lr) * res
    res = tf.cond(latent_time, lambda: decreased_lr_res, lambda: res)
  return res, losses, cache
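
Most warm-up factors in this function (`pc`, the masking schedule, `decreased_lr`) come from `common_layers.inverse_exp_decay` and `inverse_lin_decay`, which rise from near zero to 1.0 over a given number of steps. A NumPy sketch of the assumed semantics, for intuition only:

import numpy as np

def inverse_exp_decay(max_step, min_value=0.01, step=0):
  """Rises exponentially from min_value at step 0 to 1.0 at max_step."""
  inv_base = np.exp(np.log(min_value) / max_step)
  return inv_base ** max(max_step - step, 0)

def inverse_lin_decay(max_step, step=0):
  """Rises linearly from 0 at step 0 to 1.0 at max_step, then stays at 1.0."""
  return min(step / float(max_step), 1.0)

for s in [0, 50000, 100000, 200000]:
  print(s, round(inverse_exp_decay(100000, step=s), 4),
        round(inverse_lin_decay(100000, step=s), 2))
# 0 0.01 0.0 | 50000 0.1 0.5 | 100000 1.0 1.0 | 200000 1.0 1.0
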
Example #40
  def get_scheduled_sample_func(self, batch_size):
    """Creates a function for scheduled sampling based on given hparams."""
    with tf.variable_scope("scheduled_sampling_func", reuse=tf.AUTO_REUSE):
      iter_num = self.get_iteration_num()

      # Simple function to bypass scheduled sampling in gt or pred only modes.
      def scheduled_sampling_simple(ground_truth_x, generated_x,
                                    batch_size, scheduled_sample_var):
        del batch_size
        if scheduled_sample_var:
          return ground_truth_x
        return generated_x

      mode = self.hparams.scheduled_sampling_mode
      if mode == "ground_truth_only":
        scheduled_sampling_func = scheduled_sampling_simple
        scheduled_sampling_func_var = True
      elif mode == "prediction_only":
        scheduled_sampling_func = scheduled_sampling_simple
        scheduled_sampling_func_var = False
      elif mode == "prob":
        decay_steps = self.hparams.scheduled_sampling_decay_steps
        probability = tf.train.polynomial_decay(
            1.0, iter_num, decay_steps, 0.0)
        scheduled_sampling_func = common_video.scheduled_sample_prob
        scheduled_sampling_func_var = probability
      elif mode == "prob_inverse_exp":
        decay_steps = self.hparams.scheduled_sampling_decay_steps
        probability = common_layers.inverse_exp_decay(
            decay_steps, step=iter_num)
        probability *= self.hparams.scheduled_sampling_max_prob
        probability = 1.0 - probability
        scheduled_sampling_func = common_video.scheduled_sample_prob
        scheduled_sampling_func_var = probability
      elif mode == "prob_inverse_lin":
        decay_steps = self.hparams.scheduled_sampling_decay_steps
        probability = common_layers.inverse_exp_decay(
            decay_steps // 4, step=iter_num)  # Very low at start.
        probability *= common_layers.inverse_lin_decay(
            decay_steps, step=iter_num)
        probability *= self.hparams.scheduled_sampling_max_prob
        probability = 1.0 - probability
        scheduled_sampling_func = common_video.scheduled_sample_prob
        scheduled_sampling_func_var = probability
      elif mode == "count":
        # Calculate number of ground-truth frames to pass in.
        k = self.hparams.scheduled_sampling_k
        num_ground_truth = tf.to_int32(
            tf.round(
                tf.to_float(batch_size) *
                (k / (k + tf.exp(tf.to_float(iter_num) / tf.to_float(k))))))
        scheduled_sampling_func = common_video.scheduled_sample_count
        scheduled_sampling_func_var = num_ground_truth
      else:
        raise ValueError("unknown scheduled sampling method: %s" % mode)

      if isinstance(scheduled_sampling_func_var, tf.Tensor):
        tf.summary.scalar("scheduled_sampling_var", scheduled_sampling_func_var)
      partial_func = functools.partial(
          scheduled_sampling_func,
          batch_size=batch_size,
          scheduled_sample_var=scheduled_sampling_func_var)
      return partial_func
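
The `count` mode above computes how many ground-truth frames to feed per batch via the inverse-sigmoid schedule k / (k + exp(i / k)). A quick NumPy sketch of how that fraction decays with the iteration number (illustrative k, not taken from any config):

import numpy as np

def gt_fraction(iter_num, k=1000.0):
  """Fraction of ground-truth frames under the inverse-sigmoid schedule."""
  return k / (k + np.exp(iter_num / k))

for i in [0, 2000, 5000, 10000]:
  print(i, round(gt_fraction(i), 3))
# 0 0.999 | 2000 0.993 | 5000 0.871 | 10000 0.043
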
Example #41
def gumbel_softmax(x,
                   name,
                   z_size,
                   mode,
                   softmax_k=0,
                   kl_warmup_steps=150000,
                   summary=True):
  """Gumbel softmax discretization bottleneck.

  Args:
    x: Input to the discretization bottleneck.
    name: Name for the bottleneck scope.
    z_size: Number of bits used to produce discrete code; discrete codes range
      from 1 to 2**z_size.
    mode: Mode represents whether we are training or testing for bottlenecks
      that differ in behavior.
    softmax_k: If > 0 then do top-k softmax (Default: 0).
    kl_warmup_steps: Number of steps for kl warmup (Default: 150000).
    summary: If True, then write summaries (Default: True).

  Returns:
    Embedding function, discrete code and loss.
  """
  with tf.variable_scope(name):
    m = tf.layers.dense(x, 2**z_size, name="mask")
    if softmax_k > 0:
      m, kl = top_k_softmax(m, softmax_k)
      return m, m, 1.0 - tf.reduce_mean(kl)
    logsm = tf.nn.log_softmax(m)

    # Gumbel-softmax sample.
    gumbel_samples = gumbel_sample(common_layers.shape_list(m))
    steps = kl_warmup_steps
    gumbel_samples *= common_layers.inverse_exp_decay(steps // 5) * 0.5
    temperature = 1.2 - common_layers.inverse_lin_decay(steps)

    # 10% of the time keep reasonably high temperature to keep learning.
    temperature = tf.cond(
        tf.less(tf.random_uniform([]), 0.9), lambda: temperature,
        lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
    s = tf.nn.softmax((logsm + gumbel_samples) / temperature)
    m = tf.nn.softmax(m)
    kl = -tf.reduce_max(logsm, axis=-1)

    if summary:
      tf.summary.histogram("max-log", tf.reshape(kl, [-1]))

    # Calculate the argmax and construct hot vectors.
    maxvec = tf.reshape(tf.argmax(m, axis=-1), [-1])
    maxvhot = tf.stop_gradient(tf.one_hot(maxvec, 2**z_size))

    # Add losses that prevent too few being used.
    distrib = tf.reshape(logsm, [-1, 2**z_size]) * maxvhot
    d_mean = tf.reduce_mean(distrib, axis=[0], keep_dims=True)
    d_variance = tf.reduce_mean(tf.square(distrib - d_mean), axis=[0])
    d_dev = -tf.reduce_mean(d_variance)
    ret = s

    if mode != tf.contrib.learn.ModeKeys.TRAIN:
      ret = tf.reshape(maxvhot, common_layers.shape_list(s))  # Just hot @eval.
    return m, ret, d_dev * 5.0 + tf.reduce_mean(kl) * 0.002
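
`gumbel_sample` above is assumed to draw standard Gumbel(0, 1) noise via the inverse-CDF trick, -log(-log(U)) for uniform U; a minimal sketch consistent with how it is used here:

import tensorflow as tf  # TF 1.x, matching the surrounding code.

def gumbel_sample(shape):
  """Samples Gumbel(0, 1) noise by inverting the CDF of uniform draws."""
  # Keep U strictly inside (0, 1) so the double log stays finite.
  uniform = tf.random_uniform(shape, minval=1e-5, maxval=1.0 - 1e-5)
  return -tf.log(-tf.log(uniform))
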
Example #42
def ae_transformer_internal(inputs, targets, target_space, hparams, cache=None):
  """Main step used for training."""
  # Encoder.
  inputs = common_layers.flatten4d3d(inputs)
  inputs, ed = encode(inputs, target_space, hparams, "input_enc")

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}

  max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
  targets, _ = common_layers.pad_to_same_length(
      targets,
      max_targets_len_from_inputs,
      final_length_divisible_by=2**hparams.num_compress_steps)
  targets_c = compress(targets, hparams, "compress")
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    # Compress and bottleneck.
    latents_discrete_hot, extra_loss = vq_discrete_bottleneck(
        x=targets_c, hparams=hparams)
    latents_dense = vq_discrete_unbottleneck(
        latents_discrete_hot, hparams=hparams)
    latents_dense = targets_c + tf.stop_gradient(latents_dense - targets_c)
    latents_discrete = tf.argmax(latents_discrete_hot, axis=-1)
    tf.summary.histogram("codes", tf.reshape(latents_discrete[:, 0, :], [-1]))
    losses["extra"] = extra_loss

    # Extra loss predicting latent code from input.
    latents_pred = decode_transformer(inputs, ed, latents_dense, hparams,
                                      "extra")
    latent_pred_loss = get_latent_pred_loss(latents_pred, latents_discrete_hot,
                                            hparams)
    losses["latent_pred"] = tf.reduce_mean(latent_pred_loss)
  else:
    latent_len = common_layers.shape_list(targets_c)[1]
    embed = functools.partial(vq_discrete_unbottleneck, hparams=hparams)
    latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
    if cache is None:
      cache = ae_latent_sample_beam(latents_dense, inputs, ed, embed,
                                    hparams)
    cache_hot = tf.one_hot(cache, depth=2**hparams.bottleneck_bits)
    latents_dense = embed(cache_hot)

  # Postprocess.
  d = latents_dense
  pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
  pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
  latents_dense = tf.pad(latents_dense, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

  # Decompressing the dense latents
  for i in range(hparams.num_compress_steps):
    j = hparams.num_compress_steps - i - 1
    d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
    d = decompress_step(d, hparams, i > 0, "decompress_%d" % j)

  masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
  masking *= common_layers.inverse_exp_decay(
      hparams.mask_startup_steps // 4)  # Not much at start.
  masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    masking = 1.0
  mask = tf.less(masking,
                 tf.random_uniform(common_layers.shape_list(targets)[:-1]))
  mask = tf.expand_dims(tf.to_float(mask), 3)

  # targets is always [batch, length, 1, depth]
  targets = mask * targets + (1.0 - mask) * d

  res = decode_transformer(inputs, ed, targets, hparams, "decoder")
  latent_time = tf.less(hparams.mask_startup_steps,
                        tf.to_int32(tf.train.get_global_step()))
  losses["latent_pred"] *= tf.to_float(latent_time)
  return res, losses, cache
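
The line `latents_dense = targets_c + tf.stop_gradient(latents_dense - targets_c)` above is the straight-through estimator: the forward pass uses the quantized latents, while the backward pass treats quantization as the identity so gradients reach the encoder. The trick in isolation:

import tensorflow as tf  # TF 1.x, matching the surrounding code.

def straight_through(continuous, quantized):
  """Forward: quantized values. Backward: gradients of the continuous input."""
  return continuous + tf.stop_gradient(quantized - continuous)
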
Example #43
  def model_fn(self, features, skip=False, last_position_only=False):
    """Computes the entire model and produces sharded logits and losses.

    Args:
      features: A dictionary of feature name to tensor.
      skip: a boolean, if we're just dummy-calling and actually skip this model
        (but we need to create variables to not confuse distributed training).
      last_position_only: a boolean, compute logits for only the last position.

    Returns:
      sharded_logits: a list of `Tensor`s, one per datashard.
      losses: a dictionary: {loss-name (string): floating point `Scalar`}.
    """
    start_time = time.time()
    dp = self._data_parallelism

    sharded_features = self._shard_features(features)

    # Construct the model bottom for inputs.
    transformed_features = {}
    all_previous_modalities = []

    for key, input_modality in six.iteritems(
        self._problem_hparams.input_modality):
      previous_modalities = [
          self._hparams.problems[i].input_modality[key].name
          for i in xrange(self._problem_idx)
      ]
      all_previous_modalities.extend(previous_modalities)
      do_reuse = input_modality.name in all_previous_modalities
      with tf.variable_scope(input_modality.name, reuse=do_reuse):
        transformed_features[key] = input_modality.bottom_sharded(
            sharded_features[key], dp)
      all_previous_modalities.append(input_modality.name)

    # Target space id just gets copied to every shard.
    if "target_space_id" in features:
      transformed_features["target_space_id"] = [features["target_space_id"]
                                                ] * self._num_datashards

    # Targets are transformed by the autoregressive part of the modality
    previous_tgt_modalities = [
        self._hparams.problems[i].target_modality.name
        for i in xrange(self._problem_idx)
    ]
    all_previous_modalities.extend(previous_tgt_modalities)

    target_modality = self._problem_hparams.target_modality
    target_reuse = target_modality.name in previous_tgt_modalities
    with tf.variable_scope(target_modality.name, reuse=target_reuse):
      transformed_features["targets"] = target_modality.targets_bottom_sharded(
          sharded_features["targets"], dp)

    # Allows later access to pre-embedding raw targets.
    transformed_features["raw_targets"] = sharded_features["targets"]

    # Construct the model body.
    with tf.variable_scope("body", reuse=self._problem_idx > 0):
      if skip:
        body_outputs = transformed_features["targets"]
        losses = {"extra": 0.0}
      else:
        body_outputs, losses = self.model_fn_body_sharded(
            transformed_features)
        if not isinstance(losses, dict):  # If it's a single extra loss.
          losses = {"extra": losses}

    with tf.variable_scope(target_modality.name, reuse=target_reuse):
      if not last_position_only:
        sharded_logits = target_modality.top_sharded(
            body_outputs, sharded_features["targets"], dp)
        training_loss = target_modality.loss_sharded(
            sharded_logits, sharded_features["targets"], dp)

        training_loss *= self._problem_hparams.loss_multiplier
      else:
        # Take body outputs for the last position only, and targets too.
        # TODO(lukaszkaiser): warning, this doesn't work for all modalities!
        last_position_body_outputs = [
            tf.expand_dims(body_shard[:, -1, :, :], axis=[1])
            for body_shard in body_outputs
        ]
        last_position_targets = [
            tf.expand_dims(target_shard[:, -1:, :, :], axis=[1])
            for target_shard in sharded_features["targets"]
        ]
        sharded_logits = target_modality.top_sharded(last_position_body_outputs,
                                                     last_position_targets,
                                                     self._data_parallelism)
        training_loss = None
    losses["training"] = training_loss

    # Scheduled sampling.
    do_scheduled_sampling = (  # Only do it if training and set for it.
        self._hparams.scheduled_sampling_prob > 0.0 and
        self._hparams.mode == tf.estimator.ModeKeys.TRAIN and
        not skip)
    if do_scheduled_sampling:

      def sample(x):
        """Multinomial sampling from a n-dimensional tensor."""
        vocab_size = target_modality.top_dimensionality
        samples = tf.multinomial(tf.reshape(x, [-1, vocab_size]), 1)
        reshaped_samples = tf.reshape(samples, tf.shape(x)[:-1])
        return tf.to_int32(reshaped_samples)

      def mix_gold_sampled(gold_targets, sampled_targets):
        return tf.where(
            tf.less(tf.random_uniform(tf.shape(sampled_targets)),
                    self._hparams.scheduled_sampling_gold_mixin_prob),
            gold_targets, sampled_targets)

      def sampled_results():
        """Generate scheduled sampling results."""
        sampled_targets = dp(sample, sharded_logits)
        new_targets = dp(mix_gold_sampled,
                         sharded_features["targets"], sampled_targets)
        new_features = transformed_features
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
          with tf.variable_scope(target_modality.name):
            new_features["targets"] = target_modality.targets_bottom_sharded(
                new_targets, dp)
          with tf.variable_scope("body"):
            body_outputs, losses = self.model_fn_body_sharded(new_features)
            if not isinstance(losses, dict):  # If it's a single extra loss.
              losses = {"extra": losses}
          with tf.variable_scope(target_modality.name):
            new_sharded_logits = target_modality.top_sharded(
                body_outputs, sharded_features["targets"], dp)
            training_loss = target_modality.loss_sharded(
                new_sharded_logits, sharded_features["targets"], dp)
            training_loss *= self._problem_hparams.loss_multiplier
          losses["training"] = training_loss
        return new_sharded_logits, losses
      # Run the above conditionally.
      prob = self._hparams.scheduled_sampling_prob
      prob *= common_layers.inverse_exp_decay(
          self._hparams.scheduled_sampling_warmup_steps, min_value=0.001)
      sharded_logits, losses = tf.cond(
          tf.less(tf.random_uniform([]), prob),
          sampled_results,
          lambda: (sharded_logits, losses))

    tf.logging.info("This model_fn took %.3f sec." % (time.time() - start_time))
    return sharded_logits, losses
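
The `sample` helper above flattens logits to rank 2 because `tf.multinomial` only accepts [batch, vocab] logits, then restores the original shape. The same pattern as a standalone sketch (shapes assumed, helper name hypothetical):

import tensorflow as tf  # TF 1.x, matching the surrounding code.

def sample_from_logits(logits, vocab_size):
  """Draws one token id per position from logits of any leading shape."""
  flat = tf.reshape(logits, [-1, vocab_size])  # tf.multinomial needs rank 2.
  samples = tf.multinomial(flat, 1)
  return tf.to_int32(tf.reshape(samples, tf.shape(logits)[:-1]))
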
Example #44
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """Auto-encoder using a Transformer decoder and a prior over latent sequences.

  Args:
    inputs: Tensor of shape [batch, length, 1, hparams.hidden_size] or None.
    targets: Tensor of shape [batch, ..., channels]. Ellipses may be 1 or 2
      dimensions denoting sequence length.
    target_space: int. Used for encoding inputs under a target space id.
    hparams: tf.contrib.training.HParams.
    cache: Tensor of shape [batch, length] or None.
    predict_mask: Tensor masking whether to use gold targets or predictions.

  Returns:
    decoder_output: Tensor of shape [batch, ..., hparams.hidden_size] presenting
      pre-logit activations. After a transformation (`top` in `T2TModel`), it is
      used with targets to compute the "training" (reconstruction) loss.
    losses: dict of str to Tensors. There are three loss terms: "extra",
      "extra_loss", and "latent_pred". The first is hard-coded to 0. The latter
      two are Tensors of shape [batch].
    cache: Tensor of shape [batch, length], either the same as cache, or newly
      computed if the cache input is None.
  """
  original_targets_shape = common_layers.shape_list(targets)
  batch_size = original_targets_shape[0]
  if len(original_targets_shape) == 4:
    compress_fn = compress_encoder_2d
    decompress_fn = decompress_decoder_2d
  else:
    compress_fn = compress_encoder_1d
    decompress_fn = decompress_decoder_1d

  ed_attention_bias = None
  if inputs is not None:
    inputs, ed_attention_bias = transformer_text_encoder(
        inputs, target_space, hparams, name="input_encoder")

  losses = {"extra": 0.,
            "extra_loss": 0.,
            "latent_pred": 0.}
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    targets_compressed = compress_fn(targets, hparams, name="compress")

    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
      scale = common_layers.inverse_exp_decay(hparams.startup_steps)
    else:
      scale = 1.0
    scale = tf.to_float(tf.less(tf.random_uniform([batch_size]), scale))

    latents_dense, latents_discrete, extra_loss, _ = bottleneck_layer(
        targets_compressed, hparams)
    extra_loss = scale * tf.reduce_mean(extra_loss)

    _, latents_pred_loss = latent_prediction_model(
        inputs, ed_attention_bias, latents_discrete, latents_dense, hparams,
        name="latent_pred")
    latent_time = tf.less(hparams.mask_startup_steps,
                          tf.to_int32(tf.train.get_global_step()))
    latents_pred_loss = scale * tf.reduce_mean(latents_pred_loss)
    latents_pred_loss *= tf.to_float(latent_time)

    # Apply dropout noise for each data point and time step.
    latents_dense_shape = common_layers.shape_list(latents_dense)
    latents_dense = tf.nn.dropout(
        latents_dense,
        keep_prob=1 - hparams.latent_dropout,
        noise_shape=[latents_dense_shape[0], latents_dense_shape[1], 1])

    # TODO(trandustin): Can we combine extra and extra_loss?
    losses = {"extra": 0.,
              "extra_loss": extra_loss,
              "latent_pred": latents_pred_loss}
  else:
    # Set the latent length, which is num_latents times the number of latent
    # pixels. The number of latent pixels is determined by a compression factor
    # on the number of image pixels.
    # Integer division: latent_len is used below as a tensor dimension.
    latent_len = ((hparams.img_len * hparams.img_len * hparams.num_latents) //
                  (2**hparams.num_compress_steps))
    _, _, _, embed_fn = bottleneck_layer(targets_compressed, hparams)
    latents_dense = tf.zeros([batch_size, latent_len, 1, hparams.hidden_size])
    if cache is None:
      cache = ae_latent_sample_beam(latents_dense,
                                    inputs,
                                    ed_attention_bias,
                                    embed_fn,
                                    hparams)
    cache_one_hot = tf.one_hot(cache, depth=2**hparams.bottleneck_bits)
    latents_dense = embed_fn(cache_one_hot, hparams.hidden_size)

  if len(original_targets_shape) == 4:
    compressed_img_len = (hparams.img_len //
                          2**(hparams.num_compress_steps // 2))
    latents_dense = tf.reshape(latents_dense,
                               [batch_size,
                                compressed_img_len,
                                compressed_img_len,
                                hparams.num_latents * hparams.hidden_size])

  latents_dense = decompress_fn(latents_dense, hparams, name="decompress")
  latents_dense = tf.reshape(
      latents_dense,
      [-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

  if hparams.use_gold_targets:
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      masking = predict_mask
    else:
      masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
    targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)
    mask = tf.less(masking,
                   tf.random_uniform(common_layers.shape_list(targets)[:-1]))
    mask = tf.expand_dims(tf.to_float(mask), 2)
    latents_dense = mask * targets + (1.0 - mask) * latents_dense

  latents_dense = tf.reshape(latents_dense, original_targets_shape)
  if hparams.decode_autoregressive:
    decoder_output = transformer_image_decoder(
        latents_dense, inputs, ed_attention_bias, hparams, name="decoder")
  else:
    decoder_output = latents_dense
  return decoder_output, losses, cache
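
For intuition on the PREDICT-branch shapes above: with the integer-division fix, `latent_len` is the number of compressed latent positions, and for 4d image targets it must equal the square of `compressed_img_len` times `num_latents`. A worked example with assumed hparams values (not defaults from any config):

img_len = 32            # Assumed illustrative values.
num_latents = 1
num_compress_steps = 4

latent_len = (img_len * img_len * num_latents) // 2**num_compress_steps
compressed_img_len = img_len // 2**(num_compress_steps // 2)
print(latent_len, compressed_img_len)  # 64 8, and 64 == 8 * 8 * 1.
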
Example #45
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """AE Transformer, main step used for training."""
    # Summaries break with the do_refine cond, turn them off in that case.
    global _DO_SUMMARIES
    if hparams.do_refine:
        _DO_SUMMARIES = False

    # Prepare.
    if inputs is not None:
        batch_size = common_layers.shape_list(inputs)[0]
    else:
        batch_size = common_layers.shape_list(targets)[0]
    targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

    # Encoder.
    if inputs is not None:
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed = encode(inputs, target_space, hparams, "input_enc")
        inputs_ex, ed_ex = inputs, ed
    else:
        ed, inputs_ex, ed_ex = None, None, None

    # Autoencoding.
    losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
    if hparams.do_ae:
        # flatten here
        original_targets_shape = tf.shape(targets)
        if hparams.task == "image":
            targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)
        if hparams.task == "translate":
            max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
        else:
            assert hparams.task == "image"
            max_targets_len_from_inputs = targets
        targets, _ = common_layers.pad_to_same_length(
            targets,
            max_targets_len_from_inputs,
            final_length_divisible_by=2**hparams.num_compress_steps)
        targets_c = compress(targets, inputs, False, hparams, "compress")
        if hparams.mode != tf.estimator.ModeKeys.PREDICT:
            # Compress and bottleneck.
            latents_dense, latents_discrete, extra_loss, embed = hparams.bottleneck(
                x=targets_c,
                filter_size=hparams.compress_filter_size,
                name="vc",
                mode=hparams.mode)
            if _DO_SUMMARIES:
                tf.summary.histogram(
                    "b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
            pc = common_layers.inverse_exp_decay(hparams.startup_steps)
            pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
            cond = tf.less(tf.random_uniform([batch_size]), pc)
            latents_dense = tf.where(cond, latents_dense, targets_c)
            # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
            losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
            # Extra loss predicting latent code from input. Discrete only.
            if hparams.bottleneck_kind not in ["dense", "vae"]:
                latents_pred = decode_transformer(inputs_ex,
                                                  ed_ex,
                                                  embed(latents_discrete),
                                                  hparams,
                                                  "extra",
                                                  task="translate")
                _, latent_pred_loss = ae_latent_softmax(
                    latents_pred, tf.stop_gradient(latents_discrete), hparams)
                losses["latent_pred"] = tf.reduce_mean(latent_pred_loss *
                                                       tf.to_float(cond))
            else:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                losses["latent_pred"] = tf.reduce_mean(
                    (inputs_c - targets_c)**2) * 20

                def bn_inputs():
                    with tf.variable_scope(tf.get_variable_scope(),
                                           reuse=True):
                        bn, _, _, _ = hparams.bottleneck(
                            x=inputs_c,
                            filter_size=hparams.compress_filter_size,
                            name="vc",
                            mode=hparams.mode)
                    return bn

                inputs_c = bn_inputs()  # Call it: tf.where below needs a tensor.
                ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
                ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
                latents_dense = tf.where(
                    tf.less(tf.random_uniform([batch_size]), ptc),
                    latents_dense, inputs_c)
        else:
            if hparams.bottleneck_kind in ["dense", "vae"]:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                latents_dense, _, _, _ = hparams.bottleneck(
                    x=inputs_c,
                    filter_size=hparams.compress_filter_size,
                    name="vc",
                    mode=hparams.mode)
            else:
                latent_len = common_layers.shape_list(targets_c)[1]
                _, _, _, embed = hparams.bottleneck(
                    x=targets_c,
                    filter_size=hparams.compress_filter_size,
                    name="vc")
                latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
                if cache is None:
                    cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex,
                                             embed, 16, hparams)
                latents_dense = embed(cache)
        # Postprocess.
        d = latents_dense
        pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
        pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
        latents_dense = tf.pad(latents_dense,
                               [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

        # decompressing the dense latents
        for i in range(hparams.num_compress_steps):
            j = hparams.num_compress_steps - i - 1
            d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
            if hparams.do_attend_decompress:
                d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
            d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

        # Masking.
        if hparams.do_mask:
            masking = common_layers.inverse_lin_decay(
                hparams.mask_startup_steps)
            masking *= common_layers.inverse_exp_decay(
                hparams.mask_startup_steps // 4)  # Not much at start.
            if not hparams.do_refine:
                masking -= tf.random_uniform([]) * hparams.unmasked_percentage
            masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
            if hparams.use_predict_mask:
                masking = predict_mask
            if hparams.mode == tf.estimator.ModeKeys.PREDICT:
                masking = predict_mask
            mask = tf.less(
                masking,
                tf.random_uniform(common_layers.shape_list(targets)[:-1]))
            mask = tf.expand_dims(tf.to_float(mask), 3)

            # targets is always [batch, length, 1, depth]
            targets = mask * targets + (1.0 - mask) * d
            # reshape back to 4d here
            if hparams.task == "image":
                targets = tf.reshape(targets, original_targets_shape)
        if hparams.task == "translate":
            targets = tf.concat([tf.reverse(latents_dense, [1]), targets],
                                axis=1)

    res = decode_transformer(inputs,
                             ed,
                             targets,
                             hparams,
                             "decoder",
                             causal=hparams.causal)
    if hparams.do_ae:
        if hparams.task == "translate":
            res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :]
        if hparams.do_mask and hparams.do_refine:

            def refine_res():
                # return residual_conv(res, 1, (5, 1), hparams, "refine")
                r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                              "refine_enc")
                return tf.expand_dims(r, axis=2)

            masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
            all_masked = tf.less(masked_batches, 0.1)
            res = tf.where(all_masked, refine_res(), res)
        # We'll start training the extra model of latents after mask_startup_steps.
        nonlatent_steps = hparams.mask_startup_steps
        latent_time = tf.less(nonlatent_steps,
                              tf.to_int32(tf.train.get_global_step()))
        losses["latent_pred"] *= tf.to_float(latent_time)
    return res, losses, cache
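
The refine step above fires per example: `mask` is 1.0 where gold targets were kept, so an example whose mask sums to (near) zero was generated entirely from latents and is the one worth refining. The same per-example gating pattern in isolation:

import tensorflow as tf  # TF 1.x, matching the surrounding code.

def gate_per_example(mask, refined, original):
  """Uses `refined` for examples whose mask is all zeros, else `original`."""
  masked_positions = tf.reduce_sum(mask, axis=[1, 2, 3])  # One scalar each.
  all_masked = tf.less(masked_positions, 0.1)
  return tf.where(all_masked, refined, original)
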
Example #46
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """AE Transformer, main step used for training."""
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  if inputs is not None:
    batch_size = common_layers.shape_list(inputs)[0]
  else:
    batch_size = common_layers.shape_list(targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    inputs_ex, ed_ex = inputs, ed
  else:
    ed, inputs_ex, ed_ex = None, None, None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    # flatten here
    original_targets_shape = tf.shape(targets)
    if hparams.task == "image":
      targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)
    if hparams.task == "translate":
      max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    else:
      assert hparams.task == "image"
      max_targets_len_from_inputs = targets
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    if hparams.word_dropout:
      mask = tf.random_uniform(shape=common_layers.shape_list(targets),
                               minval=0.0, maxval=1.0)
      targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                               tf.zeros_like(targets))
    else:
      targets_noisy = targets
    targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, embed = hparams.bottleneck(
          x=targets_c,
          filter_size=hparams.compress_filter_size,
          name="vc",
          mode=hparams.mode)
      if _DO_SUMMARIES:
        tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
      pc = common_layers.inverse_exp_decay(hparams.startup_steps)
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            inputs_ex, ed_ex,
            embed(latents_discrete), hparams, "extra",
            task="translate")
        _, latent_pred_loss = ae_latent_softmax(
            latents_pred, tf.stop_gradient(latents_discrete), hparams)
        losses["latent_pred"] = tf.reduce_mean(
            latent_pred_loss * tf.to_float(cond))
      else:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean((inputs_c - targets_c)**2) * 20
        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _ = hparams.bottleneck(
                x=inputs_c,
                filter_size=hparams.compress_filter_size,
                name="vc",
                mode=hparams.mode)
          return bn
        inputs_c = bn_inputs()  # Call it: tf.where below needs a tensor.
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(tf.less(tf.random_uniform([batch_size]), ptc),
                                 latents_dense, inputs_c)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _ = hparams.bottleneck(
            x=inputs_c,
            filter_size=hparams.compress_filter_size,
            name="vc",
            mode=hparams.mode)
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed = hparams.bottleneck(
            x=targets_c, filter_size=hparams.compress_filter_size, name="vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(
              latents_dense, inputs_ex, ed_ex, embed, 16, hparams)
        latents_dense = embed(cache)
    # Postprocess.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # decompressing the dense latents
    for i in range(hparams.num_compress_steps):
      j = hparams.num_compress_steps - i - 1
      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
      if hparams.do_attend_decompress:
        d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
      d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
      masking *= common_layers.inverse_exp_decay(
          hparams.mask_startup_steps // 4)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * hparams.unmasked_percentage
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.use_predict_mask:
        masking = predict_mask
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking, tf.random_uniform(
          common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)

      # targets is always [batch, length, 1, depth]
      targets = mask * targets + (1.0 - mask) * d
      # reshape back to 4d here
      if hparams.task == "image":
        targets = tf.reshape(targets, original_targets_shape)

  res = decode_transformer(inputs, ed, targets, hparams, "decoder",
                           causal=hparams.causal)
  if hparams.do_ae:
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        # return residual_conv(res, 1, (5, 1), hparams, "refine")
        r, _ = encode(tf.squeeze(res, axis=[2]),
                      target_space, hparams, "refine_enc")
        return tf.expand_dims(r, axis=2)
      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # We'll start training the extra model of latents after mask_startup_steps.
    nonlatent_steps = hparams.mask_startup_steps
    latent_time = tf.less(nonlatent_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
  return res, losses, cache
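
Example #46 differs from #45 mainly in the word-dropout step before compression: entries of the target embeddings are zeroed wherever a uniform draw falls below `hparams.word_dropout`. The same idea as a standalone sketch:

import tensorflow as tf  # TF 1.x, matching the surrounding code.

def word_dropout(targets, drop_rate):
  """Zeroes target entries with probability drop_rate (elementwise, as above)."""
  mask = tf.random_uniform(tf.shape(targets), minval=0.0, maxval=1.0)
  return tf.where(mask > drop_rate, targets, tf.zeros_like(targets))
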
Example #47
def discrete_bottleneck(x,
                        hidden_size,
                        z_size,
                        filter_size,
                        name,
                        mode=None,
                        startup_steps=50000,
                        bottleneck_kind='dvq',
                        num_blocks=2,
                        reshape_method='slice',
                        projection_tensors=None,
                        means=None,
                        beta=0.25,
                        noise_dev=1.,
                        decay=0.999,
                        discrete_mix=0.5,
                        random_top_k=1,
                        epsilon=1e-5,
                        softmax_k=0,
                        kl_warmup_steps=150000,
                        ema=True,
                        ema_count=None,
                        ema_means=None,
                        summary=True,
                        dp_strength=1.0,
                        dp_decay=1.0,
                        dp_alpha=0.5):
  """Discretization bottleneck for latent variables.

  Args:
    x: Input to the discretization bottleneck.
    hidden_size: Dimension of the latent state.
    z_size: Number of bits used to produce discrete code; discrete codes range
      from 1 to 2**z_size.
    filter_size: Filter size to be used for the embedding function.
    name: Name for the bottleneck scope.
    mode: Mode represents whether we are training or testing for bottlenecks
      that differ in behavior (Default: None).
    startup_steps: Number of steps after which latent predictor is trained
      (Default: 50000).
    bottleneck_kind: Kind of discretization bottleneck to use; one of dvq,
      semhash, gumbel-softmax (Default: dvq).
    num_blocks: Number of blocks to use for decomposed vector quantization.
    reshape_method: Method to reshape for DVQ (Default: slice).
    projection_tensors: If the reshape method is project, then these are the
      tensors used to project (Default: None).
    means: The embedding table for dvq (Default: None).
    beta: Beta factor for the DVQ loss (Default: 0.25).
    noise_dev: Stddev for noise added for semhash (Default: 1.0).
    decay: Decay factor for the exponential moving average (Default: 0.999).
    discrete_mix: Factor for mixing discrete and non-discrete input for semhash
      (Default: 0.5).
    random_top_k: Noisy top-k for DVQ (Default: 1).
    epsilon: Epsilon parameter for DVQ (Default: 1e-5).
    softmax_k: If > 0 then do top-k softmax (Default: 0).
    kl_warmup_steps: Number of steps for kl warmup (Default: 150000).
    ema: If True update embeddings using exponential moving averages (Default:
      True).
    ema_count: Table of counts for each embedding corresponding to how many
      examples in a batch it was the closest to (Default: None).
    ema_means: Exponentially averaged version of the embeddings (Default: None).
    summary: If True, then write summaries (Default: True).
    dp_strength: Strength of Dirichlet Process loss prior (Default: 1.0).
    dp_decay: Decay the dp_strength using an exponential decay using this
      term (Default: 1.0).
    dp_alpha: Alpha term (pseudo-count) in Dirichlet Process (Default: 0.5).

  Returns:
    Embedding to pass to the decoder, discrete latent, loss, and the embedding
    function.

  Raises:
    ValueError: If projection_tensors is None for reshape_method project, or
    ema_count or ema_means is None if we are using ema, or unknown args.
  """
  block_v_size = None
  if bottleneck_kind == 'dvq':
    # Define the dvq parameters
    assert means is not None

    # Check block dimensions add up
    if hidden_size % num_blocks != 0:
      raise ValueError('num_blocks does not divide hidden size')

    if 2**z_size % num_blocks != 0:
      raise ValueError('num_blocks does not divide embedding table size')

    block_v_size = int(2**(z_size / num_blocks))

    # Set the reshape method corresponding to projections or slices
    if reshape_method == 'slice':
      reshape_fn = partial(
          slice_hidden, hidden_size=hidden_size, num_blocks=num_blocks)
    elif reshape_method == 'project':
      if projection_tensors is None:
        raise ValueError(
            'Projection tensors is None for reshape_method project')
      reshape_fn = partial(
          project_hidden,
          projection_tensors=projection_tensors,
          hidden_size=hidden_size,
          num_blocks=num_blocks)
    else:
      raise ValueError('Unknown reshape_method')

    # Check if the ema settings make sense
    if ema:
      if ema_count is None:
        raise ValueError('ema_count is None but ema is True')
      if ema_means is None:
        raise ValueError('ema_means is None but ema is True')

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    l = tf.constant(0.0)
    if bottleneck_kind == 'dense':
      c = tf.layers.dense(x, z_size, name='vcc')
      h1 = tf.layers.dense(c, filter_size, name='vch1')
    elif bottleneck_kind == 'vae':
      c, l, _, _ = vae(x, z_size, 'vae')
      h1 = tf.layers.dense(c, filter_size, name='vch1')
    elif bottleneck_kind == 'semhash':
      c = tf.layers.dense(x, z_size, name='vcc')
      y_clean = common_layers.saturating_sigmoid(c)
      if summary:
        tf.summary.histogram('y_clean', tf.reshape(y_clean, [-1]))
      if noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
        noise = tf.truncated_normal(
            common_layers.shape_list(c), mean=0.0, stddev=noise_dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(startup_steps * 2)
      pd *= discrete_mix
      pd = pd if mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.where(
          tf.less(tf.random_uniform([common_layers.shape_list(y)[0]]), pd),
          y_discrete, y)
      h1a = tf.layers.dense(c, filter_size, name='vch1a')
      h1b = tf.layers.dense(1.0 - c, filter_size, name='vch1b')
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, z_size)
    elif bottleneck_kind == 'gumbel-softmax':
      _, hot, l = gumbel_softmax(x, name, z_size, mode, softmax_k,
                                 kl_warmup_steps, summary)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hidden_size, name='dae_dense')
    elif bottleneck_kind == 'dvq':
      x_reshaped = reshape_fn(x)
      x_means_hot, x_means, q_loss, e_loss = embedding_lookup(
          x_reshaped, means, num_blocks, block_v_size, random_top_k)

      # Get the discrete latent representation
      x_means_idx = tf.argmax(x_means_hot, axis=-1)

      # Get the binary representation
      x_means_bits = int_to_bit(
          x_means_idx, num_bits=int(z_size / num_blocks), base=2)
      shape = common_layers.shape_list(x_means_bits)
      new_shape = shape[:-1]
      new_shape[-1] = z_size
      x_means_bits = tf.reshape(x_means_bits, shape=new_shape)
      c = bit_to_int(tf.to_int32(x_means_bits), num_bits=z_size, base=2)

      # Adjust shape of c
      shape_x = common_layers.shape_list(x)
      new_shape = shape_x[:-1]
      c = tf.reshape(c, new_shape)

      # Update the ema variables
      if ema:
        tf.logging.info('Using EMA with beta = {}'.format(beta))
        updated_ema_count = moving_averages.assign_moving_average(
            ema_count,
            tf.reduce_sum(
                tf.reshape(x_means_hot, shape=[-1, num_blocks, block_v_size]),
                axis=0),
            decay,
            zero_debias=False)

        # Add a term that puts a Dirichlet prior over cluster probabilities;
        # the hope is that it encourages rich-get-richer behavior.
        dp_prior_loss = 0.
        if dp_strength > 0.0:
          # Decay dp_strength over time to make it less important
          dp_strength = tf.train.exponential_decay(
              dp_strength,
              global_step=tf.to_int32(tf.train.get_global_step()),
              decay_steps=20000,
              decay_rate=dp_decay)
          dp_count = ema_count + dp_alpha
          p = dp_count / tf.reduce_sum(dp_count, 1, keepdims=True)
          dp_prior_loss = tf.log(p)
          dp_prior_loss = -1.0 * tf.reduce_sum(dp_prior_loss)
          dp_prior_loss /= (num_blocks * block_v_size)

        x_means_hot_flat = tf.reshape(
            x_means_hot, shape=[-1, num_blocks, block_v_size])
        dw = tf.matmul(
            tf.transpose(x_means_hot_flat, perm=[1, 2, 0]),
            tf.transpose(x_reshaped, perm=[1, 0, 2]))
        updated_ema_means = moving_averages.assign_moving_average(
            ema_means, dw, decay, zero_debias=False)
        n = tf.reduce_sum(updated_ema_count, axis=-1, keep_dims=True)
        updated_ema_count = ((updated_ema_count + epsilon) /
                             (n + 2**z_size * epsilon) * n)
        updated_ema_means /= tf.expand_dims(updated_ema_count, axis=-1)

        with tf.control_dependencies([e_loss]):
          update_means = tf.assign(means, updated_ema_means)
          with tf.control_dependencies([update_means]):
            l = beta * e_loss + dp_strength * dp_prior_loss
      else:
        l = q_loss + beta * e_loss

      x_means = tf.reshape(x_means, shape_x)
      x_reshaped = tf.reshape(x_reshaped, shape_x)
      h1 = x_reshaped + tf.stop_gradient(x_means - x_reshaped)
    else:
      raise ValueError('Unknown discretization method.')

    h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name='vch2')
    res = tf.layers.dense(tf.nn.relu(h2), hidden_size, name='vcfin')

    embed_fn = partial(
        embed,
        hidden_size=hidden_size,
        z_size=z_size,
        filter_size=filter_size,
        name=name,
        bottleneck_kind=bottleneck_kind,
        num_blocks=num_blocks,
        block_v_size=block_v_size,
        means=means)
    return res, c, l, embed_fn
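
In the EMA branch above, the running counts get Laplace smoothing before the means are normalized, so codes that were never selected in a batch keep a small positive count instead of dividing by zero. The smoothing step in isolation, as a NumPy sketch of the assumed semantics:

import numpy as np

def laplace_smooth(ema_count, z_size, epsilon=1e-5):
  """Smooths EMA cluster counts as in the dvq branch above."""
  n = ema_count.sum(axis=-1, keepdims=True)
  return (ema_count + epsilon) / (n + 2**z_size * epsilon) * n

counts = np.array([[900.0, 100.0, 0.0, 0.0]])
print(laplace_smooth(counts, z_size=2))
# Zero counts become small positives, so the mean update stays well-defined.
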
Example #48
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """Auto-encoder using transformer decoder and prior over latents."""
  losses = {"extra": 0., "latent_pred": 0.}

  # Reshape image targets as 4d tensor.
  original_targets_shape = common_layers.shape_list(targets)
  batch_size = original_targets_shape[0]
  if len(original_targets_shape) == 4:
    compress_fn = compress_encoder_2d
    decompress_fn = decompress_decoder_2d
  else:
    compress_fn = compress_encoder_1d
    decompress_fn = decompress_decoder_1d

  # Input Encoder if present.
  ed_attention_bias = None
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed_attention_bias = transformer_text_encoder(
        inputs, target_space, hparams, "input_enc")

  # Encode targets to compute targets compressed.
  targets_c = compress_fn(targets, hparams, "compress")
  targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)

  # The following creates a warm-up probability pc that rises toward 1.0;
  # per-example Bernoulli draws against it rescale the loss values below.
  pc = common_layers.inverse_exp_decay(hparams.startup_steps)
  pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
  cond = tf.less(tf.random_uniform([batch_size]), pc)

  # Call bottleneck layer, that takes encoder output and outputs the latents.
  # Returns embedded latents, discrete latent codes, loss.
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    latents_dense, latents_discrete, extra_loss = (
        bottleneck_layer(targets_c, hparams))
    extra_loss = tf.reduce_mean(extra_loss) * tf.to_float(cond)

    _, latents_pred_loss = latent_prediction_model(
        inputs,
        ed_attention_bias,
        latents_discrete,
        latents_dense,
        hparams,
        name="latent_pred")
    latents_pred_loss = tf.reduce_mean(latents_pred_loss) * tf.to_float(cond)

    latents_shape = common_layers.shape_list(latents_dense)
    latents_dense = tf.nn.dropout(
        latents_dense, 1 - hparams.latent_dropout,
        noise_shape=[latents_shape[0], latents_shape[1], 1])

    losses["extra_loss"] = extra_loss
    losses["latent_pred"] = latents_pred_loss

    # We'll start training the extra model of latents after mask_startup_steps.
    latent_time = tf.less(hparams.mask_startup_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
  else:
    latent_len = (
        hparams.img_len * hparams.img_len * hparams.num_latents) // 2**(
            hparams.num_compress_steps)
    embed = functools.partial(
        discretization.parametrized_unbottleneck, hparams=hparams)
    latents_dense = tf.zeros([batch_size, latent_len, 1, hparams.hidden_size])
    if cache is None:
      cache = ae_latent_sample_beam(latents_dense, inputs, ed_attention_bias,
                                    embed, hparams)
    latents_dense = embed(
        tf.one_hot(cache, depth=2**hparams.bottleneck_bits),
        hparams.hidden_size)

  latents_decoder = latents_dense
  if len(original_targets_shape) == 4:
    compressed_img_len = hparams.img_len // 2**(hparams.num_compress_steps // 2)
    latents_decoder = tf.reshape(latents_decoder,
                                 [batch_size,
                                  compressed_img_len,
                                  compressed_img_len,
                                  hparams.num_latents * hparams.hidden_size])

  latents_decoder = decompress_fn(latents_decoder, hparams, name="decompress")
  # For images in 2d space, reshape back to the full resolution; the last
  # dimension is hidden_size rather than the number of image channels.
  output = tf.reshape(
      latents_decoder,
      shape=[-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

  if hparams.use_gold_targets:
    masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      masking = predict_mask
    mask = tf.less(masking, tf.random_uniform(
        common_layers.shape_list(targets)[:-1]))
    mask = tf.expand_dims(tf.to_float(mask), 2)
    output = mask * targets + (1.0 - mask) * output

  # Reshape back to the original 4d targets shape.
  output = tf.reshape(output, original_targets_shape)
  if hparams.decode_autoregressive:
    # Transformer decoder that maps inputs to targets.
    decoder_output = transformer_image_decoder(
        output, inputs, ed_attention_bias, hparams, "decoder")
  else:
    decoder_output = output
  return decoder_output, losses, cache
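The PREDICT branch's shape arithmetic is easier to check with concrete
numbers. A worked example under assumed hparams (img_len=32, num_latents=1,
num_compress_steps=4), matching the integer divisions above:

img_len, num_latents, num_compress_steps = 32, 1, 4

# Each compression step halves the number of latent positions, so the
# flattened latent length shrinks by 2**num_compress_steps overall.
latent_len = (img_len * img_len * num_latents) // 2**num_compress_steps
assert latent_len == 64

# In the 2d case, half of the compress steps apply to each spatial axis.
compressed_img_len = img_len // 2**(num_compress_steps // 2)
assert compressed_img_len == 8
assert compressed_img_len * compressed_img_len * num_latents == latent_len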
Example #49
  def _model_fn(self, features, skip=False, force_full_predict=False):
    """Computes the entire model and produces sharded logits and losses.

    Args:
      features: A dictionary of feature name to tensor.
      skip: a Boolean, if we're just dummy-calling and actually skip this model
        (but we need to create variables to not confuse distributed training).
      force_full_predict: a Boolean, if set, then last-position-only
        optimizations are not used even when allowed and in PREDICT mode.

    Returns:
      logits: `Tensor`
      losses: a dictionary: {loss-name (string): floating point `Scalar`}.
    """
    start_time = time.time()
    dp = self._data_parallelism

    sharded_features = self._shard_features(features)

    # Construct the model bottom for inputs.
    transformed_features = {}
    all_previous_modalities = []

    for key, input_modality in six.iteritems(
        self._problem_hparams.input_modality):
      previous_modalities = [
          self.hparams.problems[i].input_modality[key].name
          for i in xrange(self._problem_idx)
      ]
      all_previous_modalities.extend(previous_modalities)
      do_reuse = input_modality.name in all_previous_modalities
      transformed_features[key + "_raw"] = sharded_features[key]
      with tf.variable_scope(input_modality.name, reuse=do_reuse):
        transformed_features[key] = input_modality.bottom_sharded(
            sharded_features[key], dp)
      all_previous_modalities.append(input_modality.name)

    # Target space id just gets copied to every shard.
    if "target_space_id" in features:
      transformed_features["target_space_id"] = [features["target_space_id"]
                                                ] * self._num_datashards

    # Features ending in "_raw" have no modality; pass them through as-is.
    for key, feature in sharded_features.items():
      if key not in transformed_features and key.endswith("_raw"):
        transformed_features[key] = feature

    # Targets are transformed by the autoregressive part of the modality
    previous_tgt_modalities = [
        self.hparams.problems[i].target_modality.name
        for i in xrange(self._problem_idx)
    ]
    all_previous_modalities.extend(previous_tgt_modalities)

    target_modality = self._problem_hparams.target_modality
    target_reuse = target_modality.name in previous_tgt_modalities
    with tf.variable_scope(target_modality.name, reuse=target_reuse):
      transformed_features["targets"] = target_modality.targets_bottom_sharded(
          sharded_features["targets"], dp)

    # Allows later access to pre-embedding raw targets.
    transformed_features["targets_raw"] = sharded_features["targets"]

    # Construct the model body.
    with tf.variable_scope("body", reuse=self._problem_idx > 0):
      if skip:
        body_outputs = transformed_features["targets"]
        losses = {"extra": 0.0}
      else:
        body_outputs, losses = self.model_fn_body_sharded(transformed_features)
        if not isinstance(losses, dict):  # If it's a single extra loss.
          losses = {"extra": losses}

    with tf.variable_scope(target_modality.name, reuse=target_reuse):
      last_only = (target_modality.top_is_pointwise and
                   self.hparams.mode == tf.estimator.ModeKeys.PREDICT and
                   not force_full_predict)
      if not last_only:
        sharded_logits = target_modality.top_sharded(
            body_outputs, sharded_features["targets"], dp)
        training_loss = target_modality.loss_sharded(
            sharded_logits, sharded_features["targets"], dp)
        training_loss *= self._problem_hparams.loss_multiplier
      else:
        # Take body outputs for the last position only, and targets too.
        last_position_body_outputs = [
            tf.expand_dims(body_shard[:, -1, :, :], axis=1)
            for body_shard in body_outputs
        ]
        # The -1: slice already keeps the position dimension, so no
        # expand_dims is needed for the targets.
        last_position_targets = [
            target_shard[:, -1:, :, :]
            for target_shard in sharded_features["targets"]
        ]
        sharded_logits = target_modality.top_sharded(last_position_body_outputs,
                                                     last_position_targets,
                                                     self._data_parallelism)
        training_loss = None
    losses["training"] = training_loss

    # Scheduled sampling.
    do_scheduled_sampling = (  # Only do it if training and set for it.
        self.hparams.scheduled_sampling_prob > 0.0 and
        self.hparams.mode == tf.estimator.ModeKeys.TRAIN and not skip)
    if do_scheduled_sampling:

      def sample(x):
        """Multinomial sampling from a n-dimensional tensor."""
        vocab_size = target_modality.top_dimensionality
        samples = tf.multinomial(tf.reshape(x, [-1, vocab_size]), 1)
        reshaped_samples = tf.reshape(samples, common_layers.shape_list(x)[:-1])
        return tf.to_int32(reshaped_samples)

      def mix_gold_sampled(gold_targets, sampled_targets):
        return tf.where(
            tf.less(
                tf.random_uniform(common_layers.shape_list(sampled_targets)),
                self.hparams.scheduled_sampling_gold_mixin_prob), gold_targets,
            sampled_targets)

      def sampled_results():
        """Generate scheduled sampling results."""
        sampled_targets = dp(sample, sharded_logits)
        new_targets = dp(mix_gold_sampled, sharded_features["targets"],
                         sampled_targets)
        new_features = transformed_features
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
          with tf.variable_scope(target_modality.name):
            new_features["targets"] = target_modality.targets_bottom_sharded(
                new_targets, dp)
          with tf.variable_scope("body"):
            body_outputs, losses = self.model_fn_body_sharded(new_features)
            if not isinstance(losses, dict):  # If it's a single extra loss.
              losses = {"extra": losses}
          with tf.variable_scope(target_modality.name):
            new_sharded_logits = target_modality.top_sharded(
                body_outputs, sharded_features["targets"], dp)
            training_loss = target_modality.loss_sharded(
                new_sharded_logits, sharded_features["targets"], dp)
            training_loss *= self._problem_hparams.loss_multiplier
          losses["training"] = training_loss
        return new_sharded_logits, losses

      # Run the above conditionally.
      prob = self.hparams.scheduled_sampling_prob
      prob *= common_layers.inverse_exp_decay(
          self.hparams.scheduled_sampling_warmup_steps, min_value=0.001)
      sharded_logits, losses = tf.cond(
          tf.less(tf.random_uniform([]), prob), sampled_results,
          lambda: (sharded_logits, losses))

    if not context.in_eager_mode():
      tf.logging.info("This model_fn took %.3f sec." %
                      (time.time() - start_time))
    return sharded_logits, losses
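Stripped of sharding and variable scopes, scheduled sampling amounts to:
sample tokens from the current logits, keep each gold token with probability
scheduled_sampling_gold_mixin_prob, and re-run the body on the mix. A minimal
NumPy sketch of the sampling-and-mixing step (the helper name and shapes are
illustrative, not the library's API):

import numpy as np

def mix_gold_sampled(gold, logits, gold_mixin_prob, rng):
  """Sample per-position tokens from logits, then mix gold tokens back in."""
  # Multinomial sample per position, like the inner `sample` helper above.
  probs = np.exp(logits - logits.max(-1, keepdims=True))
  probs /= probs.sum(-1, keepdims=True)
  flat = probs.reshape(-1, probs.shape[-1])
  sampled = np.array([rng.choice(flat.shape[-1], p=p) for p in flat])
  sampled = sampled.reshape(gold.shape)
  # Bernoulli mask: True positions keep the gold token.
  keep = rng.random(gold.shape) < gold_mixin_prob
  return np.where(keep, gold, sampled)

rng = np.random.default_rng(0)
gold = np.array([[3, 1, 2]])
logits = rng.normal(size=(1, 3, 5))       # [batch, length, vocab]
mixed = mix_gold_sampled(gold, logits, gold_mixin_prob=0.5, rng=rng)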
Example #50
def discrete_bottleneck(x,
                        hidden_size,
                        z_size,
                        filter_size,
                        name,
                        mode=None,
                        startup_steps=50000,
                        bottleneck_kind="dvq",
                        num_blocks=2,
                        num_residuals=1,
                        reshape_method="slice",
                        projection_tensors=None,
                        means=None,
                        beta=0.25,
                        noise_dev=1.,
                        decay=0.999,
                        discrete_mix=0.5,
                        random_top_k=1,
                        soft_em=False,
                        num_samples=1,
                        epsilon=1e-5,
                        softmax_k=0,
                        kl_warmup_steps=150000,
                        ema=True,
                        ema_count=None,
                        ema_means=None,
                        summary=True):
  """Discretization bottleneck for latent variables.

  Args:
    x: Input to the discretization bottleneck.
    hidden_size: Dimension of the latent state.
    z_size: Number of bits used to produce discrete code; discrete codes range
      from 1 to 2**z_size.
    filter_size: Filter size to be used for the embedding function.
    name: Name for the bottleneck scope.
    mode: tf.estimator mode, for bottlenecks whose behavior differs between
      training and inference (Default: None).
    startup_steps: Number of warmup steps over which the semhash discrete
      mix is ramped up (Default: 50000).
    bottleneck_kind: Kind of discretization bottleneck to use; one of dvq,
      semhash, gumbel-softmax (Default: dvq).
    num_blocks: Number of blocks to use for decomposed vector
      quantization (Default: 2).
    num_residuals: Number of residual units used to compute nearest
      neighbors (Default: 1).
    reshape_method: Method to reshape for DVQ (Default: slice).
    projection_tensors: If the reshape method is project, then these are the
      tensors used to project (Default: None).
    means: The embedding table for dvq (Default: None).
    beta: Beta factor for the DVQ loss (Default: 0.25).
    noise_dev: Stddev for noise added for semhash (Default: 1.0).
    decay: Decay factor for the exponential moving average (Default: 0.999).
    discrete_mix: Factor for mixing discrete and non-discrete input for semhash
      (Default: 0.5).
    random_top_k: Noisy top-k for DVQ (Default: 1).
    soft_em: If True then use soft EM rather than hard EM (Default: False).
    num_samples: Number of samples for soft EM (Default: 1).
    epsilon: Epsilon parameter for DVQ (Default: 1e-5).
    softmax_k: If > 0 then do top-k softmax (Default: 0).
    kl_warmup_steps: Number of steps for kl warmup (Default: 150000).
    ema: If True update embeddings using exponential moving averages (Default:
      True).
    ema_count: Table of counts for each embedding corresponding to how many
      examples in a batch it was the closest to (Default: None).
    ema_means: Exponentially averaged version of the embeddings (Default: None).
    summary: If True, then write summaries (Default: True).

  Returns:
    Embedding to pass to the decoder, discrete latent, loss, and the embedding
    function.

  Raises:
    ValueError: If projection_tensors is None for reshape_method project, or
      ema_count or ema_means is None if we are using ema, or unknown args.
  """
  block_v_size = None
  if bottleneck_kind == "dvq":
    # Define the dvq parameters
    assert means is not None

    # Check block dimensions add up
    if hidden_size % num_blocks != 0:
      raise ValueError("num_blocks does not divide hidden size")

    if z_size % num_residuals != 0:
      raise ValueError("num_residuals does not divide z_size")

    z_size_per_residual = int(z_size / num_residuals)

    if z_size_per_residual % num_blocks != 0:
      raise ValueError("num_blocks does not divide z_size per residual")

    block_v_size = 2**(z_size_per_residual // num_blocks)

    # Set the reshape method corresponding to projections or slices
    if reshape_method == "slice":
      reshape_fn = partial(
          slice_hidden, hidden_size=hidden_size, num_blocks=num_blocks)
    elif reshape_method == "project":
      if projection_tensors is None:
        raise ValueError(
            "Projection tensors is None for reshape_method project")
      reshape_fn = partial(
          project_hidden,
          projection_tensors=projection_tensors,
          hidden_size=hidden_size,
          num_blocks=num_blocks)
    else:
      raise ValueError("Unknown reshape_method")

    # Check if the ema settings make sense
    if ema:
      if ema_count is None:
        raise ValueError("ema_count is None but ema is True")
      if ema_means is None:
        raise ValueError("ema_means is None but ema is True")

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    l = tf.constant(0.0)
    if bottleneck_kind == "dense":
      c = tf.layers.dense(x, z_size, name="vcc")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    elif bottleneck_kind == "vae":
      c, l, _, _ = vae(x, z_size, "vae")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    elif bottleneck_kind == "semhash":
      c = tf.layers.dense(x, z_size, name="vcc")
      y_clean = common_layers.saturating_sigmoid(c)
      if summary:
        tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
      if noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
        noise = tf.truncated_normal(
            common_layers.shape_list(c), mean=0.0, stddev=noise_dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
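      # Straight-through binarization: evaluates to the hard bits d in the
      # forward pass but carries the gradient of the soft y.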
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(startup_steps * 2)
      pd *= discrete_mix
      pd = pd if mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.where(
          tf.less(tf.random_uniform([common_layers.shape_list(y)[0]]), pd),
          y_discrete, y)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, z_size)
    elif bottleneck_kind == "gumbel-softmax":
      _, hot, l = gumbel_softmax(x, name, z_size, mode, softmax_k,
                                 kl_warmup_steps, summary)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hidden_size, name="dae_dense")
    elif bottleneck_kind == "dvq":
      x_reshaped = reshape_fn(x)
      x_res = x_reshaped
      x_means_hot = []
      x_means = 0
      l = 0
      for i in range(num_residuals):
        x_means_hot_res, x_means_res, q_loss_res, e_loss_res = embedding_lookup(
            x_res, means[i], num_blocks, block_v_size, random_top_k, soft_em,
            num_samples)
        # Update the ema variables
        if ema:
          tf.logging.info("Using EMA with beta = {}".format(beta))
          updated_ema_count_res = moving_averages.assign_moving_average(
              ema_count[i],
              tf.reduce_sum(
                  tf.reshape(
                      x_means_hot_res, shape=[-1, num_blocks, block_v_size]),
                  axis=0),
              decay,
              zero_debias=False)

          dw = tf.matmul(
              tf.transpose(x_means_hot_res, perm=[1, 2, 0]),
              tf.transpose(x_res, perm=[1, 0, 2]))

          updated_ema_means_res = moving_averages.assign_moving_average(
              ema_means[i], dw, decay, zero_debias=False)
          n = tf.reduce_sum(updated_ema_count_res, axis=-1, keep_dims=True)
          updated_ema_count_res = ((updated_ema_count_res + epsilon) /
                                   (n + 2**z_size * epsilon) * n)
          # pylint: disable=g-no-augmented-assignment
          updated_ema_means_res = updated_ema_means_res / tf.expand_dims(
              updated_ema_count_res, axis=-1)
          # pylint: enable=g-no-augmented-assignment

          with tf.control_dependencies([e_loss_res]):
            update_means_res = tf.assign(means[i], updated_ema_means_res)
            with tf.control_dependencies([update_means_res]):
              l += beta * e_loss_res
        else:
          l += q_loss_res + beta * e_loss_res

        # Update the residuals
        x_res -= x_means_res
        x_means += x_means_res
        x_means_hot.append(x_means_hot_res)

      # Get the discrete latent representation
      x_means_hot = tf.stack(x_means_hot, axis=1)
      x_means_idx = tf.argmax(x_means_hot, axis=-1)

      # Get the binary representation
      x_means_bits = int_to_bit(
          x_means_idx,
          num_bits=int(z_size / (num_residuals * num_blocks)),
          base=2)
      shape = common_layers.shape_list(x_means_bits)
      new_shape = shape[:-2]
      new_shape[-1] = z_size
      x_means_bits = tf.reshape(x_means_bits, shape=new_shape)
      c = bit_to_int(tf.to_int32(x_means_bits), num_bits=z_size, base=2)

      # Adjust shape of c
      shape_x = common_layers.shape_list(x)
      new_shape = shape_x[:-1]
      c = tf.reshape(c, new_shape)

      # If we are doing soft EM then c is x_means_hot
      if soft_em:
        c = x_means_hot
        new_shape.append(block_v_size)
        c = tf.reshape(c, new_shape)

      x_means = tf.reshape(x_means, shape_x)
      x_reshaped = tf.reshape(x_reshaped, shape_x)
      h1 = x_reshaped + tf.stop_gradient(x_means - x_reshaped)
    else:
      raise ValueError("Unknown discretization method.")

    res = h1

    embed_fn = partial(
        embed,
        hidden_size=hidden_size,
        z_size=z_size,
        filter_size=filter_size,
        name=name,
        bottleneck_kind=bottleneck_kind,
        soft_em=soft_em,
        num_blocks=num_blocks,
        num_residuals=num_residuals,
        block_v_size=block_v_size,
        means=means)
    return res, c, l, embed_fn
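The dvq branch combines the per-residual, per-block argmax indices into a
single integer code by way of a binary representation (int_to_bit, reshape,
then bit_to_int). A minimal NumPy sketch of that round trip in base 2, with
hypothetical helper names mirroring the ones used above:

import numpy as np

def int_to_bit(x, num_bits):
  # Most-significant bit first, matching bit_to_int below.
  return (x[..., None] >> np.arange(num_bits - 1, -1, -1)) & 1

def bit_to_int(bits):
  num_bits = bits.shape[-1]
  return (bits << np.arange(num_bits - 1, -1, -1)).sum(-1)

codes = np.array([0, 5, 7])
bits = int_to_bit(codes, num_bits=3)   # [[0,0,0], [1,0,1], [1,1,1]]
assert (bit_to_int(bits) == codes).all()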