예제 #1
0
    def test_per_symbol_to_per_seq_probs(self):
        """per_symbol_to_per_seq_probs multiplies the chosen per-symbol probs.

        Builds a random per-symbol distribution over ``seq_len + 2`` candidates
        (the first two slots are reserved and masked to -inf), masks the
        candidates already picked at earlier steps, and checks that the
        per-sequence probability equals the product of the probabilities of
        the chosen candidates.
        """
        batch_size, seq_len = 1, 3
        candidate_size = seq_len + 2

        # chosen candidate at each step, shifted past the two reserved slots
        tgt_out_idx = torch.tensor([[0, 2, 1]]) + 2

        logits = torch.randn(batch_size, seq_len, candidate_size)
        neg_inf = float("-inf")
        logits[0, :, :2] = neg_inf
        # (step, candidate) pairs whose candidate was taken at an earlier step
        for step, cand in ((1, 2), (2, 2), (2, 4)):
            logits[0, step, cand] = neg_inf
        per_symbol_probs = F.log_softmax(logits, dim=2).exp()

        chosen = tgt_out_idx[0].tolist()
        expect_per_seq_probs = per_symbol_probs[0, 0, chosen[0]]
        for step in range(1, seq_len):
            expect_per_seq_probs = (
                expect_per_seq_probs * per_symbol_probs[0, step, chosen[step]]
            )

        computed_per_seq_probs = per_symbol_to_per_seq_probs(
            per_symbol_probs, tgt_out_idx
        )
        np.testing.assert_allclose(
            expect_per_seq_probs, computed_per_seq_probs, atol=0.001, rtol=0.0
        )
예제 #2
0
    def _log_probs(
        self,
        state: torch.Tensor,
        src_seq: torch.Tensor,
        tgt_in_seq: torch.Tensor,
        tgt_in_idx: torch.Tensor,
        tgt_out_idx: torch.Tensor,
        mode: str,
    ) -> Seq2SlateTransformerOutput:
        """
        Compute log of generative probabilities of given tgt sequences
        (used for REINFORCE training)

        Args:
            state: state features forwarded to ``encode`` and ``decode``.
            src_seq: candidate features; ``src_seq.shape[1]`` is the source
                length.
            tgt_in_seq: decoder input features; ``tgt_in_seq.shape[1]`` is the
                target length, which must not exceed the source length.
            tgt_in_idx: decoder input indices, forwarded to ``decode``.
            tgt_out_idx: index of the chosen candidate at each target position;
                assumed (batch_size, tgt_seq_len) long -- TODO confirm.
            mode: ``self._PER_SYMBOL_LOG_PROB_DIST_MODE`` to get per-symbol
                log-prob distributions; any other value gives per-sequence
                log-probs.

        Returns:
            Seq2SlateTransformerOutput in which only ``per_symbol_log_probs``
            (batch_size, tgt_seq_len, candidate_size) or only
            ``per_seq_log_probs`` (batch_size, 1) is populated.
        """
        # floor applied to probabilities so log() never returns -inf
        min_prob = 1e-40

        # encoder_output shape: batch_size, src_seq_len, dim_model
        encoder_output = self.encode(state, src_seq)

        tgt_seq_len = tgt_in_seq.shape[1]
        src_seq_len = src_seq.shape[1]
        assert tgt_seq_len <= src_seq_len

        # decoder_probs shape: batch_size, tgt_seq_len, candidate_size
        decoder_probs = self.decode(
            memory=encoder_output,
            state=state,
            tgt_in_idx=tgt_in_idx,
            tgt_in_seq=tgt_in_seq,
        )
        # log_probs shape:
        # if mode == PER_SEQ_LOG_PROB_MODE: batch_size, 1
        # if mode == PER_SYMBOL_LOG_PROB_DIST_MODE: batch_size, tgt_seq_len, candidate_size
        if mode == self._PER_SYMBOL_LOG_PROB_DIST_MODE:
            per_symbol_log_probs = torch.log(torch.clamp(decoder_probs, min=min_prob))
            return Seq2SlateTransformerOutput(
                ranked_per_symbol_probs=None,
                ranked_per_seq_probs=None,
                ranked_tgt_out_idx=None,
                per_symbol_log_probs=per_symbol_log_probs,
                per_seq_log_probs=None,
                encoder_scores=None,
            )

        # probability the decoder gave the symbol actually chosen at each
        # position; shape: batch_size, tgt_seq_len
        chosen_probs = torch.gather(decoder_probs, 2, tgt_out_idx.unsqueeze(2)).squeeze(2)
        # log(prod_i p_i) == sum_i log(p_i). Summing per-symbol logs avoids the
        # float underflow that multiplying many small probabilities causes on
        # long sequences, and the clamp keeps the result finite.
        per_seq_log_probs = torch.log(torch.clamp(chosen_probs, min=min_prob)).sum(
            dim=1, keepdim=True
        )
        return Seq2SlateTransformerOutput(
            ranked_per_symbol_probs=None,
            ranked_per_seq_probs=None,
            ranked_tgt_out_idx=None,
            per_symbol_log_probs=None,
            per_seq_log_probs=per_seq_log_probs,
            encoder_scores=None,
        )
예제 #3
0
    def _rank(
        self, state: torch.Tensor, src_seq: torch.Tensor, tgt_seq_len: int, greedy: bool
    ) -> Seq2SlateTransformerOutput:
        """Decode sequences based on given inputs

        Dispatches on ``self.output_arch``:
        - ENCODER_SCORE: ``self._encoder_rank`` ranks directly from the encoder
          output.
        - FRECHET_SORT with ``greedy`` true: ``self._greedy_rank``.
        - otherwise: ``self._autoregressive_rank`` (``greedy`` is passed through).

        Returns:
            Seq2SlateTransformerOutput with
            ranked_per_symbol_probs: batch_size, tgt_seq_len, candidate_size
            ranked_per_seq_probs: batch_size, 1
            ranked_tgt_out_idx: batch_size, tgt_seq_len
            and the log-prob / encoder-score fields set to the placeholder.
        """
        device = src_seq.device
        batch_size, src_seq_len, candidate_dim = src_seq.shape
        # two extra slots: start symbol and padding symbol
        candidate_size = src_seq_len + 2

        # memory shape: batch_size, src_seq_len, dim_model
        memory = self.encode(state, src_seq)

        if self.output_arch == Seq2SlateOutputArch.ENCODER_SCORE:
            # ranking comes straight from the encoder; no candidate look-up table
            tgt_out_idx, ranked_per_symbol_probs = self._encoder_rank(
                memory, tgt_seq_len
            )
        else:
            # candidate_features is used as look-up table for candidate features.
            # Its second dim is candidate_size because it also holds features of
            # the start symbol and padding symbol (the first two slots, left as
            # zeros); real candidates start at index 2.
            candidate_features = torch.zeros(
                batch_size, candidate_size, candidate_dim, device=device
            )
            # TODO: T62502977 create learnable feature vectors for start symbol
            # and padding symbol
            candidate_features[:, 2:, :] = src_seq

            if self.output_arch == Seq2SlateOutputArch.FRECHET_SORT and greedy:
                # greedy decoding for non-autoregressive decoder
                tgt_out_idx, ranked_per_symbol_probs = self._greedy_rank(
                    state, memory, candidate_features, tgt_seq_len
                )
            else:
                assert greedy is not None
                # autoregressive decoding
                tgt_out_idx, ranked_per_symbol_probs = self._autoregressive_rank(
                    state, memory, candidate_features, tgt_seq_len, greedy
                )

        # ranked_per_symbol_probs shape: batch_size, tgt_seq_len, candidate_size
        # ranked_per_seq_probs shape: batch_size, 1
        ranked_per_seq_probs = per_symbol_to_per_seq_probs(
            ranked_per_symbol_probs, tgt_out_idx
        )

        # tgt_out_idx shape: batch_size, tgt_seq_len
        return Seq2SlateTransformerOutput(
            ranked_per_symbol_probs=ranked_per_symbol_probs,
            ranked_per_seq_probs=ranked_per_seq_probs,
            ranked_tgt_out_idx=tgt_out_idx,
            per_symbol_log_probs=self._OUTPUT_PLACEHOLDER,
            per_seq_log_probs=self._OUTPUT_PLACEHOLDER,
            encoder_scores=self._OUTPUT_PLACEHOLDER,
        )
예제 #4
0
    def _log_probs(
        self,
        state,
        src_seq,
        tgt_in_seq,
        src_src_mask,
        tgt_tgt_mask,
        tgt_in_idx,
        tgt_out_idx,
        mode,
    ):
        """
        Compute log of generative probabilities of given tgt sequences
        (used for REINFORCE training)

        Args:
            state: state features forwarded to ``encode`` and ``decode``.
            src_seq: candidate features; ``src_seq.shape[1]`` is the source
                length.
            tgt_in_seq: decoder input features; ``tgt_in_seq.shape[1]`` is the
                target length, which must not exceed the source length.
            src_src_mask: mask forwarded to ``encode``.
            tgt_tgt_mask: NOTE(review): currently ignored -- it is overwritten
                by the mask ``pytorch_decoder_mask`` builds from ``tgt_in_idx``.
                Kept for interface compatibility; confirm intent with callers.
            tgt_in_idx: decoder input indices.
            tgt_out_idx: index of the chosen candidate at each target position;
                assumed (batch_size, tgt_seq_len) long -- TODO confirm.
            mode: ``Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE`` for per-symbol
                log-prob distributions; any other value gives per-sequence
                log-probs.

        Returns:
            Per-symbol log-probs of shape (batch_size, tgt_seq_len,
            candidate_size), or per-sequence log-probs of shape (batch_size, 1).
        """
        # encoder_output shape: batch_size, src_seq_len, dim_model
        encoder_output = self.encode(state, src_seq, src_src_mask)

        tgt_seq_len = tgt_in_seq.shape[1]
        src_seq_len = src_seq.shape[1]
        assert tgt_seq_len <= src_seq_len

        # tgt_tgt_mask shape: batch_size * num_heads, tgt_seq_len, tgt_seq_len
        # tgt_src_mask shape: batch_size * num_heads, tgt_seq_len, src_seq_len
        # NOTE(review): this rebinding discards the caller's tgt_tgt_mask.
        tgt_tgt_mask, tgt_src_mask = pytorch_decoder_mask(
            encoder_output, tgt_in_idx, self.num_heads
        )
        # decoder_probs shape: batch_size, tgt_seq_len, candidate_size
        decoder_probs = self.decode(
            memory=encoder_output,
            state=state,
            tgt_src_mask=tgt_src_mask,
            tgt_in_idx=tgt_in_idx,
            tgt_in_seq=tgt_in_seq,
            tgt_tgt_mask=tgt_tgt_mask,
        )
        # log_probs shape:
        # if mode == PER_SEQ_LOG_PROB_MODE: batch_size, 1
        # if mode == PER_SYMBOL_LOG_PROB_DIST_MODE: batch_size, tgt_seq_len, candidate_size
        if mode == Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE:
            per_symbol_log_probs = torch.log(torch.clamp(decoder_probs, min=EPSILON))
            return per_symbol_log_probs

        # probability the decoder gave the symbol actually chosen at each
        # position; shape: batch_size, tgt_seq_len
        chosen_probs = torch.gather(decoder_probs, 2, tgt_out_idx.unsqueeze(2)).squeeze(2)
        # log(prod_i p_i) == sum_i log(p_i). Summing per-symbol logs avoids the
        # float underflow of multiplying many small probabilities, and the clamp
        # keeps the result finite like the per-symbol branch.
        per_seq_log_probs = torch.log(torch.clamp(chosen_probs, min=EPSILON)).sum(
            dim=1, keepdim=True
        )
        return per_seq_log_probs
예제 #5
0
    def _rank(
        self, state: torch.Tensor, src_seq: torch.Tensor, tgt_seq_len: int, greedy: bool
    ) -> Seq2SlateTransformerOutput:
        """ Decode sequences based on given inputs

        With ``Seq2SlateOutputArch.ENCODER_SCORE`` the candidates are ranked by
        descending encoder score and every chosen position gets probability 1.
        Otherwise the decoder runs for ``tgt_seq_len`` steps, each step asking
        ``self.generator`` (with ``greedy`` passed through) for the next
        candidate.

        Returns:
            Seq2SlateTransformerOutput with
            ranked_per_symbol_probs: batch_size, tgt_seq_len, candidate_size
            ranked_per_seq_probs: batch_size, 1
            ranked_tgt_out_idx: batch_size, tgt_seq_len
            and the log-prob / encoder-score fields set to the placeholder.
        """
        device = src_seq.device
        batch_size, src_seq_len, candidate_dim = src_seq.shape
        # two extra slots: start symbol and padding symbol
        candidate_size = src_seq_len + 2

        # memory shape: batch_size, src_seq_len, dim_model
        memory = self.encode(state, src_seq)

        ranked_per_symbol_probs = torch.zeros(
            batch_size, tgt_seq_len, candidate_size, device=device
        )

        if self.output_arch == Seq2SlateOutputArch.ENCODER_SCORE:
            # encoder_scores shape: batch_size, src_seq_len
            encoder_scores = self.encoder_scorer(memory).squeeze(dim=2)
            tgt_out_idx = torch.argsort(encoder_scores, dim=1, descending=True)[
                :, :tgt_seq_len
            ]
            # +2 to account for start symbol and padding symbol
            tgt_out_idx += 2
            # every position has propensity of 1 because we are just using argsort
            ranked_per_symbol_probs = ranked_per_symbol_probs.scatter(
                2, tgt_out_idx.unsqueeze(2), 1.0
            )
            # allocate on `device` so it matches the other returned tensors
            ranked_per_seq_probs = torch.ones(batch_size, 1, device=device)
            return Seq2SlateTransformerOutput(
                ranked_per_symbol_probs=ranked_per_symbol_probs,
                ranked_per_seq_probs=ranked_per_seq_probs,
                ranked_tgt_out_idx=tgt_out_idx,
                per_symbol_log_probs=self._OUTPUT_PLACEHOLDER,
                per_seq_log_probs=self._OUTPUT_PLACEHOLDER,
                encoder_scores=self._OUTPUT_PLACEHOLDER,
            )

        # candidate_features is used as look-up table for candidate features.
        # Its second dim is candidate_size because it also holds features of the
        # start symbol and padding symbol (first two slots, left as zeros).
        candidate_features = torch.zeros(
            batch_size, candidate_size, candidate_dim, device=device
        )
        # TODO: T62502977 create learnable feature vectors for start symbol
        # and padding symbol
        candidate_features[:, 2:, :] = src_seq

        tgt_in_idx = (
            torch.ones(batch_size, 1, device=device)
            .fill_(self._DECODER_START_SYMBOL)
            .long()
        )

        assert greedy is not None
        for step in range(tgt_seq_len):
            tgt_in_seq = gather(candidate_features, tgt_in_idx)

            # probs shape: batch_size, step + 1, candidate_size
            probs = self.decode(
                memory=memory,
                state=state,
                tgt_in_idx=tgt_in_idx,
                tgt_in_seq=tgt_in_seq,
            )
            # next_candidate shape: batch_size, 1
            # next_candidate_sample_prob shape: batch_size, candidate_size
            next_candidate, next_candidate_sample_prob = self.generator(probs, greedy)
            ranked_per_symbol_probs[:, step, :] = next_candidate_sample_prob
            tgt_in_idx = torch.cat([tgt_in_idx, next_candidate], dim=1)

        # remove the decoder start symbol
        # tgt_out_idx shape: batch_size, tgt_seq_len
        tgt_out_idx = tgt_in_idx[:, 1:]

        ranked_per_seq_probs = per_symbol_to_per_seq_probs(
            ranked_per_symbol_probs, tgt_out_idx
        )

        # ranked_per_symbol_probs shape: batch_size, tgt_seq_len, candidate_size
        # ranked_per_seq_probs shape: batch_size, 1
        # tgt_out_idx shape: batch_size, tgt_seq_len
        return Seq2SlateTransformerOutput(
            ranked_per_symbol_probs=ranked_per_symbol_probs,
            ranked_per_seq_probs=ranked_per_seq_probs,
            ranked_tgt_out_idx=tgt_out_idx,
            per_symbol_log_probs=self._OUTPUT_PLACEHOLDER,
            per_seq_log_probs=self._OUTPUT_PLACEHOLDER,
            encoder_scores=self._OUTPUT_PLACEHOLDER,
        )
예제 #6
0
파일: seq2slate.py 프로젝트: saonam/ReAgent
    def _rank(self, state, src_seq, src_src_mask, tgt_seq_len, greedy):
        """ Decode sequences based on given inputs

        With ``Seq2SlateOutputArch.ENCODER_SCORE`` the candidates are ranked by
        descending encoder score and every chosen position gets probability 1.
        Otherwise the decoder runs for ``tgt_seq_len`` steps, each step asking
        ``self.generator`` (in ``self._DECODE_ONE_STEP_MODE``, with ``greedy``
        passed through) for the next candidate.

        Returns:
            Tuple ``(ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx)``
            with shapes (batch_size, tgt_seq_len, candidate_size),
            (batch_size, 1) and (batch_size, tgt_seq_len).
        """
        device = src_seq.device
        batch_size, src_seq_len, candidate_dim = src_seq.shape
        # two extra slots: start symbol and padding symbol
        candidate_size = src_seq_len + 2

        # memory shape: batch_size, src_seq_len, dim_model
        memory = self.encode(state, src_seq, src_src_mask)

        ranked_per_symbol_probs = torch.zeros(
            batch_size, tgt_seq_len, candidate_size, device=device
        )

        if self.output_arch == Seq2SlateOutputArch.ENCODER_SCORE:
            # encoder_scores shape: batch_size, src_seq_len
            encoder_scores = self.encoder_scorer(memory).squeeze(dim=2)
            tgt_out_idx = torch.argsort(encoder_scores, dim=1, descending=True)[
                :, :tgt_seq_len
            ]
            # +2 to account for start symbol and padding symbol
            tgt_out_idx += 2
            # every position has propensity of 1 because we are just using argsort
            ranked_per_symbol_probs = ranked_per_symbol_probs.scatter(
                2, tgt_out_idx.unsqueeze(2), 1.0
            )
            # allocate on `device` so it matches the other returned tensors
            ranked_per_seq_probs = torch.ones(batch_size, 1, device=device)
            return ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx

        # candidate_features is used as look-up table for candidate features.
        # Its second dim is candidate_size because it also holds features of the
        # start symbol and padding symbol (first two slots, left as zeros).
        candidate_features = torch.zeros(
            batch_size, candidate_size, candidate_dim, device=device
        )
        # TODO: T62502977 create learnable feature vectors for start symbol
        # and padding symbol
        candidate_features[:, 2:, :] = src_seq

        tgt_in_idx = (
            torch.ones(batch_size, 1, device=device)
            .fill_(self._DECODER_START_SYMBOL)
            .type(torch.long)
        )
        # row index per batch element; loop-invariant, so built once
        batch_idx = torch.arange(batch_size, device=device)

        assert greedy is not None
        for step in range(tgt_seq_len):
            prefix_len = step + 1
            # tgt_in_seq shape: batch_size, prefix_len, candidate_dim
            # (already on `device`, since candidate_features lives there)
            tgt_in_seq = candidate_features[
                batch_idx.repeat_interleave(prefix_len),
                tgt_in_idx.flatten(),
            ].view(batch_size, prefix_len, -1)
            tgt_src_mask = src_src_mask[:, :prefix_len, :]
            # logits shape: batch_size, prefix_len, candidate_size
            logits = self.decode(
                memory=memory,
                state=state,
                tgt_src_mask=tgt_src_mask,
                tgt_in_seq=tgt_in_seq,
                tgt_tgt_mask=subsequent_mask(prefix_len, device),
                tgt_seq_len=prefix_len,
            )
            # next_candidate shape: batch_size, 1
            # prob shape: batch_size, candidate_size
            next_candidate, prob = self.generator(
                mode=self._DECODE_ONE_STEP_MODE,
                logits=logits,
                tgt_in_idx=tgt_in_idx,
                greedy=greedy,
            )
            ranked_per_symbol_probs[:, step, :] = prob
            tgt_in_idx = torch.cat([tgt_in_idx, next_candidate], dim=1)

        # remove the decoder start symbol
        # tgt_out_idx shape: batch_size, tgt_seq_len
        tgt_out_idx = tgt_in_idx[:, 1:]

        ranked_per_seq_probs = per_symbol_to_per_seq_probs(
            ranked_per_symbol_probs, tgt_out_idx
        )

        # ranked_per_symbol_probs shape: batch_size, tgt_seq_len, candidate_size
        # ranked_per_seq_probs shape: batch_size, 1
        # tgt_out_idx shape: batch_size, tgt_seq_len
        return ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx