def test_mask_logits_by_idx(self):
    """Check mask_logits_by_idx against a hand-computed 2 x 3 x 5 example.

    The expected tensor spells out the behaviour under test: candidate
    slots 0 and 1 are -inf at every decoding step, and at step t every
    index listed in tgt_in_idx up to and including position t is -inf as
    well. All remaining logits must come back unchanged.
    """
    neg_inf = float("-inf")

    # Shape: batch_size=2, seq_len=3, candidate_size=5
    raw_logits = torch.tensor(
        [
            [
                [1.0, 2.0, 3.0, 4.0, 5.0],
                [2.0, 3.0, 4.0, 5.0, 6.0],
                [3.0, 4.0, 5.0, 6.0, 7.0],
            ],
            [
                [5.0, 4.0, 3.0, 2.0, 1.0],
                [6.0, 5.0, 4.0, 3.0, 2.0],
                [7.0, 6.0, 5.0, 4.0, 3.0],
            ],
        ]
    )
    # Decoder inputs: sequence 0 feeds candidates 2 then 3,
    # sequence 1 feeds candidates 4 then 3.
    # NOTE(review): the expected values only hold if DECODER_START_SYMBOL
    # is one of the two always-masked slots (0 or 1) - confirm.
    decoder_input_idx = torch.tensor(
        [[DECODER_START_SYMBOL, 2, 3], [DECODER_START_SYMBOL, 4, 3]]
    )

    actual = mask_logits_by_idx(raw_logits, decoder_input_idx)

    expected = torch.tensor(
        [
            [
                [neg_inf, neg_inf, 3.0, 4.0, 5.0],
                [neg_inf, neg_inf, neg_inf, 5.0, 6.0],
                [neg_inf, neg_inf, neg_inf, neg_inf, 7.0],
            ],
            [
                [neg_inf, neg_inf, 3.0, 2.0, 1.0],
                [neg_inf, neg_inf, 4.0, 3.0, neg_inf],
                [neg_inf, neg_inf, 5.0, neg_inf, neg_inf],
            ],
        ]
    )
    # -inf == -inf evaluates to True, so an element-wise comparison is valid.
    assert torch.eq(actual, expected).all()
Example #2
0
    def decode(self, memory, state, tgt_in_idx, tgt_in_seq):
        """
        Compute, for every decoding step, a probability distribution over
        candidates given the encoder output and the decoder inputs.

        :param memory: the output of the encoder, the attention of each
            input symbol. Shape: batch_size, src_seq_len, dim_model
        :param state: input to ``self.state_embedder``; only read by the
            AUTOREGRESSIVE architecture.
        :param tgt_in_idx: indices of candidates in the decoder input
            sequences. Shape: batch_size, tgt_seq_len
        :param tgt_in_seq: decoder input sequences, fed to
            ``self.candidate_embedder``; only read by the AUTOREGRESSIVE
            architecture. Shape: batch_size, tgt_seq_len, dim_candidate
        :return: probabilities over symbols.
            Shape: batch_size, tgt_seq_len, candidate_size
            where candidate_size = src_seq_len + 2
        :raises NotImplementedError: if ``self.output_arch`` is neither
            FRECHET_SORT nor AUTOREGRESSIVE.
        """
        batch_size, src_seq_len, _ = memory.shape
        _, tgt_seq_len = tgt_in_idx.shape
        # Two extra slots precede the src_seq_len real candidates; they are
        # never valid outputs (presumably padding and decoder-start symbols).
        candidate_size = src_seq_len + 2

        if self.output_arch == Seq2SlateOutputArch.FRECHET_SORT:
            # encoder_scores shape: batch_size, src_seq_len
            encoder_scores = self.encoder_scorer(memory).squeeze(dim=2)
            # Allocate directly on the target device instead of building the
            # buffer on the default device and copying it over.
            logits = torch.zeros(
                batch_size,
                tgt_seq_len,
                candidate_size,
                device=encoder_scores.device,
            )
            logits[:, :, :2] = float("-inf")
            # Every decoding step sees the same per-candidate scores, so let
            # the assignment broadcast over the tgt_seq_len axis rather than
            # materialising tgt_seq_len tiled copies first.
            logits[:, :, 2:] = encoder_scores.unsqueeze(1)
            logits = mask_logits_by_idx(logits, tgt_in_idx)
            probs = torch.softmax(logits, dim=2)
        elif self.output_arch == Seq2SlateOutputArch.AUTOREGRESSIVE:
            # candidate_embed shape: batch_size, tgt_seq_len, dim_model/2
            candidate_embed = self.candidate_embedder(tgt_in_seq)
            # state_embed: batch_size, dim_model/2
            state_embed = self.state_embedder(state)
            # state_embed: batch_size, tgt_seq_len, dim_model/2
            # (a broadcast view; torch.cat below writes the actual copy)
            state_embed = state_embed.unsqueeze(1).expand(-1, tgt_seq_len, -1)
            # tgt_embed: batch_size, tgt_seq_len, dim_model
            tgt_embed = self.positional_encoding_decoder(
                torch.cat((state_embed, candidate_embed), dim=2)
            )
            # tgt_tgt_mask shape: batch_size * num_heads, tgt_seq_len, tgt_seq_len
            # tgt_src_mask shape: batch_size * num_heads, tgt_seq_len, src_seq_len
            tgt_tgt_mask, tgt_src_mask = pytorch_decoder_mask(
                memory, tgt_in_idx, self.num_heads
            )
            # output of decoder is probabilities over symbols.
            # shape: batch_size, tgt_seq_len, candidate_size
            probs = self.decoder(tgt_embed, memory, tgt_src_mask, tgt_tgt_mask)
        else:
            raise NotImplementedError(
                f"Unsupported output_arch: {self.output_arch}"
            )

        return probs
Example #3
0
    def _log_probs(self, logits, tgt_in_idx, mode):
        """
        Compute the log probability distribution over candidates at each
        decoding step.

        :param logits: logits of decoder outputs.
            Shape: batch_size, seq_len, candidate_size
        :param tgt_in_idx: the indices of candidates in decoder input sequences.
            The first symbol is always DECODER_START_SYMBOL.
            Shape: batch_size, seq_len
        :param mode: must be Seq2SlateMode.PER_SEQ_LOG_PROB_MODE or
            Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE. It is only validated
            here; the returned tensor does not depend on it.
        :return: log probabilities, normalised over the candidate axis.
            Shape: batch_size, seq_len, candidate_size
        """
        supported_modes = (
            Seq2SlateMode.PER_SEQ_LOG_PROB_MODE,
            Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE,
        )
        assert mode in supported_modes

        masked_logits = mask_logits_by_idx(logits, tgt_in_idx)
        # Scale by the temperature before normalising over candidates (dim 2).
        tempered_logits = masked_logits / self.temperature
        return F.log_softmax(tempered_logits, dim=2)