Example #1
    @classmethod
    def create_from_tensors(
        cls,
        trainer: RLTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: Union[mt.State, torch.Tensor],
        actions: Union[mt.Action, torch.Tensor],
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_mask: torch.Tensor,
        possible_actions: Optional[mt.FeatureVector] = None,
        max_num_actions: Optional[int] = None,
        metrics: Optional[torch.Tensor] = None,
    ):
        with torch.no_grad():
            # Switch to evaluation mode for the network
            old_q_train_state = trainer.q_network.training
            old_reward_train_state = trainer.reward_network.training
            trainer.q_network.train(False)
            trainer.reward_network.train(False)

            if max_num_actions:
                # Parametric model CPE
                state_action_pairs = mt.StateAction(state=states, action=actions)
                tiled_state = mt.FeatureVector(
                    states.float_features.repeat(1, max_num_actions).reshape(
                        -1, states.float_features.shape[1]
                    )
                )
                # Pair each tiled state with every possible action
                possible_actions_state_concat = mt.StateAction(
                    state=tiled_state, action=possible_actions
                )

                # Parametric actions
                model_values = trainer.q_network(possible_actions_state_concat).q_value
                assert (
                    model_values.shape[0] * model_values.shape[1]
                    == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
                ), (
                    "Invalid shapes: "
                    + str(model_values.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_values = model_values.reshape(possible_actions_mask.shape)

                model_rewards = trainer.reward_network(
                    possible_actions_state_concat
                ).q_value
                assert (
                    model_rewards.shape[0] * model_rewards.shape[1]
                    == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
                ), (
                    "Invalid shapes: "
                    + str(model_rewards.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_rewards = model_rewards.reshape(possible_actions_mask.shape)

                model_values_for_logged_action = trainer.q_network(
                    state_action_pairs
                ).q_value
                model_rewards_for_logged_action = trainer.reward_network(
                    state_action_pairs
                ).q_value

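                # Compare each possible action's Q-value with the logged action's
                # Q-value ([batch, 1]); matches within 1e-3 recover a (near) one-hot
                # mask for the logged action.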
                action_mask = (
                    torch.abs(model_values - model_values_for_logged_action) < 1e-3
                ).float()

                model_metrics = None
                model_metrics_for_logged_action = None
                model_metrics_values = None
                model_metrics_values_for_logged_action = None
            else:
                action_mask = actions.float()

                # Switch to evaluation mode for the network
                old_q_cpe_train_state = trainer.q_network_cpe.training
                trainer.q_network_cpe.train(False)

                # Discrete actions
                rewards = trainer.boost_rewards(rewards, actions)
                model_values = trainer.get_detached_q_values(states)[0]
                assert model_values.shape == actions.shape, (
                    "Invalid shape: "
                    + str(model_values.shape)
                    + " != "
                    + str(actions.shape)
                )
                assert model_values.shape == possible_actions_mask.shape, (
                    "Invalid shape: "
                    + str(model_values.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
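                # action_mask is the one-hot logged action, so this sum selects its Q-value.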
                model_values_for_logged_action = torch.sum(
                    model_values * action_mask, dim=1, keepdim=True
                )

                if isinstance(states, mt.State):
                    states = mt.StateInput(state=states)

                rewards_and_metric_rewards = trainer.reward_network(states)

                # In case the modular Q-network model is reused as the reward network
                if hasattr(rewards_and_metric_rewards, "q_values"):
                    rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

                num_actions = trainer.num_actions

                model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
                assert model_rewards.shape == actions.shape, (
                    "Invalid shape: "
                    + str(model_rewards.shape)
                    + " != "
                    + str(actions.shape)
                )
                model_rewards_for_logged_action = torch.sum(
                    model_rewards * action_mask, dim=1, keepdim=True
                )

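                # Columns beyond the first num_actions hold per-metric reward predictions,
                # laid out as one block of num_actions columns per metric.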
                model_metrics = rewards_and_metric_rewards[:, num_actions:]

                assert model_metrics.shape[1] % num_actions == 0, (
                    "Invalid metrics shape: "
                    + str(model_metrics.shape)
                    + " "
                    + str(num_actions)
                )
                num_metrics = model_metrics.shape[1] // num_actions

                if num_metrics == 0:
                    model_metrics_values = None
                    model_metrics_for_logged_action = None
                    model_metrics_values_for_logged_action = None
                else:
                    model_metrics_values = trainer.q_network_cpe(states)
                    # Backward compatibility
                    if hasattr(model_metrics_values, "q_values"):
                        model_metrics_values = model_metrics_values.q_values
                    model_metrics_values = model_metrics_values[:, num_actions:]
                    assert model_metrics_values.shape[1] == num_actions * num_metrics, (
                        "Invalid shape: "
                        + str(model_metrics_values.shape[1])
                        + " != "
                        + str(actions.shape[1] * num_metrics)
                    )

                    model_metrics_for_logged_action_list = []
                    model_metrics_values_for_logged_action_list = []
                    for metric_index in range(num_metrics):
                        metric_start = metric_index * num_actions
                        metric_end = (metric_index + 1) * num_actions
                        model_metrics_for_logged_action_list.append(
                            torch.sum(
                                model_metrics[:, metric_start:metric_end] * action_mask,
                                dim=1,
                                keepdim=True,
                            )
                        )

                        model_metrics_values_for_logged_action_list.append(
                            torch.sum(
                                model_metrics_values[:, metric_start:metric_end]
                                * action_mask,
                                dim=1,
                                keepdim=True,
                            )
                        )
                    model_metrics_for_logged_action = torch.cat(
                        model_metrics_for_logged_action_list, dim=1
                    )
                    model_metrics_values_for_logged_action = torch.cat(
                        model_metrics_values_for_logged_action_list, dim=1
                    )

                # Switch back to the old mode
                trainer.q_network_cpe.train(old_q_cpe_train_state)

            # Switch back to the old mode
            trainer.q_network.train(old_q_train_state)
            trainer.reward_network.train(old_reward_train_state)

            return cls(
                mdp_id=mdp_ids,
                sequence_number=sequence_numbers,
                logged_propensities=propensities,
                logged_rewards=rewards,
                action_mask=action_mask,
                model_rewards=model_rewards,
                model_rewards_for_logged_action=model_rewards_for_logged_action,
                model_values=model_values,
                model_values_for_logged_action=model_values_for_logged_action,
                model_metrics_values=model_metrics_values,
                model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
                model_propensities=masked_softmax(
                    model_values, possible_actions_mask, trainer.rl_temperature
                ),
                logged_metrics=metrics,
                model_metrics=model_metrics,
                model_metrics_for_logged_action=model_metrics_for_logged_action,
                # Will compute later
                logged_values=None,
                logged_metrics_values=None,
                possible_actions_mask=possible_actions_mask,
            )
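
The parametric branch in the example above scores every (state, possible action) pair in one batched forward pass by tiling each state row max_num_actions times and then reshaping the flat Q-values back to the shape of possible_actions_mask. A minimal, self-contained sketch of that tiling with toy shapes (no trainer or mt types; all names below are illustrative only):

import torch

batch_size, state_dim, max_num_actions = 2, 3, 4
float_features = torch.arange(
    batch_size * state_dim, dtype=torch.float32
).reshape(batch_size, state_dim)

# Repeat each state row max_num_actions times so the rows line up with the
# flattened possible-actions tensor.
tiled_state = float_features.repeat(1, max_num_actions).reshape(-1, state_dim)
assert tiled_state.shape == (batch_size * max_num_actions, state_dim)

# A Q-network scoring each (state, action) pair would return one value per row;
# reshaping restores the [batch_size, max_num_actions] layout of
# possible_actions_mask, as done in the snippet above.
q_flat = torch.randn(batch_size * max_num_actions, 1)
q_per_action = q_flat.reshape(batch_size, max_num_actions)
assert q_per_action.shape == (batch_size, max_num_actions)
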
Example #2
    @classmethod
    def create_from_tensors(
        cls,
        trainer: RLTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: torch.Tensor,
        actions: torch.Tensor,
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_state_concat: Optional[torch.Tensor],
        possible_actions_mask: torch.Tensor,
        metrics: Optional[torch.Tensor] = None,
    ):
        with torch.no_grad():
            # Switch to evaluation mode for the network
            old_q_train_state = trainer.q_network.training
            old_reward_train_state = trainer.reward_network.training
            trainer.q_network.train(False)
            trainer.reward_network.train(False)

            if possible_actions_state_concat is not None:
                state_action_pairs = torch.cat((states, actions), dim=1)

                # Parametric actions
                model_values = trainer.q_network(possible_actions_state_concat)
                assert (
                    model_values.shape[0] * model_values.shape[1]
                    == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
                ), (
                    "Invalid shapes: "
                    + str(model_values.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_values = model_values.reshape(possible_actions_mask.shape)

                model_rewards = trainer.reward_network(possible_actions_state_concat)
                assert (
                    model_rewards.shape[0] * model_rewards.shape[1]
                    == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
                ), (
                    "Invalid shapes: "
                    + str(model_rewards.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_rewards = model_rewards.reshape(possible_actions_mask.shape)

                model_values_for_logged_action = trainer.q_network(state_action_pairs)
                model_rewards_for_logged_action = trainer.reward_network(
                    state_action_pairs
                )

                action_mask = (
                    torch.abs(model_values - model_values_for_logged_action) < 1e-3
                ).float()

                model_metrics = None
                model_metrics_for_logged_action = None
                model_metrics_values = None
                model_metrics_values_for_logged_action = None
            else:
                action_mask = actions.float()

                # Switch to evaluation mode for the network
                old_q_cpe_train_state = trainer.q_network_cpe.training
                trainer.q_network_cpe.train(False)

                # Discrete actions
                rewards = trainer.boost_rewards(rewards, actions)
                model_values = trainer.get_detached_q_values(states)[0]
                assert model_values.shape == actions.shape, (
                    "Invalid shape: "
                    + str(model_values.shape)
                    + " != "
                    + str(actions.shape)
                )
                assert model_values.shape == possible_actions_mask.shape, (
                    "Invalid shape: "
                    + str(model_values.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_values_for_logged_action = torch.sum(
                    model_values * action_mask, dim=1, keepdim=True
                )

                rewards_and_metric_rewards = trainer.reward_network(states)

                num_actions = trainer.num_actions

                model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
                assert model_rewards.shape == actions.shape, (
                    "Invalid shape: "
                    + str(model_rewards.shape)
                    + " != "
                    + str(actions.shape)
                )
                model_rewards_for_logged_action = torch.sum(
                    model_rewards * action_mask, dim=1, keepdim=True
                )

                model_metrics = rewards_and_metric_rewards[:, num_actions:]

                assert model_metrics.shape[1] % num_actions == 0, (
                    "Invalid metrics shape: "
                    + str(model_metrics.shape)
                    + " "
                    + str(num_actions)
                )
                num_metrics = model_metrics.shape[1] // num_actions

                if num_metrics == 0:
                    model_metrics_values = None
                    model_metrics_for_logged_action = None
                    model_metrics_values_for_logged_action = None
                else:
                    model_metrics_values = trainer.q_network_cpe(states)[
                        :, num_actions:
                    ]
                    assert model_metrics_values.shape[1] == num_actions * num_metrics, (
                        "Invalid shape: "
                        + str(model_metrics_values.shape[1])
                        + " != "
                        + str(actions.shape[1] * num_metrics)
                    )

                    model_metrics_for_logged_action_list = []
                    model_metrics_values_for_logged_action_list = []
                    for metric_index in range(num_metrics):
                        metric_start = metric_index * num_actions
                        metric_end = (metric_index + 1) * num_actions
                        model_metrics_for_logged_action_list.append(
                            torch.sum(
                                model_metrics[:, metric_start:metric_end] * action_mask,
                                dim=1,
                                keepdim=True,
                            )
                        )

                        model_metrics_values_for_logged_action_list.append(
                            torch.sum(
                                model_metrics_values[:, metric_start:metric_end]
                                * action_mask,
                                dim=1,
                                keepdim=True,
                            )
                        )
                    model_metrics_for_logged_action = torch.cat(
                        model_metrics_for_logged_action_list, dim=1
                    )
                    model_metrics_values_for_logged_action = torch.cat(
                        model_metrics_values_for_logged_action_list, dim=1
                    )

                # Switch back to the old mode
                trainer.q_network_cpe.train(old_q_cpe_train_state)

            # Switch back to the old mode
            trainer.q_network.train(old_q_train_state)
            trainer.reward_network.train(old_reward_train_state)

            return cls(
                mdp_id=mdp_ids,
                sequence_number=sequence_numbers,
                logged_propensities=propensities,
                logged_rewards=rewards,
                action_mask=action_mask,
                model_rewards=model_rewards,
                model_rewards_for_logged_action=model_rewards_for_logged_action,
                model_values=model_values,
                model_values_for_logged_action=model_values_for_logged_action,
                model_metrics_values=model_metrics_values,
                model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
                model_propensities=masked_softmax(
                    model_values, possible_actions_mask, trainer.rl_temperature
                ),
                logged_metrics=metrics,
                model_metrics=model_metrics,
                model_metrics_for_logged_action=model_metrics_for_logged_action,
                # Will compute later
                logged_values=None,
                logged_metrics_values=None,
                possible_actions_state_concat=possible_actions_state_concat,
                possible_actions_mask=possible_actions_mask,
            )
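
Both variants build model_propensities with a masked_softmax helper that is imported from elsewhere and not shown here. As a rough, hypothetical sketch of the usual approach (mask out impossible actions, then apply a temperature-scaled softmax), not the codebase's actual implementation:

import torch
import torch.nn.functional as F

def masked_softmax_sketch(
    q_values: torch.Tensor,               # [batch, num_actions]
    possible_actions_mask: torch.Tensor,  # [batch, num_actions], 1 = allowed
    temperature: float,
) -> torch.Tensor:
    # Hypothetical stand-in for the real masked_softmax: disallowed actions are
    # pushed to -inf so they receive zero probability, and the remaining
    # Q-values go through a temperature-scaled softmax.
    scaled = q_values / temperature
    scaled = scaled.masked_fill(possible_actions_mask == 0, float("-inf"))
    return F.softmax(scaled, dim=1)

model_propensities = masked_softmax_sketch(
    torch.tensor([[1.0, 2.0, 3.0]]),
    torch.tensor([[1.0, 1.0, 0.0]]),
    temperature=1.0,
)
# The disallowed third action gets probability 0; the allowed ones sum to 1.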