Example #1
    def test_per_symbol_to_per_seq_log_probs(self):
        """
        Test per_symbol_to_per_seq_log_probs method
        """
        batch_size = 1
        seq_len = 3
        candidate_size = seq_len + 2

        # Candidate indices are offset by 2 to skip the two reserved symbols
        tgt_out_idx = torch.tensor([[0, 2, 1]]) + 2
        per_symbol_log_probs = torch.randn(batch_size, seq_len, candidate_size)
        # Mask out the reserved symbols and the candidates already chosen at
        # earlier decoding steps before normalizing
        per_symbol_log_probs[0, :, :2] = float("-inf")
        per_symbol_log_probs[0, 1, 2] = float("-inf")
        per_symbol_log_probs[0, 2, 2] = float("-inf")
        per_symbol_log_probs[0, 2, 4] = float("-inf")
        per_symbol_log_probs = F.log_softmax(per_symbol_log_probs, dim=2)

        expect_per_seq_log_probs = (per_symbol_log_probs[0, 0, 2] +
                                    per_symbol_log_probs[0, 1, 4] +
                                    per_symbol_log_probs[0, 2, 3])
        computed_per_seq_log_probs = per_symbol_to_per_seq_log_probs(
            per_symbol_log_probs, tgt_out_idx)
        np.testing.assert_allclose(expect_per_seq_log_probs,
                                   computed_per_seq_log_probs,
                                   atol=0.001,
                                   rtol=0.0)
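
The test above pins down the contract of per_symbol_to_per_seq_log_probs: for each sequence, gather the log probability of the symbol emitted at every decoding step and sum them (the chain rule in log space). A minimal sketch of an equivalent computation, shown only to illustrate that contract rather than the library's actual implementation:

import torch

def per_symbol_to_per_seq_log_probs_sketch(per_symbol_log_probs, tgt_out_idx):
    # per_symbol_log_probs: (batch_size, seq_len, candidate_size)
    # tgt_out_idx: (batch_size, seq_len), index of the symbol emitted per step
    # Gather each step's log prob at the emitted symbol, then sum over steps.
    gathered = per_symbol_log_probs.gather(2, tgt_out_idx.unsqueeze(2))
    # Result shape: (batch_size, 1)
    return gathered.squeeze(2).sum(dim=1, keepdim=True)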
Example #2
    def _decoder_logits_to_log_probs(self, logits, tgt_in_idx, tgt_out_idx,
                                     mode):
        """
        :param logits: the logits from the decoder, with shape:
            (batch_size, seq_len, candidate_size)
        :param tgt_in_idx: input indices to the decoder; the first symbol is
            always the DECODER_START_SYMBOL. Shape: batch_size x seq_len
        :param tgt_out_idx: output indices of the decoder. Shape: batch_size x seq_len
        :param mode: whether to return the log-prob distribution per symbol or
            reduce it to one log probability per sequence
        """
        assert mode in (
            self._PER_SEQ_LOG_PROB_MODE,
            self._PER_SYMBOL_LOG_PROB_DIST_MODE,
        )
        # per_symbol_log_probs: log probability distribution of each symbol
        # shape: batch_size, seq_len, candidate_size
        per_symbol_log_probs = self.generator(mode=mode,
                                              logits=logits,
                                              tgt_in_idx=tgt_in_idx)

        if mode == self._PER_SYMBOL_LOG_PROB_DIST_MODE:
            return per_symbol_log_probs

        # shape: batch_size, 1
        return per_symbol_to_per_seq_log_probs(per_symbol_log_probs,
                                               tgt_out_idx)
Example #3
    def train(self, training_batch: rlt.PreprocessedRankingInput):
        assert type(training_batch) is rlt.PreprocessedRankingInput

        per_symbol_log_probs = self.seq2slate_net(
            training_batch,
            mode=Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE).log_probs
        per_seq_log_probs = per_symbol_to_per_seq_log_probs(
            per_symbol_log_probs, training_batch.tgt_out_idx)
        assert per_symbol_log_probs.requires_grad and per_seq_log_probs.requires_grad
        # pyre-fixme[16]: `Optional` has no attribute `shape`.
        assert per_seq_log_probs.shape == training_batch.tgt_out_probs.shape

        if not self.parameters.on_policy:
            # Off-policy: importance sampling ratio of the current policy's
            # sequence probability to the logged propensity, optionally
            # clamped for stability
            importance_sampling = (torch.exp(per_seq_log_probs) /
                                   training_batch.tgt_out_probs)
            importance_sampling = ips_clamp(importance_sampling,
                                            self.parameters.ips_clamp)
        else:
            # On-policy: the ratio is 1 in value, but dividing by the detached
            # denominator keeps a gradient path through the numerator
            importance_sampling = (torch.exp(per_seq_log_probs) /
                                   torch.exp(per_seq_log_probs).detach())
        assert importance_sampling.requires_grad

        # pyre-fixme[6]: Expected `Tensor` for 1st param but got
        #  `Optional[torch.Tensor]`.
        labels = self._transform_label(training_batch.tgt_out_idx)
        assert not labels.requires_grad

        batch_size, max_tgt_seq_len = training_batch.tgt_out_idx.shape
        # batch_loss shape: batch_size x max_tgt_seq_len
        batch_loss = (
            torch.sum(self.kl_div_loss(per_symbol_log_probs, labels), dim=2) *
            training_batch.position_reward)
        # weighted_batch_loss shape: batch_size, 1
        weighted_batch_loss = torch.sum(
            1.0 / torch.log(
                torch.arange(1, 1 + max_tgt_seq_len,
                             device=batch_loss.device).float() + 1.0) *
            batch_loss,
            dim=1,
            keepdim=True,
        )
        loss = 1.0 / batch_size * torch.sum(
            importance_sampling * weighted_batch_loss)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        loss = loss.detach().cpu().numpy()
        per_symbol_log_probs = per_symbol_log_probs.detach()
        self.minibatch += 1
        if self.minibatch % self.print_interval == 0:
            logger.info(f"{self.minibatch} batch: loss={loss}")

        return {"per_symbol_log_probs": per_symbol_log_probs, "sl": loss}
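
For context, weighted_batch_loss above discounts each slate position by 1 / log(position + 1), a DCG-style weighting that makes errors near the top of the slate cost more. A small standalone sketch of just that weighting step (tensor names and sizes here are illustrative, not taken from the trainer):

import torch

batch_size, max_tgt_seq_len = 2, 3
# Stand-in for the per-position loss (the KL term times position_reward above)
batch_loss = torch.rand(batch_size, max_tgt_seq_len)

# Discount of 1 / log(position + 1) for 1-indexed slate positions
position_weights = 1.0 / torch.log(
    torch.arange(1, 1 + max_tgt_seq_len).float() + 1.0)

# Broadcast the weights over the batch and reduce to shape (batch_size, 1)
weighted_batch_loss = torch.sum(position_weights * batch_loss, dim=1, keepdim=True)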
Example #4
    def test_seq2slate_transformer_propensity_computation(
        self, output_arch, temperature
    ):
        """
        Test propensity computation of seq2slate net
        """
        candidate_num = 4
        candidate_dim = 2
        hidden_size = 32
        all_perm = torch.tensor(
            list(permutations(torch.arange(candidate_num), candidate_num))
        )
        batch_size = len(all_perm)
        device = torch.device("cpu")

        seq2slate_net = create_seq2slate_net(
            MODEL_TRANSFORMER,
            candidate_num,
            candidate_dim,
            hidden_size,
            output_arch,
            temperature,
            device,
        )
        batch = create_batch(
            batch_size,
            candidate_num,
            candidate_dim,
            device,
            ON_POLICY,
            diverse_input=False,
        )
        batch = rlt.PreprocessedRankingInput.from_input(
            state=batch.state.float_features,
            candidates=batch.src_seq.float_features,
            device=device,
            action=all_perm,
        )
        per_symbol_log_prob = seq2slate_net(
            batch, mode=Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE
        ).log_probs
        per_seq_log_prob = seq2slate_net(
            batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE
        ).log_probs
        # Shift permutation indices by 2 to account for the two reserved symbols
        per_seq_log_prob_computed = per_symbol_to_per_seq_log_probs(
            per_symbol_log_prob, all_perm + 2
        )
        # log probabilities from the two modes should match
        np.testing.assert_allclose(
            per_seq_log_prob, per_seq_log_prob_computed, atol=0.00001
        )
        # probabilities of all possible permutations should sum up to 1
        np.testing.assert_allclose(
            torch.sum(torch.exp(per_seq_log_prob)), 1.0, atol=0.00001
        )