def _convert_seq2slate_to_reward_model_format(
    self, input: rlt.PreprocessedRankingInput
):
    """
    Re-pack a seq2slate ranking input into the layout the reward model uses.

    The reward model's transformer decoder is fed the complete target
    sequence, whereas the seq2slate decoder is only fed the items that
    precede the last one. The decoder input is therefore rebuilt here as
    ``self.decoder_start_vec`` followed by every item of
    ``input.tgt_out_seq``, and the decoder self-attention mask is built for
    one extra position.

    Returns:
        A new ``rlt.PreprocessedRankingInput`` created from the state,
        source sequence and source mask of ``input`` together with the
        rebuilt decoder input and its mask.
    """
    device = next(self.parameters()).device
    # pyre-fixme[16]: Optional type has no attribute `float_features`.
    tgt_out_features = input.tgt_out_seq.float_features
    batch_size, tgt_seq_len, _candidate_dim = tgt_out_features.shape
    assert self.max_tgt_seq_len == tgt_seq_len

    # One extra position for the start vector placed in front of the items.
    full_seq_len = tgt_seq_len + 1
    tgt_tgt_mask = subsequent_mask(full_seq_len, device)

    # Tile the start vector once per batch row.
    # NOTE(review): assumes decoder_start_vec is (1, 1, candidate_dim) - confirm.
    start_vecs = self.decoder_start_vec.repeat(batch_size, 1, 1)
    # shape: batch_size, tgt_seq_len + 1, candidate_dim
    tgt_in_seq = torch.cat((start_vecs, tgt_out_features), dim=1)

    return rlt.PreprocessedRankingInput.from_tensors(
        state=input.state.float_features,
        src_seq=input.src_seq.float_features,
        src_src_mask=input.src_src_mask,
        tgt_in_seq=tgt_in_seq,
        tgt_tgt_mask=tgt_tgt_mask,
    )
Example #2
0
        def process_tgt_seq(action):
            """
            Derive decoder input/output indices, features and mask from ``action``.

            Returns the 5-tuple
            ``(tgt_in_idx, tgt_out_idx, tgt_in_seq, tgt_out_seq, tgt_tgt_mask)``.
            Every element is ``None`` when ``action`` is ``None``.

            Relies on ``batch_size``, ``candidate_dim``, ``candidates`` and
            ``device`` from the enclosing scope.
            """
            if action is None:
                # Without an action there is nothing to derive.
                return None, None, None, None, None

            # action must be 2-D; only its second dimension is used.
            _, output_size = action.shape

            # Account for decoder starting symbol and padding symbol: two
            # all-zero rows are placed in front of the real candidates, so
            # every action index moves up by two.
            symbol_rows = torch.zeros(batch_size, 2, candidate_dim, device=device)
            candidates_augment = torch.cat((symbol_rows, candidates), dim=1)
            tgt_out_idx = action + 2

            # Decoder input indices: start symbol first, then the output
            # indices shifted right by one position.
            tgt_in_idx = torch.full(
                (batch_size, output_size), DECODER_START_SYMBOL, device=device
            )
            tgt_in_idx[:, 1:] = tgt_out_idx[:, :-1]

            # NOTE(review): gather presumably picks rows of candidates_augment
            # by index - confirm against its definition.
            tgt_out_seq = gather(candidates_augment, tgt_out_idx)

            # Decoder input features: zeros at position 0, then the output
            # features shifted right by one position.
            tgt_in_seq = torch.zeros(
                batch_size, output_size, candidate_dim, device=device
            )
            tgt_in_seq[:, 1:] = tgt_out_seq[:, :-1]

            tgt_tgt_mask = subsequent_mask(output_size, device)

            return tgt_in_idx, tgt_out_idx, tgt_in_seq, tgt_out_seq, tgt_tgt_mask
Example #3
0
 def test_subsequent_mask(self):
     """Check that ``subsequent_mask(3, cpu)`` is a 3x3 lower-triangular mask."""
     mask = subsequent_mask(3, torch.device("cpu"))
     # Row i allows positions 0..i and blocks everything after it.
     expected = torch.tensor(
         [
             [1, 0, 0],
             [1, 1, 0],
             [1, 1, 1],
         ]
     )
     # Element-wise comparison; every entry has to agree.
     assert (mask == expected).all()
Example #4
0
    def _rank(self, state, src_seq, src_src_mask, tgt_seq_len, greedy):
        """
        Decode sequences based on given inputs.

        Args:
            state: forwarded unchanged to ``self.encode`` and ``self.decode``.
            src_seq: candidate features, unpacked as
                (batch_size, src_seq_len, candidate_dim).
            src_src_mask: mask passed to the encoder; its first ``step + 1``
                rows along dim 1 are reused as the decoder-to-encoder mask
                at every decoding step.
            tgt_seq_len: number of positions to rank.
            greedy: forwarded to ``self.generator``. Must not be ``None``
                unless the output architecture is ``ENCODER_SCORE``, in
                which case it is not used.

        Returns:
            A tuple ``(ranked_per_symbol_probs, ranked_per_seq_probs,
            tgt_out_idx)``:
              - ranked_per_symbol_probs: batch_size, tgt_seq_len, candidate_size
              - ranked_per_seq_probs: batch_size, 1
              - tgt_out_idx: batch_size, tgt_seq_len
            All tensors allocated here live on ``src_seq.device``.
        """
        device = src_seq.device
        batch_size, src_seq_len, candidate_dim = src_seq.shape
        # +2 because two extra slots hold the start symbol and padding symbol
        candidate_size = src_seq_len + 2

        # memory shape: batch_size, src_seq_len, dim_model
        memory = self.encode(state, src_seq, src_src_mask)

        ranked_per_symbol_probs = torch.zeros(
            batch_size, tgt_seq_len, candidate_size, device=device
        )

        if self.output_arch == Seq2SlateOutputArch.ENCODER_SCORE:
            # encoder_scores shape: batch_size, src_seq_len
            encoder_scores = self.encoder_scorer(memory).squeeze(dim=2)
            tgt_out_idx = torch.argsort(encoder_scores, dim=1, descending=True)[
                :, :tgt_seq_len
            ]
            # +2 to account for start symbol and padding symbol
            tgt_out_idx += 2
            # every position has propensity of 1 because we are just using argsort
            ranked_per_symbol_probs = ranked_per_symbol_probs.scatter(
                2, tgt_out_idx.unsqueeze(2), 1.0
            )
            # Allocated on `device` so that all three returned tensors share
            # a device (previously this one was always created on the CPU).
            ranked_per_seq_probs = torch.ones(batch_size, 1, device=device)
            return ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx

        assert greedy is not None

        # candidate_features is used as look-up table for candidate features.
        # the second dim is src_seq_len + 2 because we also want to include
        # features of start symbol and padding symbol
        candidate_features = torch.zeros(
            batch_size, candidate_size, candidate_dim, device=device
        )
        # TODO: T62502977 create learnable feature vectors for start symbol
        # and padding symbol
        candidate_features[:, 2:, :] = src_seq

        # Every sequence starts with the decoder start symbol.
        tgt_in_idx = torch.full(
            (batch_size, 1),
            self._DECODER_START_SYMBOL,
            dtype=torch.long,
            device=device,
        )
        # Column of batch indices, built once. It broadcasts against
        # tgt_in_idx (batch_size, cur_len) to look up one feature row per
        # (batch, position) pair.
        batch_idx = torch.arange(batch_size, device=device).unsqueeze(1)

        for step in range(tgt_seq_len):
            cur_len = step + 1
            # shape: batch_size, cur_len, candidate_dim
            tgt_in_seq = candidate_features[batch_idx, tgt_in_idx]
            tgt_src_mask = src_src_mask[:, :cur_len, :]
            # shape batch_size, cur_len, candidate_size
            logits = self.decode(
                memory=memory,
                state=state,
                tgt_src_mask=tgt_src_mask,
                tgt_in_seq=tgt_in_seq,
                tgt_tgt_mask=subsequent_mask(cur_len, device),
                tgt_seq_len=cur_len,
            )
            # next candidate shape: batch_size, 1
            # prob shape: batch_size, candidate_size
            next_candidate, prob = self.generator(
                mode=self._DECODE_ONE_STEP_MODE,
                logits=logits,
                tgt_in_idx=tgt_in_idx,
                greedy=greedy,
            )
            ranked_per_symbol_probs[:, step, :] = prob
            tgt_in_idx = torch.cat([tgt_in_idx, next_candidate], dim=1)

        # remove the decoder start symbol
        # tgt_out_idx shape: batch_size, tgt_seq_len
        tgt_out_idx = tgt_in_idx[:, 1:]

        ranked_per_seq_probs = per_symbol_to_per_seq_probs(
            ranked_per_symbol_probs, tgt_out_idx
        )

        # ranked_per_symbol_probs shape: batch_size, tgt_seq_len, candidate_size
        # ranked_per_seq_probs shape: batch_size, 1
        # tgt_out_idx shape: batch_size, tgt_seq_len
        return ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx