def body(self, features, decode_step=None, cache=None, decoding_stats=None):
  """Runs the (optionally class-conditional) sparse image transformer decoder.

  Args:
    features: dict of tensors. Must contain "targets". When
      `self.hparams.unconditional` is falsy it must also contain "inputs",
      holding raw integer (int32/int64) class labels that have NOT been
      embedded yet.
    decode_step: optional decoding step, forwarded to
      `utils.transformer_decoder_layers`.
    cache: optional decoding cache, forwarded to the decoder layers.
    decoding_stats: optional statistics container, forwarded to the decoder
      layers.

  Returns:
    In decode mode (`self.is_decode`), just the logits tensor. Otherwise a
    `(logits, loss_dict)` tuple. `loss_dict` has an "extra_loss" entry only
    when the decoder layers put auxiliary losses into the list passed to
    them as `losses`.

  Raises:
    ValueError: if conditioning is requested but "inputs" is not a class
      label, or if "inputs" is not an int32/int64 tensor.
  """

  def _add_class_embedding(dec_in, tgt_ids):
    """Embeds the class label and adds it at every spatial position."""
    if not self.is_inputs_class_label:
      raise ValueError("SparseImagetransformer can only condition on "
                       "'inputs' feature if it represents class label.")
    class_ids = features["inputs"]
    # The class embedding is created here rather than in bottom(), so the
    # labels must still be raw integer ids at this point.
    if inputs_not_integer(class_ids):
      raise ValueError("Do not embed 'inputs' before body(). "
                       "Found dtype=%s." % class_ids.dtype)
    class_emb = utils.get_embeddings(
        targets=class_ids,
        vocab_size=self.inputs_vocab_size,
        hidden_size=self.hidden_size,
        name="class_conditional_embedding")
    # Reshape the embedding to [batch, 1, ..., 1, hidden] so it broadcasts
    # across all middle (spatial) dimensions of the decoder input.
    broadcast_shape = [tf.shape(tgt_ids)[0]]
    broadcast_shape.append(tf.shape(class_emb)[-1])
    middle_ones = [1] * (len(dec_in.shape) - 2)
    broadcast_shape[1:1] = middle_ones
    return dec_in + tf.reshape(class_emb, broadcast_shape)

  def inputs_not_integer(tensor):
    # Only integer label ids are accepted; anything else was embedded early.
    return tensor.dtype not in (tf.int32, tf.int64)

  tgt_ids = tf.to_int32(features["targets"])
  if self.is_decode:
    tgt_ids = self.process_partial_targets_decoding(tgt_ids)
  dec_in = self.prepare_decoder(tgt_ids)

  aux_losses = []
  if not self.hparams.unconditional:
    dec_in = _add_class_embedding(dec_in, tgt_ids)

  dec_out = utils.transformer_decoder_layers(
      inputs=dec_in,
      num_layers=self.num_decoder_layers,
      hparams=self.hparams,
      decode_step=decode_step,
      losses=aux_losses,
      cache=cache,
      name="decoder",
      decoding_stats=decoding_stats)
  logits = self.produce_output(dec_out)

  if self.is_decode:
    # Decoding callers only need the raw logits.
    return logits

  # Summary of the output: presumably samples pixels from the logits at
  # hparams.sampling_temp (see multinomial_squeeze), shaped as a batch of
  # frame_height x frame_width x num_channels images.
  sampled = tf.reshape(
      self.multinomial_squeeze(logits, self.hparams.sampling_temp),
      [-1, self.frame_height, self.frame_width, self.num_channels])
  if utils.is_xla_compiled():
    # NOTE(review): the samples are stashed in the module-level _IMGS dict
    # only under XLA; presumably consumed elsewhere (e.g. a host call).
    _IMGS["predictions"] = sampled

  loss_dict = {"extra_loss": tf.add_n(aux_losses)} if aux_losses else {}
  return logits, loss_dict
def body(self,
         features,
         decode_step=None,
         cache=None,
         decoding_stats=None,
         add_summary=True):
  """Runs the optional encoder and the sparse transformer decoder.

  Args:
    features: dict of tensors. Must contain "targets". May contain "inputs"
      (enables the encoder and a padding bias), "token_bias_inputs" and
      "token_bias_targets"; these are forwarded to the encoder/decoder layers.
    decode_step: optional decoding step for the decoder layers. It is forced
      to None unless `hparams.fast_decode` is set.
    cache: optional dict used while decoding. The encoder output is stored
      under key -1 the first time and reused on later calls.
    decoding_stats: optional statistics container, forwarded to the decoder
      layers.
    add_summary: whether to emit the scalar loss summaries.

  Returns:
    In decode mode (`self.is_decode`), the raw logits tensor. Otherwise a
    `(logits, loss_dict)` tuple:
      - logits are reshaped from [batch, length, vocab_size] to
        [batch, length, 1, 1, vocab_size];
      - loss_dict["training"] is the mean cross entropy over non-padding
        target tokens (token id 0 is padding). It is 0, not NaN, when a batch
        contains only padding;
      - loss_dict["extra_loss"] is present when the encoder/decoder layers
        reported auxiliary losses.
  """
  encoder_output = None
  extra_losses = []
  padding_bias = None
  if not self.hparams.fast_decode:
    decode_step = None

  if "inputs" in features:
    inputs = features["inputs"]
    # Reshape to [batch, length, hidden_size]. NOTE(review): presumably the
    # inputs arrive as [batch, length, 1, hidden_size]; confirm.
    inputs = tf.reshape(
        inputs, utils.shape_list(inputs)[:2] + [self.hidden_size])
    # Padding bias is only used for seq2seq models. It is computed before the
    # input dropout below is applied.
    padding_bias = utils.embedding_to_padding(inputs)
    shape = utils.shape_list(inputs)
    if self.hparams.input_dropout:
      # Zero out random entries of the input embeddings. NOTE(review): the
      # mask has the full [batch, length, hidden_size] shape, so individual
      # elements (not whole positions) are dropped, and survivors are not
      # rescaled; confirm this is intended.
      inputs = tf.where(
          tf.random.uniform(shape) < self.hparams.input_dropout,
          tf.zeros_like(inputs), inputs)
    if self.hparams.add_timing_signal:
      # NOTE(review): the signal is built for hparams.max_length positions,
      # so this presumably requires inputs padded to exactly max_length.
      inputs += utils.get_timing_signal_1d(self.hparams.max_length,
                                           self.hidden_size)
    if cache is not None and -1 in cache:
      # Reuse the encoder output stored by an earlier decoding call.
      encoder_output = cache[-1]
    else:
      encoder_output = utils.transformer_encoder_layers(
          inputs=inputs,
          num_layers=self.num_encoder_layers,
          hparams=self.hparams,
          losses=extra_losses,
          name="encoder",
          token_bias=features.get("token_bias_inputs"),
          padding_bias=padding_bias)
      if cache is not None:
        cache[-1] = encoder_output

  targets = tf.to_int32(features["targets"])
  # Collapse targets to [batch, length] by dropping the trailing dimensions.
  targets = tf.reshape(targets, utils.shape_list(targets)[:2])
  # Clamp targets to max_target_length.
  targets = targets[:, :self.hparams.max_target_length]
  if self.is_decode:
    targets = self.process_partial_targets_decoding(targets)
  decoder_input = self.prepare_decoder(targets)

  decoder_output = utils.transformer_decoder_layers(
      inputs=decoder_input,
      num_layers=self.num_decoder_layers,
      hparams=self.hparams,
      encoder_output=encoder_output,
      decode_step=decode_step,
      losses=extra_losses,
      cache=cache,
      name="decoder",
      decoding_stats=decoding_stats,
      token_bias_inputs=features.get("token_bias_inputs"),
      token_bias_targets=features.get("token_bias_targets"),
      padding_bias=padding_bias)
  logits = self.produce_output(decoder_output)

  # Return logits as-is in decoding mode.
  if self.is_decode:
    return logits

  # Cross-entropy loss. `targets` is already int32 here (tf.to_int32 above),
  # so no further cast is needed before one-hot encoding.
  one_hot_targets = tf.one_hot(targets, self.vocab_size)
  x_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=one_hot_targets, logits=logits)
  # Token id 0 is treated as padding and excluded from the loss.
  weights = tf.to_float(tf.not_equal(targets, 0))
  weighted_x_entropy = tf.reduce_sum(x_entropy * weights)
  total_weight = tf.reduce_sum(weights)
  # Guard against an all-padding batch. The numerator is then 0 too, so
  # dividing by max(total_weight, 1) yields 0 instead of NaN. Otherwise
  # total_weight >= 1 (weights are 0/1) and the result is unchanged.
  loss = weighted_x_entropy / tf.maximum(total_weight, 1.0)
  if add_summary:
    tf.summary.scalar("losses/weight", total_weight)
    tf.summary.scalar("losses/x_entropy", weighted_x_entropy)

  loss_dict = {"training": loss}
  if extra_losses:
    loss_dict["extra_loss"] = tf.add_n(extra_losses)

  # Hack for T2T metrics: [batch, length, vocab] -> [batch, length, 1, 1, vocab].
  logits_shape = utils.shape_list(logits)
  logits = tf.reshape(logits, logits_shape[:2] + [1, 1] + logits_shape[-1:])
  return logits, loss_dict