def test_no_soft_update(self):
        model = Model()
        target_model = copy.deepcopy(model)

        # After the deepcopy, the target network's parameters are distinct
        # tensors that start out with the same values as the model's.
        for target_param, param in zip(target_model.parameters(),
                                       model.parameters()):
            self.assertIsNot(target_param, param)

        optimizer = torch.optim.Adam(model.parameters())

        x = torch.tensor([1, 2], dtype=torch.int64)
        emb = model(x)

        loss = emb.sum()

        loss.backward()
        optimizer.step()

        params = list(model.parameters())
        self.assertEqual(1, len(params))
        param = params[0].detach().numpy()

        trainer = RLTrainer(DiscreteActionModelParameters(rl=RLParameters()),
                            use_gpu=False)
        # tau=1.0 is a full ("hard") copy, so the target must match the trained model.
        trainer._soft_update(model, target_model, 1.0)

        target_params = list(target_model.parameters())
        self.assertEqual(1, len(target_params))
        target_param = target_params[0].detach().numpy()

        npt.assert_array_equal(target_param, param)
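
A minimal sketch (my own, not Horizon's actual implementation) of the Polyak-style soft update that _soft_update is expected to perform; with tau=1.0 it reduces to the hard copy the test above checks for.

import torch

def soft_update(network, target_network, tau):
    # Sketch only; the real RLTrainer._soft_update may differ in details.
    # target <- tau * online + (1 - tau) * target, parameter by parameter.
    with torch.no_grad():
        for param, target_param in zip(network.parameters(),
                                       target_network.parameters()):
            target_param.copy_(tau * param + (1.0 - tau) * target_param)
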
Example #2
    def __init__(
        self,
        q_network,
        q_network_target,
        reward_network,
        parameters: ContinuousActionModelParameters,
    ) -> None:
        self.double_q_learning = parameters.rainbow.double_q_learning
        self.minibatch_size = parameters.training.minibatch_size

        RLTrainer.__init__(
            self,
            parameters,
            use_gpu=False,
            additional_feature_types=None,
            gradient_handler=None,
        )

        self.q_network = q_network
        self.q_network_target = q_network_target
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )

        self.reward_network = reward_network
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(),
            lr=parameters.training.learning_rate)
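
_set_optimizer and optimizer_func are not shown in these snippets; a plausible sketch of what they amount to (an assumption, not the library's code) is a lookup from the configured optimizer name to a torch.optim constructor.

import torch

# Hypothetical mapping; the real trainer may support other optimizers.
_OPTIMIZER_FUNCS = {
    "ADAM": torch.optim.Adam,
    "SGD": torch.optim.SGD,
    "RMSPROP": torch.optim.RMSprop,
}

def get_optimizer_func(name: str):
    return _OPTIMIZER_FUNCS[name.upper()]
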
Example #3
    def __init__(
        self,
        parameters: ContinuousActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        action_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu=False,
        additional_feature_types: AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
    ) -> None:

        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self.state_normalization_parameters = state_normalization_parameters
        self.action_normalization_parameters = action_normalization_parameters
        self.num_state_features = get_num_output_features(
            state_normalization_parameters)
        self.num_action_features = get_num_output_features(
            action_normalization_parameters)
        self.num_features = self.num_state_features + self.num_action_features

        # ensure state and action IDs have no intersection
        overlapping_features = set(
            state_normalization_parameters.keys()) & set(
                action_normalization_parameters.keys())
        assert len(overlapping_features) == 0, (
            "There are some overlapping state and action features: " +
            str(overlapping_features))

        if parameters.training.factorization_parameters is None:
            parameters.training.layers[0] = self.num_features
            parameters.training.layers[-1] = 1
        else:
            parameters.training.factorization_parameters.state.layers[
                0] = self.num_state_features
            parameters.training.factorization_parameters.action.layers[
                0] = self.num_action_features

        RLTrainer.__init__(self, parameters, use_gpu, additional_feature_types,
                           None)

        self.q_network = self._get_model(parameters.training)

        self.q_network_target = deepcopy(self.q_network)
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(), lr=parameters.training.learning_rate)

        self.reward_network = self._get_model(parameters.training)
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(),
            lr=parameters.training.learning_rate)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()
            self.reward_network.cuda()
Example #4
    def __init__(
        self,
        parameters: DiscreteActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu=False,
        additional_feature_types: AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        gradient_handler=None,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self._actions = parameters.actions if parameters.actions is not None else []

        self.reward_shape = {}  # type: Dict[int, float]
        if parameters.rl.reward_boost is not None and self._actions is not None:
            for k in parameters.rl.reward_boost.keys():
                i = self._actions.index(k)
                self.reward_shape[i] = parameters.rl.reward_boost[k]

        if parameters.training.cnn_parameters is None:
            self.state_normalization_parameters: Optional[Dict[
                int, NormalizationParameters]] = state_normalization_parameters
            self.num_features = get_num_output_features(
                state_normalization_parameters)
            parameters.training.layers[0] = self.num_features
        else:
            self.state_normalization_parameters = None
        parameters.training.layers[-1] = self.num_actions

        RLTrainer.__init__(self, parameters, use_gpu, additional_feature_types,
                           gradient_handler)

        if parameters.rainbow.dueling_architecture:
            self.q_network = DuelingArchitectureQNetwork(
                parameters.training.layers, parameters.training.activations)
        else:
            self.q_network = GenericFeedForwardNetwork(
                parameters.training.layers, parameters.training.activations)
        self.q_network_target = deepcopy(self.q_network)
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(), lr=parameters.training.learning_rate)

        self.reward_network = GenericFeedForwardNetwork(
            parameters.training.layers, parameters.training.activations)
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(),
            lr=parameters.training.learning_rate)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()
            self.reward_network.cuda()
Example #5
    def __init__(self,
                 imitator,
                 parameters: DiscreteActionModelParameters,
                 use_gpu=False) -> None:
        self._set_optimizer(parameters.training.optimizer)
        self.minibatch_size = parameters.training.minibatch_size
        self.imitator = imitator
        self.imitator_optimizer = self.optimizer_func(
            imitator.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )
        RLTrainer.__init__(self, parameters, use_gpu=use_gpu)
Example #6
    def __init__(
        self,
        q_network,
        q_network_target,
        parameters: DiscreteActionModelParameters,
        use_gpu=False,
        metrics_to_score=None,
    ) -> None:
        RLTrainer.__init__(
            self,
            parameters,
            use_gpu=use_gpu,
            metrics_to_score=metrics_to_score,
            actions=parameters.actions,
        )

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.minibatch_size = parameters.training.minibatch_size
        self.minibatches_per_step = parameters.training.minibatches_per_step or 1
        self._actions = parameters.actions if parameters.actions is not None else []

        self.q_network = q_network
        self.q_network_target = q_network_target
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.rainbow.c51_l2_decay,
        )

        self.qmin = parameters.rainbow.qmin
        self.qmax = parameters.rainbow.qmax
        self.num_atoms = parameters.rainbow.num_atoms
        self.support = torch.linspace(self.qmin,
                                      self.qmax,
                                      self.num_atoms,
                                      device=self.device)

        self.reward_boosts = torch.zeros([1, len(self._actions)],
                                         device=self.device)
        if parameters.rl.reward_boost is not None:
            for k in parameters.rl.reward_boost.keys():
                i = self._actions.index(k)
                self.reward_boosts[0, i] = parameters.rl.reward_boost[k]
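
As an illustration of how the reward_boosts tensor built above is typically applied (my own sketch; the trainer's boost_rewards method is not shown in this listing), the boost for the logged one-hot action is added to each reward.

import torch

reward_boosts = torch.tensor([[0.0, 0.5]])          # [1, num_actions], hypothetical values
one_hot_actions = torch.tensor([[1.0, 0.0],
                                [0.0, 1.0]])        # [batch, num_actions]
rewards = torch.tensor([[1.0], [1.0]])              # [batch, 1]

boosted = rewards + (one_hot_actions * reward_boosts).sum(dim=1, keepdim=True)
# boosted == [[1.0], [1.5]]: only transitions with the second action get the 0.5 boost.
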
Example #7
    def __init__(
        self,
        q_network,
        q_network_target,
        parameters: C51TrainerParameters,
        use_gpu=False,
        metrics_to_score=None,
    ) -> None:
        RLTrainer.__init__(
            self,
            parameters.rl,
            use_gpu=use_gpu,
            metrics_to_score=metrics_to_score,
            actions=parameters.actions,
        )

        self.double_q_learning = parameters.double_q_learning
        self.minibatch_size = parameters.minibatch_size
        self.minibatches_per_step = parameters.minibatches_per_step or 1
        self._actions = parameters.actions if parameters.actions is not None else []
        self.q_network = q_network
        self.q_network_target = q_network_target
        self.q_network_optimizer = self._get_optimizer(q_network,
                                                       parameters.optimizer)
        self.qmin = parameters.qmin
        self.qmax = parameters.qmax
        self.num_atoms = parameters.num_atoms
        self.support = torch.linspace(self.qmin,
                                      self.qmax,
                                      self.num_atoms,
                                      device=self.device)

        self.reward_boosts = torch.zeros([1, len(self._actions)],
                                         device=self.device)
        if parameters.rl.reward_boost is not None:
            for k in parameters.rl.reward_boost.keys():
                i = self._actions.index(k)
                self.reward_boosts[0, i] = parameters.rl.reward_boost[k]
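
The support tensor built above holds the num_atoms fixed return values of the C51 categorical distribution. A minimal sketch of the assumed usage (not the trainer's code): per-atom probabilities are collapsed into scalar Q-values by taking the expectation over the support.

import torch

batch, num_actions, num_atoms = 4, 2, 51
support = torch.linspace(-10.0, 10.0, num_atoms)    # analogous to self.support
logits = torch.randn(batch, num_actions, num_atoms)
pmf = torch.softmax(logits, dim=2)                  # per-atom probabilities
q_values = (pmf * support).sum(dim=2)               # [batch, num_actions]
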
Example #8
    def __init__(
        self,
        actor_network,
        critic_network,
        parameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        action_normalization_parameters: Dict[int, NormalizationParameters],
        min_action_range_tensor_serving: torch.Tensor,
        max_action_range_tensor_serving: torch.Tensor,
        use_gpu: bool = False,
        use_all_avail_gpus: bool = False,
    ) -> None:

        self.state_normalization_parameters = state_normalization_parameters
        self.action_normalization_parameters = action_normalization_parameters

        for param in self.action_normalization_parameters.values():
            assert param.feature_type == CONTINUOUS, (
                "DDPG Actor features must be set to continuous (set to "
                + param.feature_type
                + ")"
            )

        self.state_dim = get_num_output_features(state_normalization_parameters)
        self.action_dim = min_action_range_tensor_serving.shape[1]
        self.num_features = self.state_dim + self.action_dim

        # The actor's tanh output layer produces actions in [-1, 1], so logged
        # actions are converted to the [-1, 1] range before training.
        self.min_action_range_tensor_training = torch.ones(1, self.action_dim) * -1
        self.max_action_range_tensor_training = torch.ones(1, self.action_dim)
        self.min_action_range_tensor_serving = min_action_range_tensor_serving
        self.max_action_range_tensor_serving = max_action_range_tensor_serving

        # Shared params
        self.warm_start_model_path = parameters.shared_training.warm_start_model_path
        self.minibatch_size = parameters.shared_training.minibatch_size
        self.final_layer_init = parameters.shared_training.final_layer_init
        self._set_optimizer(parameters.shared_training.optimizer)

        # Actor params
        self.actor_params = parameters.actor_training
        self.actor = actor_network
        self.actor_target = deepcopy(self.actor)
        self.actor_optimizer = self.optimizer_func(
            self.actor.parameters(),
            lr=self.actor_params.learning_rate,
            weight_decay=self.actor_params.l2_decay,
        )
        self.noise = OrnsteinUhlenbeckProcessNoise(self.action_dim)

        # Critic params
        self.critic_params = parameters.critic_training
        self.critic = self.q_network = critic_network
        self.critic_target = deepcopy(self.critic)
        self.critic_optimizer = self.optimizer_func(
            self.critic.parameters(),
            lr=self.critic_params.learning_rate,
            weight_decay=self.critic_params.l2_decay,
        )

        # ensure state and action IDs have no intersection
        overlapping_features = set(state_normalization_parameters.keys()) & set(
            action_normalization_parameters.keys()
        )
        assert len(overlapping_features) == 0, (
            "There are some overlapping state and action features: "
            + str(overlapping_features)
        )

        RLTrainer.__init__(self, parameters, use_gpu, None)

        self.min_action_range_tensor_training = self.min_action_range_tensor_training.type(
            self.dtype
        )
        self.max_action_range_tensor_training = self.max_action_range_tensor_training.type(
            self.dtype
        )
        self.min_action_range_tensor_serving = self.min_action_range_tensor_serving.type(
            self.dtype
        )
        self.max_action_range_tensor_serving = self.max_action_range_tensor_serving.type(
            self.dtype
        )
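
A sketch of the linear rescaling implied by the training/serving range tensors above (my own helpers, assuming elementwise ranges): actions move between the tanh-bounded training range [-1, 1] and the serving range.

import torch

def training_to_serving(action, min_serving, max_serving):
    # Map [-1, 1] -> [min_serving, max_serving], elementwise.
    return (action + 1.0) / 2.0 * (max_serving - min_serving) + min_serving

def serving_to_training(action, min_serving, max_serving):
    # Map [min_serving, max_serving] -> [-1, 1], elementwise.
    return (action - min_serving) / (max_serving - min_serving) * 2.0 - 1.0
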
Example #9
    def __init__(
        self,
        parameters: DiscreteActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu: bool = False,
        additional_feature_types: AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        metrics_to_score=None,
        gradient_handler=None,
        use_all_avail_gpus: bool = False,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self._actions = parameters.actions if parameters.actions is not None else []

        if parameters.training.cnn_parameters is None:
            self.state_normalization_parameters: Optional[Dict[
                int, NormalizationParameters]] = state_normalization_parameters
            self.num_features = get_num_output_features(
                state_normalization_parameters)
            logger.info("Number of state features: " + str(self.num_features))
            parameters.training.layers[0] = self.num_features
        else:
            self.state_normalization_parameters = None
        parameters.training.layers[-1] = self.num_actions

        RLTrainer.__init__(
            self,
            parameters,
            use_gpu,
            additional_feature_types,
            metrics_to_score,
            gradient_handler,
            actions=self._actions,
        )

        self.reward_boosts = torch.zeros([1,
                                          len(self._actions)]).type(self.dtype)
        if parameters.rl.reward_boost is not None:
            for k in parameters.rl.reward_boost.keys():
                i = self._actions.index(k)
                self.reward_boosts[0, i] = parameters.rl.reward_boost[k]

        if parameters.rainbow.dueling_architecture:
            self.q_network = DuelingQNetwork(
                parameters.training.layers,
                parameters.training.activations,
                use_batch_norm=parameters.training.use_batch_norm,
            )
        else:
            if parameters.training.cnn_parameters is None:
                self.q_network = FullyConnectedNetwork(
                    parameters.training.layers,
                    parameters.training.activations,
                    use_noisy_linear_layers=parameters.training.
                    use_noisy_linear_layers,
                    min_std=parameters.training.weight_init_min_std,
                    use_batch_norm=parameters.training.use_batch_norm,
                )
            else:
                self.q_network = ConvolutionalNetwork(
                    parameters.training.cnn_parameters,
                    parameters.training.layers,
                    parameters.training.activations,
                    use_noisy_linear_layers=parameters.training.
                    use_noisy_linear_layers,
                    min_std=parameters.training.weight_init_min_std,
                    use_batch_norm=parameters.training.use_batch_norm,
                )

        self.q_network_target = deepcopy(self.q_network)
        self.q_network._name = "training"
        self.q_network_target._name = "target"
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )

        self._init_cpe_networks(parameters, use_all_avail_gpus)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()

            if use_all_avail_gpus:
                self.q_network = torch.nn.DataParallel(self.q_network)
                self.q_network_target = torch.nn.DataParallel(
                    self.q_network_target)
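
DuelingQNetwork itself is not part of this listing; the sketch below is my own minimal version of the dueling decomposition Q(s, a) = V(s) + A(s, a) - mean_a A(s, a) that such an architecture implements.

import torch
import torch.nn as nn

class TinyDuelingQNetwork(nn.Module):
    """Minimal dueling head; not the library's DuelingQNetwork."""

    def __init__(self, state_dim, num_actions, hidden_dim=64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.ReLU())
        self.value = nn.Linear(hidden_dim, 1)
        self.advantage = nn.Linear(hidden_dim, num_actions)

    def forward(self, state):
        h = self.body(state)
        advantage = self.advantage(h)
        # Subtracting the mean advantage keeps the V/A split identifiable.
        return self.value(h) + advantage - advantage.mean(dim=1, keepdim=True)
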
Example #10
    def __init__(
        self,
        parameters: ContinuousActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        action_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu: bool = False,
        additional_feature_types: AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        gradient_handler=None,
        use_all_avail_gpus: bool = False,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self.state_normalization_parameters = state_normalization_parameters
        self.action_normalization_parameters = action_normalization_parameters
        self.num_state_features = get_num_output_features(
            state_normalization_parameters)
        self.num_action_features = get_num_output_features(
            action_normalization_parameters)
        self.num_features = self.num_state_features + self.num_action_features

        # ensure state and action IDs have no intersection
        overlapping_features = set(
            state_normalization_parameters.keys()) & set(
                action_normalization_parameters.keys())
        assert len(overlapping_features) == 0, (
            "There are some overlapping state and action features: " +
            str(overlapping_features))

        reward_network_layers = deepcopy(parameters.training.layers)
        reward_network_layers[0] = self.num_features
        reward_network_layers[-1] = 1

        if parameters.rainbow.dueling_architecture:
            parameters.training.layers[0] = self.num_state_features
            parameters.training.layers[-1] = 1
        elif parameters.training.factorization_parameters is None:
            parameters.training.layers[0] = self.num_features
            parameters.training.layers[-1] = 1
        else:
            parameters.training.factorization_parameters.state.layers[
                0] = self.num_state_features
            parameters.training.factorization_parameters.action.layers[
                0] = self.num_action_features

        RLTrainer.__init__(self, parameters, use_gpu, additional_feature_types,
                           gradient_handler)

        self.q_network = self._get_model(
            parameters.training, parameters.rainbow.dueling_architecture)

        self.q_network_target = deepcopy(self.q_network)
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )

        self.reward_network = FullyConnectedNetwork(
            reward_network_layers, parameters.training.activations)
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(),
            lr=parameters.training.learning_rate)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()
            self.reward_network.cuda()

            if use_all_avail_gpus:
                self.q_network = torch.nn.DataParallel(self.q_network)
                self.q_network_target = torch.nn.DataParallel(
                    self.q_network_target)
                self.reward_network = torch.nn.DataParallel(
                    self.reward_network)
Example #11
    def create_from_tensors(
        cls,
        trainer: RLTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: Union[mt.State, torch.Tensor],
        actions: Union[mt.Action, torch.Tensor],
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_mask: torch.Tensor,
        possible_actions: Optional[mt.FeatureVector] = None,
        max_num_actions: Optional[int] = None,
        metrics: Optional[torch.Tensor] = None,
    ):
        with torch.no_grad():
            # Switch to evaluation mode for the network
            old_q_train_state = trainer.q_network.training
            old_reward_train_state = trainer.reward_network.training
            trainer.q_network.train(False)
            trainer.reward_network.train(False)

            if max_num_actions:
                # Parametric model CPE
                state_action_pairs = mt.StateAction(state=states, action=actions)
                tiled_state = mt.FeatureVector(
                    states.float_features.repeat(1, max_num_actions).reshape(
                        -1, states.float_features.shape[1]
                    )
                )
                # Get Q-value of action taken
                possible_actions_state_concat = mt.StateAction(
                    state=tiled_state, action=possible_actions
                )

                # Parametric actions
                model_values = trainer.q_network(possible_actions_state_concat).q_value
                assert (
                    model_values.shape[0] * model_values.shape[1]
                    == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
                ), (
                    "Invalid shapes: "
                    + str(model_values.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_values = model_values.reshape(possible_actions_mask.shape)

                model_rewards = trainer.reward_network(
                    possible_actions_state_concat
                ).q_value
                assert (
                    model_rewards.shape[0] * model_rewards.shape[1]
                    == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
                ), (
                    "Invalid shapes: "
                    + str(model_rewards.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_rewards = model_rewards.reshape(possible_actions_mask.shape)

                model_values_for_logged_action = trainer.q_network(
                    state_action_pairs
                ).q_value
                model_rewards_for_logged_action = trainer.reward_network(
                    state_action_pairs
                ).q_value

                action_mask = (
                    torch.abs(model_values - model_values_for_logged_action) < 1e-3
                ).float()

                model_metrics = None
                model_metrics_for_logged_action = None
                model_metrics_values = None
                model_metrics_values_for_logged_action = None
            else:
                action_mask = actions.float()

                # Switch to evaluation mode for the network
                old_q_cpe_train_state = trainer.q_network_cpe.training
                trainer.q_network_cpe.train(False)

                # Discrete actions
                rewards = trainer.boost_rewards(rewards, actions)
                model_values = trainer.get_detached_q_values(states)[0]
                assert model_values.shape == actions.shape, (
                    "Invalid shape: "
                    + str(model_values.shape)
                    + " != "
                    + str(actions.shape)
                )
                assert model_values.shape == possible_actions_mask.shape, (
                    "Invalid shape: "
                    + str(model_values.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_values_for_logged_action = torch.sum(
                    model_values * action_mask, dim=1, keepdim=True
                )

                if isinstance(states, mt.State):
                    states = mt.StateInput(state=states)

                rewards_and_metric_rewards = trainer.reward_network(states)

                # In case the modular Q-network model is reused here, unwrap its q_values field
                if hasattr(rewards_and_metric_rewards, "q_values"):
                    rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

                num_actions = trainer.num_actions

                model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
                assert model_rewards.shape == actions.shape, (
                    "Invalid shape: "
                    + str(model_rewards.shape)
                    + " != "
                    + str(actions.shape)
                )
                model_rewards_for_logged_action = torch.sum(
                    model_rewards * action_mask, dim=1, keepdim=True
                )

                model_metrics = rewards_and_metric_rewards[:, num_actions:]

                assert model_metrics.shape[1] % num_actions == 0, (
                    "Invalid metrics shape: "
                    + str(model_metrics.shape)
                    + " "
                    + str(num_actions)
                )
                num_metrics = model_metrics.shape[1] // num_actions

                if num_metrics == 0:
                    model_metrics_values = None
                    model_metrics_for_logged_action = None
                    model_metrics_values_for_logged_action = None
                else:
                    model_metrics_values = trainer.q_network_cpe(states)
                    # Backward compatibility
                    if hasattr(model_metrics_values, "q_values"):
                        model_metrics_values = model_metrics_values.q_values
                    model_metrics_values = model_metrics_values[:, num_actions:]
                    assert model_metrics_values.shape[1] == num_actions * num_metrics, (
                        "Invalid shape: "
                        + str(model_metrics_values.shape[1])
                        + " != "
                        + str(actions.shape[1] * num_metrics)
                    )

                    model_metrics_for_logged_action_list = []
                    model_metrics_values_for_logged_action_list = []
                    for metric_index in range(num_metrics):
                        metric_start = metric_index * num_actions
                        metric_end = (metric_index + 1) * num_actions
                        model_metrics_for_logged_action_list.append(
                            torch.sum(
                                model_metrics[:, metric_start:metric_end] * action_mask,
                                dim=1,
                                keepdim=True,
                            )
                        )

                        model_metrics_values_for_logged_action_list.append(
                            torch.sum(
                                model_metrics_values[:, metric_start:metric_end]
                                * action_mask,
                                dim=1,
                                keepdim=True,
                            )
                        )
                    model_metrics_for_logged_action = torch.cat(
                        model_metrics_for_logged_action_list, dim=1
                    )
                    model_metrics_values_for_logged_action = torch.cat(
                        model_metrics_values_for_logged_action_list, dim=1
                    )

                # Switch back to the old mode
                trainer.q_network_cpe.train(old_q_cpe_train_state)

            # Switch back to the old mode
            trainer.q_network.train(old_q_train_state)
            trainer.reward_network.train(old_reward_train_state)

            return cls(
                mdp_id=mdp_ids,
                sequence_number=sequence_numbers,
                logged_propensities=propensities,
                logged_rewards=rewards,
                action_mask=action_mask,
                model_rewards=model_rewards,
                model_rewards_for_logged_action=model_rewards_for_logged_action,
                model_values=model_values,
                model_values_for_logged_action=model_values_for_logged_action,
                model_metrics_values=model_metrics_values,
                model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
                model_propensities=masked_softmax(
                    model_values, possible_actions_mask, trainer.rl_temperature
                ),
                logged_metrics=metrics,
                model_metrics=model_metrics,
                model_metrics_for_logged_action=model_metrics_for_logged_action,
                # Will compute later
                logged_values=None,
                logged_metrics_values=None,
                possible_actions_mask=possible_actions_mask,
            )
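
masked_softmax, used to compute model_propensities above, is not defined in this listing. The sketch below reflects the assumed semantics (zero probability for masked-out actions, logits scaled by a temperature); it is my own illustration, not the library's implementation.

import torch

def masked_softmax_sketch(values, mask, temperature):
    # Masked-out entries get a -inf logit and therefore zero probability.
    logits = (values / temperature).masked_fill(mask == 0, float("-inf"))
    return torch.softmax(logits, dim=1)
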
Example #12
    def __init__(
        self,
        parameters: ContinuousActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        action_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu: bool = False,
        additional_feature_types: AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        metrics_to_score=None,
        gradient_handler=None,
        use_all_avail_gpus: bool = False,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self.state_normalization_parameters = state_normalization_parameters
        self.action_normalization_parameters = action_normalization_parameters
        self.num_state_features = get_num_output_features(
            state_normalization_parameters
        )
        self.num_action_features = get_num_output_features(
            action_normalization_parameters
        )
        self.num_features = self.num_state_features + self.num_action_features

        # ensure state and action IDs have no intersection
        overlapping_features = set(state_normalization_parameters.keys()) & set(
            action_normalization_parameters.keys()
        )
        assert len(overlapping_features) == 0, (
            "There are some overlapping state and action features: "
            + str(overlapping_features)
        )

        reward_network_layers = deepcopy(parameters.training.layers)
        reward_network_layers[0] = self.num_features
        reward_network_layers[-1] = 1

        if parameters.rainbow.dueling_architecture:
            parameters.training.layers[0] = self.num_state_features
            parameters.training.layers[-1] = 1
        elif parameters.training.factorization_parameters is None:
            parameters.training.layers[0] = self.num_features
            parameters.training.layers[-1] = 1
        else:
            parameters.training.factorization_parameters.state.layers[
                0
            ] = self.num_state_features
            parameters.training.factorization_parameters.action.layers[
                0
            ] = self.num_action_features

        RLTrainer.__init__(
            self,
            parameters,
            use_gpu,
            additional_feature_types,
            metrics_to_score,
            gradient_handler,
        )

        self.q_network = self._get_model(
            parameters.training, parameters.rainbow.dueling_architecture
        )

        self.q_network_target = deepcopy(self.q_network)
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )

        self.reward_network = FullyConnectedNetwork(
            reward_network_layers, parameters.training.activations
        )
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(), lr=parameters.training.learning_rate
        )

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()
            self.reward_network.cuda()

            if use_all_avail_gpus:
                self.q_network = torch.nn.DataParallel(self.q_network)
                self.q_network_target = torch.nn.DataParallel(self.q_network_target)
                self.reward_network = torch.nn.DataParallel(self.reward_network)
Example #13
    def __init__(
        self,
        parameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        action_normalization_parameters: Dict[int, NormalizationParameters],
        min_action_range_tensor_serving: torch.Tensor,
        max_action_range_tensor_serving: torch.Tensor,
        use_gpu: bool = False,
        additional_feature_types: AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        use_all_avail_gpus: bool = False,
    ) -> None:

        self.state_normalization_parameters = state_normalization_parameters
        self.action_normalization_parameters = action_normalization_parameters

        for param in self.action_normalization_parameters.values():
            assert param.feature_type == CONTINUOUS, (
                "DDPG Actor features must be set to continuous (set to "
                + param.feature_type
                + ")"
            )

        self.state_dim = get_num_output_features(state_normalization_parameters)
        self.action_dim = min_action_range_tensor_serving.shape[1]
        self.num_features = self.state_dim + self.action_dim

        # The actor's tanh output layer produces actions in [-1, 1], so logged
        # actions are converted to the [-1, 1] range before training.
        self.min_action_range_tensor_training = torch.ones(1, self.action_dim) * -1
        self.max_action_range_tensor_training = torch.ones(1, self.action_dim)
        self.min_action_range_tensor_serving = min_action_range_tensor_serving
        self.max_action_range_tensor_serving = max_action_range_tensor_serving

        # Shared params
        self.warm_start_model_path = parameters.shared_training.warm_start_model_path
        self.minibatch_size = parameters.shared_training.minibatch_size
        self.final_layer_init = parameters.shared_training.final_layer_init
        self._set_optimizer(parameters.shared_training.optimizer)

        # Actor params
        self.actor_params = parameters.actor_training
        assert (
            self.actor_params.activations[-1] == "tanh"
        ), "Actor final layer activation must be tanh"
        self.actor_params.layers[0] = self.state_dim
        self.actor_params.layers[-1] = self.action_dim
        self.noise_generator = OrnsteinUhlenbeckProcessNoise(self.action_dim)
        self.actor = ActorNet(
            self.actor_params.layers,
            self.actor_params.activations,
            self.final_layer_init,
        )
        self.actor_target = deepcopy(self.actor)
        self.actor_optimizer = self.optimizer_func(
            self.actor.parameters(),
            lr=self.actor_params.learning_rate,
            weight_decay=self.actor_params.l2_decay,
        )
        self.noise = self.noise_generator

        # Critic params
        self.critic_params = parameters.critic_training
        self.critic_params.layers[0] = self.state_dim
        self.critic_params.layers[-1] = 1
        self.critic = self.q_network = CriticNet(
            self.critic_params.layers,
            self.critic_params.activations,
            self.final_layer_init,
            self.action_dim,
        )
        self.critic_target = deepcopy(self.critic)
        self.critic_optimizer = self.optimizer_func(
            self.critic.parameters(),
            lr=self.critic_params.learning_rate,
            weight_decay=self.critic_params.l2_decay,
        )

        # ensure state and action IDs have no intersection
        overlapping_features = set(state_normalization_parameters.keys()) & set(
            action_normalization_parameters.keys()
        )
        assert len(overlapping_features) == 0, (
            "There are some overlapping state and action features: "
            + str(overlapping_features)
        )

        RLTrainer.__init__(self, parameters, use_gpu, additional_feature_types, None)

        self.min_action_range_tensor_training = self.min_action_range_tensor_training.type(
            self.dtype
        )
        self.max_action_range_tensor_training = self.max_action_range_tensor_training.type(
            self.dtype
        )
        self.min_action_range_tensor_serving = self.min_action_range_tensor_serving.type(
            self.dtype
        )
        self.max_action_range_tensor_serving = self.max_action_range_tensor_serving.type(
            self.dtype
        )

        if self.use_gpu:
            self.actor.cuda()
            self.actor_target.cuda()
            self.critic.cuda()
            self.critic_target.cuda()

            if use_all_avail_gpus:
                self.actor = nn.DataParallel(self.actor)
                self.actor_target = nn.DataParallel(self.actor_target)
                self.critic = nn.DataParallel(self.critic)
                self.critic_target = nn.DataParallel(self.critic_target)
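
OrnsteinUhlenbeckProcessNoise is referenced but not defined in this listing. The sketch below shows the standard discrete-time OU exploration noise it presumably implements; the parameter values are assumptions, not taken from the library.

import torch

class OUNoiseSketch:
    """Standard OU process: x <- x + theta * (mu - x) + sigma * N(0, 1)."""

    def __init__(self, action_dim, mu=0.0, theta=0.15, sigma=0.2):
        self.mu, self.theta, self.sigma = mu, theta, sigma
        self.state = torch.full((action_dim,), mu)

    def sample(self):
        self.state = self.state + self.theta * (self.mu - self.state) \
            + self.sigma * torch.randn_like(self.state)
        return self.state
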
Example #14
    def __init__(
        self,
        parameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        action_normalization_parameters: Dict[int, NormalizationParameters],
        min_action_range_tensor_serving: torch.Tensor,
        max_action_range_tensor_serving: torch.Tensor,
        use_gpu: bool = False,
        additional_feature_types: AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        use_all_avail_gpus: bool = False,
    ) -> None:

        self.state_normalization_parameters = state_normalization_parameters
        self.action_normalization_parameters = action_normalization_parameters

        self.state_dim = get_num_output_features(state_normalization_parameters)
        self.action_dim = min_action_range_tensor_serving.shape[1]

        # The actor's tanh output layer produces actions in [-1, 1], so logged
        # actions are converted to the [-1, 1] range before training.
        self.min_action_range_tensor_training = torch.ones(1, self.action_dim) * -1
        self.max_action_range_tensor_training = torch.ones(1, self.action_dim)
        self.min_action_range_tensor_serving = min_action_range_tensor_serving
        self.max_action_range_tensor_serving = max_action_range_tensor_serving

        # Shared params
        self.warm_start_model_path = parameters.shared_training.warm_start_model_path
        self.minibatch_size = parameters.shared_training.minibatch_size
        self.final_layer_init = parameters.shared_training.final_layer_init
        self._set_optimizer(parameters.shared_training.optimizer)

        # Actor params
        self.actor_params = parameters.actor_training
        assert (
            self.actor_params.activations[-1] == "tanh"
        ), "Actor final layer activation must be tanh"
        self.actor_params.layers[0] = self.state_dim
        self.actor_params.layers[-1] = self.action_dim
        self.noise_generator = OrnsteinUhlenbeckProcessNoise(self.action_dim)
        self.actor = ActorNet(
            self.actor_params.layers,
            self.actor_params.activations,
            self.final_layer_init,
        )
        self.actor_target = deepcopy(self.actor)
        self.actor_optimizer = self.optimizer_func(
            self.actor.parameters(),
            lr=self.actor_params.learning_rate,
            weight_decay=self.actor_params.l2_decay,
        )
        self.noise = self.noise_generator

        # Critic params
        self.critic_params = parameters.critic_training
        self.critic_params.layers[0] = self.state_dim
        self.critic_params.layers[-1] = 1
        self.critic = CriticNet(
            self.critic_params.layers,
            self.critic_params.activations,
            self.final_layer_init,
            self.action_dim,
        )
        self.critic_target = deepcopy(self.critic)
        self.critic_optimizer = self.optimizer_func(
            self.critic.parameters(),
            lr=self.critic_params.learning_rate,
            weight_decay=self.critic_params.l2_decay,
        )

        RLTrainer.__init__(self, parameters, use_gpu, additional_feature_types, None)

        self.min_action_range_tensor_training = self.min_action_range_tensor_training.type(
            self.dtype
        )
        self.max_action_range_tensor_training = self.max_action_range_tensor_training.type(
            self.dtype
        )
        self.min_action_range_tensor_serving = self.min_action_range_tensor_serving.type(
            self.dtype
        )
        self.max_action_range_tensor_serving = self.max_action_range_tensor_serving.type(
            self.dtype
        )

        if self.use_gpu:
            self.actor.cuda()
            self.actor_target.cuda()
            self.critic.cuda()
            self.critic_target.cuda()

            if use_all_avail_gpus:
                self.actor = nn.DataParallel(self.actor)
                self.actor_target = nn.DataParallel(self.actor_target)
                self.critic = nn.DataParallel(self.critic)
                self.critic_target = nn.DataParallel(self.critic_target)
Example #15
    def create_from_tensors(
        cls,
        trainer: RLTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: torch.Tensor,
        actions: torch.Tensor,
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_state_concat: Optional[torch.Tensor],
        possible_actions_mask: torch.Tensor,
        metrics: Optional[torch.Tensor] = None,
    ):
        with torch.no_grad():
            # Switch to evaluation mode for the network
            old_q_train_state = trainer.q_network.training
            old_reward_train_state = trainer.reward_network.training
            trainer.q_network.train(False)
            trainer.reward_network.train(False)

            if possible_actions_state_concat is not None:
                state_action_pairs = torch.cat((states, actions), dim=1)

                # Parametric actions
                rewards = rewards
                model_values = trainer.q_network(possible_actions_state_concat)
                assert (
                    model_values.shape[0] *
                    model_values.shape[1] == possible_actions_mask.shape[0] *
                    possible_actions_mask.shape[1]), (
                        "Invalid shapes: " + str(model_values.shape) + " != " +
                        str(possible_actions_mask.shape))
                model_values = model_values.reshape(
                    possible_actions_mask.shape)

                model_rewards = trainer.reward_network(
                    possible_actions_state_concat)
                assert (
                    model_rewards.shape[0] *
                    model_rewards.shape[1] == possible_actions_mask.shape[0] *
                    possible_actions_mask.shape[1]), (
                        "Invalid shapes: " + str(model_rewards.shape) +
                        " != " + str(possible_actions_mask.shape))
                model_rewards = model_rewards.reshape(
                    possible_actions_mask.shape)

                model_values_for_logged_action = trainer.q_network(
                    state_action_pairs)
                model_rewards_for_logged_action = trainer.reward_network(
                    state_action_pairs)

                action_mask = (
                    torch.abs(model_values - model_values_for_logged_action) <
                    1e-3).float()

                model_metrics = None
                model_metrics_for_logged_action = None
                model_metrics_values = None
                model_metrics_values_for_logged_action = None
            else:
                action_mask = actions.float()

                # Switch to evaluation mode for the network
                old_q_cpe_train_state = trainer.q_network_cpe.training
                trainer.q_network_cpe.train(False)

                # Discrete actions
                rewards = trainer.boost_rewards(rewards, actions)
                model_values = trainer.get_detached_q_values(states)[0]
                assert model_values.shape == actions.shape, (
                    "Invalid shape: " + str(model_values.shape) + " != " +
                    str(actions.shape))
                assert model_values.shape == possible_actions_mask.shape, (
                    "Invalid shape: " + str(model_values.shape) + " != " +
                    str(possible_actions_mask.shape))
                model_values_for_logged_action = torch.sum(model_values *
                                                           action_mask,
                                                           dim=1,
                                                           keepdim=True)

                rewards_and_metric_rewards = trainer.reward_network(states)

                num_actions = trainer.num_actions

                model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
                assert model_rewards.shape == actions.shape, (
                    "Invalid shape: " + str(model_rewards.shape) + " != " +
                    str(actions.shape))
                model_rewards_for_logged_action = torch.sum(model_rewards *
                                                            action_mask,
                                                            dim=1,
                                                            keepdim=True)

                model_metrics = rewards_and_metric_rewards[:, num_actions:]

                assert model_metrics.shape[1] % num_actions == 0, (
                    "Invalid metrics shape: " + str(model_metrics.shape) +
                    " " + str(num_actions))
                num_metrics = model_metrics.shape[1] // num_actions

                if num_metrics == 0:
                    model_metrics_values = None
                    model_metrics_for_logged_action = None
                    model_metrics_values_for_logged_action = None
                else:
                    model_metrics_values = trainer.q_network_cpe(
                        states)[:, num_actions:]
                    assert model_metrics_values.shape[
                        1] == num_actions * num_metrics, (
                            "Invalid shape: " +
                            str(model_metrics_values.shape[1]) + " != " +
                            str(actions.shape[1] * num_metrics))

                    model_metrics_for_logged_action_list = []
                    model_metrics_values_for_logged_action_list = []
                    for metric_index in range(num_metrics):
                        metric_start = metric_index * num_actions
                        metric_end = (metric_index + 1) * num_actions
                        model_metrics_for_logged_action_list.append(
                            torch.sum(
                                model_metrics[:, metric_start:metric_end] *
                                action_mask,
                                dim=1,
                                keepdim=True,
                            ))

                        model_metrics_values_for_logged_action_list.append(
                            torch.sum(
                                model_metrics_values[:,
                                                     metric_start:metric_end] *
                                action_mask,
                                dim=1,
                                keepdim=True,
                            ))
                    model_metrics_for_logged_action = torch.cat(
                        model_metrics_for_logged_action_list, dim=1)
                    model_metrics_values_for_logged_action = torch.cat(
                        model_metrics_values_for_logged_action_list, dim=1)

                # Switch back to the old mode
                trainer.q_network_cpe.train(old_q_cpe_train_state)

            # Switch back to the old mode
            trainer.q_network.train(old_q_train_state)
            trainer.reward_network.train(old_reward_train_state)

            return cls(
                mdp_id=mdp_ids,
                sequence_number=sequence_numbers,
                logged_propensities=propensities,
                logged_rewards=rewards,
                action_mask=action_mask,
                model_rewards=model_rewards,
                model_rewards_for_logged_action=model_rewards_for_logged_action,
                model_values=model_values,
                model_values_for_logged_action=model_values_for_logged_action,
                model_metrics_values=model_metrics_values,
                model_metrics_values_for_logged_action=
                model_metrics_values_for_logged_action,
                model_propensities=masked_softmax(model_values,
                                                  possible_actions_mask,
                                                  trainer.rl_temperature),
                logged_metrics=metrics,
                model_metrics=model_metrics,
                model_metrics_for_logged_action=model_metrics_for_logged_action,
                # Will compute later
                logged_values=None,
                logged_metrics_values=None,
                possible_actions_state_concat=possible_actions_state_concat,
                possible_actions_mask=possible_actions_mask,
            )
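
The per-metric loop above slices the CPE outputs in contiguous blocks of num_actions columns, one block per metric. A small self-contained illustration of that layout (shapes are made up for the example):

import torch

batch, num_actions, num_metrics = 4, 2, 3
model_metrics = torch.randn(batch, num_actions * num_metrics)
action_mask = torch.tensor([[1.0, 0.0]] * batch)   # logged action is action 0

per_metric = [
    (model_metrics[:, m * num_actions:(m + 1) * num_actions] * action_mask)
    .sum(dim=1, keepdim=True)
    for m in range(num_metrics)
]
model_metrics_for_logged_action = torch.cat(per_metric, dim=1)  # [batch, num_metrics]
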
Example #16
def train_gym_offline_rl(
    gym_env: OpenAIGymEnvironment,
    replay_buffer: OpenAIGymMemoryPool,
    model_type: str,
    trainer: RLTrainer,
    predictor: OnPolicyPredictor,
    test_run_name: str,
    score_bar: Optional[float],
    max_steps: int,
    avg_over_num_episodes: int,
    offline_train_epochs: int,
    num_batch_per_epoch: Optional[int],
    bcq_imitator_hyper_params: Optional[Dict[str, Any]] = None,
):
    if num_batch_per_epoch is None:
        num_batch_per_epoch = replay_buffer.size // trainer.minibatch_size
    assert num_batch_per_epoch > 0, "The size of replay buffer is not sufficient"

    logger.info(
        "{} offline transitions in replay buffer.\n"
        "Training will take {} epochs, with each epoch having {} mini-batches"
        " and each mini-batch having {} samples".format(
            replay_buffer.size,
            offline_train_epochs,
            num_batch_per_epoch,
            trainer.minibatch_size,
        )
    )

    avg_reward_history, epoch_history = [], []

    # Pre-train a GBDT imitator if doing batch constrained q-learning in Gym
    if getattr(trainer, "bcq", None):
        assert bcq_imitator_hyper_params is not None
        gbdt = GradientBoostingClassifier(
            n_estimators=bcq_imitator_hyper_params["gbdt_trees"],
            max_depth=bcq_imitator_hyper_params["max_depth"],
        )
        samples = replay_buffer.sample_memories(replay_buffer.size, model_type)
        X, y = samples.states.numpy(), torch.max(samples.actions, dim=1)[1].numpy()
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)
        logger.info("Fitting GBDT...")
        gbdt.fit(X_train, y_train)
        train_score = round(gbdt.score(X_train, y_train) * 100, 1)
        test_score = round(gbdt.score(X_test, y_test) * 100, 1)
        logger.info(
            "GBDT train accuracy {}% || test accuracy {}%".format(
                train_score, test_score
            )
        )
        trainer.bcq_imitator = gbdt.predict_proba  # type: ignore

    # Offline training
    for i_epoch in range(offline_train_epochs):
        for _ in range(num_batch_per_epoch):
            samples = replay_buffer.sample_memories(trainer.minibatch_size, model_type)
            samples.set_device(trainer.device)
            trainer.train(samples)

        batch_td_loss = float(
            torch.mean(
                torch.tensor(
                    [stat.td_loss for stat in trainer.loss_reporter.incoming_stats]
                )
            )
        )
        trainer.loss_reporter.flush()
        logger.info(
            "Average TD loss: {} in epoch {}".format(batch_td_loss, i_epoch + 1)
        )

        # test model performance for this epoch
        avg_rewards, avg_discounted_rewards = gym_env.run_ep_n_times(
            avg_over_num_episodes, predictor, test=True, max_steps=max_steps
        )
        avg_reward_history.append(avg_rewards)

        # For offline training, use epoch number as timestep history since
        # we have a fixed batch of data to count epochs over.
        epoch_history.append(i_epoch)
        logger.info(
            "Achieved an average reward score of {} over {} evaluations"
            " after epoch {}.".format(avg_rewards, avg_over_num_episodes, i_epoch)
        )
        if score_bar is not None and avg_rewards > score_bar:
            logger.info(
                "Avg. reward history for {}: {}".format(
                    test_run_name, avg_reward_history
                )
            )
            return avg_reward_history, epoch_history, trainer, predictor, gym_env

    logger.info(
        "Avg. reward history for {}: {}".format(test_run_name, avg_reward_history)
    )
    return avg_reward_history, epoch_history, trainer, predictor, gym_env
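
For context on the GBDT imitator wired in above: in batch-constrained Q-learning, the imitator's action propensities are typically used to rule out actions that were rarely logged. The helper below is my own illustration of that thresholding, not the trainer's code.

import numpy as np

def bcq_action_mask(imitator_probs: np.ndarray, threshold: float = 0.3) -> np.ndarray:
    # imitator_probs: [batch, num_actions], e.g. the output of gbdt.predict_proba.
    # An action stays eligible only if its propensity, relative to the most
    # likely action in that state, clears the threshold.
    relative = imitator_probs / imitator_probs.max(axis=1, keepdims=True)
    return relative >= threshold
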
Example #17
    def create_from_tensors(
        cls,
        trainer: RLTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: torch.Tensor,
        actions: torch.Tensor,
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_state_concat: Optional[torch.Tensor],
        possible_actions_mask: torch.Tensor,
        metrics: Optional[torch.Tensor] = None,
    ):
        with torch.no_grad():
            # Switch to evaluation mode for the network
            old_q_train_state = trainer.q_network.training
            old_reward_train_state = trainer.reward_network.training
            trainer.q_network.train(False)
            trainer.reward_network.train(False)

            if possible_actions_state_concat is not None:
                state_action_pairs = torch.cat((states, actions), dim=1)

                # Parametric actions
                model_values = trainer.q_network(possible_actions_state_concat)
                assert (
                    model_values.shape[0] * model_values.shape[1]
                    == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
                ), (
                    "Invalid shapes: "
                    + str(model_values.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_values = model_values.reshape(possible_actions_mask.shape)

                model_rewards = trainer.reward_network(possible_actions_state_concat)
                assert (
                    model_rewards.shape[0] * model_rewards.shape[1]
                    == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
                ), (
                    "Invalid shapes: "
                    + str(model_rewards.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_rewards = model_rewards.reshape(possible_actions_mask.shape)

                model_values_for_logged_action = trainer.q_network(state_action_pairs)
                model_rewards_for_logged_action = trainer.reward_network(
                    state_action_pairs
                )

                action_mask = (
                    torch.abs(model_values - model_values_for_logged_action) < 1e-3
                ).float()
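                # The comparison above recovers which possible action is the logged
                # one by Q-value proximity (within 1e-3) rather than by matching raw
                # action features.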

                model_metrics = None
                model_metrics_for_logged_action = None
                model_metrics_values = None
                model_metrics_values_for_logged_action = None
            else:
                action_mask = actions.float()

                # Switch the CPE network to evaluation mode
                old_q_cpe_train_state = trainer.q_network_cpe.training
                trainer.q_network_cpe.train(False)

                # Discrete actions
                rewards = trainer.boost_rewards(rewards, actions)
                model_values = trainer.get_detached_q_values(states)[0]
                assert model_values.shape == actions.shape, (
                    "Invalid shape: "
                    + str(model_values.shape)
                    + " != "
                    + str(actions.shape)
                )
                assert model_values.shape == possible_actions_mask.shape, (
                    "Invalid shape: "
                    + str(model_values.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_values_for_logged_action = torch.sum(
                    model_values * action_mask, dim=1, keepdim=True
                )

                rewards_and_metric_rewards = trainer.reward_network(states)

                num_actions = trainer.num_actions

                model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
                assert model_rewards.shape == actions.shape, (
                    "Invalid shape: "
                    + str(model_rewards.shape)
                    + " != "
                    + str(actions.shape)
                )
                model_rewards_for_logged_action = torch.sum(
                    model_rewards * action_mask, dim=1, keepdim=True
                )

                model_metrics = rewards_and_metric_rewards[:, num_actions:]
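                # Columns beyond the first num_actions hold the metric rewards,
                # grouped metric-by-metric with one column per action in each group
                # (see the slicing in the loop below).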

                assert model_metrics.shape[1] % num_actions == 0, (
                    "Invalid metrics shape: "
                    + str(model_metrics.shape)
                    + " "
                    + str(num_actions)
                )
                num_metrics = model_metrics.shape[1] // num_actions

                if num_metrics == 0:
                    model_metrics_values = None
                    model_metrics_for_logged_action = None
                    model_metrics_values_for_logged_action = None
                else:
                    model_metrics_values = trainer.q_network_cpe(states)[
                        :, num_actions:
                    ]
                    assert model_metrics_values.shape[1] == num_actions * num_metrics, (
                        "Invalid shape: "
                        + str(model_metrics_values.shape[1])
                        + " != "
                        + str(actions.shape[1] * num_metrics)
                    )

                    model_metrics_for_logged_action_list = []
                    model_metrics_values_for_logged_action_list = []
                    for metric_index in range(num_metrics):
                        metric_start = metric_index * num_actions
                        metric_end = (metric_index + 1) * num_actions
                        model_metrics_for_logged_action_list.append(
                            torch.sum(
                                model_metrics[:, metric_start:metric_end] * action_mask,
                                dim=1,
                                keepdim=True,
                            )
                        )

                        model_metrics_values_for_logged_action_list.append(
                            torch.sum(
                                model_metrics_values[:, metric_start:metric_end]
                                * action_mask,
                                dim=1,
                                keepdim=True,
                            )
                        )
                    model_metrics_for_logged_action = torch.cat(
                        model_metrics_for_logged_action_list, dim=1
                    )
                    model_metrics_values_for_logged_action = torch.cat(
                        model_metrics_values_for_logged_action_list, dim=1
                    )

                # Switch back to the old mode
                trainer.q_network_cpe.train(old_q_cpe_train_state)

            # Switch back to the old mode
            trainer.q_network.train(old_q_train_state)
            trainer.reward_network.train(old_reward_train_state)

            return cls(
                mdp_id=mdp_ids,
                sequence_number=sequence_numbers,
                logged_propensities=propensities,
                logged_rewards=rewards,
                action_mask=action_mask,
                model_rewards=model_rewards,
                model_rewards_for_logged_action=model_rewards_for_logged_action,
                model_values=model_values,
                model_values_for_logged_action=model_values_for_logged_action,
                model_metrics_values=model_metrics_values,
                model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
                model_propensities=masked_softmax(
                    model_values, possible_actions_mask, trainer.rl_temperature
                ),
                logged_metrics=metrics,
                model_metrics=model_metrics,
                model_metrics_for_logged_action=model_metrics_for_logged_action,
                # Will compute later
                logged_values=None,
                logged_metrics_values=None,
                possible_actions_state_concat=possible_actions_state_concat,
                possible_actions_mask=possible_actions_mask,
            )
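
The model_propensities above come from masked_softmax, the helper used to turn
Q-values over the allowed actions into probabilities. A minimal sketch of the idea,
assuming the helper suppresses invalid actions before a temperature-scaled softmax
(the real implementation may differ):

import torch
import torch.nn.functional as F

def masked_softmax_sketch(values, mask, temperature):
    # Push scores of invalid actions toward -inf so they get ~zero probability,
    # then apply a temperature-scaled softmax per row.
    masked_values = values / temperature + (1.0 - mask) * -1e20
    return F.softmax(masked_values, dim=1)

# Hypothetical usage: q = torch.tensor([[1.0, 2.0, 0.5]]) with
# mask = torch.tensor([[1.0, 1.0, 0.0]]) assigns ~zero probability to the last action.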