    def prepare_decoder(self, targets):
        """Prepares targets for transformer decoder."""
        shape = utils.shape_list(targets)
        # Sequence should be [batch, seq_length].
        assert len(shape) == 2, "Sequence tensors should be 2-dimensional"
        assert len(self.hparams.query_shape) == 1, \
            "query shape should be 1-dimensional"

        # Randomly mask (zero out) positions with probability target_dropout.
        if self.hparams.target_dropout:
            targets = tf.where(
                tf.random.uniform(shape) < self.hparams.target_dropout,
                tf.zeros_like(targets), targets)
        # Shift right by one position in blockwise order so each position
        # only conditions on previously generated tokens.
        targets = tf.expand_dims(targets, axis=-1)
        targets = utils.right_shift_blockwise_nd(targets,
                                                 self.hparams.query_shape)
        targets = tf.squeeze(targets, axis=-1)
        # Add token embeddings
        targets = utils.get_embeddings(targets=targets,
                                       hidden_size=self.hparams.embedding_dims,
                                       vocab_size=self.vocab_size)
        if self.hparams.dropout:
            # TF 1.x tf.nn.dropout takes a keep probability, not a drop rate.
            targets = tf.nn.dropout(targets, keep_prob=1 - self.hparams.dropout)
        targets = tf.layers.dense(targets,
                                  self.hidden_size,
                                  activation=None,
                                  name="emb_dense")
        # Optionally add a fixed timing signal to encode positions.
        if self.hparams.add_timing_signal:
            targets += utils.get_timing_signal_1d(
                self.hparams.max_target_length, self.hidden_size)
        return targets
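
For the 1-D query_shape asserted above, blockwise generation order coincides with the natural left-to-right order, so right_shift_blockwise_nd should reduce to moving every token one step to the right and zero-padding the front; position i then conditions only on positions before it. A minimal sketch of that special case in plain TF 1.x ops (the helper name right_shift_1d is illustrative, not the library's API):

import tensorflow as tf

def right_shift_1d(targets):
    """Shifts [batch, seq_length] right by one, zero-padding position 0.

    Sketch of what the blockwise right shift should reduce to when
    query_shape is 1-D, since blockwise order is then the natural order.
    """
    return tf.pad(targets, [[0, 0], [1, 0]])[:, :-1]
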
Example 2
    def prepare_decoder(self, targets):
        """Prepares targets for transformer decoder."""
        shape = utils.shape_list(targets)
        # Image should be [batch, height, width, channels].
        assert len(shape) == 4, "Image tensors should be 4-dimensional"

        # Shift right by one position in blockwise order so each position
        # only conditions on previously generated values.
        targets = tf.reshape(targets,
                             [-1] + self.get_shape_for_decoder() + [1])
        targets = utils.right_shift_blockwise_nd(targets,
                                                 self.hparams.query_shape)

        # Add channel embeddings
        targets = tf.reshape(
            targets,
            [-1, self.frame_height, self.frame_width, self.num_channels])
        targets = utils.get_channel_embeddings(io_depth=self.num_channels,
                                               targets=targets,
                                               hidden_size=self.hidden_size)

        # Add positional embeddings if needed.
        if self.add_positional_emb:
            targets = utils.add_positional_embedding_nd(
                targets,
                max_length=max(self.frame_height, self.frame_width,
                               self.num_channels),
                name="pos_emb")
        targets = tf.reshape(targets, [-1] + self.get_shape_for_decoder() +
                             [self.hidden_size])
        return targets
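
add_positional_embedding_nd above presumably maintains one learned embedding table per spatial axis and broadcasts their sum over the input. A minimal sketch under that assumption (the function name, variable naming, initializer, and the static-shape requirement are illustrative choices, not the library's actual implementation):

import tensorflow as tf

def add_positional_embedding_nd_sketch(x, max_length, name):
    """Adds a learned positional embedding along each spatial axis of x.

    x: [batch, d1, ..., dn, hidden] with statically known spatial dims.
    One [max_length, hidden] table per axis, sliced to that axis's length
    and broadcast over the remaining axes.
    """
    static_shape = x.shape.as_list()
    num_axes = len(static_shape) - 2  # exclude batch and hidden dims
    hidden = static_shape[-1]
    for axis in range(num_axes):
        table = tf.get_variable(
            "%s_%d" % (name, axis), [max_length, hidden],
            initializer=tf.random_normal_initializer(stddev=hidden ** -0.5))
        size = static_shape[axis + 1]
        # Reshape the sliced table so it broadcasts over the other axes.
        broadcast_shape = [1] * (num_axes + 2)
        broadcast_shape[axis + 1] = size
        broadcast_shape[-1] = hidden
        x += tf.reshape(table[:size], broadcast_shape)
    return x
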