def prepare_decoder(self, targets):
  """Prepares targets for transformer decoder."""
  shape = utils.shape_list(targets)
  # Sequence should be [batch, seq_length].
  assert len(shape) == 2, "Sequence tensors should be 2-dimensional"
  assert len(self.hparams.query_shape) == 1, (
      "query shape should be 1-dimensional")

  # Mask random positions.
  if self.hparams.target_dropout:
    targets = tf.where(
        tf.random.uniform(shape) < self.hparams.target_dropout,
        tf.zeros_like(targets), targets)

  # Shift positions.
  targets = tf.expand_dims(targets, axis=-1)
  targets = utils.right_shift_blockwise_nd(targets, self.hparams.query_shape)
  targets = tf.squeeze(targets, axis=-1)

  # Add token embeddings.
  targets = utils.get_embeddings(
      targets=targets,
      hidden_size=self.hparams.embedding_dims,
      vocab_size=self.vocab_size)
  if self.hparams.dropout:
    # tf.nn.dropout takes a keep probability, hence 1 - dropout rate.
    targets = tf.nn.dropout(targets, 1 - self.hparams.dropout)
  targets = tf.layers.dense(
      targets, self.hidden_size, activation=None, name="emb_dense")
  if self.hparams.add_timing_signal:
    targets += utils.get_timing_signal_1d(
        self.hparams.max_target_length, self.hidden_size)
  return targets
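
# A minimal sketch (not part of the original module) of what the blockwise
# right shift does in the 1-D case: with a length-1 query_shape,
# utils.right_shift_blockwise_nd is assumed here to reduce to a plain right
# shift by one position, so position i only ever sees tokens at positions < i
# (the standard decoder offset). The helper name below is hypothetical.
def _right_shift_1d_sketch(seq):
  """Shifts a [batch, seq_length] tensor right by one: [a, b, c] -> [0, a, b]."""
  return tf.concat([tf.zeros_like(seq[:, :1]), seq[:, :-1]], axis=1)
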
def prepare_decoder(self, targets):
  """Prepares targets for transformer decoder."""
  shape = utils.shape_list(targets)
  # Image should be [batch, height, width, channels].
  assert len(shape) == 4, "Image tensors should be 4-dimensional"

  # Shift positions.
  targets = tf.reshape(targets, [-1] + self.get_shape_for_decoder() + [1])
  targets = utils.right_shift_blockwise_nd(targets, self.hparams.query_shape)

  # Add channel embeddings.
  targets = tf.reshape(
      targets, [-1, self.frame_height, self.frame_width, self.num_channels])
  targets = utils.get_channel_embeddings(
      io_depth=self.num_channels, targets=targets,
      hidden_size=self.hidden_size)

  # Add positional embeddings if needed.
  if self.add_positional_emb:
    targets = utils.add_positional_embedding_nd(
        targets,
        max_length=max(self.frame_height, self.frame_width,
                       self.num_channels),
        name="pos_emb")
  targets = tf.reshape(
      targets, [-1] + self.get_shape_for_decoder() + [self.hidden_size])
  return targets
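
# A minimal sketch (not part of the original module) of per-channel pixel
# embeddings in the spirit of utils.get_channel_embeddings: each of the
# io_depth channels gets its own lookup table and the per-channel embeddings
# are summed. The real helper's exact semantics may differ; the function and
# variable names here are hypothetical.
def _channel_embeddings_sketch(images, io_depth, hidden_size, vocab_size=256):
  """Embeds [batch, height, width, io_depth] integer pixels to hidden_size."""
  out = 0.
  for c in range(io_depth):
    table = tf.get_variable(
        "channel_emb_sketch_%d" % c, [vocab_size, hidden_size])
    out += tf.gather(table, tf.cast(images[..., c], tf.int32))
  return out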