def prepare_decoder(self, targets):
  """Prepares targets for transformer decoder."""
  shape = utils.shape_list(targets)
  # sequence should be [batch, seq_length]
  assert len(shape) == 2, "Sequence tensors should be 2-dimensional"
  assert len(self.hparams.query_shape) == 1, (
      "query shape should be 1-dimensional")

  # Mask random positions
  if self.hparams.target_dropout:
    targets = tf.where(
        tf.random.uniform(shape) < self.hparams.target_dropout,
        tf.zeros_like(targets), targets)

  # Shift positions
  targets = tf.expand_dims(targets, axis=-1)
  targets = utils.right_shift_blockwise_nd(targets, self.hparams.query_shape)
  targets = tf.squeeze(targets, axis=-1)

  # Add token embeddings
  targets = utils.get_embeddings(
      targets=targets,
      hidden_size=self.hparams.embedding_dims,
      vocab_size=self.vocab_size)
  if self.hparams.dropout:
    targets = tf.nn.dropout(targets, 1 - self.hparams.dropout)
  targets = tf.layers.dense(
      targets, self.hidden_size, activation=None, name="emb_dense")
  if self.hparams.add_timing_signal:
    targets += utils.get_timing_signal_1d(self.hparams.max_target_length,
                                          self.hidden_size)
  return targets
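# Illustrative sketch (not part of the original code): for a 1-D query shape,
# the blockwise right shift in prepare_decoder reduces to shifting the whole
# [batch, seq_length] target sequence right by one position and padding the
# first slot with zero, so each position is predicted only from earlier
# tokens. The helper name below is hypothetical and only shows the effect of
# utils.right_shift_blockwise_nd in this special case.
def _right_shift_1d_sketch(targets):
  """Shifts integer targets right by one along the time axis (zero padded)."""
  # [batch, seq_length] -> prepend a zero token, drop the last token.
  return tf.pad(targets, [[0, 0], [1, 0]])[:, :-1]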
def body(self,
         features,
         decode_step=None,
         cache=None,
         decoding_stats=None,
         add_summary=True):
  encoder_output = None
  extra_losses = []
  padding_bias = None
  if not self.hparams.fast_decode:
    decode_step = None
  if "inputs" in features:
    inputs = features["inputs"]
    # remove the last two dimensions that are always 1.
    inputs = tf.reshape(
        inputs, utils.shape_list(inputs)[:2] + [self.hidden_size])
    # Padding bias only used for seq2seq models.
    padding_bias = utils.embedding_to_padding(inputs)
    # Mask random positions
    shape = utils.shape_list(inputs)
    if self.hparams.input_dropout:
      inputs = tf.where(
          tf.random.uniform(shape) < self.hparams.input_dropout,
          tf.zeros_like(inputs), inputs)
    if self.hparams.add_timing_signal:
      inputs += utils.get_timing_signal_1d(self.hparams.max_length,
                                           self.hidden_size)
    if cache is not None and -1 in cache:
      encoder_output = cache[-1]
    else:
      encoder_output = utils.transformer_encoder_layers(
          inputs=inputs,
          num_layers=self.num_encoder_layers,
          hparams=self.hparams,
          losses=extra_losses,
          name="encoder",
          token_bias=features.get("token_bias_inputs"),
          padding_bias=padding_bias)
    if cache is not None and -1 not in cache:
      cache[-1] = encoder_output

  targets = tf.to_int32(features["targets"])
  # remove the last two dimensions that are always 1.
  targets = tf.reshape(targets, utils.shape_list(targets)[:2])
  # Clamp targets to max_target_length
  targets = targets[:, :self.hparams.max_target_length]
  if self.is_decode:
    targets = self.process_partial_targets_decoding(targets)
  decoder_input = self.prepare_decoder(targets)

  decoder_output = utils.transformer_decoder_layers(
      inputs=decoder_input,
      num_layers=self.num_decoder_layers,
      hparams=self.hparams,
      encoder_output=encoder_output,
      decode_step=decode_step,
      losses=extra_losses,
      cache=cache,
      name="decoder",
      decoding_stats=decoding_stats,
      token_bias_inputs=features.get("token_bias_inputs"),
      token_bias_targets=features.get("token_bias_targets"),
      padding_bias=padding_bias)
  logits = self.produce_output(decoder_output)

  # Return logits as-is in decoding mode
  if self.is_decode:
    return logits

  # Add cross entropy loss
  one_hot_targets = tf.one_hot(
      tf.cast(targets, dtype=tf.int32), self.vocab_size)
  x_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=one_hot_targets, logits=logits)
  weights = tf.to_float(tf.not_equal(targets, 0))
  loss = tf.reduce_sum(x_entropy * weights) / tf.reduce_sum(weights)
  if add_summary:
    tf.summary.scalar("losses/weight", tf.reduce_sum(weights))
    tf.summary.scalar("losses/x_entropy",
                      tf.reduce_sum(x_entropy * weights))

  loss_dict = {"training": loss}
  if extra_losses:
    loss_dict["extra_loss"] = tf.add_n(extra_losses)
  # hack for T2T metrics
  logits = tf.reshape(
      logits,
      utils.shape_list(logits)[:2] + [1, 1] + utils.shape_list(logits)[-1:])
  return logits, loss_dict
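# Illustrative sketch (not part of the original code): how the cache argument
# lets incremental decoding run the encoder only once. body() stores the
# encoder output under the sentinel key -1 on the first call and reuses it on
# later calls. The loop below is a hypothetical calling pattern, assuming the
# model is in decode mode (is_decode) so body() returns logits directly; the
# real decoding driver in the codebase may differ.
def _incremental_decode_sketch(model, features, num_steps):
  """Calls body() step by step, sharing one cache across decode steps."""
  cache = {}  # key -1 will hold the encoder output after the first step
  logits = None
  for step in range(num_steps):
    # After step 0, body() finds -1 in cache and skips the encoder entirely.
    logits = model.body(features, decode_step=step, cache=cache)
  return logits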