def padded_neg_log_perplexity(predictions,
                              labels,
                              weights_fn=common_layers.weights_nonzero):
  """Average log-perplexity excluding padding 0s. No smoothing."""
  num, den = common_layers.padded_cross_entropy(
      predictions, labels, 0.0, weights_fn=weights_fn, reduce_sum=False)
  return (-num, den)
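The (negative numerator, denominator) pair follows the T2T metrics convention: a weighted sum of per-position values plus the total weight, which the caller divides to obtain a scalar. A minimal sketch of that reduction (the tensor names here are illustrative, not from the original):

# Minimal sketch: collapse the (num, den) pair into a scalar metric, as a
# T2T eval harness would. `predictions` and `labels` are assumed tensors of
# shape [batch, length, vocab] and [batch, length] respectively.
neg_num, den = padded_neg_log_perplexity(predictions, labels)
neg_log_perplexity = tf.reduce_sum(neg_num) / tf.reduce_sum(den)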
def testPaddingCrossEntropyFactored(self):
  vocab_size = 19
  rows = 5
  cols = 4
  depth = 11
  label_smoothing = 0.1
  features = np.random.rand(rows, cols, depth)
  weights = np.random.rand(vocab_size, depth)
  labels = np.random.randint(0, vocab_size - 1, size=(rows, cols))
  with self.test_session() as session:
    features = tf.to_float(features)
    weights = tf.to_float(weights)
    labels = tf.to_int32(labels)
    logits = tf.matmul(
        tf.reshape(features, [rows * cols, depth]), weights, transpose_b=True)
    logits = tf.reshape(logits, [rows, cols, vocab_size])
    loss_num, loss_den = common_layers.padded_cross_entropy(
        logits, labels, label_smoothing=label_smoothing, reduce_sum=False)
    factored_logits = common_layers.FactoredTensor(features, weights)
    loss_num_f, loss_den_f = common_layers.padded_cross_entropy_factored(
        factored_logits,
        labels=labels,
        label_smoothing=label_smoothing,
        reduce_sum=False)
    num, den, num_f, den_f = session.run(
        [loss_num, loss_den, loss_num_f, loss_den_f])
    self.assertEqual(num.shape, (rows, cols))
    self.assertEqual(den.shape, (rows, cols))
    self.assertEqual(num_f.shape, (rows, cols))
    self.assertEqual(den_f.shape, (rows, cols))
    self.assertAllClose(num, num_f)
    self.assertAllClose(den, den_f)
def loss(self, top_out, targets):
  """Compute loss numerator and denominator for one shard of output."""
  logits = top_out
  return common_layers.padded_cross_entropy(
      logits,
      targets,
      self._model_hparams.label_smoothing,
      weights_fn=self.targets_weights_fn)
def loss(self, top_out, targets):
  """Compute loss numerator and denominator for one shard of output."""
  logits = top_out
  logits = tf.reshape(logits, [-1] + common_layers.shape_list(logits)[2:])
  targets = tf.reshape(targets, [-1] + common_layers.shape_list(targets)[2:])
  cutoff = getattr(self._model_hparams, "video_modality_loss_cutoff", 0.01)
  return common_layers.padded_cross_entropy(
      logits,
      targets,
      self._model_hparams.label_smoothing,
      cutoff=cutoff,
      weights_fn=self.targets_weights_fn)
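In `common_layers.padded_cross_entropy`, a positive `cutoff` in effect zeroes out per-token loss below that threshold (roughly `relu(xent - cutoff)`), so pixels the model already predicts confidently stop contributing gradient. The threshold is read via `getattr`, so it can presumably be overridden as an hparam; a hypothetical override:

# Hypothetical hparams override for the cutoff read via getattr() above;
# `my_video_hparams` is an assumed hparams constructor.
hparams = my_video_hparams()
hparams.add_hparam("video_modality_loss_cutoff", 0.02)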
def get_extra_internal_loss(self, extra_raw_gts, extra_gts, extra_pds):
  """Hacky code to get the loss on predicted frames from input frames.

  Recurrent models consume the frames one-by-one. Therefore, if there is
  more than one input frame, the input frames beyond the first also get
  predicted. T2T only calculates loss on the predicted target frames,
  which means no loss is applied on the predicted input frames. This code
  fixes that issue. Since the model is not aware of the modality, it has
  to match the pre-processing happening in the bottom function, which is
  what makes this very hacky code. It should match the bottom, top, and
  loss of the modalities, otherwise it will calculate the wrong loss.

  Args:
    extra_raw_gts: extra raw ground truth frames.
    extra_gts: extra normalized ground truth frames.
    extra_pds: extra predicted frames.

  Returns:
    Additional reconstruction loss.

  Raises:
    ValueError: in case of unknown modality.
  """
  if self._target_modality == "VideoModalityL2Raw":
    recon_loss = tf.losses.mean_squared_error(extra_gts, extra_pds)
  elif self._target_modality == "VideoModality":
    shape = common_layers.shape_list(extra_pds)
    updated_shape = shape[:-1] + [3, 256]
    extra_pds = tf.reshape(extra_pds, updated_shape)
    # Merge time and batch.
    logits = tf.reshape(extra_pds, [-1] + updated_shape[2:])
    targets = extra_raw_gts
    targets_shape = common_layers.shape_list(targets)
    targets = tf.reshape(targets, [-1] + targets_shape[2:])
    mod = self.hparams.problem_hparams.modality["targets"]
    numerator, denominator = common_layers.padded_cross_entropy(
        logits,
        targets,
        self.hparams.label_smoothing,
        cutoff=getattr(self.hparams, "video_modality_loss_cutoff", 0.01),
        weights_fn=mod.targets_weights_fn)
    recon_loss = numerator / denominator
  else:
    raise ValueError("internal loss only supports specific modalities.")
  tf.summary.scalar("recon_extra", recon_loss)
  return recon_loss
def loss(self, top_out, targets):
  """Average loss over the labels."""
  logits = top_out
  num_labels = tf.shape(targets)[1]
  logits = tf.tile(logits, [1, num_labels, 1, 1, 1])

  xent, weights = common_layers.padded_cross_entropy(
      logits,
      targets,
      self._model_hparams.label_smoothing,
      weights_fn=self.targets_weights_fn,
      reduce_sum=False,
  )

  xent = tf.squeeze(xent, [2, 3])
  weights = tf.squeeze(weights, [2, 3])
  # Average loss over all labels.
  loss = tf.reduce_sum(xent, axis=1)
  weights = tf.reduce_sum(weights, axis=1)
  loss /= (weights + 1e-8)
  weights = tf.to_float(tf.greater(weights, 0.))
  return tf.reduce_sum(loss * weights), tf.reduce_sum(weights)
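The `1e-8` guard plus the indicator conversion implement a masked average: label rows with zero total weight contribute zero loss and are excluded from the final count. The same pattern in a tiny self-contained NumPy sketch (values made up):

import numpy as np

xent_sum = np.array([2.0, 0.0, 3.0])    # per-example summed cross-entropy
weight_sum = np.array([2.0, 0.0, 1.0])  # per-example total label weight
per_example = xent_sum / (weight_sum + 1e-8)      # safe divide: 0/eps -> 0
indicator = (weight_sum > 0.).astype(np.float32)  # drop empty examples
mean_loss = np.sum(per_example * indicator) / np.sum(indicator)  # -> 2.0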
def testPaddingCrossEntropyFactored(self):
  if tf.executing_eagerly():
    return  # don't run test in Eager mode
  vocab_size = 19
  rows = 5
  cols = 4
  depth = 11
  label_smoothing = 0.1
  features = np.random.rand(rows, cols, depth)
  weights = np.random.rand(vocab_size, depth)
  labels = np.random.randint(0, vocab_size - 1, size=(rows, cols))
  with self.session() as session:
    features = tf.to_float(features)
    weights = tf.to_float(weights)
    labels = tf.to_int32(labels)
    logits = tf.matmul(
        tf.reshape(features, [rows * cols, depth]), weights, transpose_b=True)
    logits = tf.reshape(logits, [rows, cols, vocab_size])
    loss_num, loss_den = common_layers.padded_cross_entropy(
        logits, labels, label_smoothing=label_smoothing, reduce_sum=False)
    factored_logits = common_layers.FactoredTensor(features, weights)
    loss_num_f, loss_den_f = common_layers.padded_cross_entropy_factored(
        factored_logits,
        labels=labels,
        label_smoothing=label_smoothing,
        reduce_sum=False)
    num, den, num_f, den_f = session.run(
        [loss_num, loss_den, loss_num_f, loss_den_f])
    self.assertEqual(num.shape, (rows, cols))
    self.assertEqual(den.shape, (rows, cols))
    self.assertEqual(num_f.shape, (rows, cols))
    self.assertEqual(den_f.shape, (rows, cols))
    self.assertAllClose(num, num_f)
    self.assertAllClose(den, den_f)
def loss(self, top_out, targets, weights_fn=None, features=None,
         curriculum=False):
  """Compute loss numerator and denominator for one shard of output."""
  logits = top_out
  if weights_fn is None:
    weights_fn = self.targets_weights_fn
  if curriculum:
    # Scale the per-position weights by the per-example curriculum weight.
    # Note: the original had weights_fn(args), which would pass a tuple;
    # weights_fn(*args) forwards the arguments as intended.
    scaled_weights_fn = lambda *args: (
        weights_fn(*args) *
        (features["weight"] if features is not None else 1.0))
  else:
    scaled_weights_fn = weights_fn
  return common_layers.padded_cross_entropy(
      logits,
      targets,
      self._model_hparams.label_smoothing,
      weights_fn=scaled_weights_fn)
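A hedged sketch of how this curriculum path might be driven; the "weight" feature, its shape, and the call-site names (`modality`, `body_output`, `targets`) are assumptions, not part of the original code:

# Hypothetical call site: scale each example's loss by a difficulty weight
# carried in features["weight"] (broadcastable against the weights tensor).
features = {"weight": tf.constant([[[[1.0]]], [[[0.5]]]])}
loss_num, loss_den = modality.loss(
    body_output, targets, features=features, curriculum=True)
training_loss = loss_num / loss_den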
def auxiliary_loss(self, body_output, features, shift):
  """Auxiliary predict loss.

  Args:
    body_output: Tensor with shape [batch_size, decoder_length, hidden_dim].
    features: Map of features to the model. Must contain the following:
        "targets": Target decoder outputs.
            [batch_size, decoder_length, 1, hidden_dim]
    shift: int != 0, amount to shift/pad the target sequence.
      If shift > 0, it represents the number of previous timesteps to
      reconstruct; if shift < 0, it represents the number of future
      timesteps to predict.

  Returns:
    A 2-tuple of the numerator and denominator of the cross-entropy loss.

  Raises:
    ValueError: if features does not contain a targets_raw tensor.
  """
  assert isinstance(shift, int) and shift != 0

  name = "reconst_%d" % shift if shift > 0 else "predict_%d" % abs(shift)

  if features and "targets_raw" in features:
    targets = features["targets_raw"]
    targets = common_layers.flatten4d3d(targets)
  else:
    raise ValueError("Feature map must contain a targets_raw tensor.")

  with tf.variable_scope(name):
    logits = self.top(body_output, features)
    labels = shift_and_pad(targets, shift, axis=1)
    return common_layers.padded_cross_entropy(
        logits, labels, self._hparams.label_smoothing)
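`shift_and_pad` is referenced above but not defined in this snippet. A plausible reconstruction under the shapes used here (shift along the time axis, zero-padding the vacated positions); this is a sketch, not the original helper:

def shift_and_pad(tensor, shift, axis=1):
  """Shifts `tensor` by `shift` along `axis`, zero-padding the gap.

  Sketch assuming axis == 1 (the time axis of a [batch, length, ...] tensor).
  shift > 0 moves values toward later timesteps; shift < 0 toward earlier.
  """
  assert axis == 1
  length = common_layers.shape_list(tensor)[axis]
  paddings = [[0, 0]] * len(tensor.shape)
  if shift > 0:
    paddings[axis] = [shift, 0]          # pad the front ...
    tensor = tensor[:, :length - shift]  # ... and drop the tail.
  else:
    paddings[axis] = [0, -shift]         # pad the back ...
    tensor = tensor[:, -shift:]          # ... and drop the front.
  return tf.pad(tensor, paddings)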
def testPaddingCrossEntropyFactoredGrad(self):
  vocab_size = 19
  rows = 5
  cols = 4
  depth = 11
  label_smoothing = 0.1
  features = np.random.rand(rows, cols, depth)
  weights = np.random.rand(vocab_size, depth)
  labels = np.random.randint(0, vocab_size - 1, size=(rows, cols))
  with self.test_session() as session:
    features = tf.to_float(features)
    weights = tf.to_float(weights)
    labels = tf.to_int32(labels)
    logits = tf.matmul(
        tf.reshape(features, [rows * cols, depth]), weights, transpose_b=True)
    logits = tf.reshape(logits, [rows, cols, vocab_size])
    loss_num, loss_den = common_layers.padded_cross_entropy(
        logits, labels, label_smoothing=label_smoothing, reduce_sum=False)
    factored_logits = common_layers.FactoredTensor(features, weights)
    loss_num_factored, loss_den_factored = (
        common_layers.padded_cross_entropy_factored(
            factored_logits,
            labels=labels,
            label_smoothing=label_smoothing,
            reduce_sum=False))
    df, dw = tf.gradients(ys=[loss_num, loss_den], xs=[features, weights])
    df_factored, dw_factored = tf.gradients(
        ys=[loss_num_factored, loss_den_factored], xs=[features, weights])
    actual_df, actual_dw, actual_df_factored, actual_dw_factored = (
        session.run([df, dw, df_factored, dw_factored]))
    self.assertEqual(actual_df.shape, (rows, cols, depth))
    self.assertEqual(actual_dw.shape, (vocab_size, depth))
    self.assertEqual(actual_df_factored.shape, (rows, cols, depth))
    self.assertEqual(actual_dw_factored.shape, (vocab_size, depth))
    self.assertAllClose(actual_df, actual_df_factored)
    self.assertAllClose(actual_dw, actual_dw_factored)
def body(self, features):
  hparams = self._hparams
  ps_devices = self._ps_devices
  single_device = (len(ps_devices) == 1)
  assert hparams.num_model_shards % len(ps_devices) == 0
  shards_per_device = hparams.num_model_shards // len(ps_devices)
  model_devices = [ps_devices[i // shards_per_device]
                   for i in range(hparams.num_model_shards)]
  print("model_devices = %s" % model_devices)
  mp = expert_utils.Parallelism(model_devices, reuse=False)
  targets_vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size
  # Squeeze out channels, heights.
  targets = tf.squeeze(features["targets_raw"], [2, 3])
  targets_embedding_var = mp(
      tf.get_variable, "embedding",
      [[targets_vocab_size, hparams.hidden_size]] * mp.n,
      initializer=tf.random_normal_initializer(
          0.0, hparams.hidden_size**-0.5))
  shifted_targets = common_layers.shift_right_2d(targets)
  # Bypass the symbol modality and use a different embedding on each shard.
  if single_device:
    targets_embedding_var_combined = tf.concat(targets_embedding_var, 1)
    decoder_input_combined = common_layers.embedding(
        shifted_targets,
        targets_vocab_size,
        hparams.hidden_size * mp.n,
        multiplier=hparams.hidden_size**0.5,
        embedding_var=targets_embedding_var_combined,
    )
    decoder_input = tf.split(decoder_input_combined, mp.n, axis=2)
  else:
    targets_embedding_var_combined = None
    decoder_input = mp(
        common_layers.embedding,
        shifted_targets,
        targets_vocab_size,
        hparams.hidden_size,
        multiplier=hparams.hidden_size**0.5,
        embedding_var=targets_embedding_var,
    )
  decoder_self_attention_bias = mp(
      common_attention.attention_bias_lower_triangle, tf.shape(targets)[1])
  if "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias = mp(
        tf.add, decoder_self_attention_bias,
        mp(common_attention.attention_bias_same_segment,
           targets_segmentation, targets_segmentation))
    decoder_input = mp(
        common_attention.add_timing_signal_1d_given_position,
        decoder_input, targets_position)
  else:
    targets_position = None
    decoder_self_attention_bias = mp(
        common_attention.attention_bias_lower_triangle, tf.shape(targets)[1])
    decoder_input = mp(common_attention.add_timing_signal_1d, decoder_input)

  if self.has_input:
    inputs = tf.squeeze(features["inputs_raw"], [2, 3])
    inputs_vocab_size = self._problem_hparams.vocabulary["inputs"].vocab_size
    # Share everything for now.
    share_inputs_and_targets_embedding = True
    if share_inputs_and_targets_embedding:
      assert inputs_vocab_size == targets_vocab_size
      inputs_embedding_var = targets_embedding_var
      inputs_embedding_var_combined = targets_embedding_var_combined
    if single_device:
      encoder_input_combined = common_layers.embedding(
          inputs,
          inputs_vocab_size,
          hparams.hidden_size * mp.n,
          multiplier=hparams.hidden_size**0.5,
          embedding_var=inputs_embedding_var_combined,
      )
      encoder_input = tf.split(encoder_input_combined, mp.n, axis=2)
    else:
      encoder_input = mp(
          common_layers.embedding,
          inputs,
          inputs_vocab_size,
          hparams.hidden_size,
          multiplier=hparams.hidden_size**0.5,
          embedding_var=inputs_embedding_var,
      )
    if "inputs_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      inputs_segmentation = features["inputs_segmentation"]
      inputs_position = features["inputs_position"]
      encoder_self_attention_bias = mp(
          common_attention.attention_bias_same_segment,
          inputs_segmentation, inputs_segmentation)
      encoder_decoder_attention_bias = mp(
          common_attention.attention_bias_same_segment,
          targets_segmentation, inputs_segmentation)
      encoder_input = mp(
          common_attention.add_timing_signal_1d_given_position,
          encoder_input, inputs_position)
    else:
      encoder_padding = tf.to_float(tf.equal(inputs, 0))
      ignore_padding = common_attention.attention_bias_ignore_padding(
          encoder_padding)
      encoder_self_attention_bias = ignore_padding
      encoder_decoder_attention_bias = ignore_padding
      inputs_position = None
      encoder_input = mp(common_attention.add_timing_signal_1d, encoder_input)

    # Encoder stack.
    with tf.variable_scope("encoder"):
      encoder_input = mp(
          tf.nn.dropout, encoder_input,
          1.0 - hparams.layer_prepostprocess_dropout)
      encoder_output = _layer_stack(
          mp,
          encoder_input,
          encoder_self_attention_bias,
          hparams.encoder_layers,
          hparams)
  else:
    encoder_decoder_attention_bias = None
    encoder_output = None

  with tf.variable_scope("decoder"):
    decoder_input = mp(
        tf.nn.dropout, decoder_input,
        1.0 - hparams.layer_prepostprocess_dropout)
    decoder_output = _layer_stack(
        mp,
        decoder_input,
        decoder_self_attention_bias,
        layers=hparams.decoder_layers,
        hparams=hparams,
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias)

  # Bypass the symbol modality and compute logits directly.
  # We compute a different set of logits on each shard, and sum them.
  # Share the weights with the target embedding.
  output_var = targets_embedding_var
  output_var_combined = targets_embedding_var_combined
  if single_device:
    decoder_output = tf.concat(decoder_output, 2)
    logits = tf.tensordot(decoder_output, output_var_combined, [[2], [1]])
    num, denom = common_layers.padded_cross_entropy(
        logits, targets, hparams.label_smoothing)
    training_loss = num / denom
  else:
    logits = mp(
        tf.tensordot, decoder_output, output_var, [[[2], [1]]] * mp.n)
    logits = expert_utils.all_reduce_ring(logits, mp)
    # On each device, we compute the loss for a part of the batch.
    # This is faster than computing the whole loss on one shard.
    mp, logits = expert_utils.reduce_by_device(mp, logits, lambda l: l[0])

    def _loss_for_shard(logits, targets, shard):
      logits = common_layers.approximate_split(logits, mp.n, 0)[shard]
      targets = common_layers.approximate_split(targets, mp.n, 0)[shard]
      return common_layers.padded_cross_entropy(
          logits, targets, hparams.label_smoothing)

    num, denom = mp(_loss_for_shard, logits, targets, range(mp.n))
    training_loss = tf.add_n(num) / tf.add_n(denom)
    logits = logits[0]
  logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)
  # Override training loss so that it is not computed externally.
  losses = {"training": training_loss}
  return logits, losses
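The `mp(...)` calls above apply a function once per model shard, one call per device. A rough, simplified sketch of that calling convention (the real `expert_utils.Parallelism` also handles device placement, variable reuse, and keyword arguments):

class SimpleParallelism(object):
  """Stripped-down stand-in for expert_utils.Parallelism (sketch only)."""

  def __init__(self, devices):
    self._devices = devices
    self.n = len(devices)

  def __call__(self, fn, *args):
    # Each positional arg is either a list of length n (one value per shard)
    # or a single value broadcast to every shard; ditto for fn itself.
    def pick(arg, i):
      return arg[i] if isinstance(arg, list) else arg
    return [
        pick(fn, i)(*[pick(a, i) for a in args]) for i in range(self.n)
    ]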
def reconstruct_loss(x, gt, hparams, reuse=None):
  pred = tf.layers.dense(x, hparams.vocab_size, name="softmax", reuse=reuse)
  xent, w = common_layers.padded_cross_entropy(pred, gt, 0.0)
  return xent / w
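A hypothetical call site; the `reuse` flag lets a second call share the same "softmax" projection variables (the tensor names below are assumptions):

# Hypothetical usage: two decoder outputs scored against the same vocabulary
# projection by reusing the "softmax" dense layer's variables.
loss_a = reconstruct_loss(decoder_out_a, targets, hparams)
loss_b = reconstruct_loss(decoder_out_b, targets, hparams, reuse=True)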
"""Compute softmax values for each sets of scores in x.""" e_x = np.exp(x - np.max(x, axis=-1, keepdims=True)) return e_x / e_x.sum(axis=-1, keepdims=True) # only difference model_fn = utils.model_builder.build_model_fn( FLAGS.model, problem_names=[FLAGS.problems], train_steps=FLAGS.train_steps, worker_id=FLAGS.worker_id, worker_replicas=FLAGS.worker_replicas, eval_run_autoregressive=FLAGS.eval_run_autoregressive, decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams)) est_spec = model_fn(features, target, mode, hparams) loss, weight = common_layers.padded_cross_entropy( est_spec.predictions['predictions'], target, 0, reduce_sum=False) with tf.variable_scope(tf.get_variable_scope(), reuse=True): beam_out = model_fn(features, target, tf.contrib.learn.ModeKeys.INFER, hparams) sv = tf.train.Supervisor(logdir=FLAGS.output_dir, global_step=tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')) sess = sv.PrepareSession(config=tf.ConfigProto(allow_soft_placement=True)) sv.StartQueueRunners( sess, tf.get_default_graph().get_collection(tf.GraphKeys.QUEUE_RUNNERS))
def body(self, features):
  hparams = self.hparams
  batch_size = common_layers.shape_list(features["inputs"])[0]

  # Swap time and batch axes.
  input_frames = common_video.swap_time_and_batch_axes(features["inputs"])
  target_frames = common_video.swap_time_and_batch_axes(features["targets"])

  # Get actions if they exist, otherwise use zeros.
  input_actions = self.get_input_if_exists(
      features, "input_action", batch_size, hparams.video_num_input_frames)
  target_actions = self.get_input_if_exists(
      features, "target_action", batch_size, hparams.video_num_target_frames)

  # Get rewards if they exist, otherwise use zeros.
  input_rewards = self.get_input_if_exists(
      features, "input_reward", batch_size, hparams.video_num_input_frames)
  target_rewards = self.get_input_if_exists(
      features, "target_reward", batch_size, hparams.video_num_target_frames)

  all_actions = tf.concat([input_actions, target_actions], axis=0)
  all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
  all_frames = tf.concat([input_frames, target_frames], axis=0)

  # Each image is used twice: in the latent tower and in the main tower.
  # This is to make sure we are using the *same* image for both, given how
  # TF queues work. Not sure if this is required at all; doesn't hurt
  # though! :)
  all_frames = tf.identity(all_frames)

  gen_images, gen_rewards, latent_means, latent_stds = self.construct_model(
      images=all_frames,
      actions=all_actions,
      rewards=all_rewards,
  )

  extra_loss = self.get_extra_loss(
      latent_means=latent_means,
      latent_stds=latent_stds,
      true_frames=all_frames,
      gen_frames=gen_images)

  # Visualize predictions in TensorBoard.
  if self.is_training:
    self.visualize_predictions(all_frames[1:], gen_images)

  # Ignore the predictions from the input frames.
  # This is NOT the same as the original paper/implementation.
  predictions = gen_images[hparams.video_num_input_frames - 1:]
  reward_pred = gen_rewards[hparams.video_num_input_frames - 1:]
  reward_pred = tf.squeeze(reward_pred, axis=2)  # Remove extra dimension.

  # Swap back time and batch axes.
  predictions = common_video.swap_time_and_batch_axes(predictions)
  reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

  if self.is_training and hparams.internal_loss:
    # Add the MSE loss for input frames as well.
    extra_gts = all_frames[1:hparams.video_num_input_frames + 1]
    extra_gts = common_video.swap_time_and_batch_axes(extra_gts)
    extra_pds = gen_images[:hparams.video_num_input_frames]
    extra_pds = common_video.swap_time_and_batch_axes(extra_pds)
    if self._target_modality == "VideoModalityL2Raw":
      recon_loss = tf.losses.mean_squared_error(extra_gts, extra_pds)
    elif self._target_modality == "VideoModality":
      shape = common_layers.shape_list(extra_pds)
      updated_shape = shape[:-1] + [3, 256]
      extra_pds = tf.reshape(extra_pds, updated_shape)
      # Merge time and batch.
      logits = tf.reshape(extra_pds, [-1] + updated_shape[2:])
      targets_shape = common_layers.shape_list(features["targets_raw"])
      targets = tf.reshape(features["targets_raw"], [-1] + targets_shape[2:])
      mod = self.hparams.problem_hparams.target_modality["targets"]
      numerator, denominator = common_layers.padded_cross_entropy(
          logits,
          targets,
          hparams.label_smoothing,
          cutoff=getattr(hparams, "video_modality_loss_cutoff", 0.01),
          weights_fn=mod.targets_weights_fn)
      recon_loss = numerator / denominator
    else:
      raise ValueError("internal loss only supports specific modalities.")
    tf.summary.scalar("recon_extra", recon_loss)
    extra_loss += recon_loss

  return_targets = predictions
  if hparams.reward_prediction:
    return_targets = {"targets": predictions, "target_reward": reward_pred}

  return return_targets, extra_loss