def begin(self):
  """Initialize model variables from the checkpoint.

  New model variables have the following name format:
  new_model_scope/old_model_scope/xxx/xxx:0

  To build the checkpoint-name -> variable map, the new_model_scope prefix is
  stripped from each variable name, only names that then start with
  old_model_scope are kept, and the ":0" suffix is removed.
  """
  candidates = contrib.framework().get_variables_to_restore(
      include=self._include, exclude=self._exclude)

  # Pass 1: keep variables under new_model_scope and drop that prefix.
  without_new_scope = {}
  for var in candidates:
    if var.name.startswith(self._new_model_scope):
      without_new_scope[var.name[len(self._new_model_scope):]] = var

  # Pass 2: keep names under old_model_scope and drop the ":0" suffix.
  restore_map = {}
  for stripped_name, var in six.iteritems(without_new_scope):
    if stripped_name.startswith(self._old_model_scope):
      restore_map[stripped_name.split(":")[0]] = var

  self._assignment_map = restore_map

  message = "restoring %d variables from checkpoint %s" % (
      len(restore_map), self._checkpoint_path)
  tf.logging.info(message)
  tf.train.init_from_checkpoint(self._checkpoint_path, self._assignment_map)
def make_input_fn_from_generator(gen):
  """Build an input_fn that yields the generator's elements via py_func.

  One example is consumed from `gen` immediately to learn the nested
  structure, the dtype of every leaf and the rank of every leaf. That example
  is not lost: it is replayed as the first element produced by the op.

  Args:
    gen: generator of (possibly nested) examples whose leaves expose `dtype`
      and `shape`.

  Returns:
    A zero-argument function that creates the py_func op and returns its
    outputs packed into the structure of the first example.
  """
  peeked = six.next(gen)
  leaves = contrib.framework().nest.flatten(peeked)
  dtypes = [leaf.dtype for leaf in leaves]
  # Only the rank is fixed; every dimension is left unknown.
  unknown_shapes = [[None] * len(leaf.shape) for leaf in leaves]
  pending = [peeked]

  def next_flat_example():
    # Replay the peeked example once, then keep drawing from the generator.
    if pending:
      current = pending.pop()
    else:
      current = six.next(gen)
    return contrib.framework().nest.flatten(current)

  def input_fn():
    flat_tensors = tf.py_func(next_flat_example, [], dtypes)
    for tensor, shape in zip(flat_tensors, unknown_shapes):
      tensor.set_shape(shape)
    return contrib.framework().nest.pack_sequence_as(peeked, flat_tensors)

  return input_fn
from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np from tensor2tensor.layers import common_hparams from tensor2tensor.layers import common_layers from tensor2tensor.models.research import glow_init_hook from tensor2tensor.models.research import glow_ops from tensor2tensor.utils import contrib from tensor2tensor.utils import registry from tensor2tensor.utils import t2t_model import tensorflow as tf arg_scope = contrib.framework().arg_scope add_arg_scope = contrib.framework().add_arg_scope GLOW_DECODE_HPARAMS = ("identity_output=True,log_results=False," "decode_in_memory=True,display_decoded_images=True") @registry.register_hparams def glow_hparams(): """Glow Hparams.""" hparams = common_hparams.basic_params1() hparams.clip_grad_norm = None hparams.weight_decay = 0.0 hparams.learning_rate_constant = 3e-4 hparams.batch_size = 32 # can be prev_level, prev_step or normal.
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """AE Transformer, main step used for training.

  Optionally encodes `inputs`, compresses `targets` and passes them through
  `hparams.bottleneck`, decompresses the resulting latents, optionally mixes
  them with the real targets through a random mask, and finally runs the
  decoder on the result.

  Args:
    inputs: input tensor, or None. When None no encoder is run and the batch
      size is read from `targets` instead.
    targets: target tensor; reshaped to [batch, -1, 1, hparams.hidden_size].
    target_space: forwarded unchanged to `encode`.
    hparams: hyperparameter object. Besides plain values it provides the
      callable `hparams.bottleneck`.
    cache: optional latent codes. Only consulted on the PREDICT path with a
      bottleneck kind other than "dense"/"vae"; when None there, codes are
      produced by `ae_latent_sample`.
    predict_mask: value used as the masking level instead of the scheduled
      one when `hparams.use_predict_mask` is set or the mode is PREDICT.

  Returns:
    A tuple `(res, losses, cache, data_dim, latent_dim)`:
      res: decoder output, sliced on axis 1 to the original target length.
      losses: dict with keys "extra", "latent_pred" and "neg_q_entropy".
      cache: the `cache` argument, or the sampled latents if it was sampled.
      data_dim: size of axis 1 of `res`.
      latent_dim: size of axis 1 of the compressed targets.

  NOTE(review): `original_targets_shape` and `targets_c` are only assigned
  inside `if hparams.do_ae:` but are read unconditionally at the end, so this
  function appears to require `hparams.do_ae` to be true -- confirm.
  """
  # Summaries break with the do_refine cond, turn them off in that case.
  # This rebinds the module-level flag; nothing in this function turns it
  # back on.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  if inputs is not None:
    batch_size = common_layers.shape_list(inputs)[0]
  else:
    batch_size = common_layers.shape_list(targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    # Keep separate handles for the "extra" latent-prediction model below.
    inputs_ex, ed_ex = inputs, ed
  else:
    ed, inputs_ex, ed_ex = None, None, None

  # Autoencoding.
  # All losses default to zero and are overwritten only on the branches that
  # compute them.
  losses = {
      "extra": tf.constant(0.0),
      "latent_pred": tf.constant(0.0),
      "neg_q_entropy": tf.constant(0.0)
  }
  if hparams.do_ae:
    # flatten here
    original_targets = targets
    # Dynamic shape, used later to undo padding / reshaping.
    original_targets_shape = tf.shape(original_targets)
    if hparams.task == "image":
      # NOTE(review): the return value is discarded, so this call looks like
      # it has no effect on `targets` -- confirm intent.
      cia.maybe_reshape_4d_to_3d(targets)
    # Pick the tensor whose length `targets` will be padded to.
    if hparams.task == "translate":
      if inputs is not None:
        # Inputs concatenated with themselves on axis 1, i.e. twice as long.
        max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
      else:
        max_targets_len_from_inputs = targets
    else:
      assert hparams.task == "image"
      max_targets_len_from_inputs = targets
    if hparams.word_shuffle:
      tf.logging.info("Using word shuffle with rate = {}".format(
          hparams.word_shuffle))
      # Add uniform noise in [0, 1 + word_shuffle) to each position index and
      # reorder the positions (axis 1) by the noisy indices.
      targets_idx = tf.range(start=0,
                             limit=common_layers.shape_list(targets)[1],
                             delta=1)
      targets_idx = tf.to_float(targets_idx)
      noise = tf.random_uniform(
          shape=common_layers.shape_list(targets_idx),
          minval=0,
          maxval=1 + hparams.word_shuffle)
      targets_idx += noise
      permutation = contrib.framework().argsort(targets_idx)
      targets_permuted = tf.gather(targets, indices=permutation, axis=1)
      targets = targets_permuted
    # Presumably makes the length a multiple of 2**num_compress_steps so it
    # survives the compression steps -- see pad_to_same_length.
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    # Add positional information
    # Axis 2 is dropped for the embedding call and restored right after.
    targets_shape = common_layers.shape_list(targets)
    targets = tf.reshape(
        targets, [targets_shape[0], targets_shape[1], targets_shape[3]])
    targets = common_attention.add_positional_embedding(
        targets, hparams.max_length, name="targets_position")
    targets = tf.reshape(targets, shape=targets_shape)
    if hparams.word_dropout:
      # Elements whose uniform sample is <= word_dropout are zeroed.
      mask = tf.random_uniform(shape=common_layers.shape_list(targets),
                               minval=0.0, maxval=1.0)
      targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                               tf.zeros_like(targets))
    else:
      targets_noisy = targets

    targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, embed, neg_q_entropy = (
          hparams.bottleneck(inputs=targets_c,
                             filter_size=hparams.compress_filter_size,
                             mode=hparams.mode,
                             name="vc"))
      if _DO_SUMMARIES:
        tf.summary.histogram(
            "b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
      # Per-example coin flip: where `cond` is true the bottleneck output is
      # used, otherwise the raw compressed targets. Outside TRAIN, pc is 1.0.
      pc = common_layers.inverse_exp_decay(hparams.startup_steps)
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(inputs_ex, ed_ex,
                                          embed(latents_discrete), hparams,
                                          "extra", task="translate")
        # The discrete latents serve as fixed labels: no gradient flows
        # through them from this loss.
        _, latent_pred_loss = ae_latent_softmax(
            latents_pred, tf.stop_gradient(latents_discrete), hparams)

        # Scale by latent dimension for summary so we can compare across
        # batches.
        if _DO_SUMMARIES:
          tf.summary.scalar("latent_pred_loss_mean",
                            tf.reduce_mean(latent_pred_loss))
        if hparams.sum_over_latents:
          latent_pred_loss = tf.reduce_sum(latent_pred_loss, [1, 2])

        # Only examples that went through the bottleneck (`cond`) contribute.
        losses["latent_pred"] = tf.reduce_mean(
            latent_pred_loss * tf.to_float(cond)) * hparams.prior_scale
        losses["neg_q_entropy"] = neg_q_entropy * hparams.entropy_scale
      else:
        # "dense"/"vae": predict the compressed targets from the inputs and
        # penalize the squared difference (scaled by the constant 20).
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean(
            tf.squared_difference(inputs_c, targets_c)) * 20

        def bn_inputs():
          # Re-run the "vc" bottleneck on the predicted latents with
          # variable reuse enabled in the current scope.
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _, _ = hparams.bottleneck(
                inputs=inputs_c,
                filter_size=hparams.compress_filter_size,
                mode=hparams.mode,
                name="vc")
          return bn

        inputs_c = bn_inputs()
        # Per example, choose between the target-derived latents (with
        # probability ptc) and the input-derived ones. Outside TRAIN ptc is
        # 1.0.
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(
            tf.less(tf.random_uniform([batch_size]), ptc),
            latents_dense, inputs_c)
    else:
      # PREDICT mode.
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _, _ = hparams.bottleneck(
            inputs=inputs_c,
            filter_size=hparams.compress_filter_size,
            mode=hparams.mode,
            name="vc")
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        # Only the `embed` function is needed from the bottleneck here.
        # NOTE(review): unlike the other calls, `mode` is not passed --
        # confirm that the bottleneck's default is intended.
        _, _, _, embed, _ = hparams.bottleneck(
            inputs=targets_c,
            filter_size=hparams.compress_filter_size,
            name="vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          # No codes supplied by the caller: sample them.
          cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex, embed, 16,
                                   hparams)
        latents_dense = embed(cache)
    # Postprocess.
    # Add positional information to the latents (axis 2 dropped and restored
    # around the call, as for the targets above).
    d = latents_dense
    d_shape = common_layers.shape_list(d)
    d = tf.reshape(d, [d_shape[0], d_shape[1], d_shape[3]])
    d = common_attention.add_positional_embedding(d, hparams.max_length,
                                                  name="latents_position")
    d = tf.reshape(d, shape=d_shape)

    # decompressing the dense latents
    for i in range(hparams.num_compress_steps):
      # Layer names count down from num_compress_steps - 1 to 0.
      j = hparams.num_compress_steps - i - 1
      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
      if inputs is not None and hparams.do_attend_decompress:
        d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
      d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(
          hparams.mask_startup_steps)
      masking *= common_layers.inverse_exp_decay(
          hparams.mask_startup_steps // 4)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * hparams.unmasked_percentage
      # Clamp to [0, 1].
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      # Either condition below replaces the scheduled value with the
      # caller-supplied one.
      if hparams.use_predict_mask:
        masking = predict_mask
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      # mask is 1.0 where the real target is kept and 0.0 where the
      # decompressed latent `d` is substituted.
      mask = tf.less(
          masking, tf.random_uniform(common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)

      # targets is always [batch, length, 1, depth]
      targets = mask * targets + (1.0 - mask) * d
      # reshape back to 4d here
      if hparams.task == "image":
        targets = tf.reshape(targets, original_targets_shape)
    else:
      targets = d

  res = decode_transformer(inputs, ed, targets, hparams, "decoder",
                           causal=hparams.causal)
  if hparams.do_ae:
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        # return residual_conv(res, 1, (5, 1), hparams, "refine")
        r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                      "refine_enc")
        return tf.expand_dims(r, axis=2)

      # Examples whose mask sums to (almost) zero kept no real targets; only
      # for those is `res` replaced by its refined version.
      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # We'll start training the extra model of latents after mask_startup_steps.
    # Before that the latent_pred loss is multiplied by zero.
    nonlatent_steps = hparams.mask_startup_steps
    latent_time = tf.less(nonlatent_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)

  # res was generated from padded targets, which means it has some extra
  # elements. These can cause shape problems when computing loss with respect to
  # the original (unpadded) targets. So we remove their extra elements here.
  res = res[:, :original_targets_shape[1], :, :]

  data_dim = common_layers.shape_list(res)[1]
  latent_dim = common_layers.shape_list(targets_c)[1]
  return res, losses, cache, data_dim, latent_dim
from __future__ import absolute_import from __future__ import division from __future__ import print_function from six.moves import range # pylint: disable=redefined-builtin from tensor2tensor.layers import common_attention from tensor2tensor.layers import common_hparams from tensor2tensor.layers import common_layers from tensor2tensor.utils import contrib from tensor2tensor.utils import registry from tensor2tensor.utils import t2t_model import tensorflow as tf framework = contrib.framework(msg="warn") @framework.deprecated("2018-09-15", "Use Transformer, which supports decoder-only mode when " "Transformer.has_input=False.") @registry.register_model class AttentionLM(t2t_model.T2TModel): """Attention net. See file docstring.""" def body(self, features): # Remove dropout if not training hparams = self._hparams targets = features["targets"] targets = tf.squeeze(targets, 2) (decoder_input,
def input_fn():
  """Create the py_func op and return its outputs in the example structure.

  Each flat output gets the corresponding entry of `shapes` as its static
  shape; the outputs are then packed back into the nested structure of
  `first_ex`.
  """
  flat_tensors = tf.py_func(py_func, [], types)
  for tensor, static_shape in zip(flat_tensors, shapes):
    tensor.set_shape(static_shape)
  return contrib.framework().nest.pack_sequence_as(first_ex, flat_tensors)
def py_func():
  """Return the next example as a flat list of leaves.

  Anything still held in `first_ex_list` is handed out first; once that list
  is empty, examples come from the generator.
  """
  current = first_ex_list.pop() if first_ex_list else six.next(gen)
  return contrib.framework().nest.flatten(current)