コード例 #1
0
    def test_log_prob_padding(self):
        """Check ``FrechetSort.log_prob`` on a padded action sequence.

        Row 1 of the action tensor uses index 5 (one past the last of the five
        candidates) as padding, so only its first two positions are real picks.
        The test verifies that:

        * the shorter, padded ranking gets a higher log-probability than the
          full ranking in row 0;
        * gradients flow back into ``scores``;
        * row 1 equals a manual computation over the two unpadded picks only.
        """
        shape = 2.0
        score_row = [1.0, 2.0, 3.0, 4.0, 5.0]
        scores = torch.tensor([score_row, score_row], requires_grad=True)
        frechet_sort = FrechetSort(topk=3, shape=shape, log_scores=True)

        # Row 0 is a full ranking; row 1 stops after two picks (5 == padding).
        action = torch.tensor(
            [[0, 1, 2, 3, 4], [0, 1, 5, 5, 5]],
            dtype=torch.long,
        )
        log_probs = frechet_sort.log_prob(scores, action)
        # A shorter sequence should have a higher prob
        self.assertLess(log_probs[0], log_probs[1])

        # The log-probabilities must be differentiable w.r.t. the scores.
        log_probs.sum().backward()
        self.assertGreater(scores.grad.sum(), 0)

        # Recompute row 1 by hand, keeping only the non-padding picks.
        num_real_picks = 2
        picked = scores[1][action[1][:num_real_picks]]
        expected = 0.0
        for pos in range(num_real_picks):
            tail = (picked[pos:] - picked[pos]) * shape
            expected = expected - torch.exp(tail).sum().log()

        self.assertAlmostEqual(expected, log_probs[1])
コード例 #2
0
    def test_log_prob(self):
        """Check ``FrechetSort.log_prob`` on permuted inputs and against a manual sum.

        With ``topk=3`` only the first three ranked positions contribute, so two
        rankings that pick the same three scores in the same order must score
        identically whatever their trailing positions are. Picking the scores
        3, 2, 1 in descending order (row 1 of the second action) must be more
        likely than picking 1, 2, 3 in ascending order (row 0). Row 1 must also
        equal ``-sum_p log(sum(exp((s[p:] - s[p]) * shape)))`` over the first
        ``topk`` positions.
        """
        shape = 2.0
        topk = 3
        scores = torch.tensor(
            [[1.0, 2.0, 3.0, 4.0, 5.0], [5.0, 1.0, 2.0, 3.0, 4.0]]
        )
        frechet_sort = FrechetSort(topk=topk, shape=shape, log_scores=True)

        def _paired_actions(second_row):
            # Row 0 is always the identity ranking over the first score row.
            return torch.tensor([[0, 1, 2, 3, 4], second_row], dtype=torch.long)

        # The log-prob should be the same; the last 2 positions don't matter
        same_prefix = frechet_sort.log_prob(scores, _paired_actions([1, 2, 3, 0, 4]))
        self.assertEqual(same_prefix[0], same_prefix[1])

        action = _paired_actions([3, 2, 1, 0, 4])
        log_probs = frechet_sort.log_prob(scores, action)
        self.assertLess(log_probs[0], log_probs[1])

        # manually calculating the log prob for the second case
        ranked = scores[1][action[1]]
        expected = -sum(
            torch.exp((ranked[pos:] - ranked[pos]) * shape).sum().log()
            for pos in range(topk)
        )
        self.assertAlmostEqual(expected, log_probs[1])
コード例 #3
0
    def test_ips_ratio_mean(self, output_arch, shape):
        """Check that the running mean of the IPS ratio converges to about 1.

        Every row of the batch shares one set of candidates. Slates are sampled
        from a ``FrechetSort`` logging policy, and ``trainer._compute_impt_smpl``
        compares the per-sequence propensities of ``seq2slate_net`` with the
        logged ones. Because E_{a ~ logged}[model(a) / logged(a)] equals
        sum_a model(a) = 1, the running mean over batches should approach 1.
        The test passes as soon as that mean is within ``atol=0.03`` of 1 after
        more than 100 batches. It fails if this never happens within
        ``num_batches``.

        Args:
            output_arch: seq2slate output architecture passed to
                ``create_seq2slate_transformer``.
            shape: shape parameter of the ``FrechetSort`` logging sampler.
        """
        logger.info(f"output arch: {output_arch}")
        logger.info(f"frechet shape: {shape}")

        candidate_num = 5
        candidate_dim = 2
        state_dim = 1
        hidden_size = 8
        device = torch.device("cpu")
        batch_size = 1024
        num_batches = 400
        learning_rate = 0.001
        policy_gradient_interval = 1

        state = torch.zeros(batch_size, state_dim)
        # all data have same candidates
        candidates = torch.randint(
            5, (batch_size, candidate_num, candidate_dim)).float()
        candidates[1:] = candidates[0]
        candidate_scores = torch.sum(candidates, dim=-1)

        seq2slate_params = Seq2SlateParameters(on_policy=False)
        seq2slate_net = create_seq2slate_transformer(state_dim, candidate_num,
                                                     candidate_dim,
                                                     hidden_size, output_arch,
                                                     device)
        trainer = create_trainer(
            seq2slate_net,
            batch_size,
            learning_rate,
            device,
            seq2slate_params,
            policy_gradient_interval,
        )

        sampler = FrechetSort(shape=shape, topk=candidate_num)
        sum_of_ips_ratio = 0
        # Defined up front so the failure message below never refers to an
        # unbound name.
        mean_of_ips_ratio = float("nan")

        for i in range(num_batches):
            # One independent sample per row, each from a (1, candidate_num) slice.
            sample_outputs = [
                sampler.sample_action(candidate_scores[j:j + 1])
                for j in range(batch_size)
            ]
            action = torch.stack(
                [out.action.squeeze(0) for out in sample_outputs])
            logged_propensity = torch.stack(
                [torch.exp(out.log_prob) for out in sample_outputs])
            batch = rlt.PreprocessedRankingInput.from_input(
                state=state,
                candidates=candidates,
                device=device,
                action=action,
                logged_propensities=logged_propensity,
            )
            model_propensities = torch.exp(
                seq2slate_net(
                    batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE).log_probs)
            # The first return value is the unclamped importance ratio.
            impt_smpl, _ = trainer._compute_impt_smpl(model_propensities,
                                                      logged_propensity)
            sum_of_ips_ratio += torch.mean(impt_smpl).detach().numpy()
            mean_of_ips_ratio = sum_of_ips_ratio / (i + 1)
            logger.info(f"{i}-th batch, mean ips ratio={mean_of_ips_ratio}")

            # Require a reasonable number of batches before accepting convergence.
            if i > 100 and np.allclose(mean_of_ips_ratio, 1, atol=0.03):
                return

        self.fail(f"Mean ips ratio {mean_of_ips_ratio} is not close to 1")
コード例 #4
0
    def test_compute_impt_smpl(self, output_arch, clamp_method, clamp_max,
                               shape):
        """Check IPS ratios and their clamping over every possible slate.

        For each permutation of ``candidate_num`` candidates, the test:

        * computes the logged propensity under a ``FrechetSort`` sampler;
        * computes the model propensity under ``seq2slate_net``;
        * checks that ``trainer._compute_impt_smpl`` clamps ratios above
          ``clamp_max`` to 0 for ``IPSClampMethod.AGGRESSIVE`` and to
          ``clamp_max`` otherwise.

        Finally, the logged and the model propensities must each sum to 1 over
        all permutations, i.e. each is a proper distribution over slates.

        Args:
            output_arch: seq2slate output architecture under test.
            clamp_method: IPS clamping method passed to ``IPSClamp``.
            clamp_max: clamping threshold passed to ``IPSClamp``.
            shape: shape parameter of the ``FrechetSort`` logging sampler.
        """
        logger.info(f"output arch: {output_arch}")
        logger.info(f"clamp method: {clamp_method}")
        logger.info(f"clamp max: {clamp_max}")
        logger.info(f"frechet shape: {shape}")

        candidate_num = 5
        candidate_dim = 2
        state_dim = 1
        hidden_size = 32
        device = torch.device("cpu")
        batch_size = 32
        learning_rate = 0.001
        policy_gradient_interval = 1

        candidates = torch.randint(5, (candidate_num, candidate_dim)).float()
        candidate_scores = torch.sum(candidates, dim=1)

        seq2slate_params = Seq2SlateParameters(
            on_policy=False,
            ips_clamp=IPSClamp(clamp_method=clamp_method, clamp_max=clamp_max),
        )
        seq2slate_net = create_seq2slate_transformer(state_dim, candidate_num,
                                                     candidate_dim,
                                                     hidden_size, output_arch,
                                                     device)
        trainer = create_trainer(
            seq2slate_net,
            batch_size,
            learning_rate,
            device,
            seq2slate_params,
            policy_gradient_interval,
        )

        # Each row is one full ranking of the candidates.
        all_permt = torch.tensor(
            list(permutations(range(candidate_num), candidate_num)))
        sampler = FrechetSort(shape=shape, topk=candidate_num)
        sum_of_logged_propensity = 0
        sum_of_model_propensity = 0
        sum_of_ips_ratio = 0

        for sample_action in all_permt:
            logged_propensity = torch.exp(
                sampler.log_prob(candidate_scores, sample_action))
            batch = rlt.PreprocessedRankingInput.from_input(
                state=torch.zeros(1, state_dim),
                candidates=candidates.unsqueeze(0),
                device=device,
                action=sample_action.unsqueeze(0),
                logged_propensities=logged_propensity.reshape(1, 1),
            )
            model_propensities = torch.exp(
                seq2slate_net(
                    batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE).log_probs)
            impt_smpl, clamped_impt_smpl = trainer._compute_impt_smpl(
                model_propensities, logged_propensity)
            if impt_smpl > clamp_max:
                clamped = clamped_impt_smpl.detach().numpy()
                if clamp_method == IPSClampMethod.AGGRESSIVE:
                    # Ratios above the threshold are expected to be zeroed.
                    # Use atol here: rtol is scaled by |desired| and has no
                    # effect when the desired value is 0.
                    npt.assert_allclose(clamped, 0, atol=1e-5)
                else:
                    npt.assert_allclose(clamped, clamp_max, rtol=1e-5)

            sum_of_model_propensity += model_propensities
            sum_of_logged_propensity += logged_propensity
            sum_of_ips_ratio += model_propensities / logged_propensity
            logger.info(
                f"shape={shape}, sample_action={sample_action}, logged_propensity={logged_propensity},"
                f" model_propensity={model_propensities}")

        logger.info(
            f"shape {shape}, sum_of_logged_propensity={sum_of_logged_propensity}, "
            f"sum_of_model_propensity={sum_of_model_propensity}, "
            f"mean sum_of_ips_ratio={sum_of_ips_ratio / len(all_permt)}")
        # Both policies must be normalized over the full permutation space.
        npt.assert_allclose(sum_of_logged_propensity.detach().numpy(),
                            1,
                            rtol=1e-5)
        npt.assert_allclose(sum_of_model_propensity.detach().numpy(),
                            1,
                            rtol=1e-5)