Example #1
 def input_prototype(self):
     return rlt.PreprocessedStateAction(
         state=rlt.PreprocessedFeatureVector(
             float_features=torch.randn(1, 1, self.state_dim)),
         action=rlt.PreprocessedFeatureVector(
             float_features=torch.randn(1, 1, self.action_dim)),
     )
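A note on usage (not part of the original snippet): an input_prototype like the one above supplies a dummy batch whose shapes match what the network expects, e.g. for shape checks or tracing/export. A minimal sketch of the same idea with a hypothetical toy module:

# Sketch only; the module and dimensions below are hypothetical, not from the original code.
import torch
import torch.nn as nn

class ToyStateActionModel(nn.Module):
    def __init__(self, state_dim: int = 4, action_dim: int = 2):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.net = nn.Linear(state_dim + action_dim, 1)

    def input_prototype(self):
        # Same (1, 1, dim) layout as the example above: batch of 1, sequence length of 1.
        return (
            torch.randn(1, 1, self.state_dim),
            torch.randn(1, 1, self.action_dim),
        )

    def forward(self, state, action):
        return self.net(torch.cat((state, action), dim=-1))

model = ToyStateActionModel()
out = model(*model.input_prototype())  # shape sanity check; also usable with torch.jit.trace
assert out.shape == (1, 1, 1)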
Example #2
    def __call__(self,
                 batch: rlt.RawTrainingBatch) -> rlt.PreprocessedTrainingBatch:
        training_input = batch.training_input
        assert isinstance(training_input,
                          rlt.RawMemoryNetworkInput), "Wrong Type: {}".format(
                              str(type(training_input)))

        preprocessed_state = self.state_preprocessor(
            training_input.state.float_features.value,
            training_input.state.float_features.presence,
        )
        preprocessed_next_state = self.state_preprocessor(
            training_input.next_state.float_features.value,
            training_input.next_state.float_features.presence,
        )
        new_training_input = training_input.preprocess_tensors(
            state=preprocessed_state, next_state=preprocessed_next_state)
        preprocessed_batch = batch.preprocess(new_training_input)
        assert isinstance(new_training_input,
                          rlt.PreprocessedMemoryNetworkInput)
        preprocessed_batch = preprocessed_batch._replace(
            training_input=new_training_input._replace(
                state=rlt.PreprocessedFeatureVector(
                    float_features=new_training_input.state.float_features.reshape(
                        -1, self.seq_len, self.state_dim)),
                action=new_training_input.action.reshape(
                    -1, self.seq_len, self.action_dim),
                next_state=rlt.PreprocessedFeatureVector(
                    float_features=new_training_input.next_state.float_features.reshape(
                        -1, self.seq_len, self.state_dim)),
                reward=new_training_input.reward.reshape(-1, self.seq_len),
                not_terminal=new_training_input.not_terminal.reshape(
                    -1, self.seq_len),
            ))
        return preprocessed_batch
 def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
     obs, action, reward, next_obs, next_action, next_reward, terminal, idxs, possible_actions_mask, log_prob = (
         train_batch)
     obs = torch.tensor(obs).squeeze(2)
     action = torch.tensor(action).float()
     reward = torch.tensor(reward).unsqueeze(1)
     next_obs = torch.tensor(next_obs).squeeze(2)
     next_action = torch.tensor(next_action)
     not_terminal = 1.0 - torch.tensor(terminal).unsqueeze(1).float()
     idxs = torch.tensor(idxs)
     possible_actions_mask = torch.tensor(possible_actions_mask).float()
     log_prob = torch.tensor(log_prob)
     return rlt.PreprocessedTrainingBatch(
         training_input=rlt.PreprocessedPolicyNetworkInput(
             state=rlt.PreprocessedFeatureVector(float_features=obs),
             action=rlt.PreprocessedFeatureVector(float_features=action),
             next_state=rlt.PreprocessedFeatureVector(
                 float_features=next_obs),
             next_action=rlt.PreprocessedFeatureVector(
                 float_features=next_action),
             reward=reward,
             not_terminal=not_terminal,
             step=None,
             time_diff=None,
         ),
         extras=rlt.ExtraData(),
     )
 def __call__(  # type: ignore
     self, batch: Dict[str, torch.Tensor]
 ) -> rlt.DiscreteDqnInput:
     batch = batch_to_device(batch, self.device)
     preprocessed_state = self.state_preprocessor(
         batch["state_features"], batch["state_features_presence"]
     )
     preprocessed_next_state = self.state_preprocessor(
         batch["next_state_features"], batch["next_state_features_presence"]
     )
     # not terminal iff at least one next action is possible
     not_terminal = batch["possible_next_actions_mask"].max(dim=1)[0].float()
     action = F.one_hot(batch["action"].to(torch.int64), self.num_actions)
     # next action can take the value self.num_actions when no next action is available
     next_action = F.one_hot(
         batch["next_action"].to(torch.int64), self.num_actions + 1
     )[:, : self.num_actions]
     return rlt.DiscreteDqnInput(
         state=rlt.PreprocessedFeatureVector(preprocessed_state),
         next_state=rlt.PreprocessedFeatureVector(preprocessed_next_state),
         action=action,
         next_action=next_action,
         reward=batch["reward"].unsqueeze(1),
         time_diff=batch["time_diff"].unsqueeze(1),
         step=batch["step"].unsqueeze(1),
         not_terminal=not_terminal.unsqueeze(1),
         possible_actions_mask=batch["possible_actions_mask"],
         possible_next_actions_mask=batch["possible_next_actions_mask"],
         extras=rlt.ExtraData(
             mdp_id=batch["mdp_id"].unsqueeze(1).cpu().numpy(),
             sequence_number=batch["sequence_number"].unsqueeze(1),
             action_probability=batch["action_probability"].unsqueeze(1),
         ),
     )
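The next_action encoding above uses a one-hot trick: encoding with num_actions + 1 classes and slicing off the last column maps the sentinel index num_actions (used when no next action is available) to an all-zero row. A small self-contained check of that slicing behavior:

# Minimal illustration of the F.one_hot(..., num_actions + 1)[:, :num_actions] trick above.
import torch
import torch.nn.functional as F

num_actions = 3
next_action = torch.tensor([0, 2, 3])  # 3 == num_actions acts as the "no next action" sentinel

encoded = F.one_hot(next_action, num_actions + 1)[:, :num_actions]
assert torch.equal(
    encoded,
    torch.tensor([[1, 0, 0],
                  [0, 0, 1],
                  [0, 0, 0]]),
)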
 def as_slate_q_training_batch(self):
     batch_size, state_dim = self.states.shape
     action_dim = self.actions.shape[1]
     return rlt.PreprocessedSlateQInput(
         state=rlt.PreprocessedFeatureVector(float_features=self.states),
         next_state=rlt.PreprocessedFeatureVector(
             float_features=self.next_states),
         action=rlt.PreprocessedSlateFeatureVector(
             float_features=self.possible_actions_state_concat[
                 :, state_dim:].view(batch_size, -1, action_dim),
             item_mask=self.possible_actions_mask,
             item_probability=self.propensities,
         ),
         next_action=rlt.PreprocessedSlateFeatureVector(
             float_features=self.possible_next_actions_state_concat[
                 :, state_dim:].view(batch_size, -1, action_dim),
             item_mask=self.possible_next_actions_mask,
             item_probability=self.next_propensities,
         ),
         reward=self.rewards,
         reward_mask=self.rewards_mask,
         time_diff=self.time_diffs,
         step=self.step,
         not_terminal=self.not_terminal,
         extras=rlt.ExtraData(
             mdp_id=self.mdp_ids,
             sequence_number=self.sequence_numbers,
             action_probability=self.propensities,
             max_num_actions=self.max_num_actions,
             metrics=self.metrics,
         ),
     )
Example #6
    def acc_rewards_of_one_solution(
        self, init_state: torch.Tensor, solution: torch.Tensor, solution_idx: int
    ):
        """
        ensemble_pop_size trajectories will be sampled to evaluate a
        CEM solution. Each trajectory is generated by one world model

        :param init_state: its shape is (state_dim, )
        :param solution: its shape is (plan_horizon_length, action_dim)
        :param solution_idx: the index of the solution
        :return reward: Reward of each of ensemble_pop_size trajectories
        """
        reward_matrix = np.zeros((self.ensemble_pop_size, self.plan_horizon_length))

        for i in range(self.ensemble_pop_size):
            state = init_state
            mem_net_idx = np.random.randint(0, len(self.mem_net_list))
            for j in range(self.plan_horizon_length):
                # world_model_input.state shape:
                # (1, 1, state_dim)
                # world_model_input.action shape:
                # (1, 1, action_dim)
                world_model_input = rlt.PreprocessedStateAction(
                    state=rlt.PreprocessedFeatureVector(
                        float_features=state.reshape((1, 1, self.state_dim))
                    ),
                    action=rlt.PreprocessedFeatureVector(
                        float_features=solution[j, :].reshape((1, 1, self.action_dim))
                    ),
                )
                reward, next_state, not_terminal, not_terminal_prob = self.sample_reward_next_state_terminal(
                    world_model_input, self.mem_net_list[mem_net_idx]
                )
                reward_matrix[i, j] = reward * (self.gamma ** j)

                if not not_terminal:
                    logger.debug(
                        f"Solution {solution_idx}: predict terminal at step {j}"
                        f" with prob. {1.0 - not_terminal_prob}"
                    )
                    break

                state = next_state

        return np.sum(reward_matrix, axis=1)
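The per-step discounting above (reward_matrix[i, j] = reward * gamma ** j, summed over the horizon) can be checked in isolation; the rewards and gamma below are made-up numbers, purely for illustration:

# Stand-alone check of the discounted-return accumulation used above.
import numpy as np

gamma = 0.9
rewards = [1.0, 1.0, 1.0]  # rewards at steps j = 0, 1, 2 of one sampled trajectory
discounted = sum(r * gamma ** j for j, r in enumerate(rewards))
assert abs(discounted - (1.0 + 0.9 + 0.81)) < 1e-9

# With a (pop_size, horizon) matrix, the per-trajectory return is the row sum,
# matching np.sum(reward_matrix, axis=1) in the method above.
reward_matrix = np.array([[r * gamma ** j for j, r in enumerate(rewards)]])
assert np.allclose(reward_matrix.sum(axis=1), discounted)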
 def forward(
     self,
     state: torch.Tensor,
     src_seq: torch.Tensor,
     tgt_out_seq: torch.Tensor,
     src_src_mask: torch.Tensor,
     tgt_out_idx: torch.Tensor,
 ) -> torch.Tensor:
     return self.model(
         rlt.PreprocessedRankingInput(
             state=rlt.PreprocessedFeatureVector(float_features=state),
             src_seq=rlt.PreprocessedFeatureVector(float_features=src_seq),
             tgt_out_seq=rlt.PreprocessedFeatureVector(
                 float_features=tgt_out_seq),
             src_src_mask=src_src_mask,
             tgt_out_idx=tgt_out_idx,
         )).predicted_reward
 def as_policy_network_training_batch(self):
     return rlt.PreprocessedTrainingBatch(
         training_input=rlt.PreprocessedPolicyNetworkInput(
             state=rlt.PreprocessedFeatureVector(
                 float_features=self.states),
             action=rlt.PreprocessedFeatureVector(
                 float_features=self.actions),
             next_state=rlt.PreprocessedFeatureVector(
                 float_features=self.next_states),
             next_action=rlt.PreprocessedFeatureVector(
                 float_features=self.next_actions),
             reward=self.rewards,
             not_terminal=self.not_terminal,
             step=self.step,
             time_diff=self.time_diffs,
         ),
         extras=rlt.ExtraData(),
     )
Example #9
 def __call__(self,
              batch: rlt.RawTrainingBatch) -> rlt.PreprocessedTrainingBatch:
     preprocessed_batch = super().__call__(batch)
     training_input = preprocessed_batch.training_input
     assert isinstance(training_input, rlt.PreprocessedMemoryNetworkInput)
     preprocessed_batch = preprocessed_batch._replace(
         training_input=training_input._replace(
             state=rlt.PreprocessedFeatureVector(
                 float_features=training_input.state.float_features.reshape(
                     -1, self.seq_len, self.state_dim)),
             action=training_input.action.reshape(-1, self.seq_len,
                                                  self.action_dim),
             next_state=rlt.PreprocessedFeatureVector(
                 float_features=training_input.next_state.float_features.reshape(
                     -1, self.seq_len, self.state_dim)),
             reward=training_input.reward.reshape(-1, self.seq_len),
             not_terminal=training_input.not_terminal.reshape(
                 -1, self.seq_len),
         ))
     return preprocessed_batch
Example #10
    def sample_memories(self,
                        batch_size,
                        use_gpu=False,
                        batch_first=False) -> rlt.PreprocessedTrainingBatch:
        """
        :param batch_size: number of samples to return
        :param use_gpu: whether to put samples on gpu
        :param batch_first: If True, the first dimension of data is batch_size.
            If False (default), the first dimension is SEQ_LEN. Therefore,
            state's shape is SEQ_LEN x BATCH_SIZE x STATE_DIM, for example. By default,
            MDN-RNN consumes data with SEQ_LEN as the first dimension.
        """
        sample_indices = np.random.randint(self.memory_size, size=batch_size)
        device = (
            torch.device("cuda")
            if use_gpu else torch.device("cpu")  # type: ignore
        )
        # state/next state shape: batch_size x seq_len x state_dim
        # action shape: batch_size x seq_len x action_dim
        # reward/not_terminal shape: batch_size x seq_len
        state, action, next_state, reward, not_terminal = map(
            lambda x: stack(x).float().to(device),
            zip(*self.deque_sample(sample_indices)),
        )

        if not batch_first:
            state, action, next_state, reward, not_terminal = transpose(
                state, action, next_state, reward, not_terminal)

        training_input = rlt.PreprocessedMemoryNetworkInput(
            state=rlt.PreprocessedFeatureVector(float_features=state),
            reward=reward,
            time_diff=torch.ones_like(reward).float(),
            action=action,
            next_state=rlt.PreprocessedFeatureVector(
                float_features=next_state),
            not_terminal=not_terminal,
            step=None,
        )
        return rlt.PreprocessedTrainingBatch(training_input=training_input,
                                             extras=None)
    def as_cem_training_batch(self, batch_first=False):
        """
        Generate one-step samples needed by CEM trainer.
        The samples will be used to train an ensemble of world models used by CEM.

        If batch_first = True:
            state/next state shape: batch_size x 1 x state_dim
            action shape: batch_size x 1 x action_dim
            reward/terminal shape: batch_size x 1
        else (default):
             state/next state shape: 1 x batch_size x state_dim
             action shape: 1 x batch_size x action_dim
             reward/terminal shape: 1 x batch_size
        """
        if batch_first:
            seq_len_dim = 1
            reward, not_terminal = self.rewards, self.not_terminal
        else:
            seq_len_dim = 0
            reward, not_terminal = transpose(self.rewards, self.not_terminal)
        training_input = rlt.PreprocessedMemoryNetworkInput(
            state=rlt.PreprocessedFeatureVector(
                float_features=self.states.unsqueeze(seq_len_dim)),
            action=self.actions.unsqueeze(seq_len_dim),
            next_state=rlt.PreprocessedFeatureVector(
                float_features=self.next_states.unsqueeze(seq_len_dim)),
            reward=reward,
            not_terminal=not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        )
        return rlt.PreprocessedTrainingBatch(
            training_input=training_input,
            extras=rlt.ExtraData(
                mdp_id=self.mdp_ids,
                sequence_number=self.sequence_numbers,
                action_probability=self.propensities,
                max_num_actions=self.max_num_actions,
                metrics=self.metrics,
            ),
        )
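The unsqueeze(seq_len_dim) calls above only insert a length-1 sequence dimension; a quick shape check with arbitrary sizes:

# Shape check for the batch_first switch above; the sizes are arbitrary.
import torch

batch_size, state_dim = 4, 3
states = torch.randn(batch_size, state_dim)

# batch_first=True: the length-1 sequence dim goes after the batch dim.
assert states.unsqueeze(1).shape == (batch_size, 1, state_dim)
# batch_first=False (default): the sequence dim comes first.
assert states.unsqueeze(0).shape == (1, batch_size, state_dim)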
Example #12
 def internal_prediction(
     self, state: torch.Tensor
 ) -> Union[rlt.SacPolicyActionSet, rlt.DqnPolicyActionSet]:
     """
     Only used by Gym. Return the predicted next action
     """
     input = rlt.PreprocessedState(state=rlt.PreprocessedFeatureVector(
         float_features=state))
     output = self.cem_planner_network(input)
     if not self.cem_planner_network.discrete_action:
         return rlt.SacPolicyActionSet(greedy=output, greedy_propensity=1.0)
     return rlt.DqnPolicyActionSet(greedy=output[0])
 def as_discrete_maxq_training_batch(self):
     return rlt.PreprocessedDiscreteDqnInput(
         state=rlt.PreprocessedFeatureVector(float_features=self.states),
         action=self.actions,
         next_state=rlt.PreprocessedFeatureVector(
             float_features=self.next_states),
         next_action=self.next_actions,
         possible_actions_mask=self.possible_actions_mask,
         possible_next_actions_mask=self.possible_next_actions_mask,
         reward=self.rewards,
         not_terminal=self.not_terminal,
         step=self.step,
         time_diff=self.time_diffs,
         extras=rlt.ExtraData(
             mdp_id=self.mdp_ids,
             sequence_number=self.sequence_numbers,
             action_probability=self.propensities,
             max_num_actions=self.max_num_actions,
             metrics=self.metrics,
         ),
     )
 def as_parametric_maxq_training_batch(self):
     state_dim = self.states.shape[1]
     return rlt.PreprocessedTrainingBatch(
         training_input=rlt.PreprocessedParametricDqnInput(
             state=rlt.PreprocessedFeatureVector(
                 float_features=self.states),
             action=rlt.PreprocessedFeatureVector(
                 float_features=self.actions),
             next_state=rlt.PreprocessedFeatureVector(
                 float_features=self.next_states),
             next_action=rlt.PreprocessedFeatureVector(
                 float_features=self.next_actions),
             tiled_next_state=rlt.PreprocessedFeatureVector(
                 float_features=self.possible_next_actions_state_concat[
                     :, :state_dim]),
             possible_actions=None,
             possible_actions_mask=self.possible_actions_mask,
             possible_next_actions=rlt.PreprocessedFeatureVector(
                 float_features=self.possible_next_actions_state_concat[
                     :, state_dim:]),
             possible_next_actions_mask=self.possible_next_actions_mask,
             reward=self.rewards,
             not_terminal=self.not_terminal,
             step=self.step,
             time_diff=self.time_diffs,
         ),
         extras=rlt.ExtraData(),
     )
Example #15
    def score(preprocessed_obs: rlt.PreprocessedState) -> torch.Tensor:
        tiled_state = preprocessed_obs.repeat_interleave(repeats=num_actions, axis=0)

        actions = rlt.PreprocessedFeatureVector(float_features=torch.eye(num_actions))
        preprocessed_obs = rlt.PreprocessedStateAction(tiled_state.state, actions)

        q_network.eval()
        scores = q_network(preprocessed_obs).q_value.view(-1, num_actions)
        assert (
            scores.size(1) == num_actions
        ), f"scores size is {scores.size(0)}, num_actions is {num_actions}"
        q_network.train()
        return F.log_softmax(scores, dim=-1)
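The tiling in score() pairs each state with every one-hot action so a parametric Q-network can score all actions in one forward pass. A sketch of that pattern with plain tensors (the real code wraps them in rlt types; for a batch of states the one-hot block is typically repeated per state, as in the preprocess_batch example further below):

# Illustration of the state-tiling pattern; values are arbitrary.
import torch

num_actions, state_dim = 3, 2
states = torch.tensor([[1.0, 2.0], [3.0, 4.0]])              # (batch, state_dim)

tiled_states = states.repeat_interleave(num_actions, dim=0)  # (batch * num_actions, state_dim)
actions = torch.eye(num_actions).repeat(states.shape[0], 1)  # (batch * num_actions, num_actions)

assert tiled_states.shape == (6, state_dim)
assert actions.shape == (6, num_actions)
# Row k pairs state k // num_actions with the one-hot vector for action k % num_actions.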
    def test_discrete_wrapper_with_id_list(self):
        state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
        state_preprocessor = Preprocessor(state_normalization_parameters,
                                          False)
        action_dim = 2
        state_feature_config = rlt.ModelFeatureConfig(
            float_feature_infos=[
                rlt.FloatFeatureInfo(name=str(i), feature_id=i)
                for i in range(1, 5)
            ],
            id_list_feature_configs=[
                rlt.IdListFeatureConfig(name="A",
                                        feature_id=10,
                                        id_mapping_name="A_mapping")
            ],
            id_mapping_config={"A_mapping": rlt.IdMapping(ids=[0, 1, 2])},
        )
        dqn = FullyConnectedDQNWithEmbedding(
            state_dim=len(state_normalization_parameters),
            action_dim=action_dim,
            sizes=[16],
            activations=["relu"],
            model_feature_config=state_feature_config,
            embedding_dim=8,
        )
        dqn_with_preprocessor = DiscreteDqnWithPreprocessorWithIdList(
            dqn, state_preprocessor, state_feature_config)
        action_names = ["L", "R"]
        wrapper = DiscreteDqnPredictorWrapperWithIdList(
            dqn_with_preprocessor, action_names, state_feature_config)
        input_prototype = dqn_with_preprocessor.input_prototype()
        output_action_names, q_values = wrapper(*input_prototype)
        self.assertEqual(action_names, output_action_names)
        self.assertEqual(q_values.shape, (1, 2))

        feature_id_to_name = {
            config.feature_id: config.name
            for config in state_feature_config.id_list_feature_configs
        }
        state_id_list_features = {
            feature_id_to_name[k]: v
            for k, v in input_prototype[1].items()
        }
        expected_output = dqn(
            rlt.PreprocessedState(state=rlt.PreprocessedFeatureVector(
                float_features=state_preprocessor(*input_prototype[0]),
                id_list_features=state_id_list_features,
            ))).q_values
        self.assertTrue((expected_output == q_values).all())
Example #17
 def forward(
     self,
     state_with_presence: Tuple[torch.Tensor, torch.Tensor],
     state_id_list_features: Dict[int, Tuple[torch.Tensor, torch.Tensor]],
 ):
     preprocessed_state = self.state_preprocessor(state_with_presence[0],
                                                  state_with_presence[1])
     id_list_features = {
         id_list_feature_config.name:
         state_id_list_features[id_list_feature_config.feature_id]
         for id_list_feature_config in self.id_list_feature_configs
     }
     state_feature_vector = rlt.PreprocessedState(
         state=rlt.PreprocessedFeatureVector(
             float_features=preprocessed_state,
             id_list_features=id_list_features))
     q_values = self.model(state_feature_vector).q_values
     return q_values
Example #18
    def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
        obs, action, reward, next_obs, next_action, next_reward, terminal, idxs, possible_actions_mask, log_prob = (
            train_batch)
        batch_size = obs.shape[0]

        obs = torch.tensor(obs).squeeze(2)
        action = torch.tensor(action).float()
        next_obs = torch.tensor(next_obs).squeeze(2)
        next_action = torch.tensor(next_action).to(torch.float32)
        reward = torch.tensor(reward).unsqueeze(1)
        not_terminal = 1 - torch.tensor(terminal).unsqueeze(1).to(torch.uint8)
        possible_actions_mask = torch.ones_like(action).to(torch.bool)

        tiled_next_state = torch.repeat_interleave(next_obs,
                                                   repeats=num_actions,
                                                   axis=0)
        possible_next_actions = torch.eye(num_actions).repeat(batch_size, 1)
        possible_next_actions_mask = not_terminal.repeat(1, num_actions).to(
            torch.bool)
        return rlt.PreprocessedTrainingBatch(
            rlt.PreprocessedParametricDqnInput(
                state=rlt.PreprocessedFeatureVector(float_features=obs),
                action=rlt.PreprocessedFeatureVector(float_features=action),
                next_state=rlt.PreprocessedFeatureVector(
                    float_features=next_obs),
                next_action=rlt.PreprocessedFeatureVector(
                    float_features=next_action),
                possible_actions=None,
                possible_actions_mask=possible_actions_mask,
                possible_next_actions=rlt.PreprocessedFeatureVector(
                    float_features=possible_next_actions),
                possible_next_actions_mask=possible_next_actions_mask,
                tiled_next_state=rlt.PreprocessedFeatureVector(
                    float_features=tiled_next_state),
                reward=reward,
                not_terminal=not_terminal,
                step=None,
                time_diff=None,
            ),
            extras=rlt.ExtraData(),
        )
Example #19
    def get_loss(
        self,
        training_batch: rlt.PreprocessedTrainingBatch,
        state_dim: Optional[int] = None,
        batch_first: bool = False,
    ):
        """
        Compute losses:
            GMMLoss(next_state, GMMPredicted) / (STATE_DIM + 2)
            + MSE(reward, predicted_reward)
            + BCE(not_terminal, logit_not_terminal)

        The STATE_DIM + 2 factor is here to counteract the fact that the GMMLoss scales
            approximately linearly with STATE_DIM, the feature size of states. All losses
            are averaged both on the batch and the sequence dimensions (the two first
            dimensions).

        :param training_batch:
            training_batch.training_input has these fields:
            - state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
            - action: (BATCH_SIZE, SEQ_LEN, ACTION_DIM) torch tensor
            - reward: (BATCH_SIZE, SEQ_LEN) torch tensor
            - not_terminal: (BATCH_SIZE, SEQ_LEN) torch tensor
            - next_state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
            the first two dimensions may be swapped depending on batch_first

        :param state_dim: the dimension of states. If provided, use it to normalize
            gmm loss

        :param batch_first: whether data's first dimension represents batch size. If
            False, state, action, reward, not_terminal, and next_state's first
            two dimensions are SEQ_LEN and BATCH_SIZE.

        :returns: dictionary of losses, containing the gmm, the mse, the bce and
            the averaged loss.
        """
        learning_input = training_batch.training_input
        assert isinstance(learning_input, rlt.PreprocessedMemoryNetworkInput)
        # mdnrnn's input should have seq_len as the first dimension
        if batch_first:
            state, action, next_state, reward, not_terminal = transpose(
                learning_input.state.float_features,
                learning_input.action,
                learning_input.next_state.float_features,
                learning_input.reward,
                learning_input.not_terminal,  # type: ignore
            )
            learning_input = rlt.PreprocessedMemoryNetworkInput(  # type: ignore
                state=rlt.PreprocessedFeatureVector(float_features=state),
                reward=reward,
                time_diff=torch.ones_like(reward).float(),
                action=action,
                not_terminal=not_terminal,
                next_state=rlt.PreprocessedFeatureVector(
                    float_features=next_state),
                step=None,
            )

        mdnrnn_input = rlt.PreprocessedStateAction(
            state=learning_input.state,  # type: ignore
            action=rlt.PreprocessedFeatureVector(
                float_features=learning_input.action),  # type: ignore
        )
        mdnrnn_output = self.mdnrnn(mdnrnn_input)
        mus, sigmas, logpi, rs, nts = (
            mdnrnn_output.mus,
            mdnrnn_output.sigmas,
            mdnrnn_output.logpi,
            mdnrnn_output.reward,
            mdnrnn_output.not_terminal,
        )

        next_state = learning_input.next_state.float_features
        not_terminal = learning_input.not_terminal  # type: ignore
        reward = learning_input.reward
        if self.params.fit_only_one_next_step:
            next_state, not_terminal, reward, mus, sigmas, logpi, nts, rs = tuple(
                map(
                    lambda x: x[-1:],
                    (next_state, not_terminal, reward, mus, sigmas, logpi, nts,
                     rs),
                ))

        gmm = (gmm_loss(next_state, mus, sigmas, logpi) *
               self.params.next_state_loss_weight)
        bce = (F.binary_cross_entropy_with_logits(nts, not_terminal) *
               self.params.not_terminal_loss_weight)
        mse = F.mse_loss(rs, reward) * self.params.reward_loss_weight
        if state_dim is not None:
            loss = gmm / (state_dim + 2) + bce + mse
        else:
            loss = gmm + bce + mse
        return {"gmm": gmm, "bce": bce, "mse": mse, "loss": loss}
Example #20
    def create_from_tensors_parametric_dqn(
        cls,
        trainer: ParametricDQNTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: rlt.PreprocessedFeatureVector,
        actions: rlt.PreprocessedFeatureVector,
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_mask: torch.Tensor,
        possible_actions: rlt.PreprocessedFeatureVector,
        max_num_actions: int,
        metrics: Optional[torch.Tensor] = None,
    ):
        old_q_train_state = trainer.q_network.training
        old_reward_train_state = trainer.reward_network.training
        trainer.q_network.train(False)
        trainer.reward_network.train(False)

        state_action_pairs = rlt.PreprocessedStateAction(state=states,
                                                         action=actions)
        tiled_state = states.float_features.repeat(1, max_num_actions).reshape(
            -1, states.float_features.shape[1])
        assert possible_actions is not None
        # Get Q-value of action taken
        possible_actions_state_concat = rlt.PreprocessedStateAction(
            state=rlt.PreprocessedFeatureVector(float_features=tiled_state),
            action=possible_actions,
        )

        # FIXME: model_values, model_values_for_logged_action, and model_metrics_values
        # should be calculated using q_network_cpe (as in discrete dqn).
        # q_network_cpe has not been added in parametric dqn yet.
        model_values = trainer.q_network(
            possible_actions_state_concat).q_value  # type: ignore
        optimal_q_values, _ = trainer.get_detached_q_values(
            possible_actions_state_concat.state,
            possible_actions_state_concat.action)
        eval_action_idxs = None

        assert (model_values.shape[1] == 1
                and model_values.shape[0] == possible_actions_mask.shape[0] *
                possible_actions_mask.shape[1]), (
                    "Invalid shapes: " + str(model_values.shape) + " != " +
                    str(possible_actions_mask.shape))
        model_values = model_values.reshape(possible_actions_mask.shape)
        optimal_q_values = optimal_q_values.reshape(
            possible_actions_mask.shape)
        model_propensities = masked_softmax(optimal_q_values,
                                            possible_actions_mask,
                                            trainer.rl_temperature)

        rewards_and_metric_rewards = trainer.reward_network(
            possible_actions_state_concat).q_value  # type: ignore
        model_rewards = rewards_and_metric_rewards[:, :1]
        assert (model_rewards.shape[0] *
                model_rewards.shape[1] == possible_actions_mask.shape[0] *
                possible_actions_mask.shape[1]), (
                    "Invalid shapes: " + str(model_rewards.shape) + " != " +
                    str(possible_actions_mask.shape))
        model_rewards = model_rewards.reshape(possible_actions_mask.shape)

        model_metrics = rewards_and_metric_rewards[:, 1:]
        model_metrics = model_metrics.reshape(possible_actions_mask.shape[0],
                                              -1)

        model_values_for_logged_action = trainer.q_network(
            state_action_pairs).q_value
        model_rewards_and_metrics_for_logged_action = trainer.reward_network(
            state_action_pairs).q_value
        model_rewards_for_logged_action = (
            model_rewards_and_metrics_for_logged_action[:, :1])

        action_dim = possible_actions.float_features.shape[1]
        action_mask = torch.all(
            possible_actions.float_features.view(
                -1, max_num_actions,
                action_dim) == actions.float_features.unsqueeze(dim=1),
            dim=2,
        ).float()
        assert torch.all(action_mask.sum(dim=1) == 1)
        num_metrics = model_metrics.shape[1] // max_num_actions

        model_metrics_values = None
        model_metrics_for_logged_action = None
        model_metrics_values_for_logged_action = None
        if num_metrics > 0:
            # FIXME: calculate model_metrics_values when q_network_cpe is added
            # to parametric dqn
            model_metrics_values = model_values.repeat(1, num_metrics)

        trainer.q_network.train(old_q_train_state)  # type: ignore
        trainer.reward_network.train(old_reward_train_state)  # type: ignore

        return cls(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            logged_propensities=propensities,
            logged_rewards=rewards,
            action_mask=action_mask,
            model_rewards=model_rewards,
            model_rewards_for_logged_action=model_rewards_for_logged_action,
            model_values=model_values,
            model_values_for_logged_action=model_values_for_logged_action,
            model_metrics_values=model_metrics_values,
            model_metrics_values_for_logged_action=(
                model_metrics_values_for_logged_action),
            model_propensities=model_propensities,
            logged_metrics=metrics,
            model_metrics=model_metrics,
            model_metrics_for_logged_action=model_metrics_for_logged_action,
            # Will compute later
            logged_values=None,
            logged_metrics_values=None,
            possible_actions_mask=possible_actions_mask,
            optimal_q_values=optimal_q_values,
            eval_action_idxs=eval_action_idxs,
        )
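A toy check of the state tiling and action-mask construction used above; all tensors below are made up for illustration:

# repeat(1, k).reshape(-1, d) repeats each state row k consecutive times; the mask
# marks which of the tiled possible actions equals the logged action for that state.
import torch

max_num_actions, state_dim, action_dim = 2, 3, 2
states = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

tiled = states.repeat(1, max_num_actions).reshape(-1, state_dim)
assert tiled.shape == (4, state_dim)
assert torch.equal(tiled[0], tiled[1]) and torch.equal(tiled[2], tiled[3])

possible_actions = torch.tensor([[1.0, 0.0], [0.0, 1.0],   # candidates for state 0
                                 [1.0, 0.0], [0.0, 1.0]])  # candidates for state 1
logged_actions = torch.tensor([[0.0, 1.0], [1.0, 0.0]])

action_mask = torch.all(
    possible_actions.view(-1, max_num_actions, action_dim)
    == logged_actions.unsqueeze(dim=1),
    dim=2,
).float()
assert torch.equal(action_mask, torch.tensor([[0.0, 1.0], [1.0, 0.0]]))
assert torch.all(action_mask.sum(dim=1) == 1)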
Example #21
    def train(self, training_batch) -> None:
        """
        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.
        """
        if hasattr(training_batch, "as_policy_network_training_batch"):
            training_batch = training_batch.as_policy_network_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state
        action = learning_input.action
        next_state = learning_input.next_state
        reward = learning_input.reward
        not_done_mask = learning_input.not_terminal

        action = self._maybe_scale_action_in_train(action.float_features)

        max_action = (self.max_action_range_tensor_training
                      if self.max_action_range_tensor_training else torch.ones(
                          action.shape, device=self.device))
        min_action = (self.min_action_range_tensor_serving
                      if self.min_action_range_tensor_serving else
                      -torch.ones(action.shape, device=self.device))

        # Compute current value estimates
        current_state_action = rlt.PreprocessedStateAction(
            state=state,
            action=rlt.PreprocessedFeatureVector(float_features=action))
        q1_value = self.q1_network(current_state_action).q_value
        if self.q2_network:
            q2_value = self.q2_network(current_state_action).q_value
        actor_action = self.actor_network(
            rlt.PreprocessedState(state=state)).action

        # Generate target = r + y * min (Q1(s',pi(s')), Q2(s',pi(s')))
        with torch.no_grad():
            next_actor = self.actor_network_target(
                rlt.PreprocessedState(state=next_state)).action
            next_actor += (torch.randn_like(next_actor) *
                           self.target_policy_smoothing).clamp(
                               -self.noise_clip, self.noise_clip)
            next_actor = torch.max(torch.min(next_actor, max_action),
                                   min_action)
            next_state_actor = rlt.PreprocessedStateAction(
                state=next_state,
                action=rlt.PreprocessedFeatureVector(
                    float_features=next_actor),
            )
            next_state_value = self.q1_network_target(next_state_actor).q_value

            if self.q2_network is not None:
                next_state_value = torch.min(
                    next_state_value,
                    self.q2_network_target(next_state_actor).q_value)

            target_q_value = (
                reward + self.gamma * next_state_value * not_done_mask.float())

        # Optimize Q1 and Q2
        q1_loss = F.mse_loss(q1_value, target_q_value)
        q1_loss.backward()
        self._maybe_run_optimizer(self.q1_network_optimizer,
                                  self.minibatches_per_step)
        if self.q2_network:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            q2_loss.backward()
            self._maybe_run_optimizer(self.q2_network_optimizer,
                                      self.minibatches_per_step)

        # Only update actor and target networks after a fixed number of Q updates
        if self.minibatch % self.delayed_policy_update == 0:
            actor_loss = -self.q1_network(
                rlt.PreprocessedStateAction(
                    state=state,
                    action=rlt.PreprocessedFeatureVector(
                        float_features=actor_action),
                )).q_value.mean()
            actor_loss.backward()
            self._maybe_run_optimizer(self.actor_network_optimizer,
                                      self.minibatches_per_step)

            # Use the soft update rule to update the target networks
            self._maybe_soft_update(
                self.q1_network,
                self.q1_network_target,
                self.tau,
                self.minibatches_per_step,
            )
            self._maybe_soft_update(
                self.actor_network,
                self.actor_network_target,
                self.tau,
                self.minibatches_per_step,
            )
            if self.q2_network is not None:
                self._maybe_soft_update(
                    self.q2_network,
                    self.q2_network_target,
                    self.tau,
                    self.minibatches_per_step,
                )

        # Logging at the end to schedule all the cuda operations first
        if (self.tensorboard_logging_freq != 0
                and self.minibatch % self.tensorboard_logging_freq == 0):
            SummaryWriterContext.add_histogram("q1/logged_state_value",
                                               q1_value)
            if self.q2_network:
                SummaryWriterContext.add_histogram("q2/logged_state_value",
                                                   q2_value)

            SummaryWriterContext.add_histogram("q_network/next_state_value",
                                               next_state_value)
            SummaryWriterContext.add_histogram("q_network/target_q_value",
                                               target_q_value)
            SummaryWriterContext.add_histogram("actor/loss", actor_loss)

        self.loss_reporter.report(
            td_loss=float(q1_loss),
            reward_loss=None,
            logged_rewards=reward,
            model_values_on_logged_actions=q1_value,
        )
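The _maybe_soft_update calls above presumably apply the usual soft (Polyak) target update; a hedged sketch of that rule, which may differ in detail (e.g. update-frequency gating) from the actual helper:

# Hedged sketch of a soft (Polyak) target update: target <- tau * online + (1 - tau) * target.
import torch
import torch.nn as nn

def soft_update(online: nn.Module, target: nn.Module, tau: float) -> None:
    with torch.no_grad():
        for p, target_p in zip(online.parameters(), target.parameters()):
            target_p.mul_(1.0 - tau).add_(p, alpha=tau)

online_net = nn.Linear(4, 1)
target_net = nn.Linear(4, 1)
target_net.load_state_dict(online_net.state_dict())  # start from identical weights
soft_update(online_net, target_net, tau=0.005)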
Example #22
    def train(self, training_batch) -> None:
        """
        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.
        """
        if hasattr(training_batch, "as_policy_network_training_batch"):
            training_batch = training_batch.as_policy_network_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state
        action = learning_input.action
        reward = learning_input.reward
        discount = torch.full_like(reward, self.gamma)
        not_done_mask = learning_input.not_terminal

        if self._should_scale_action_in_train():
            action = action._replace(float_features=rescale_torch_tensor(
                action.float_features,
                new_min=self.min_action_range_tensor_training,
                new_max=self.max_action_range_tensor_training,
                prev_min=self.min_action_range_tensor_serving,
                prev_max=self.max_action_range_tensor_serving,
            ))

        # We need to zero out grad here because gradient from actor update
        # should not be used in Q-network update
        self.actor_network_optimizer.zero_grad()
        self.q1_network_optimizer.zero_grad()
        if self.q2_network is not None:
            self.q2_network_optimizer.zero_grad()
        if self.value_network is not None:
            self.value_network_optimizer.zero_grad()

        with torch.enable_grad():
            #
            # First, optimize Q networks; minimizing MSE between
            # Q(s, a) & r + discount * V'(next_s)
            #

            current_state_action = rlt.PreprocessedStateAction(state=state,
                                                               action=action)
            q1_value = self.q1_network(current_state_action).q_value
            if self.q2_network:
                q2_value = self.q2_network(current_state_action).q_value
            actor_output = self.actor_network(
                rlt.PreprocessedState(state=state))

            # Optimize Alpha
            if self.alpha_optimizer is not None:
                alpha_loss = -(self.log_alpha *
                               (actor_output.log_prob +
                                self.target_entropy).detach()).mean()
                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()
                self.entropy_temperature = self.log_alpha.exp()

            with torch.no_grad():
                if self.value_network is not None:
                    next_state_value = self.value_network_target(
                        learning_input.next_state.float_features)
                else:
                    next_state_actor_output = self.actor_network(
                        rlt.PreprocessedState(state=learning_input.next_state))
                    next_state_actor_action = rlt.PreprocessedStateAction(
                        state=learning_input.next_state,
                        action=rlt.PreprocessedFeatureVector(
                            float_features=next_state_actor_output.action),
                    )
                    next_state_value = self.q1_network_target(
                        next_state_actor_action).q_value

                    if self.q2_network is not None:
                        target_q2_value = self.q2_network_target(
                            next_state_actor_action).q_value
                        next_state_value = torch.min(next_state_value,
                                                     target_q2_value)

                    log_prob_a = self.actor_network.get_log_prob(
                        learning_input.next_state,
                        next_state_actor_output.action)
                    log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                    next_state_value -= self.entropy_temperature * log_prob_a

                if self.gamma > 0.0:
                    target_q_value = (
                        reward +
                        discount * next_state_value * not_done_mask.float())
                else:
                    # This is useful in debugging instability issues
                    target_q_value = reward

            q1_loss = F.mse_loss(q1_value, target_q_value)
            q1_loss.backward()
            self._maybe_run_optimizer(self.q1_network_optimizer,
                                      self.minibatches_per_step)
            if self.q2_network:
                q2_loss = F.mse_loss(q2_value, target_q_value)
                q2_loss.backward()
                self._maybe_run_optimizer(self.q2_network_optimizer,
                                          self.minibatches_per_step)

            #
            # Second, optimize the actor; minimizing KL-divergence between action propensity
            # & softmax of value. Due to reparameterization trick, it ends up being
            # log_prob(actor_action) - Q(s, actor_action)
            #

            state_actor_action = rlt.PreprocessedStateAction(
                state=state,
                action=rlt.PreprocessedFeatureVector(
                    float_features=actor_output.action),
            )
            q1_actor_value = self.q1_network(state_actor_action).q_value
            min_q_actor_value = q1_actor_value
            if self.q2_network:
                q2_actor_value = self.q2_network(state_actor_action).q_value
                min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

            actor_loss = (self.entropy_temperature * actor_output.log_prob -
                          min_q_actor_value)
            # Do this in 2 steps so we can log histogram of actor loss
            actor_loss_mean = actor_loss.mean()

            if self.add_kld_to_loss:
                if self.apply_kld_on_mean:
                    action_batch_m = torch.mean(actor_output.action_mean,
                                                axis=0)
                    action_batch_v = torch.var(actor_output.action_mean,
                                               axis=0)
                else:
                    action_batch_m = torch.mean(actor_output.action, axis=0)
                    action_batch_v = torch.var(actor_output.action, axis=0)
                kld = (0.5 * ((action_batch_v +
                               (action_batch_m - self.action_emb_mean)**2) /
                              self.action_emb_variance - 1 +
                              self.action_emb_variance.log() -
                              action_batch_v.log()).sum())

                actor_loss_mean += self.kld_weight * kld

            actor_loss_mean.backward()
            self._maybe_run_optimizer(self.actor_network_optimizer,
                                      self.minibatches_per_step)

            #
            # Lastly, if applicable, optimize value network; minimizing MSE between
            # V(s) & E_a~pi(s) [ Q(s,a) - log(pi(a|s)) ]
            #

            if self.value_network is not None:
                state_value = self.value_network(state.float_features)

                if self.logged_action_uniform_prior:
                    log_prob_a = torch.zeros_like(min_q_actor_value)
                    target_value = min_q_actor_value
                else:
                    with torch.no_grad():
                        log_prob_a = actor_output.log_prob
                        log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                        target_value = (min_q_actor_value -
                                        self.entropy_temperature * log_prob_a)

                value_loss = F.mse_loss(state_value, target_value.detach())
                value_loss.backward()
                self._maybe_run_optimizer(self.value_network_optimizer,
                                          self.minibatches_per_step)

        # Use the soft update rule to update the target networks
        if self.value_network is not None:
            self._maybe_soft_update(
                self.value_network,
                self.value_network_target,
                self.tau,
                self.minibatches_per_step,
            )
        else:
            self._maybe_soft_update(
                self.q1_network,
                self.q1_network_target,
                self.tau,
                self.minibatches_per_step,
            )
            if self.q2_network is not None:
                self._maybe_soft_update(
                    self.q2_network,
                    self.q2_network_target,
                    self.tau,
                    self.minibatches_per_step,
                )

        # Logging at the end to schedule all the cuda operations first
        if (self.tensorboard_logging_freq != 0
                and self.minibatch % self.tensorboard_logging_freq == 0):
            SummaryWriterContext.add_histogram("q1/logged_state_value",
                                               q1_value)
            if self.q2_network:
                SummaryWriterContext.add_histogram("q2/logged_state_value",
                                                   q2_value)

            SummaryWriterContext.add_scalar("entropy_temperature",
                                            self.entropy_temperature)
            SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
            if self.value_network:
                SummaryWriterContext.add_histogram("value_network/target",
                                                   target_value)

            SummaryWriterContext.add_histogram("q_network/next_state_value",
                                               next_state_value)
            SummaryWriterContext.add_histogram("q_network/target_q_value",
                                               target_q_value)
            SummaryWriterContext.add_histogram("actor/min_q_actor_value",
                                               min_q_actor_value)
            SummaryWriterContext.add_histogram("actor/action_log_prob",
                                               actor_output.log_prob)
            SummaryWriterContext.add_histogram("actor/loss", actor_loss)
            if self.add_kld_to_loss:
                SummaryWriterContext.add_histogram("kld/mean", action_batch_m)
                SummaryWriterContext.add_histogram("kld/var", action_batch_v)
                SummaryWriterContext.add_scalar("kld/kld", kld)

        self.loss_reporter.report(
            td_loss=float(q1_loss),
            reward_loss=None,
            logged_rewards=reward,
            model_values_on_logged_actions=q1_value,
            model_propensities=actor_output.log_prob.exp(),
            model_values=min_q_actor_value,
        )
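The kld term above is the closed-form KL divergence between a diagonal Gaussian fitted to the action batch and the prior N(action_emb_mean, action_emb_variance). A stand-alone version of the same expression, with arbitrary example tensors:

# KL(N(m, v) || N(mu0, var0)) for diagonal Gaussians, matching the expression above.
import torch

def diag_gaussian_kld(m, v, mu0, var0):
    return 0.5 * ((v + (m - mu0) ** 2) / var0 - 1 + var0.log() - v.log()).sum()

m = torch.tensor([0.1, -0.2])       # batch mean of the actions (or action means)
v = torch.tensor([0.5, 0.8])        # batch variance
mu0 = torch.zeros(2)                # prior mean (action_emb_mean)
var0 = torch.ones(2)                # prior variance (action_emb_variance)
kld = diag_gaussian_kld(m, v, mu0, var0)

# The KL is zero exactly when the batch statistics match the prior.
assert diag_gaussian_kld(mu0, var0, mu0, var0).abs() < 1e-6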
    def _simulated_training_input(self, training_input, sim_tgt_out_idx,
                                  sim_distance, device):
        batch_size, max_tgt_seq_len = sim_tgt_out_idx.shape
        _, max_src_seq_len, candidate_feat_dim = (
            training_input.src_seq.float_features.shape)

        # candidates + padding_symbol + decoder_start_symbol
        candidate_size = max_src_seq_len + 2
        src_seq_augment = torch.zeros(batch_size,
                                      candidate_size,
                                      candidate_feat_dim,
                                      device=device)
        src_seq_augment[:, 2:, :] = training_input.src_seq.float_features

        sim_tgt_in_idx = torch.zeros_like(sim_tgt_out_idx).long()
        sim_tgt_in_idx[:, 0] = DECODER_START_SYMBOL
        sim_tgt_in_idx[:, 1:] = sim_tgt_out_idx[:, :-1]

        sim_tgt_in_seq = rlt.PreprocessedFeatureVector(
            float_features=src_seq_augment[
                torch.arange(batch_size, device=device).repeat_interleave(  # type: ignore
                    max_tgt_seq_len),
                sim_tgt_in_idx.flatten(),
            ].view(batch_size, max_tgt_seq_len, candidate_feat_dim))
        sim_tgt_out_seq = rlt.PreprocessedFeatureVector(
            float_features=src_seq_augment[
                torch.arange(batch_size, device=device).repeat_interleave(  # type: ignore
                    max_tgt_seq_len),
                sim_tgt_out_idx.flatten(),
            ].view(batch_size, max_tgt_seq_len, candidate_feat_dim))
        sim_tgt_out_probs = torch.tensor([1.0 / len(self.permutation_index)],
                                         device=self.device).repeat(batch_size)

        if self.reward_net is None:
            self.reward_net = _load_reward_net(self.reward_net_path,
                                               self.use_gpu)
        slate_reward = (self.reward_net(
            training_input.state.float_features,
            training_input.src_seq.float_features,
            sim_tgt_out_seq.float_features,
            training_input.src_src_mask,
            sim_tgt_out_idx,
        ).squeeze().detach())
        # guard-rail reward prediction range
        reward_clamp = self.parameters.simulation_reward_clamp
        if reward_clamp is not None:
            slate_reward = torch.clamp(slate_reward,
                                       min=reward_clamp.clamp_min,
                                       max=reward_clamp.clamp_max)
        # guard-rail sequence similarity
        distance_penalty = self.parameters.simulation_distance_penalty
        if distance_penalty is not None:
            slate_reward += distance_penalty * (self.MAX_DISTANCE -
                                                sim_distance)

        on_policy_input = rlt.PreprocessedRankingInput(
            state=training_input.state,
            src_seq=training_input.src_seq,
            src_src_mask=training_input.src_src_mask,
            tgt_in_seq=sim_tgt_in_seq,
            tgt_out_seq=sim_tgt_out_seq,
            tgt_tgt_mask=training_input.tgt_tgt_mask,
            slate_reward=slate_reward,
            src_in_idx=training_input.src_in_idx,
            tgt_in_idx=sim_tgt_in_idx,
            tgt_out_idx=sim_tgt_out_idx,
            tgt_out_probs=sim_tgt_out_probs,
        )
        return on_policy_input
    def test_seq2slate_eval_data_page(self):
        """
        Create 3 slate ranking logs and evaluate using Direct Method, Inverse
        Propensity Scores, and Doubly Robust.

        The logs are as follows:
        state: [1, 0, 0], [0, 1, 0], [0, 0, 1]
        indices in logged slates: [3, 2], [3, 2], [3, 2]
        model output indices: [2, 3], [3, 2], [2, 3]
        logged reward: 4, 5, 7
        logged propensities: 0.2, 0.5, 0.4
        predicted rewards on logged slates: 2, 4, 6
        predicted rewards on model outputted slates: 1, 4, 5
        predicted propensities: 0.4, 0.3, 0.7

        When eval_greedy=True:

        Direct Method uses the predicted rewards on model outputted slates.
        Thus the result is expected to be (1 + 4 + 5) / 3

        Inverse Propensity Scores would scale the reward by 1.0 / logged propensities
        whenever the model output slate matches with the logged slate.
        Since only the second log matches with the model output, the IPS result
        is expected to be 5 / 0.5 / 3

        Doubly Robust is the sum of the direct method result and propensity-scaled
        reward difference; the latter is defined as:
        1.0 / logged_propensities * (logged reward - predicted reward on logged slate)
         * Indicator(model slate == logged slate)
        Since only the second logged slate matches with the model outputted slate,
        the DR result is expected to be (1 + 4 + 5) / 3 + 1.0 / 0.5 * (5 - 4) / 3


        When eval_greedy=False:

        Only Inverse Propensity Scores would be accurate, because it would be too
        expensive to compute all possible slates' propensities and predicted rewards
        for the Direct Method.

        The expected IPS = (0.4 / 0.2 * 4 + 0.3 / 0.5 * 5 + 0.7 / 0.4 * 7) / 3
        """
        batch_size = 3
        state_dim = 3
        src_seq_len = 2
        tgt_seq_len = 2
        candidate_dim = 2

        reward_net = FakeSeq2SlateRewardNetwork()
        seq2slate_net = FakeSeq2SlateTransformerNet()

        src_seq = torch.eye(candidate_dim).repeat(batch_size, 1, 1)
        tgt_out_idx = torch.LongTensor([[3, 2], [3, 2], [3, 2]])
        tgt_out_seq = src_seq[
            torch.arange(batch_size).repeat_interleave(tgt_seq_len),  # type: ignore
            tgt_out_idx.flatten() - 2,
        ].reshape(batch_size, tgt_seq_len, candidate_dim)

        ptb = rlt.PreprocessedTrainingBatch(
            training_input=rlt.PreprocessedRankingInput(
                state=rlt.PreprocessedFeatureVector(
                    float_features=torch.eye(state_dim)),
                src_seq=rlt.PreprocessedFeatureVector(float_features=src_seq),
                tgt_out_seq=rlt.PreprocessedFeatureVector(
                    float_features=tgt_out_seq),
                src_src_mask=torch.ones(batch_size, src_seq_len, src_seq_len),
                tgt_out_idx=tgt_out_idx,
                tgt_out_probs=torch.tensor([0.2, 0.5, 0.4]),
                slate_reward=torch.tensor([4.0, 5.0, 7.0]),
            ),
            extras=rlt.ExtraData(
                sequence_number=torch.tensor([0, 0, 0]),
                mdp_id=np.array(["0", "1", "2"]),
            ),
        )

        edp = EvaluationDataPage.create_from_tensors_seq2slate(
            seq2slate_net, reward_net, ptb.training_input, eval_greedy=True)
        logger.info(
            "---------- Start evaluating eval_greedy=True -----------------")
        doubly_robust_estimator = DoublyRobustEstimator()
        direct_method, inverse_propensity, doubly_robust = doubly_robust_estimator.estimate(
            edp)
        logger.info(f"{direct_method}, {inverse_propensity}, {doubly_robust}")

        avg_logged_reward = (4 + 5 + 7) / 3
        self.assertAlmostEqual(direct_method.raw, (1 + 4 + 5) / 3, delta=1e-6)
        self.assertAlmostEqual(direct_method.normalized,
                               direct_method.raw / avg_logged_reward,
                               delta=1e-6)
        self.assertAlmostEqual(inverse_propensity.raw, 5 / 0.5 / 3, delta=1e-6)
        self.assertAlmostEqual(
            inverse_propensity.normalized,
            inverse_propensity.raw / avg_logged_reward,
            delta=1e-6,
        )
        self.assertAlmostEqual(doubly_robust.raw,
                               direct_method.raw + 1 / 0.5 * (5 - 4) / 3,
                               delta=1e-6)
        self.assertAlmostEqual(doubly_robust.normalized,
                               doubly_robust.raw / avg_logged_reward,
                               delta=1e-6)
        logger.info(
            "---------- Finish evaluating eval_greedy=True -----------------")

        logger.info(
            "---------- Start evaluating eval_greedy=False -----------------")
        edp = EvaluationDataPage.create_from_tensors_seq2slate(
            seq2slate_net, reward_net, ptb.training_input, eval_greedy=False)
        doubly_robust_estimator = DoublyRobustEstimator()
        _, inverse_propensity, _ = doubly_robust_estimator.estimate(edp)
        self.assertAlmostEqual(
            inverse_propensity.raw,
            (0.4 / 0.2 * 4 + 0.3 / 0.5 * 5 + 0.7 / 0.4 * 7) / 3,
            delta=1e-6,
        )
        self.assertAlmostEqual(
            inverse_propensity.normalized,
            inverse_propensity.raw / avg_logged_reward,
            delta=1e-6,
        )
        logger.info(
            "---------- Finish evaluating eval_greedy=False -----------------")