# Example #1
    def test_seq2slate_transformer_propensity_computation(
        self, output_arch, temperature
    ):
        """
        Verify that seq2slate propensities are self-consistent.

        Checks that (1) per-sequence log-probabilities agree with the ones
        reconstructed from per-symbol log-probabilities, and (2) the
        probabilities of all candidate permutations sum to 1, i.e. form a
        valid distribution.
        """
        num_candidates = 4
        dim_candidate = 2
        hidden = 32
        device = torch.device("cpu")

        # Enumerate every possible ranking of the candidates; use one
        # batch row per permutation.
        every_ranking = torch.tensor(
            list(permutations(torch.arange(num_candidates), num_candidates))
        )
        num_rankings = len(every_ranking)

        net = create_seq2slate_net(
            MODEL_TRANSFORMER,
            num_candidates,
            dim_candidate,
            hidden,
            output_arch,
            temperature,
            device,
        )
        raw_batch = create_batch(
            num_rankings,
            num_candidates,
            dim_candidate,
            device,
            ON_POLICY,
            diverse_input=False,
        )
        # Re-wrap the batch so that every permutation appears as the action.
        ranking_input = rlt.PreprocessedRankingInput.from_input(
            state=raw_batch.state.float_features,
            candidates=raw_batch.src_seq.float_features,
            device=device,
            action=every_ranking,
        )
        symbol_log_probs = net(
            ranking_input, mode=Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE
        ).log_probs
        seq_log_probs = net(
            ranking_input, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE
        ).log_probs
        # NOTE(review): the +2 presumably shifts raw candidate indices past
        # reserved padding/start symbols in the decoder vocabulary — confirm
        # against per_symbol_to_per_seq_log_probs.
        reconstructed_seq_log_probs = per_symbol_to_per_seq_log_probs(
            symbol_log_probs, every_ranking + 2
        )
        # probabilities of two modes should match
        np.testing.assert_allclose(
            seq_log_probs, reconstructed_seq_log_probs, atol=0.00001
        )
        # probabilities of all possible permutations should sum up to 1
        np.testing.assert_allclose(
            torch.sum(torch.exp(seq_log_probs)), 1.0, atol=0.00001
        )
# Example #2
    def test_seq2slate_transformer_onplicy_basic_logic(self, output_arch,
                                                       temperature):
        """
        Test basic logic of seq2slate on-policy sampling.

        Repeatedly samples rankings from the net and checks that:
        1. the same sampled action always comes with the same propensity, and
        2. each action's empirical frequency matches its computed propensity.
        """
        device = torch.device("cpu")
        candidate_num = 4
        candidate_dim = 2
        batch_size = 4096
        hidden_size = 32
        seq2slate_net = create_seq2slate_net(
            MODEL_TRANSFORMER,
            candidate_num,
            candidate_dim,
            hidden_size,
            output_arch,
            temperature,
            device,
        )
        # NOTE(review): unlike the propensity test above, no policy flag is
        # passed here — confirm create_batch's default is the intended one.
        batch = create_batch(
            batch_size,
            candidate_num,
            candidate_dim,
            device,
            diverse_input=False,
        )

        action_to_propensity_map = {}  # action string -> first seen propensity
        action_count = defaultdict(int)  # action string -> occurrence count
        total_count = 0
        for i in range(50):
            model_propensity, model_action = rank_on_policy(
                seq2slate_net,
                batch,
                candidate_num,
                greedy=False,
            )
            for propensity, action in zip(model_propensity, model_action):
                action_str = ",".join(map(str, action.numpy().tolist()))

                # Same action always leads to same propensity.
                # Membership test instead of .get(...) is None: single
                # lookup, and stored values are floats (never None).
                if action_str not in action_to_propensity_map:
                    action_to_propensity_map[action_str] = float(propensity)
                else:
                    np.testing.assert_allclose(
                        action_to_propensity_map[action_str],
                        float(propensity),
                        atol=0.001,
                        rtol=0.0,
                    )

                action_count[action_str] += 1
                total_count += 1

            # Lazy %-style args: formatting is deferred until the record
            # is actually emitted (this runs 50 times in the sampling loop).
            logger.info("Finish %d round, %d data counts", i, total_count)

        # Check action distribution
        for action_str, count in action_count.items():
            empirical_propensity = count / total_count
            computed_propensity = action_to_propensity_map[action_str]
            logger.info(
                "action=%s, empirical propensity=%s, computed propensity=%s",
                action_str,
                empirical_propensity,
                computed_propensity,
            )
            np.testing.assert_allclose(
                computed_propensity,
                empirical_propensity,
                atol=0.01,
                rtol=0.0,
            )