Beispiel #1
0
    def begin(self):
        """Build the checkpoint assignment map and initialize from checkpoint.

        New model variables have the following name format:
        new_model_scope/old_model_scope/xxx/xxx:0. To map checkpoint names
        to variables, the new_model_scope prefix is stripped, only names that
        then start with old_model_scope are kept, and everything from the
        first ":" onward (e.g. the ":0" suffix) is dropped.

        Side effects: stores the resulting map on `self._assignment_map`,
        logs how many variables are restored, and calls
        `tf.train.init_from_checkpoint` with `self._checkpoint_path`.
        """
        candidates = contrib.framework().get_variables_to_restore(
            include=self._include, exclude=self._exclude)

        name_to_variable = {}
        for candidate in candidates:
            full_name = candidate.name
            # Only variables living under the new model scope are considered.
            if not full_name.startswith(self._new_model_scope):
                continue
            # Drop the new_model_scope prefix from the variable name.
            scoped_name = full_name[len(self._new_model_scope):]
            # What remains has to sit under the old model scope.
            if not scoped_name.startswith(self._old_model_scope):
                continue
            # Drop the output-index suffix (":0") from the name.
            checkpoint_name = scoped_name.split(":")[0]
            name_to_variable[checkpoint_name] = candidate

        self._assignment_map = name_to_variable

        tf.logging.info("restoring %d variables from checkpoint %s" %
                        (len(name_to_variable), self._checkpoint_path))
        tf.train.init_from_checkpoint(self._checkpoint_path,
                                      self._assignment_map)
Beispiel #2
0
def make_input_fn_from_generator(gen):
  """Use py_func to yield elements from the given generator.

  The first element of `gen` is consumed immediately so that the dtypes and
  ranks of its flattened leaves can be recorded; that element is replayed on
  the first call of the wrapped Python function, so it is not lost.

  Args:
    gen: a generator of (possibly nested) examples. Each flattened leaf of
      the first example must expose `.dtype` and `.shape`.

  Returns:
    A zero-argument `input_fn` that builds a `tf.py_func` op pulling from
    `gen` and returns tensors packed in the structure of the first example.
  """
  template = six.next(gen)
  template_leaves = contrib.framework().nest.flatten(template)
  leaf_dtypes = [leaf.dtype for leaf in template_leaves]
  # Rank is kept, every dimension is left unknown.
  leaf_shapes = [[None] * len(leaf.shape) for leaf in template_leaves]
  # One-slot buffer holding the already-consumed first example.
  pending = [template]

  def py_func():
    example = pending.pop() if pending else six.next(gen)
    return contrib.framework().nest.flatten(example)

  def input_fn():
    flat_tensors = tf.py_func(py_func, [], leaf_dtypes)
    for tensor, shape in zip(flat_tensors, leaf_shapes):
      tensor.set_shape(shape)
    return contrib.framework().nest.pack_sequence_as(template, flat_tensors)

  return input_fn
Beispiel #3
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.models.research import glow_init_hook
from tensor2tensor.models.research import glow_ops
from tensor2tensor.utils import contrib
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model
import tensorflow as tf

# Module-level shortcuts to the contrib framework's arg_scope helpers.
arg_scope = contrib.framework().arg_scope
add_arg_scope = contrib.framework().add_arg_scope

# Comma-separated "key=value" string. The name suggests these are hparams
# overrides applied when decoding -- usage is not visible here, TODO confirm.
GLOW_DECODE_HPARAMS = ("identity_output=True,log_results=False,"
                       "decode_in_memory=True,display_decoded_images=True")


@registry.register_hparams
def glow_hparams():
    """Glow Hparams.

    Starts from `common_hparams.basic_params1()` and overrides gradient
    clipping, weight decay, the learning-rate constant and the batch size.
    """
    hparams = common_hparams.basic_params1()
    hparams.clip_grad_norm = None
    hparams.weight_decay = 0.0
    hparams.learning_rate_constant = 3e-4
    hparams.batch_size = 32
    # NOTE(review): the visible definition stops at the comment below -- the
    # setting it describes and any `return hparams` are not shown here.
    # can be prev_level, prev_step or normal.
Beispiel #4
0
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """AE Transformer, main step used for training.

    Encodes `inputs` (when given), optionally runs the autoencoding path on
    `targets` (compress -> bottleneck -> decompress, with optional masking)
    when `hparams.do_ae` is set, and finally decodes with
    `decode_transformer`.

    Args:
      inputs: input tensor or None. When not None it is flattened with
        `common_layers.flatten4d3d` and passed through `encode`.
      targets: target tensor; reshaped here to
        [batch_size, -1, 1, hparams.hidden_size].
      target_space: forwarded to `encode`.
      hparams: hyperparameters object. Fields read here include do_ae,
        do_mask, do_refine, task, mode, bottleneck, bottleneck_kind,
        num_compress_steps, word_shuffle, word_dropout and several scales.
      cache: optional latent codes. Only consulted in PREDICT mode with a
        bottleneck kind other than "dense"/"vae"; when None there, it is
        produced by `ae_latent_sample`.
      predict_mask: replaces the computed masking value when
        `hparams.use_predict_mask` is set or the mode is PREDICT.

    Returns:
      A tuple (res, losses, cache, data_dim, latent_dim) where `res` is the
      decoder output cropped along axis 1, `losses` is a dict with keys
      "extra", "latent_pred" and "neg_q_entropy", `cache` is the (possibly
      newly sampled) latent cache, `data_dim` is the size of axis 1 of `res`
      and `latent_dim` is the size of axis 1 of the compressed targets.
    """
    # Summaries break with the do_refine cond, turn them off in that case.
    # Note: this flips a module-level flag and nothing here turns it back on.
    global _DO_SUMMARIES
    if hparams.do_refine:
        _DO_SUMMARIES = False

    # Prepare.
    # Batch size comes from inputs when available, otherwise from targets.
    if inputs is not None:
        batch_size = common_layers.shape_list(inputs)[0]
    else:
        batch_size = common_layers.shape_list(targets)[0]
    targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

    # Encoder.
    if inputs is not None:
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed = encode(inputs, target_space, hparams, "input_enc")
        # Keep handles to the encoder outputs for the latent-prediction model.
        inputs_ex, ed_ex = inputs, ed
    else:
        ed, inputs_ex, ed_ex = None, None, None

    # Autoencoding.
    # Default every loss to zero so the returned dict always has all keys.
    losses = {
        "extra": tf.constant(0.0),
        "latent_pred": tf.constant(0.0),
        "neg_q_entropy": tf.constant(0.0)
    }
    if hparams.do_ae:
        # flatten here
        original_targets = targets
        original_targets_shape = tf.shape(original_targets)
        if hparams.task == "image":
            # NOTE(review): the return value is discarded here, so unless the
            # helper works in place this call has no effect -- confirm.
            cia.maybe_reshape_4d_to_3d(targets)
        if hparams.task == "translate":
            if inputs is not None:
                # Two copies of inputs concatenated on axis 1 serve as the
                # length reference for padding targets below.
                max_targets_len_from_inputs = tf.concat([inputs, inputs],
                                                        axis=1)
            else:
                max_targets_len_from_inputs = targets
        else:
            assert hparams.task == "image"
            max_targets_len_from_inputs = targets
        if hparams.word_shuffle:
            tf.logging.info("Using word shuffle with rate = {}".format(
                hparams.word_shuffle))
            # Add uniform noise to each position index and argsort the
            # result: positions get permuted along axis 1.
            targets_idx = tf.range(start=0,
                                   limit=common_layers.shape_list(targets)[1],
                                   delta=1)
            targets_idx = tf.to_float(targets_idx)
            noise = tf.random_uniform(
                shape=common_layers.shape_list(targets_idx),
                minval=0,
                maxval=1 + hparams.word_shuffle)
            targets_idx += noise
            permutation = contrib.framework().argsort(targets_idx)
            targets_permuted = tf.gather(targets, indices=permutation, axis=1)
            targets = targets_permuted
        # Pad so the length is divisible by 2**num_compress_steps.
        targets, _ = common_layers.pad_to_same_length(
            targets,
            max_targets_len_from_inputs,
            final_length_divisible_by=2**hparams.num_compress_steps)
        # Add positional information
        # (drop axis 2 for the embedding call, then restore the 4d shape).
        targets_shape = common_layers.shape_list(targets)
        targets = tf.reshape(
            targets, [targets_shape[0], targets_shape[1], targets_shape[3]])
        targets = common_attention.add_positional_embedding(
            targets, hparams.max_length, name="targets_position")
        targets = tf.reshape(targets, shape=targets_shape)
        if hparams.word_dropout:
            # Zero out elements whose uniform draw is <= word_dropout.
            mask = tf.random_uniform(shape=common_layers.shape_list(targets),
                                     minval=0.0,
                                     maxval=1.0)
            targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                                     tf.zeros_like(targets))
        else:
            targets_noisy = targets

        targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
        if hparams.mode != tf.estimator.ModeKeys.PREDICT:
            # Compress and bottleneck.
            latents_dense, latents_discrete, extra_loss, embed, neg_q_entropy = (
                hparams.bottleneck(inputs=targets_c,
                                   filter_size=hparams.compress_filter_size,
                                   mode=hparams.mode,
                                   name="vc"))
            if _DO_SUMMARIES:
                tf.summary.histogram(
                    "b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
            # Per-example choice between the bottleneck output and the raw
            # compressed targets; pc is forced to 1.0 outside of TRAIN.
            pc = common_layers.inverse_exp_decay(hparams.startup_steps)
            pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
            cond = tf.less(tf.random_uniform([batch_size]), pc)
            latents_dense = tf.where(cond, latents_dense, targets_c)
            # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
            losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
            # Extra loss predicting latent code from input. Discrete only.
            if hparams.bottleneck_kind not in ["dense", "vae"]:
                latents_pred = decode_transformer(inputs_ex,
                                                  ed_ex,
                                                  embed(latents_discrete),
                                                  hparams,
                                                  "extra",
                                                  task="translate")
                _, latent_pred_loss = ae_latent_softmax(
                    latents_pred, tf.stop_gradient(latents_discrete), hparams)

                # Scale by latent dimension for summary so we can compare across
                # batches.
                if _DO_SUMMARIES:
                    tf.summary.scalar("latent_pred_loss_mean",
                                      tf.reduce_mean(latent_pred_loss))
                if hparams.sum_over_latents:
                    latent_pred_loss = tf.reduce_sum(latent_pred_loss, [1, 2])

                losses["latent_pred"] = tf.reduce_mean(
                    latent_pred_loss * tf.to_float(cond)) * hparams.prior_scale
                losses["neg_q_entropy"] = neg_q_entropy * hparams.entropy_scale
            else:
                # Dense / VAE bottleneck: predict the compressed targets from
                # the inputs and penalise the squared difference.
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                losses["latent_pred"] = tf.reduce_mean(
                    tf.squared_difference(inputs_c, targets_c)) * 20

                def bn_inputs():
                    # Re-apply the "vc" bottleneck with reuse=True on the
                    # current variable scope, this time on inputs_c.
                    with tf.variable_scope(tf.get_variable_scope(),
                                           reuse=True):
                        bn, _, _, _, _ = hparams.bottleneck(
                            inputs=inputs_c,
                            filter_size=hparams.compress_filter_size,
                            mode=hparams.mode,
                            name="vc")
                    return bn

                inputs_c = bn_inputs()
                # Per-example choice between target-derived and input-derived
                # latents; ptc is forced to 1.0 outside of TRAIN.
                ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
                ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
                latents_dense = tf.where(
                    tf.less(tf.random_uniform([batch_size]), ptc),
                    latents_dense, inputs_c)
        else:
            # PREDICT mode.
            if hparams.bottleneck_kind in ["dense", "vae"]:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                latents_dense, _, _, _, _ = hparams.bottleneck(
                    inputs=inputs_c,
                    filter_size=hparams.compress_filter_size,
                    mode=hparams.mode,
                    name="vc")
            else:
                latent_len = common_layers.shape_list(targets_c)[1]
                # Only the embed function of the bottleneck is used here.
                _, _, _, embed, _ = hparams.bottleneck(
                    inputs=targets_c,
                    filter_size=hparams.compress_filter_size,
                    name="vc")
                latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
                # Sample latent codes only when the caller supplied no cache.
                if cache is None:
                    cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex,
                                             embed, 16, hparams)
                latents_dense = embed(cache)
        # Postprocess.
        # Same reshape / positional-embedding / reshape dance as for targets.
        d = latents_dense
        d_shape = common_layers.shape_list(d)
        d = tf.reshape(d, [d_shape[0], d_shape[1], d_shape[3]])
        d = common_attention.add_positional_embedding(d,
                                                      hparams.max_length,
                                                      name="latents_position")
        d = tf.reshape(d, shape=d_shape)

        # decompressing the dense latents
        for i in range(hparams.num_compress_steps):
            # Layer names count down from num_compress_steps - 1 to 0.
            j = hparams.num_compress_steps - i - 1
            d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
            if inputs is not None and hparams.do_attend_decompress:
                d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
            d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

        # Masking.
        if hparams.do_mask:
            masking = common_layers.inverse_lin_decay(
                hparams.mask_startup_steps)
            masking *= common_layers.inverse_exp_decay(
                hparams.mask_startup_steps // 4)  # Not much at start.
            if not hparams.do_refine:
                masking -= tf.random_uniform([]) * hparams.unmasked_percentage
            # Clamp to [0, 1].
            masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
            # The caller-provided value overrides the schedule in these cases.
            if hparams.use_predict_mask:
                masking = predict_mask
            if hparams.mode == tf.estimator.ModeKeys.PREDICT:
                masking = predict_mask
            # mask is 1.0 where the uniform draw exceeds masking, else 0.0.
            mask = tf.less(
                masking,
                tf.random_uniform(common_layers.shape_list(targets)[:-1]))
            mask = tf.expand_dims(tf.to_float(mask), 3)

            # targets is always [batch, length, 1, depth]
            # Where mask is 1 keep targets, elsewhere use the decompressed d.
            targets = mask * targets + (1.0 - mask) * d
            # reshape back to 4d here
            if hparams.task == "image":
                targets = tf.reshape(targets, original_targets_shape)
        else:
            targets = d

    res = decode_transformer(inputs,
                             ed,
                             targets,
                             hparams,
                             "decoder",
                             causal=hparams.causal)
    if hparams.do_ae:
        if hparams.do_mask and hparams.do_refine:

            def refine_res():
                # return residual_conv(res, 1, (5, 1), hparams, "refine")
                r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                              "refine_enc")
                return tf.expand_dims(r, axis=2)

            # Use the refined output for batch entries whose mask sums to
            # (almost) zero, i.e. where no target element was kept.
            masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
            all_masked = tf.less(masked_batches, 0.1)
            res = tf.where(all_masked, refine_res(), res)
        # We'll start training the extra model of latents after mask_startup_steps.
        nonlatent_steps = hparams.mask_startup_steps
        latent_time = tf.less(nonlatent_steps,
                              tf.to_int32(tf.train.get_global_step()))
        losses["latent_pred"] *= tf.to_float(latent_time)

    # res was generated from padded targets, which means it has some extra
    # elements. These can cause shape problems when computing loss with respect to
    # the original (unpadded) targets. So we remove their extra elements here.
    # NOTE(review): original_targets_shape (here) and targets_c (below) are
    # only assigned inside `if hparams.do_ae:` above, so this tail would raise
    # NameError when hparams.do_ae is falsy -- confirm callers always set it.
    res = res[:, :original_targets_shape[1], :, :]

    data_dim = common_layers.shape_list(res)[1]
    latent_dim = common_layers.shape_list(targets_c)[1]
    return res, losses, cache, data_dim, latent_dim
Beispiel #5
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range  # pylint: disable=redefined-builtin

from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import contrib
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow as tf

# Handle to the contrib framework, used below for its `deprecated` decorator.
# Presumably msg="warn" makes a missing contrib module warn rather than
# fail -- TODO confirm against tensor2tensor.utils.contrib.
framework = contrib.framework(msg="warn")


@framework.deprecated("2018-09-15",
                      "Use Transformer, which supports decoder-only mode when "
                      "Transformer.has_input=False.")
@registry.register_model
class AttentionLM(t2t_model.T2TModel):
    """Attention net.  See file docstring."""
    def body(self, features):
        # Remove dropout if not training
        hparams = self._hparams
        targets = features["targets"]
        targets = tf.squeeze(targets, 2)

        (decoder_input,
Beispiel #6
0
 def input_fn():
   """Build the py_func op and return its outputs packed like `first_ex`.

   Relies on `py_func`, `types`, `shapes` and `first_ex` from the enclosing
   scope (not visible in this fragment).
   """
   flat_tensors = tf.py_func(py_func, [], types)
   # py_func outputs carry no static shape; attach the recorded ones.
   for tensor, shape in zip(flat_tensors, shapes):
     tensor.set_shape(shape)
   return contrib.framework().nest.pack_sequence_as(first_ex, flat_tensors)
Beispiel #7
0
 def py_func():
   """Return the next example as a flat list of leaves.

   Drains `first_ex_list` first and only then pulls from `gen`; both names
   come from the enclosing scope (not visible in this fragment).
   """
   example = first_ex_list.pop() if first_ex_list else six.next(gen)
   return contrib.framework().nest.flatten(example)