Esempio n. 1
0
    def test_pytorch_decoder_mask(self):
        """Check both masks produced by ``pytorch_decoder_mask`` on a tiny batch.

        The expectations below encode two things:
        * the target-to-target mask is strictly upper triangular (True above
          the diagonal) and identical for every (example, head) pair;
        * in the target-to-source mask, the row for decoding step ``t`` is True
          at source position ``idx - 2`` for every non-start index ``idx``
          that appears in ``tgt_in_idx`` up to and including step ``t``
          (index 1 is used as the start symbol in this fixture).
        """
        batch_size, src_seq_len, num_heads = 3, 4, 2

        memory = torch.randn(batch_size, src_seq_len, num_heads)
        tgt_in_idx = torch.tensor(
            [[1, 2, 3], [1, 4, 2], [1, 5, 4]], dtype=torch.long
        )
        actual_tgt_tgt, actual_tgt_src = pytorch_decoder_mask(
            memory, tgt_in_idx, num_heads
        )

        # One causal pattern, tiled once per (example, head) pair.
        causal_pattern = torch.tensor(
            [
                [False, True, True],
                [False, False, True],
                [False, False, False],
            ]
        )
        want_tgt_tgt = causal_pattern.unsqueeze(0).repeat(
            batch_size * num_heads, 1, 1
        )

        # One pattern per example; each is duplicated for every head so that
        # the heads of a given example occupy consecutive rows.
        per_example_pattern = torch.tensor(
            [
                [
                    [False, False, False, False],
                    [True, False, False, False],
                    [True, True, False, False],
                ],
                [
                    [False, False, False, False],
                    [False, False, True, False],
                    [True, False, True, False],
                ],
                [
                    [False, False, False, False],
                    [False, False, False, True],
                    [False, False, True, True],
                ],
            ]
        )
        want_tgt_src = per_example_pattern.repeat_interleave(num_heads, dim=0)

        assert (actual_tgt_tgt == want_tgt_tgt).all()
        assert (actual_tgt_src == want_tgt_src).all()
Esempio n. 2
0
    def decode(self, memory, state, tgt_in_idx, tgt_in_seq):
        """Compute, for every target position, a distribution over candidates.

        Args:
            memory: encoder output, shape (batch_size, src_seq_len, dim_model).
            state: state features; only consumed (via ``self.state_embedder``)
                by the autoregressive architecture.
            tgt_in_idx: decoder input indices, shape (batch_size, tgt_seq_len).
            tgt_in_seq: decoder input candidate features, shape
                (batch_size, tgt_seq_len, dim_candidate); only consumed by the
                autoregressive architecture.

        Returns:
            Probabilities of shape (batch_size, tgt_seq_len, candidate_size),
            where ``candidate_size = src_seq_len + 2`` (the first two slots are
            reserved symbols and never receive probability mass in the
            Frechet-sort branch).

        Raises:
            NotImplementedError: if ``self.output_arch`` is neither
                FRECHET_SORT nor AUTOREGRESSIVE.
        """
        batch_size, src_seq_len, _ = memory.shape
        _, tgt_seq_len = tgt_in_idx.shape
        candidate_size = src_seq_len + 2

        if self.output_arch == Seq2SlateOutputArch.FRECHET_SORT:
            # encoder_scores shape: batch_size, src_seq_len
            encoder_scores = self.encoder_scorer(memory).squeeze(dim=2)
            # Allocate directly on the scores' device instead of creating the
            # buffer on the default device and copying it over afterwards.
            logits = torch.zeros(
                batch_size,
                tgt_seq_len,
                candidate_size,
                device=encoder_scores.device,
            )
            # The two reserved symbols can never be selected.
            logits[:, :, :2] = float("-inf")
            # Every decoding step starts from the same per-candidate scores;
            # broadcasting over the tgt_seq_len axis does the tiling.
            logits[:, :, 2:] = encoder_scores.unsqueeze(1)
            logits = mask_logits_by_idx(logits, tgt_in_idx)
            probs = torch.softmax(logits, dim=2)
        elif self.output_arch == Seq2SlateOutputArch.AUTOREGRESSIVE:
            # candidate_embed shape: batch_size, tgt_seq_len, dim_model/2
            candidate_embed = self.candidate_embedder(tgt_in_seq)
            # state_embed shape: batch_size, dim_model/2
            state_embed = self.state_embedder(state)
            # Share the same state embedding across all target positions.
            # state_embed shape: batch_size, tgt_seq_len, dim_model/2
            state_embed = state_embed.unsqueeze(1).expand(
                batch_size, tgt_seq_len, -1
            )
            # tgt_embed shape: batch_size, tgt_seq_len, dim_model
            tgt_embed = self.positional_encoding_decoder(
                torch.cat((state_embed, candidate_embed), dim=2)
            )
            # tgt_tgt_mask shape: batch_size * num_heads, tgt_seq_len, tgt_seq_len
            # tgt_src_mask shape: batch_size * num_heads, tgt_seq_len, src_seq_len
            tgt_tgt_mask, tgt_src_mask = pytorch_decoder_mask(
                memory, tgt_in_idx, self.num_heads
            )
            # The decoder outputs probabilities over symbols.
            # shape: batch_size, tgt_seq_len, candidate_size
            probs = self.decoder(tgt_embed, memory, tgt_src_mask, tgt_tgt_mask)
        else:
            raise NotImplementedError()

        return probs
Esempio n. 3
0
    def _log_probs(
        self,
        state,
        src_seq,
        tgt_in_seq,
        src_src_mask,
        tgt_tgt_mask,
        tgt_in_idx,
        tgt_out_idx,
        mode,
    ):
        """
        Log of the probabilities the model assigns to the supplied target
        sequences (the quantity needed for REINFORCE training).

        Return shape depends on ``mode``:
        * PER_SYMBOL_LOG_PROB_DIST_MODE: batch_size, tgt_seq_len, candidate_size
          (probabilities are clamped to at least EPSILON before the log);
        * any other mode: batch_size, 1 (log of the per-sequence probability).

        NOTE: the ``tgt_tgt_mask`` argument is never read; the mask is always
        rebuilt from the encoder output and ``tgt_in_idx``.
        """
        # memory shape: batch_size, src_seq_len, dim_model
        memory = self.encode(state, src_seq, src_src_mask)

        # A target sequence cannot be longer than the source it is drawn from.
        assert tgt_in_seq.shape[1] <= src_seq.shape[1]

        # causal_mask shape: batch_size * num_heads, tgt_seq_len, tgt_seq_len
        # memory_mask shape: batch_size * num_heads, tgt_seq_len, src_seq_len
        causal_mask, memory_mask = pytorch_decoder_mask(
            memory, tgt_in_idx, self.num_heads
        )

        # symbol_probs shape: batch_size, tgt_seq_len, candidate_size
        symbol_probs = self.decode(
            memory=memory,
            state=state,
            tgt_src_mask=memory_mask,
            tgt_in_idx=tgt_in_idx,
            tgt_in_seq=tgt_in_seq,
            tgt_tgt_mask=causal_mask,
        )

        wants_per_symbol = mode == Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE
        if not wants_per_symbol:
            # Collapse to one probability per sequence before taking the log.
            seq_probs = per_symbol_to_per_seq_probs(symbol_probs, tgt_out_idx)
            return torch.log(seq_probs)

        # Clamp first so that zero-probability symbols do not yield -inf.
        return torch.log(torch.clamp(symbol_probs, min=EPSILON))
Esempio n. 4
0
    def _rank(self, state, src_seq, src_src_mask, tgt_seq_len, greedy):
        """Decode a ranked sequence of candidates for each example.

        Args:
            state: state features passed to ``self.encode`` / ``self.decode``.
            src_seq: candidate features, shape
                (batch_size, src_seq_len, candidate_dim).
            src_src_mask: source mask passed to ``self.encode``.
            tgt_seq_len: number of candidates to rank per example.
            greedy: forwarded to ``self.generator``; must not be None unless
                the ENCODER_SCORE architecture is used.

        Returns:
            A tuple ``(ranked_per_symbol_probs, ranked_per_seq_probs,
            tgt_out_idx)`` with shapes
            (batch_size, tgt_seq_len, candidate_size), (batch_size, 1) and
            (batch_size, tgt_seq_len). All three live on ``src_seq.device``.
            Indices in ``tgt_out_idx`` are offset by 2 because slots 0 and 1
            are reserved for the start and padding symbols.
        """
        device = src_seq.device
        batch_size, src_seq_len, candidate_dim = src_seq.shape
        # +2 to make room for the start symbol and the padding symbol
        candidate_size = src_seq_len + 2

        # candidate_features is a look-up table of candidate features; the
        # rows of the two reserved symbols are left as zeros.
        candidate_features = torch.zeros(
            batch_size, candidate_size, candidate_dim, device=device
        )
        # TODO: T62502977 create learnable feature vectors for start symbol
        # and padding symbol
        candidate_features[:, 2:, :] = src_seq

        # memory shape: batch_size, src_seq_len, dim_model
        memory = self.encode(state, src_seq, src_src_mask)

        ranked_per_symbol_probs = torch.zeros(
            batch_size, tgt_seq_len, candidate_size, device=device
        )

        if self.output_arch == Seq2SlateOutputArch.ENCODER_SCORE:
            # encoder_scores shape: batch_size, src_seq_len
            encoder_scores = self.encoder_scorer(memory).squeeze(dim=2)
            # +2 to account for start symbol and padding symbol
            tgt_out_idx = (
                torch.argsort(encoder_scores, dim=1, descending=True)[
                    :, :tgt_seq_len
                ]
                + 2
            )
            # every position has propensity of 1 because we are just using
            # argsort
            ranked_per_symbol_probs = ranked_per_symbol_probs.scatter(
                2, tgt_out_idx.unsqueeze(2), 1.0
            )
            # Created on `device` so that all returned tensors are co-located.
            ranked_per_seq_probs = torch.ones(batch_size, 1, device=device)
            return ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx

        # Every sequence starts with the decoder start symbol.
        tgt_in_idx = torch.full(
            (batch_size, 1),
            self._DECODER_START_SYMBOL,
            dtype=torch.long,
            device=device,
        )

        assert greedy is not None
        for step in range(tgt_seq_len):
            tgt_in_seq = gather(candidate_features, tgt_in_idx)
            tgt_tgt_mask, tgt_src_mask = pytorch_decoder_mask(
                memory, tgt_in_idx, self.num_heads
            )
            # probs shape: batch_size, step + 1, candidate_size
            probs = self.decode(
                memory=memory,
                state=state,
                tgt_src_mask=tgt_src_mask,
                tgt_in_idx=tgt_in_idx,
                tgt_in_seq=tgt_in_seq,
                tgt_tgt_mask=tgt_tgt_mask,
            )
            # next_candidate shape: batch_size, 1
            # next_candidate_sample_prob shape: batch_size, candidate_size
            next_candidate, next_candidate_sample_prob = self.generator(
                probs, greedy
            )
            ranked_per_symbol_probs[:, step, :] = next_candidate_sample_prob
            # Feed the chosen candidate back in as input for the next step.
            tgt_in_idx = torch.cat([tgt_in_idx, next_candidate], dim=1)

        # remove the decoder start symbol
        # tgt_out_idx shape: batch_size, tgt_seq_len
        tgt_out_idx = tgt_in_idx[:, 1:]

        ranked_per_seq_probs = per_symbol_to_per_seq_probs(
            ranked_per_symbol_probs, tgt_out_idx
        )

        # ranked_per_symbol_probs shape: batch_size, tgt_seq_len, candidate_size
        # ranked_per_seq_probs shape: batch_size, 1
        # tgt_out_idx shape: batch_size, tgt_seq_len
        return ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx