Example #1
 def __call__(self, batch):
     not_terminal = 1.0 - batch.terminal.float()
     assert len(batch.state.shape) == 2, f"{batch.state.shape} is not (batch_size, state_dim)."
     batch_size, _ = batch.state.shape
     action, next_action = one_hot_actions(self.num_actions, batch.action,
                                           batch.next_action,
                                           batch.terminal)
     possible_actions = get_possible_actions_for_gym(
         batch_size, self.num_actions)
     possible_next_actions = possible_actions.clone()
     possible_actions_mask = torch.ones((batch_size, self.num_actions))
     possible_next_actions_mask = possible_actions_mask.clone()
     return rlt.ParametricDqnInput(
         state=rlt.FeatureData(float_features=batch.state),
         action=rlt.FeatureData(float_features=action),
         next_state=rlt.FeatureData(float_features=batch.next_state),
         next_action=rlt.FeatureData(float_features=next_action),
         possible_actions=possible_actions,
         possible_actions_mask=possible_actions_mask,
         possible_next_actions=possible_next_actions,
         possible_next_actions_mask=possible_next_actions_mask,
         reward=batch.reward,
         not_terminal=not_terminal,
         step=None,
         time_diff=None,
         extras=rlt.ExtraData(
             mdp_id=None,
             sequence_number=None,
             action_probability=batch.log_prob.exp(),
             max_num_actions=None,
             metrics=None,
         ),
     )
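
The transform above reads state, action, next_action, next_state, terminal, reward and log_prob from the replay-buffer batch, and, because every discrete action is always available in the gym setting, fills both possible-action masks with ones of shape (batch_size, num_actions). A standalone sketch of that shape logic in plain PyTorch, using made-up stand-in tensors rather than ReAgent's batch type:

    import torch

    batch_size, num_actions = 4, 2
    # per-transition terminal flags; not_terminal is the complement, as in the transform
    terminal = torch.tensor([[0.0], [0.0], [1.0], [0.0]])
    not_terminal = 1.0 - terminal.float()

    # every action is allowed in the gym setting, so the masks are all ones
    possible_actions_mask = torch.ones((batch_size, num_actions))
    possible_next_actions_mask = possible_actions_mask.clone()

    print(not_terminal.squeeze(1))      # tensor([1., 1., 0., 1.])
    print(possible_actions_mask.shape)  # torch.Size([4, 2])
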
Example #2
 def forward(self, batch: Dict[str, torch.Tensor]) -> rlt.ParametricDqnInput:
     batch = batch_to_device(batch, self.device)
     # preprocess state/next_state and action/next_action features
     preprocessed_state = self.state_preprocessor(
         batch["state_features"], batch["state_features_presence"])
     preprocessed_next_state = self.state_preprocessor(
         batch["next_state_features"],
         batch["next_state_features_presence"])
     preprocessed_action = self.action_preprocessor(
         batch["action"], batch["action_presence"])
     preprocessed_next_action = self.action_preprocessor(
         batch["next_action"], batch["next_action_presence"])
     return rlt.ParametricDqnInput(
         state=rlt.FeatureData(preprocessed_state),
         next_state=rlt.FeatureData(preprocessed_next_state),
         action=rlt.FeatureData(preprocessed_action),
         next_action=rlt.FeatureData(preprocessed_next_action),
         reward=batch["reward"].unsqueeze(1),
         time_diff=batch["time_diff"].unsqueeze(1),
         step=batch["step"].unsqueeze(1),
         not_terminal=batch["not_terminal"].unsqueeze(1),
         possible_actions=batch["possible_actions"],
         possible_actions_mask=batch["possible_actions_mask"],
         possible_next_actions=batch["possible_next_actions"],
         possible_next_actions_mask=batch["possible_next_actions_mask"],
         extras=rlt.ExtraData(
             mdp_id=batch["mdp_id"].unsqueeze(1),
             sequence_number=batch["sequence_number"].unsqueeze(1),
             action_probability=batch["action_probability"].unsqueeze(1),
         ),
     )
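
In this batch every raw feature tensor comes paired with a presence tensor (state_features / state_features_presence, action / action_presence, and so on), and both are handed to the preprocessor. As a rough, hypothetical illustration of that (values, presence) convention only (ReAgent's actual Preprocessor additionally applies learned normalization), missing feature slots can be masked out like this:

    import torch

    def mask_by_presence(values: torch.Tensor, presence: torch.Tensor) -> torch.Tensor:
        # zero out feature slots whose presence flag is 0 (feature missing)
        return values * presence.float()

    values = torch.tensor([[1.5, -2.0], [0.3, 7.0]])
    presence = torch.tensor([[1, 0], [1, 1]])
    print(mask_by_presence(values, presence))
    # tensor([[1.5000, -0.0000],
    #         [0.3000,  7.0000]])
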
Example #3
 def forward(self, batch: Dict[str, torch.Tensor]) -> rlt.DiscreteDqnInput:
     batch = batch_to_device(batch, self.device)
     preprocessed_state = self.state_preprocessor(
         batch["state_features"], batch["state_features_presence"])
     preprocessed_next_state = self.state_preprocessor(
         batch["next_state_features"],
         batch["next_state_features_presence"])
     # not terminal iff at least one next action is possible
     not_terminal = batch["possible_next_actions_mask"].max(dim=1)[0].float()
     action = F.one_hot(batch["action"].to(torch.int64), self.num_actions)
     # next_action can take the value self.num_actions when no next action is available
     next_action = F.one_hot(batch["next_action"].to(torch.int64),
                             self.num_actions + 1)[:, :self.num_actions]
     return rlt.DiscreteDqnInput(
         state=rlt.FeatureData(preprocessed_state),
         next_state=rlt.FeatureData(preprocessed_next_state),
         action=action,
         next_action=next_action,
         reward=batch["reward"].unsqueeze(1),
         time_diff=batch["time_diff"].unsqueeze(1),
         step=batch["step"].unsqueeze(1),
         not_terminal=not_terminal.unsqueeze(1),
         possible_actions_mask=batch["possible_actions_mask"],
         possible_next_actions_mask=batch["possible_next_actions_mask"],
         extras=rlt.ExtraData(
             mdp_id=batch["mdp_id"].unsqueeze(1),
             sequence_number=batch["sequence_number"].unsqueeze(1),
             action_probability=batch["action_probability"].unsqueeze(1),
         ),
     )
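
The next_action comment above refers to a sentinel trick: when no next action is available, next_action is stored as num_actions, one-hot encoded over num_actions + 1 classes, and the sentinel column is sliced off, leaving an all-zero row for that transition. A minimal self-contained demonstration:

    import torch
    import torch.nn.functional as F

    num_actions = 3
    next_action = torch.tensor([0, 2, 3])  # 3 == num_actions marks "no next action"
    encoded = F.one_hot(next_action, num_actions + 1)[:, :num_actions]
    print(encoded)
    # tensor([[1, 0, 0],
    #         [0, 0, 1],
    #         [0, 0, 0]])
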
Example #4
 def __call__(self, batch):
     not_terminal = 1.0 - batch.terminal.float()
     # normalize actions
     action = rescale_actions(
         batch.action,
         new_min=self.train_low,
         new_max=self.train_high,
         prev_min=self.action_low,
         prev_max=self.action_high,
     )
     # only rescale non-terminal next actions
     non_terminal_indices = (batch.terminal == 0).squeeze(1)
     next_action = torch.zeros_like(action)
     next_action[non_terminal_indices] = rescale_actions(
         batch.next_action[non_terminal_indices],
         new_min=self.train_low,
         new_max=self.train_high,
         prev_min=self.action_low,
         prev_max=self.action_high,
     )
     dict_batch = {
         InputColumn.STATE_FEATURES: batch.state,
         InputColumn.NEXT_STATE_FEATURES: batch.next_state,
         InputColumn.ACTION: action,
         InputColumn.NEXT_ACTION: next_action,
         InputColumn.REWARD: batch.reward,
         InputColumn.NOT_TERMINAL: not_terminal,
         InputColumn.STEP: None,
         InputColumn.TIME_DIFF: None,
         InputColumn.EXTRAS: rlt.ExtraData(
             mdp_id=None,
             sequence_number=None,
             action_probability=batch.log_prob.exp(),
             max_num_actions=None,
             metrics=None,
         ),
     }
     has_candidate_features = False
     try:
         dict_batch.update(
             {
                 InputColumn.CANDIDATE_FEATURES: batch.doc,
                 InputColumn.NEXT_CANDIDATE_FEATURES: batch.next_doc,
             }
         )
         has_candidate_features = True
     except AttributeError:
         pass
     output = rlt.PolicyNetworkInput.from_dict(dict_batch)
     if has_candidate_features:
         output.state = rlt._embed_states(output.state)
         output.next_state = rlt._embed_states(output.next_state)
     return output
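
rescale_actions maps the logged continuous actions from the environment's range (action_low, action_high) into the range the trainer expects (train_low, train_high); next_action is only rescaled at non-terminal steps. A sketch of the standard min-max mapping it is assumed to perform (the actual ReAgent implementation may differ in details such as validation or clamping):

    import torch

    def rescale(x, prev_min, prev_max, new_min, new_max):
        # linear map from [prev_min, prev_max] onto [new_min, new_max]
        return new_min + (x - prev_min) * (new_max - new_min) / (prev_max - prev_min)

    x = torch.tensor([[-2.0], [0.0], [2.0]])
    print(rescale(x, prev_min=-2.0, prev_max=2.0, new_min=-1.0, new_max=1.0))
    # tensor([[-1.],
    #         [ 0.],
    #         [ 1.]])
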
Example #5
    def __call__(self, batch):
        not_terminal = 1.0 - batch.terminal.float()
        action, next_action = one_hot_actions(self.num_actions, batch.action,
                                              batch.next_action,
                                              batch.terminal)
        if self.trainer_preprocessor is not None:
            state = self.trainer_preprocessor(batch.state)
            next_state = self.trainer_preprocessor(batch.next_state)
        else:
            state = rlt.FeatureData(float_features=batch.state)
            next_state = rlt.FeatureData(float_features=batch.next_state)

        try:
            possible_actions_mask = batch.possible_actions_mask.float()
        except AttributeError:
            possible_actions_mask = torch.ones_like(action).float()

        try:
            possible_next_actions_mask = batch.next_possible_actions_mask.float()
        except AttributeError:
            possible_next_actions_mask = torch.ones_like(next_action).float()

        return rlt.DiscreteDqnInput(
            state=state,
            action=action,
            next_state=next_state,
            next_action=next_action,
            possible_actions_mask=possible_actions_mask,
            possible_next_actions_mask=possible_next_actions_mask,
            reward=batch.reward,
            not_terminal=not_terminal,
            step=None,
            time_diff=None,
            extras=rlt.ExtraData(
                mdp_id=None,
                sequence_number=None,
                action_probability=batch.log_prob.exp(),
                max_num_actions=None,
                metrics=None,
            ),
        )
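
The try/except AttributeError blocks default each mask to all ones when the batch carries no explicit possible-action information; since action and next_action are one-hot encoded with shape (batch_size, num_actions), torch.ones_like produces masks of the matching shape. A more compact equivalent of that fallback (shown only as an alternative idiom, not ReAgent's code) uses getattr:

    import torch
    from types import SimpleNamespace

    def mask_or_ones(batch, attr_name, like):
        # fall back to an all-ones mask when the batch lacks the attribute
        mask = getattr(batch, attr_name, None)
        return mask.float() if mask is not None else torch.ones_like(like).float()

    action = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
    batch = SimpleNamespace()  # no possible_actions_mask attribute
    print(mask_or_ones(batch, "possible_actions_mask", action))
    # tensor([[1., 1.],
    #         [1., 1.]])
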
Example #6
    def test_seq2slate_eval_data_page(self):
        """
        Create 3 slate ranking logs and evaluate using Direct Method, Inverse
        Propensity Scores, and Doubly Robust.

        The logs are as follows:
        state: [1, 0, 0], [0, 1, 0], [0, 0, 1]
        indices in logged slates: [3, 2], [3, 2], [3, 2]
        model output indices: [2, 3], [3, 2], [2, 3]
        logged reward: 4, 5, 7
        logged propensities: 0.2, 0.5, 0.4
        predicted rewards on logged slates: 2, 4, 6
        predicted rewards on model outputted slates: 1, 4, 5
        predicted propensities: 0.4, 0.3, 0.7

        When eval_greedy=True:

        Direct Method uses the predicted rewards on model outputted slates.
        Thus the result is expected to be (1 + 4 + 5) / 3

        Inverse Propensity Scores would scale the reward by 1.0 / logged propensity
        whenever the model output slate matches the logged slate.
        Since only the second log matches the model output, the IPS result
        is expected to be 5 / 0.5 / 3

        Doubly Robust is the sum of the Direct Method result and the propensity-scaled
        reward difference; the latter is defined as:
        1.0 / logged_propensity * (logged reward - predicted reward on logged slate)
            * Indicator(model slate == logged slate)
        Since only the second logged slate matches the model outputted slate,
        the DR result is expected to be (1 + 4 + 5) / 3 + 1.0 / 0.5 * (5 - 4) / 3


        When eval_greedy=False:

        Only Inverse Propensity Scores would be accurate, because it would be too
        expensive to compute the propensities and predicted rewards of all possible
        slates for the Direct Method.

        The expected IPS = (0.4 / 0.2 * 4 + 0.3 / 0.5 * 5 + 0.7 / 0.4 * 7) / 3
        """
        batch_size = 3
        state_dim = 3
        src_seq_len = 2
        tgt_seq_len = 2
        candidate_dim = 2

        reward_net = FakeSeq2SlateRewardNetwork()
        seq2slate_net = FakeSeq2SlateTransformerNet()

        src_seq = torch.eye(candidate_dim).repeat(batch_size, 1, 1)
        tgt_out_idx = torch.LongTensor([[3, 2], [3, 2], [3, 2]])
        tgt_out_seq = src_seq[
            torch.arange(batch_size).repeat_interleave(tgt_seq_len),
            tgt_out_idx.flatten() - 2,
        ].reshape(batch_size, tgt_seq_len, candidate_dim)

        ptb = rlt.PreprocessedRankingInput(
            state=rlt.FeatureData(float_features=torch.eye(state_dim)),
            src_seq=rlt.FeatureData(float_features=src_seq),
            tgt_out_seq=rlt.FeatureData(float_features=tgt_out_seq),
            src_src_mask=torch.ones(batch_size, src_seq_len, src_seq_len),
            tgt_out_idx=tgt_out_idx,
            tgt_out_probs=torch.tensor([0.2, 0.5, 0.4]),
            slate_reward=torch.tensor([4.0, 5.0, 7.0]),
            extras=rlt.ExtraData(
                sequence_number=torch.tensor([0, 0, 0]),
                mdp_id=np.array(["0", "1", "2"]),
            ),
        )

        edp = EvaluationDataPage.create_from_tensors_seq2slate(
            seq2slate_net, reward_net, ptb, eval_greedy=True)
        logger.info(
            "---------- Start evaluating eval_greedy=True -----------------")
        doubly_robust_estimator = OPEstimatorAdapter(DoublyRobustEstimator())
        dm_estimator = OPEstimatorAdapter(DMEstimator())
        ips_estimator = OPEstimatorAdapter(IPSEstimator())
        switch_estimator = OPEstimatorAdapter(SwitchEstimator())
        switch_dr_estimator = OPEstimatorAdapter(SwitchDREstimator())

        doubly_robust = doubly_robust_estimator.estimate(edp)
        inverse_propensity = ips_estimator.estimate(edp)
        direct_method = dm_estimator.estimate(edp)

        # Verify that Switch with low exponent is equivalent to IPS
        switch_ips = switch_estimator.estimate(edp, exp_base=1)
        # Verify that Switch with no candidates is equivalent to DM
        switch_dm = switch_estimator.estimate(edp, candidates=0)
        # Verify that SwitchDR with low exponent is equivalent to DR
        switch_dr_dr = switch_dr_estimator.estimate(edp, exp_base=1)
        # Verify that SwitchDR with no candidates is equivalent to DM
        switch_dr_dm = switch_dr_estimator.estimate(edp, candidates=0)

        logger.info(f"{direct_method}, {inverse_propensity}, {doubly_robust}")

        avg_logged_reward = (4 + 5 + 7) / 3
        self.assertAlmostEqual(direct_method.raw, (1 + 4 + 5) / 3, delta=1e-6)
        self.assertAlmostEqual(direct_method.normalized,
                               direct_method.raw / avg_logged_reward,
                               delta=1e-6)
        self.assertAlmostEqual(inverse_propensity.raw, 5 / 0.5 / 3, delta=1e-6)
        self.assertAlmostEqual(
            inverse_propensity.normalized,
            inverse_propensity.raw / avg_logged_reward,
            delta=1e-6,
        )
        self.assertAlmostEqual(doubly_robust.raw,
                               direct_method.raw + 1 / 0.5 * (5 - 4) / 3,
                               delta=1e-6)
        self.assertAlmostEqual(doubly_robust.normalized,
                               doubly_robust.raw / avg_logged_reward,
                               delta=1e-6)
        self.assertAlmostEqual(switch_ips.raw,
                               inverse_propensity.raw,
                               delta=1e-6)
        self.assertAlmostEqual(switch_dm.raw, direct_method.raw, delta=1e-6)
        self.assertAlmostEqual(switch_dr_dr.raw, doubly_robust.raw, delta=1e-6)
        self.assertAlmostEqual(switch_dr_dm.raw, direct_method.raw, delta=1e-6)
        logger.info(
            "---------- Finish evaluating eval_greedy=True -----------------")

        logger.info(
            "---------- Start evaluating eval_greedy=False -----------------")
        edp = EvaluationDataPage.create_from_tensors_seq2slate(
            seq2slate_net, reward_net, ptb, eval_greedy=False)
        doubly_robust_estimator = OPEstimatorAdapter(DoublyRobustEstimator())
        dm_estimator = OPEstimatorAdapter(DMEstimator())
        ips_estimator = OPEstimatorAdapter(IPSEstimator())

        doubly_robust = doubly_robust_estimator.estimate(edp)
        inverse_propensity = ips_estimator.estimate(edp)
        direct_method = dm_estimator.estimate(edp)
        self.assertAlmostEqual(
            inverse_propensity.raw,
            (0.4 / 0.2 * 4 + 0.3 / 0.5 * 5 + 0.7 / 0.4 * 7) / 3,
            delta=1e-6,
        )
        self.assertAlmostEqual(
            inverse_propensity.normalized,
            inverse_propensity.raw / avg_logged_reward,
            delta=1e-6,
        )
        logger.info(
            "---------- Finish evaluating eval_greedy=False -----------------")