def __call__(self, batch):
    """Build an ``rlt.ParametricDqnInput`` from ``batch``.

    The batch is read through attributes: ``state``, ``next_state``,
    ``action``, ``next_action``, ``terminal``, ``reward`` and ``log_prob``.
    Every action is marked as possible (all-ones masks) for both the
    current and the next step, and ``step`` / ``time_diff`` are left unset.

    Raises:
        AssertionError: if ``batch.state`` does not have exactly two
            dimensions.
    """
    # 1.0 where the transition continues, 0.0 where it is terminal.
    not_terminal = 1.0 - batch.terminal.float()

    state_shape = batch.state.shape
    assert (
        len(state_shape) == 2
    ), f"{batch.state.shape} is not (batch_size, state_dim)."
    num_rows, _ = state_shape

    # presumably one-hot encodes both actions; see one_hot_actions to confirm
    current_action, following_action = one_hot_actions(
        self.num_actions, batch.action, batch.next_action, batch.terminal
    )

    # The same candidate set and an all-ones mask are used for the current
    # and next step; the "next" versions are independent copies.
    candidate_actions = get_possible_actions_for_gym(num_rows, self.num_actions)
    candidate_next_actions = candidate_actions.clone()
    all_allowed = torch.ones((num_rows, self.num_actions))
    all_allowed_next = all_allowed.clone()

    # The batch stores log-probabilities; exponentiate to get probabilities.
    extras = rlt.ExtraData(
        mdp_id=None,
        sequence_number=None,
        action_probability=batch.log_prob.exp(),
        max_num_actions=None,
        metrics=None,
    )

    return rlt.ParametricDqnInput(
        state=rlt.FeatureData(float_features=batch.state),
        action=rlt.FeatureData(float_features=current_action),
        next_state=rlt.FeatureData(float_features=batch.next_state),
        next_action=rlt.FeatureData(float_features=following_action),
        possible_actions=candidate_actions,
        possible_actions_mask=all_allowed,
        possible_next_actions=candidate_next_actions,
        possible_next_actions_mask=all_allowed_next,
        reward=batch.reward,
        not_terminal=not_terminal,
        step=None,
        time_diff=None,
        extras=extras,
    )
def forward(self, batch: Dict[str, torch.Tensor]) -> rlt.ParametricDqnInput:
    """Turn a dict of raw tensors into an ``rlt.ParametricDqnInput``.

    The batch is first moved to ``self.device``. State features go through
    ``self.state_preprocessor`` and action features through
    ``self.action_preprocessor``, each called with a value tensor and its
    matching ``*_presence`` tensor. The reward, time_diff, step,
    not_terminal, mdp_id, sequence_number and action_probability columns
    get an extra dimension at index 1 via ``unsqueeze(1)``. The
    possible-action tensors and their masks are passed through unchanged.
    """
    batch = batch_to_device(batch, self.device)

    # States first, then actions (current before next in each case).
    state_out = self.state_preprocessor(
        batch["state_features"], batch["state_features_presence"]
    )
    next_state_out = self.state_preprocessor(
        batch["next_state_features"], batch["next_state_features_presence"]
    )
    action_out = self.action_preprocessor(
        batch["action"], batch["action_presence"]
    )
    next_action_out = self.action_preprocessor(
        batch["next_action"], batch["next_action_presence"]
    )

    extras = rlt.ExtraData(
        mdp_id=batch["mdp_id"].unsqueeze(1),
        sequence_number=batch["sequence_number"].unsqueeze(1),
        action_probability=batch["action_probability"].unsqueeze(1),
    )

    return rlt.ParametricDqnInput(
        state=rlt.FeatureData(state_out),
        next_state=rlt.FeatureData(next_state_out),
        action=rlt.FeatureData(action_out),
        next_action=rlt.FeatureData(next_action_out),
        reward=batch["reward"].unsqueeze(1),
        time_diff=batch["time_diff"].unsqueeze(1),
        step=batch["step"].unsqueeze(1),
        not_terminal=batch["not_terminal"].unsqueeze(1),
        possible_actions=batch["possible_actions"],
        possible_actions_mask=batch["possible_actions_mask"],
        possible_next_actions=batch["possible_next_actions"],
        possible_next_actions_mask=batch["possible_next_actions_mask"],
        extras=extras,
    )
def forward(self, batch: Dict[str, torch.Tensor]) -> rlt.DiscreteDqnInput:
    """Turn a dict of raw tensors into an ``rlt.DiscreteDqnInput``.

    The batch is first moved to ``self.device`` and state features are run
    through ``self.state_preprocessor``. Actions are one-hot encoded, and
    ``not_terminal`` is derived from ``possible_next_actions_mask`` rather
    than read from the batch.
    """
    batch = batch_to_device(batch, self.device)

    state_out = self.state_preprocessor(
        batch["state_features"], batch["state_features_presence"]
    )
    next_state_out = self.state_preprocessor(
        batch["next_state_features"], batch["next_state_features_presence"]
    )

    # A row is non-terminal iff at least one next action is possible, i.e.
    # the maximum of its next-action mask along dim 1 is non-zero.
    next_mask_max, _ = batch["possible_next_actions_mask"].max(dim=1)
    not_terminal = next_mask_max.float()

    action = F.one_hot(batch["action"].to(torch.int64), self.num_actions)

    # The next action may hold the value ``self.num_actions`` when it is not
    # available. Encode with one extra class, then drop that last column so
    # such rows become all zeros.
    next_action_padded = F.one_hot(
        batch["next_action"].to(torch.int64), self.num_actions + 1
    )
    next_action = next_action_padded[:, : self.num_actions]

    extras = rlt.ExtraData(
        mdp_id=batch["mdp_id"].unsqueeze(1),
        sequence_number=batch["sequence_number"].unsqueeze(1),
        action_probability=batch["action_probability"].unsqueeze(1),
    )

    return rlt.DiscreteDqnInput(
        state=rlt.FeatureData(state_out),
        next_state=rlt.FeatureData(next_state_out),
        action=action,
        next_action=next_action,
        reward=batch["reward"].unsqueeze(1),
        time_diff=batch["time_diff"].unsqueeze(1),
        step=batch["step"].unsqueeze(1),
        not_terminal=not_terminal.unsqueeze(1),
        possible_actions_mask=batch["possible_actions_mask"],
        possible_next_actions_mask=batch["possible_next_actions_mask"],
        extras=extras,
    )
def __call__(self, batch):
    """Build an ``rlt.PolicyNetworkInput`` from ``batch``.

    Actions are rescaled from ``[self.action_low, self.action_high]`` to
    ``[self.train_low, self.train_high]`` via ``rescale_actions``. Next
    actions are rescaled only on rows where ``batch.terminal == 0``; the
    remaining rows stay zero. If the batch exposes ``doc`` and ``next_doc``
    they are added as candidate features, and the resulting state and
    next_state are passed through ``rlt._embed_states``.
    """
    not_terminal = 1.0 - batch.terminal.float()

    # Both rescale calls use the same range mapping.
    rescale_kwargs = {
        "new_min": self.train_low,
        "new_max": self.train_high,
        "prev_min": self.action_low,
        "prev_max": self.action_high,
    }

    # normalize actions
    action = rescale_actions(batch.action, **rescale_kwargs)

    # only normalize non-terminal rows; terminal rows keep the zero fill
    keep_rows = (batch.terminal == 0).squeeze(1)
    next_action = torch.zeros_like(action)
    next_action[keep_rows] = rescale_actions(
        batch.next_action[keep_rows], **rescale_kwargs
    )

    extras = rlt.ExtraData(
        mdp_id=None,
        sequence_number=None,
        action_probability=batch.log_prob.exp(),
        max_num_actions=None,
        metrics=None,
    )
    dict_batch = {
        InputColumn.STATE_FEATURES: batch.state,
        InputColumn.NEXT_STATE_FEATURES: batch.next_state,
        InputColumn.ACTION: action,
        InputColumn.NEXT_ACTION: next_action,
        InputColumn.REWARD: batch.reward,
        InputColumn.NOT_TERMINAL: not_terminal,
        InputColumn.STEP: None,
        InputColumn.TIME_DIFF: None,
        InputColumn.EXTRAS: extras,
    }

    # Candidate features are optional: an AttributeError while looking them
    # up means none are added (both entries or neither).
    try:
        candidate_columns = {
            InputColumn.CANDIDATE_FEATURES: batch.doc,
            InputColumn.NEXT_CANDIDATE_FEATURES: batch.next_doc,
        }
    except AttributeError:
        has_candidate_features = False
    else:
        dict_batch.update(candidate_columns)
        has_candidate_features = True

    output = rlt.PolicyNetworkInput.from_dict(dict_batch)
    if not has_candidate_features:
        return output

    output.state = rlt._embed_states(output.state)
    output.next_state = rlt._embed_states(output.next_state)
    return output
def __call__(self, batch):
    """Build an ``rlt.DiscreteDqnInput`` from ``batch``.

    States are passed through ``self.trainer_preprocessor`` when one is
    configured; otherwise they are wrapped in ``rlt.FeatureData``. The
    action masks are taken from ``batch.possible_actions_mask`` and
    ``batch.next_possible_actions_mask``. When reading one raises
    ``AttributeError``, that mask falls back to all ones shaped like the
    corresponding action tensor.
    """
    not_terminal = 1.0 - batch.terminal.float()

    # presumably one-hot encodes both actions; see one_hot_actions to confirm
    action, next_action = one_hot_actions(
        self.num_actions, batch.action, batch.next_action, batch.terminal
    )

    preprocess = self.trainer_preprocessor
    if preprocess is None:
        state = rlt.FeatureData(float_features=batch.state)
        next_state = rlt.FeatureData(float_features=batch.next_state)
    else:
        state = preprocess(batch.state)
        next_state = preprocess(batch.next_state)

    # Masks are optional on the batch; default to "every action allowed".
    try:
        current_mask = batch.possible_actions_mask.float()
    except AttributeError:
        current_mask = torch.ones_like(action).float()

    try:
        upcoming_mask = batch.next_possible_actions_mask.float()
    except AttributeError:
        upcoming_mask = torch.ones_like(next_action).float()

    return rlt.DiscreteDqnInput(
        state=state,
        action=action,
        next_state=next_state,
        next_action=next_action,
        possible_actions_mask=current_mask,
        possible_next_actions_mask=upcoming_mask,
        reward=batch.reward,
        not_terminal=not_terminal,
        step=None,
        time_diff=None,
        extras=rlt.ExtraData(
            mdp_id=None,
            sequence_number=None,
            # The batch stores log-probabilities; exponentiate them.
            action_probability=batch.log_prob.exp(),
            max_num_actions=None,
            metrics=None,
        ),
    )
def test_seq2slate_eval_data_page(self):
    """
    Evaluate 3 slate ranking logs with Direct Method (DM), Inverse
    Propensity Scores (IPS) and Doubly Robust (DR).

    The logs are:
        state: [1, 0, 0], [0, 1, 0], [0, 0, 1]
        indices in logged slates: [3, 2], [3, 2], [3, 2]
        model output indices: [2, 3], [3, 2], [2, 3]
        logged reward: 4, 5, 7
        logged propensities: 0.2, 0.5, 0.4
        predicted rewards on logged slates: 2, 4, 6
        predicted rewards on model outputted slates: 1, 4, 5
        predicted propensities: 0.4, 0.3, 0.7

    With eval_greedy=True:

    DM uses the predicted rewards on the model outputted slates, so the
    expected result is (1 + 4 + 5) / 3.

    IPS scales the logged reward by 1.0 / logged propensity whenever the
    model output slate matches the logged slate. Only the second log
    matches, so the expected result is 5 / 0.5 / 3.

    DR is the DM result plus the propensity-scaled reward difference,
    defined as:
        1.0 / logged_propensities
        * (logged reward - predicted reward on logged slate)
        * Indicator(model slate == logged slate)
    Only the second log matches, so the expected result is
    (1 + 4 + 5) / 3 + 1.0 / 0.5 * (5 - 4) / 3.

    With eval_greedy=False:

    Only IPS is expected to be accurate, because computing propensities
    and predicted rewards for all possible slates would be too expensive
    for DM. The expected IPS is
    (0.4 / 0.2 * 4 + 0.3 / 0.5 * 5 + 0.7 / 0.4 * 7) / 3.
    """
    tolerance = 1e-6

    def expect_close(actual, expected):
        # Every numeric check in this test uses the same absolute tolerance.
        self.assertAlmostEqual(actual, expected, delta=tolerance)

    batch_size = 3
    state_dim = 3
    src_seq_len = 2
    tgt_seq_len = 2
    candidate_dim = 2

    reward_net = FakeSeq2SlateRewardNetwork()
    seq2slate_net = FakeSeq2SlateTransformerNet()

    src_seq = torch.eye(candidate_dim).repeat(batch_size, 1, 1)
    tgt_out_idx = torch.LongTensor([[3, 2], [3, 2], [3, 2]])

    # Gather the logged candidates out of src_seq: one (row, position) pair
    # per target slot. Logged indices are shifted down by 2 to address
    # positions in src_seq.
    row_index = torch.arange(batch_size).repeat_interleave(tgt_seq_len)
    position_index = tgt_out_idx.flatten() - 2
    tgt_out_seq = src_seq[row_index, position_index].reshape(
        batch_size, tgt_seq_len, candidate_dim
    )

    ptb = rlt.PreprocessedRankingInput(
        state=rlt.FeatureData(float_features=torch.eye(state_dim)),
        src_seq=rlt.FeatureData(float_features=src_seq),
        tgt_out_seq=rlt.FeatureData(float_features=tgt_out_seq),
        src_src_mask=torch.ones(batch_size, src_seq_len, src_seq_len),
        tgt_out_idx=tgt_out_idx,
        tgt_out_probs=torch.tensor([0.2, 0.5, 0.4]),
        slate_reward=torch.tensor([4.0, 5.0, 7.0]),
        extras=rlt.ExtraData(
            sequence_number=torch.tensor([0, 0, 0]),
            mdp_id=np.array(["0", "1", "2"]),
        ),
    )

    avg_logged_reward = (4 + 5 + 7) / 3

    # ------------------------- eval_greedy=True -------------------------
    edp = EvaluationDataPage.create_from_tensors_seq2slate(
        seq2slate_net, reward_net, ptb, eval_greedy=True
    )
    logger.info("---------- Start evaluating eval_greedy=True -----------------")

    doubly_robust_estimator = OPEstimatorAdapter(DoublyRobustEstimator())
    dm_estimator = OPEstimatorAdapter(DMEstimator())
    ips_estimator = OPEstimatorAdapter(IPSEstimator())
    switch_estimator = OPEstimatorAdapter(SwitchEstimator())
    switch_dr_estimator = OPEstimatorAdapter(SwitchDREstimator())

    doubly_robust = doubly_robust_estimator.estimate(edp)
    inverse_propensity = ips_estimator.estimate(edp)
    direct_method = dm_estimator.estimate(edp)

    # Switch with low exponent should be equivalent to IPS
    switch_ips = switch_estimator.estimate(edp, exp_base=1)
    # Switch with no candidates should be equivalent to DM
    switch_dm = switch_estimator.estimate(edp, candidates=0)
    # SwitchDR with low exponent should be equivalent to DR
    switch_dr_dr = switch_dr_estimator.estimate(edp, exp_base=1)
    # SwitchDR with no candidates should be equivalent to DM
    switch_dr_dm = switch_dr_estimator.estimate(edp, candidates=0)

    logger.info(f"{direct_method}, {inverse_propensity}, {doubly_robust}")

    expect_close(direct_method.raw, (1 + 4 + 5) / 3)
    expect_close(direct_method.normalized, direct_method.raw / avg_logged_reward)

    expect_close(inverse_propensity.raw, 5 / 0.5 / 3)
    expect_close(
        inverse_propensity.normalized,
        inverse_propensity.raw / avg_logged_reward,
    )

    expect_close(doubly_robust.raw, direct_method.raw + 1 / 0.5 * (5 - 4) / 3)
    expect_close(doubly_robust.normalized, doubly_robust.raw / avg_logged_reward)

    expect_close(switch_ips.raw, inverse_propensity.raw)
    expect_close(switch_dm.raw, direct_method.raw)
    expect_close(switch_dr_dr.raw, doubly_robust.raw)
    expect_close(switch_dr_dm.raw, direct_method.raw)
    logger.info("---------- Finish evaluating eval_greedy=True -----------------")

    # ------------------------- eval_greedy=False ------------------------
    logger.info("---------- Start evaluating eval_greedy=False -----------------")
    edp = EvaluationDataPage.create_from_tensors_seq2slate(
        seq2slate_net, reward_net, ptb, eval_greedy=False
    )
    doubly_robust_estimator = OPEstimatorAdapter(DoublyRobustEstimator())
    dm_estimator = OPEstimatorAdapter(DMEstimator())
    ips_estimator = OPEstimatorAdapter(IPSEstimator())

    # DR and DM are still run here, but only the IPS result is asserted on.
    doubly_robust = doubly_robust_estimator.estimate(edp)
    inverse_propensity = ips_estimator.estimate(edp)
    direct_method = dm_estimator.estimate(edp)

    expect_close(
        inverse_propensity.raw,
        (0.4 / 0.2 * 4 + 0.3 / 0.5 * 5 + 0.7 / 0.4 * 7) / 3,
    )
    expect_close(
        inverse_propensity.normalized,
        inverse_propensity.raw / avg_logged_reward,
    )
    logger.info("---------- Finish evaluating eval_greedy=False -----------------")