def test_ips_clamp(self):
    importance_sampling = torch.tensor([0.5, 0.3, 3.0, 10.0, 40.0])
    assert torch.all(
        ips_clamp(importance_sampling, None) == importance_sampling)
    assert torch.all(
        ips_clamp(importance_sampling,
                  IPSClamp(IPSClampMethod.AGGRESSIVE, 3.0)) ==
        torch.tensor([0.5, 0.3, 3.0, 0.0, 0.0]))
    assert torch.all(
        ips_clamp(importance_sampling,
                  IPSClamp(IPSClampMethod.UNIVERSAL, 3.0)) == torch.tensor(
                      [0.5, 0.3, 3.0, 3.0, 3.0]))
Example #2
    def test_seq2slate_trainer_off_policy_with_clamp(self, clamp_method,
                                                     output_arch):
        batch_size = 32
        state_dim = 2
        candidate_num = 15
        candidate_dim = 4
        hidden_size = 16
        learning_rate = 1.0
        device = torch.device("cpu")
        policy_gradient_interval = 1
        seq2slate_params = Seq2SlateParameters(
            on_policy=False,
            ips_clamp=IPSClamp(clamp_method=clamp_method, clamp_max=0.3),
        )

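        # Build the ranking network and keep a deep copy with identical
        # weights for the manual gradient check below.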
        seq2slate_net = create_seq2slate_transformer(state_dim, candidate_num,
                                                     candidate_dim,
                                                     hidden_size, output_arch,
                                                     device)
        seq2slate_net_copy = copy.deepcopy(seq2slate_net)
        trainer = create_trainer(
            seq2slate_net,
            batch_size,
            learning_rate,
            device,
            seq2slate_params,
            policy_gradient_interval,
        )
        batch = create_off_policy_batch(seq2slate_net, batch_size, state_dim,
                                        candidate_num, candidate_dim, device)

        for _ in range(policy_gradient_interval):
            trainer.train(batch)

        # Manually compute the gradient on the copy and compare it
        # against the trainer's update.
        ranked_per_seq_probs = torch.exp(
            seq2slate_net_copy(
                batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE).log_probs)
        logger.info(f"ips ratio={ranked_per_seq_probs / batch.tgt_out_probs}")
        loss = -(torch.mean(
            ips_clamp(
                ranked_per_seq_probs / batch.tgt_out_probs,
                seq2slate_params.ips_clamp,
            ) * batch.slate_reward))
        loss.backward()
        self.assert_correct_gradient(seq2slate_net_copy, seq2slate_net,
                                     policy_gradient_interval, learning_rate)

    def test_compute_impt_smpl(self, output_arch, clamp_method, clamp_max,
                               shape):
        logger.info(f"output arch: {output_arch}")
        logger.info(f"clamp method: {clamp_method}")
        logger.info(f"clamp max: {clamp_max}")
        logger.info(f"frechet shape: {shape}")

        candidate_num = 5
        candidate_dim = 2
        state_dim = 1
        hidden_size = 32
        device = torch.device("cpu")
        learning_rate = 0.001
        policy_gradient_interval = 1

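        # Random integer candidate features; each candidate's score is
        # simply the sum of its features.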
        candidates = torch.randint(5, (candidate_num, candidate_dim)).float()
        candidate_scores = torch.sum(candidates, dim=1)

        seq2slate_params = Seq2SlateParameters(
            on_policy=False,
            ips_clamp=IPSClamp(clamp_method=clamp_method, clamp_max=clamp_max),
        )
        seq2slate_net = create_seq2slate_transformer(state_dim, candidate_num,
                                                     candidate_dim,
                                                     hidden_size, output_arch)
        trainer = create_trainer(
            seq2slate_net,
            learning_rate,
            seq2slate_params,
            policy_gradient_interval,
        )

        all_permt = torch.tensor(
            list(permutations(range(candidate_num), candidate_num)))
        sampler = FrechetSort(shape=shape, topk=candidate_num)
        sum_of_logged_propensity = 0
        sum_of_model_propensity = 0
        sum_of_ips_ratio = 0

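        # Enumerate every permutation of the candidates: the logged and model
        # propensities over all possible slates should each sum to 1.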
        for i in range(len(all_permt)):
            sample_action = all_permt[i]
            logged_propensity = torch.exp(
                sampler.log_prob(candidate_scores, sample_action))
            batch = rlt.PreprocessedRankingInput.from_input(
                state=torch.zeros(1, state_dim),
                candidates=candidates.unsqueeze(0),
                device=device,
                action=sample_action.unsqueeze(0),
                logged_propensities=logged_propensity.reshape(1, 1),
            )
            model_propensities = torch.exp(
                seq2slate_net(
                    batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE).log_probs)
            impt_smpl, clamped_impt_smpl = trainer._compute_impt_smpl(
                model_propensities, logged_propensity)
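            # The two clamp methods diverge only above clamp_max:
            # AGGRESSIVE zeroes the ratio, UNIVERSAL clips it to clamp_max.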
            if impt_smpl > clamp_max:
                if clamp_method == IPSClampMethod.AGGRESSIVE:
                    npt.assert_allclose(clamped_impt_smpl.detach().numpy(),
                                        0,
                                        rtol=1e-5)
                else:
                    npt.assert_allclose(clamped_impt_smpl.detach().numpy(),
                                        clamp_max,
                                        rtol=1e-5)

            sum_of_model_propensity += model_propensities
            sum_of_logged_propensity += logged_propensity
            sum_of_ips_ratio += model_propensities / logged_propensity
            logger.info(
                f"shape={shape}, sample_action={sample_action}, logged_propensity={logged_propensity},"
                f" model_propensity={model_propensities}")

        logger.info(
            f"shape {shape}, sum_of_logged_propensity={sum_of_logged_propensity}, "
            f"sum_of_model_propensity={sum_of_model_propensity}, "
            f"mean sum_of_ips_ratio={sum_of_ips_ratio / len(all_permt)}")
        npt.assert_allclose(sum_of_logged_propensity.detach().numpy(),
                            1,
                            rtol=1e-5)
        npt.assert_allclose(sum_of_model_propensity.detach().numpy(),
                            1,
                            rtol=1e-5)