Example #1
    def test_ips_clamp(self):
        importance_sampling = torch.tensor([0.5, 0.3, 3.0, 10.0, 40.0])
        # No clamp config: ratios pass through unchanged.
        assert torch.all(ips_clamp(importance_sampling, None) == importance_sampling)
        # AGGRESSIVE clamping zeroes out ratios above clamp_max.
        assert torch.all(
            ips_clamp(importance_sampling, IPSClamp(IPSClampMethod.AGGRESSIVE, 3.0))
            == torch.tensor([0.5, 0.3, 3.0, 0.0, 0.0])
        )
        # UNIVERSAL clamping caps ratios at clamp_max.
        assert torch.all(
            ips_clamp(importance_sampling, IPSClamp(IPSClampMethod.UNIVERSAL, 3.0))
            == torch.tensor([0.5, 0.3, 3.0, 3.0, 3.0])
        )
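
The assertions above fully pin down the clamping semantics, so a minimal sketch of ips_clamp consistent with them could look as follows. This is an illustration, not ReAgent's actual implementation: the enum member spellings are assumptions, while the IPSClamp field names match the keyword arguments used in Example #4.

from dataclasses import dataclass
from enum import Enum
from typing import Optional

import torch


class IPSClampMethod(Enum):
    UNIVERSAL = "universal"      # assumed spelling
    AGGRESSIVE = "aggressive"    # assumed spelling


@dataclass
class IPSClamp:
    clamp_method: IPSClampMethod
    clamp_max: float


def ips_clamp(impt_smpl: torch.Tensor, ips_clamp: Optional[IPSClamp]) -> torch.Tensor:
    if ips_clamp is None:
        # No clamp configured: pass the ratios through unchanged.
        return impt_smpl
    if ips_clamp.clamp_method == IPSClampMethod.UNIVERSAL:
        # Cap every ratio at clamp_max.
        return torch.clamp(impt_smpl, max=ips_clamp.clamp_max)
    # AGGRESSIVE: zero out ratios that exceed clamp_max.
    return torch.where(
        impt_smpl > ips_clamp.clamp_max,
        torch.zeros_like(impt_smpl),
        impt_smpl,
    )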
Example #2
    def train(self, training_batch: rlt.PreprocessedRankingInput):
        assert type(training_batch) is rlt.PreprocessedRankingInput

        per_symbol_log_probs = self.seq2slate_net(
            training_batch,
            mode=Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE).log_probs
        per_seq_log_probs = per_symbol_to_per_seq_log_probs(
            per_symbol_log_probs, training_batch.tgt_out_idx)
        assert per_symbol_log_probs.requires_grad and per_seq_log_probs.requires_grad
        # pyre-fixme[16]: `Optional` has no attribute `shape`.
        assert per_seq_log_probs.shape == training_batch.tgt_out_probs.shape

        if not self.parameters.on_policy:
            # Off-policy: importance-sample with the ratio of the model's
            # sequence probability to the logged propensity, then clamp.
            importance_sampling = (torch.exp(per_seq_log_probs) /
                                   training_batch.tgt_out_probs)
            importance_sampling = ips_clamp(importance_sampling,
                                            self.parameters.ips_clamp)
        else:
            # On-policy: the ratio is identically 1, but dividing by the
            # detached copy keeps the gradient flowing through the numerator.
            importance_sampling = (torch.exp(per_seq_log_probs) /
                                   torch.exp(per_seq_log_probs).detach())
        assert importance_sampling.requires_grad

        # pyre-fixme[6]: Expected `Tensor` for 1st param but got
        #  `Optional[torch.Tensor]`.
        labels = self._transform_label(training_batch.tgt_out_idx)
        assert not labels.requires_grad

        batch_size, max_tgt_seq_len = training_batch.tgt_out_idx.shape
        # batch_loss shape: batch_size x max_tgt_seq_len
        batch_loss = (
            torch.sum(self.kl_div_loss(per_symbol_log_probs, labels), dim=2) *
            training_batch.position_reward)
        # weighted_batch_loss shape: batch_size x 1
        # Position i (1-indexed) is down-weighted by 1 / log(i + 1).
        weighted_batch_loss = torch.sum(
            1.0 / torch.log(
                torch.arange(1, 1 + max_tgt_seq_len,
                             device=batch_loss.device).float() + 1.0) *
            batch_loss,
            dim=1,
            keepdim=True,
        )
        loss = 1.0 / batch_size * torch.sum(
            importance_sampling * weighted_batch_loss)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        loss = loss.detach().cpu().numpy()
        per_symbol_log_probs = per_symbol_log_probs.detach()
        self.minibatch += 1
        if self.minibatch % self.print_interval == 0:
            logger.info(f"{self.minibatch} batch: loss={loss}")

        return {"per_symbol_log_probs": per_symbol_log_probs, "sl": loss}
Example #3
    def _compute_impt_smpl(
            self, model_propensities,
            logged_propensities) -> Tuple[torch.Tensor, torch.Tensor]:
        logged_propensities = logged_propensities.reshape(-1, 1)
        assert (model_propensities.shape == logged_propensities.shape
                and len(model_propensities.shape) == 2
                and model_propensities.shape[1] == 1
                ), f"{model_propensities.shape} {logged_propensities.shape}"

        # IPS ratio: model (target) propensity over logged (behavior) propensity.
        impt_smpl = model_propensities / logged_propensities
        clamped_impt_smpl = ips_clamp(impt_smpl, self.parameters.ips_clamp)
        return impt_smpl, clamped_impt_smpl
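
A hedged usage sketch of the method above; the evaluator instance and the clamp configuration are hypothetical, and the shapes follow the assertion in the method.

# Assume evaluator.parameters.ips_clamp == IPSClamp(IPSClampMethod.UNIVERSAL, 3.0).
model_propensities = torch.tensor([[0.6], [0.2], [0.9]])
logged_propensities = torch.tensor([0.5, 0.4, 0.1])  # reshaped to (3, 1) internally

impt_smpl, clamped_impt_smpl = evaluator._compute_impt_smpl(
    model_propensities, logged_propensities
)
# impt_smpl         -> [[1.2], [0.5], [9.0]]
# clamped_impt_smpl -> [[1.2], [0.5], [3.0]]  (9.0 capped at clamp_max = 3.0)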
Example #4
    def test_seq2slate_trainer_off_policy_with_clamp(self, clamp_method, output_arch):
        batch_size = 32
        state_dim = 2
        candidate_num = 15
        candidate_dim = 4
        hidden_size = 16
        learning_rate = 1.0
        device = torch.device("cpu")
        policy_gradient_interval = 1
        seq2slate_params = Seq2SlateParameters(
            on_policy=False,
            ips_clamp=IPSClamp(clamp_method=clamp_method, clamp_max=0.3),
        )

        seq2slate_net = create_seq2slate_transformer(
            state_dim, candidate_num, candidate_dim, hidden_size, output_arch, device
        )
        seq2slate_net_copy = copy.deepcopy(seq2slate_net)
        trainer = create_trainer(
            seq2slate_net,
            batch_size,
            learning_rate,
            device,
            seq2slate_params,
            policy_gradient_interval,
        )
        batch = create_off_policy_batch(
            seq2slate_net, batch_size, state_dim, candidate_num, candidate_dim, device
        )

        for _ in range(policy_gradient_interval):
            trainer.train(rlt.PreprocessedTrainingBatch(training_input=batch))

        # Manually compute the gradient of the clamped-IPS objective
        # -E[clamp(p_model / p_logged) * slate_reward] for comparison.
        ranked_per_seq_probs = torch.exp(
            seq2slate_net_copy(
                batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE
            ).log_probs
        )
        logger.info(f"ips ratio={ranked_per_seq_probs / batch.tgt_out_probs}")
        loss = -(
            torch.mean(
                ips_clamp(
                    ranked_per_seq_probs / batch.tgt_out_probs,
                    seq2slate_params.ips_clamp,
                )
                * batch.slate_reward
            )
        )
        loss.backward()
        self.assert_correct_gradient(
            seq2slate_net_copy, seq2slate_net, policy_gradient_interval, learning_rate
        )
Example #5
A variant of the same test that drives the trainer through pytorch_lightning's Trainer.fit on a single-batch DataLoader instead of calling trainer.train directly.
    def test_seq2slate_trainer_off_policy_with_clamp(self, clamp_method,
                                                     output_arch):
        batch_size = 32
        state_dim = 2
        candidate_num = 15
        candidate_dim = 4
        hidden_size = 16
        learning_rate = 1.0
        device = torch.device("cpu")
        policy_gradient_interval = 1
        seq2slate_params = Seq2SlateParameters(
            on_policy=False,
            ips_clamp=IPSClamp(clamp_method=clamp_method, clamp_max=0.3),
        )

        seq2slate_net = create_seq2slate_transformer(state_dim, candidate_num,
                                                     candidate_dim,
                                                     hidden_size, output_arch)
        seq2slate_net_copy = copy.deepcopy(seq2slate_net)
        trainer = create_trainer(
            seq2slate_net,
            learning_rate,
            seq2slate_params,
            policy_gradient_interval,
        )
        batch = create_off_policy_batch(seq2slate_net, batch_size, state_dim,
                                        candidate_num, candidate_dim, device)

        training_data = DataLoader([batch], collate_fn=lambda x: x[0])
        pl_trainer = pl.Trainer(max_epochs=policy_gradient_interval,
                                logger=False)
        pl_trainer.fit(trainer, training_data)

        # Manually compute the expected gradient for comparison.
        ranked_per_seq_probs = torch.exp(
            seq2slate_net_copy(
                batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE).log_probs)
        logger.info(f"ips ratio={ranked_per_seq_probs / batch.tgt_out_probs}")
        loss = -(torch.mean(
            ips_clamp(
                ranked_per_seq_probs / batch.tgt_out_probs,
                seq2slate_params.ips_clamp,
            ) * batch.slate_reward))
        loss.backward()
        self.assert_correct_gradient(seq2slate_net_copy, seq2slate_net,
                                     policy_gradient_interval, learning_rate)
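
assert_correct_gradient is not shown in these snippets. Assuming the trainer applies plain SGD, one plausible sketch of such a check compares each updated parameter against a manual SGD step on the untouched copy (the signature and tolerance here are guesses):

def assert_correct_gradient(net_with_grad, updated_net,
                            policy_gradient_interval, learning_rate):
    # Vanilla SGD assumption: w_new = w_old - lr * grad, applied once
    # per policy_gradient_interval accumulation window.
    for p_old, p_new in zip(net_with_grad.parameters(), updated_net.parameters()):
        expected = p_old.detach() - learning_rate * p_old.grad
        assert torch.allclose(expected, p_new.detach(), rtol=1e-4)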