def prepare_decoder(self, targets):
  """Prepares targets for transformer decoder."""
  shape = utils.shape_list(targets)
  # Sequence should be [batch, seq_length].
  assert len(shape) == 2, "Sequence tensors should be 2-dimensional"
  assert len(
      self.hparams.query_shape) == 1, "query shape should be 1-dimensional"

  # Mask random positions.
  if self.hparams.target_dropout:
    targets = tf.where(
        tf.random.uniform(shape) < self.hparams.target_dropout,
        tf.zeros_like(targets), targets)

  # Shift positions so each position is predicted only from earlier tokens.
  targets = tf.expand_dims(targets, axis=-1)
  targets = utils.right_shift_blockwise_nd(targets, self.hparams.query_shape)
  targets = tf.squeeze(targets, axis=-1)

  # Add token embeddings.
  targets = utils.get_embeddings(
      targets=targets,
      hidden_size=self.hparams.embedding_dims,
      vocab_size=self.vocab_size)
  if self.hparams.dropout:
    # TF1's tf.nn.dropout takes keep_prob, hence 1 - dropout rate.
    targets = tf.nn.dropout(targets, 1 - self.hparams.dropout)
  # Project embeddings up to the model's hidden size.
  targets = tf.layers.dense(
      targets, self.hidden_size, activation=None, name="emb_dense")
  if self.hparams.add_timing_signal:
    targets += utils.get_timing_signal_1d(self.hparams.max_target_length,
                                          self.hidden_size)
  return targets
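
# A minimal sketch of the shift step above (illustrative only, not part of the
# model): assuming a 1-D query_shape that tiles the sequence,
# right_shift_blockwise_nd amounts to prepending a zero token and dropping the
# last one, so position i conditions only on tokens < i. The helper name below
# is hypothetical.
def _right_shift_1d_sketch():
  import numpy as np
  tokens = np.array([[5, 7, 9, 11]])  # [batch=1, seq_length=4]
  pad = np.zeros_like(tokens[:, :1])
  shifted = np.concatenate([pad, tokens[:, :-1]], axis=1)
  print(shifted)  # [[0 5 7 9]]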
def body(self, features, decode_step=None, cache=None, decoding_stats=None):
  targets = tf.to_int32(features["targets"])
  if self.is_decode:
    targets = self.process_partial_targets_decoding(targets)
  decoder_input = self.prepare_decoder(targets)
  extra_losses = []

  if not self.hparams.unconditional:  # Condition on class label.
    if not self.is_inputs_class_label:
      raise ValueError("SparseImagetransformer can only condition on "
                       "'inputs' feature if it represents class label.")
    inputs = features["inputs"]
    # Embed class here rather than in bottom().
    if inputs.dtype not in [tf.int32, tf.int64]:
      raise ValueError("Do not embed 'inputs' before body(). "
                       "Found dtype=%s." % inputs.dtype)
    inputs = utils.get_embeddings(
        targets=inputs,
        vocab_size=self.inputs_vocab_size,
        hidden_size=self.hidden_size,
        name="class_conditional_embedding")
    # Add class embedding to each spatial location.
    batch_size = tf.shape(targets)[0]
    hidden_size = tf.shape(inputs)[-1]
    num_middle_dims = len(decoder_input.shape) - 2
    decoder_input += tf.reshape(
        inputs, [batch_size] + [1] * num_middle_dims + [hidden_size])

  decoder_output = utils.transformer_decoder_layers(
      inputs=decoder_input,
      num_layers=self.num_decoder_layers,
      hparams=self.hparams,
      decode_step=decode_step,
      losses=extra_losses,
      cache=cache,
      name="decoder",
      decoding_stats=decoding_stats)
  logits = self.produce_output(decoder_output)

  # Return logits as-is in decoding mode.
  if self.is_decode:
    return logits

  # Produce a summary of the output.
  results = self.multinomial_squeeze(logits, self.hparams.sampling_temp)
  results = tf.reshape(
      results,
      [-1, self.frame_height, self.frame_width, self.num_channels])
  if utils.is_xla_compiled():
    _IMGS["predictions"] = results

  # Prepare loss.
  loss_dict = {}
  if extra_losses:
    loss_dict["extra_loss"] = tf.add_n(extra_losses)
  return logits, loss_dict
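
# A minimal sketch of the class-conditioning broadcast in body() above (shapes
# and the helper name are illustrative, not from the model): a [batch, hidden]
# class embedding is reshaped to [batch, 1, ..., 1, hidden] so the add
# broadcasts the same class vector over every middle (spatial) dimension of
# decoder_input.
def _class_broadcast_sketch():
  import numpy as np
  batch, length, hidden = 2, 6, 4
  decoder_input = np.zeros([batch, length, hidden])
  class_emb = np.ones([batch, hidden])
  num_middle_dims = decoder_input.ndim - 2  # here: 1
  decoder_input += class_emb.reshape(
      [batch] + [1] * num_middle_dims + [hidden])
  print(decoder_input.shape)  # (2, 6, 4); every position got the class vector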