Example #1
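create_policy builds an online DiscreteDQN Policy: in serving mode it wraps the exported serving module with create_predictor_policy_from_model, otherwise it pairs a SoftmaxActionSampler with a parametric_dqn_scorer; get_num_output_features derives the action dimension (max_num_actions) from the action normalization parameters.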
    def create_policy(self, serving: bool) -> Policy:
        """ Create an online DiscreteDQN Policy from env. """

        # FIXME: this only works for one-hot encoded actions
        action_dim = get_num_output_features(
            self.action_normalization_data.dense_normalization_parameters)
        if serving:
            return create_predictor_policy_from_model(
                self.build_serving_module(), max_num_actions=action_dim)
        else:
            sampler = SoftmaxActionSampler(
                temperature=self.rl_parameters.temperature)
            scorer = parametric_dqn_scorer(max_num_actions=action_dim,
                                           q_network=self._q_network)
            return Policy(scorer=scorer, sampler=sampler)
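Every example on this page leans on get_num_output_features, which the snippets never define. As a rough orientation, here is a minimal, self-contained sketch of what such a helper might compute, under the assumption that continuous features contribute one output column each and enum-style features expand to one-hot vectors; the names below (SimpleNormalizationParameters, feature_type, possible_values) are illustrative stand-ins, not ReAgent's actual API.

from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class SimpleNormalizationParameters:
    # Illustrative stand-in for reagent's NormalizationParameters.
    feature_type: str                            # e.g. "CONTINUOUS", "ENUM", "DO_NOT_PREPROCESS"
    possible_values: Optional[List[int]] = None  # only relevant for "ENUM"

def get_num_output_features_sketch(
    normalization_parameters: Dict[int, SimpleNormalizationParameters],
) -> int:
    """Count post-preprocessing feature columns (assuming ENUM -> one-hot)."""
    total = 0
    for params in normalization_parameters.values():
        if params.feature_type == "ENUM":
            total += len(params.possible_values or [])
        else:
            total += 1
    return total

# Two continuous features plus a 3-value enum -> 5 output features.
params = {
    0: SimpleNormalizationParameters("CONTINUOUS"),
    1: SimpleNormalizationParameters("CONTINUOUS"),
    2: SimpleNormalizationParameters("ENUM", possible_values=[0, 1, 2]),
}
assert get_num_output_features_sketch(params) == 5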
Example #2
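build_actor constructs a FullyConnectedActor whose state_dim is computed by get_num_output_features from the state normalization data.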
 def build_actor(
     self,
     state_normalization_data: NormalizationData,
     num_actions: int,
 ) -> ModelBase:
     state_dim = get_num_output_features(
         state_normalization_data.dense_normalization_parameters)
     return FullyConnectedActor(
         state_dim=state_dim,
         action_dim=num_actions,
         sizes=self.sizes,
         activations=self.activations,
         use_batch_norm=self.use_batch_norm,
         action_activation=self.action_activation,
         exploration_variance=self.exploration_variance,
     )
Example #3
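A module constructor for action features: get_num_output_features fixes num_output_features, all action dimensions are required to share a single feature type, and for CONTINUOUS_ACTION the per-feature minimum values and a scaling factor (max_value - min_value) / (2 * (1 - EPS)) are precomputed as tensors on the target device.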
    def __init__(
        self,
        normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu: bool,
    ) -> None:
        super().__init__()
        self.num_output_features = get_num_output_features(
            normalization_parameters)
        feature_types = {
            norm_param.feature_type
            for norm_param in normalization_parameters.values()
        }
        assert (
            len(feature_types) == 1
        ), "All dimensions of actions should have the same preprocessing"
        self.feature_type = list(feature_types)[0]
        assert self.feature_type in {
            DISCRETE_ACTION,
            CONTINUOUS_ACTION,
            DO_NOT_PREPROCESS,
        }, f"{self.feature_type} is not DISCRETE_ACTION, CONTINUOUS_ACTION or DO_NOT_PREPROCESS"

        self.device = torch.device("cuda" if use_gpu else "cpu")

        if self.feature_type == CONTINUOUS_ACTION:
            sorted_features = sorted(normalization_parameters.keys())
            self.min_serving_value = torch.tensor(
                [
                    normalization_parameters[f].min_value
                    for f in sorted_features
                ],
                device=self.device,
            ).float()
            self.scaling_factor = torch.tensor(
                [
                    (
                        # pyre-fixme[58]: `-` is not supported for operand types
                        #  `Optional[float]` and `Optional[float]`.
                        normalization_parameters[f].max_value -
                        normalization_parameters[f].min_value) /
                    (2 * (1 - EPS)) for f in sorted_features
                ],
                device=self.device,
            ).float()
            self.almost_one = torch.tensor(1.0 - EPS,
                                           device=self.device).float()
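The buffers above imply an affine map between the serving range [min_value, max_value] and a training range of roughly [-(1 - EPS), 1 - EPS]. A short sketch of that forward and inverse transform, assuming this is how min_serving_value, scaling_factor and almost_one are consumed downstream (the EPS value is also an assumption):

import torch

EPS = 1e-6  # assumed; the real constant lives in ReAgent's preprocessing code

min_serving_value = torch.tensor([0.0, -5.0])
max_serving_value = torch.tensor([10.0, 5.0])
scaling_factor = (max_serving_value - min_serving_value) / (2 * (1 - EPS))
almost_one = torch.tensor(1.0 - EPS)

def to_training(serving: torch.Tensor) -> torch.Tensor:
    # [min, max] -> approximately [-(1 - EPS), 1 - EPS]
    return (serving - min_serving_value) / scaling_factor - almost_one

def to_serving(training: torch.Tensor) -> torch.Tensor:
    # Inverse of to_training.
    return (training + almost_one) * scaling_factor + min_serving_value

x = torch.tensor([[10.0, -5.0]])
assert torch.allclose(to_serving(to_training(x)), x)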
Example #4
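A close variant of Example #3 that accepts only CONTINUOUS_ACTION and DO_NOT_PREPROCESS and omits the explicit .float() casts and the almost_one buffer.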
    def __init__(
        self,
        normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu: bool,
    ) -> None:
        super().__init__()
        self.num_output_features = get_num_output_features(
            normalization_parameters)

        feature_types = {
            norm_param.feature_type
            for norm_param in normalization_parameters.values()
        }
        assert (
            len(feature_types) == 1
        ), "All dimensions of actions should have the same preprocessing"
        self.feature_type = list(feature_types)[0]
        assert self.feature_type in {
            CONTINUOUS_ACTION,
            DO_NOT_PREPROCESS,
        }, "Only support CONTINUOUS_ACTION & DO_NOT_PREPROCESS"

        self.device = torch.device(
            "cuda" if use_gpu else "cpu")  # type: ignore

        if self.feature_type == CONTINUOUS_ACTION:
            sorted_features = sorted(normalization_parameters.keys())
            self.min_serving_value = torch.tensor(
                [
                    normalization_parameters[f].min_value
                    for f in sorted_features
                ],
                device=self.device,
            )
            self.scaling_factor = torch.tensor(
                [
                    (
                        normalization_parameters[f].max_value  # type: ignore
                        - normalization_parameters[f].min_value  # type: ignore
                    ) / (2 * (1 - EPS)) for f in sorted_features
                ],
                device=self.device,
            )
Example #5
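build_trainer creates a MemoryNetwork whose state_dim comes from get_num_output_features on the state normalization parameters and wraps it in an MDNRNNTrainer.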
    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> MDNRNNTrainer:
        memory_network = MemoryNetwork(
            state_dim=get_num_output_features(normalization_data_map[
                NormalizationKey.STATE].dense_normalization_parameters),
            action_dim=self.trainer_param.action_dim,
            num_hiddens=self.trainer_param.hidden_size,
            num_hidden_layers=self.trainer_param.num_hidden_layers,
            num_gaussians=self.trainer_param.num_gaussians,
        )
        if use_gpu:
            memory_network = memory_network.cuda()

        return MDNRNNTrainer(memory_network=memory_network,
                             params=self.trainer_param)
Example #6
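build_trainer for a cross-entropy-method (CEM) planner: an ensemble of world-model trainers is created, action bounds are read from the action normalization parameters when the action space is continuous, and get_num_output_features supplies the state and action dimensions of the CEMPlannerNetwork.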
    def build_trainer(self) -> CEMTrainer:
        world_model_manager: WorldModel = WorldModel(
            trainer_param=self.trainer_param.mdnrnn)
        world_model_manager.initialize_trainer(
            self.use_gpu,
            self.reward_options,
            # pyre-fixme[6]: Expected `Dict[str,
            #  reagent.parameters.NormalizationData]` for 3rd param but got
            #  `Optional[typing.Dict[str, reagent.parameters.NormalizationData]]`.
            self._normalization_data_map,
        )
        world_model_trainers = [
            world_model_manager.build_trainer()
            for _ in range(self.trainer_param.num_world_models)
        ]
        world_model_nets = [
            trainer.memory_network for trainer in world_model_trainers
        ]
        terminal_effective = self.trainer_param.mdnrnn.not_terminal_loss_weight > 0

        action_normalization_parameters = (
            self.action_normalization_data.dense_normalization_parameters)
        sorted_action_norm_vals = list(
            action_normalization_parameters.values())
        discrete_action = (
            sorted_action_norm_vals[0].feature_type != CONTINUOUS_ACTION)
        action_upper_bounds, action_lower_bounds = None, None
        if not discrete_action:
            action_upper_bounds = np.array(
                [v.max_value for v in sorted_action_norm_vals])
            action_lower_bounds = np.array(
                [v.min_value for v in sorted_action_norm_vals])

        cem_planner_network = CEMPlannerNetwork(
            mem_net_list=world_model_nets,
            cem_num_iterations=self.trainer_param.cem_num_iterations,
            cem_population_size=self.trainer_param.cem_population_size,
            ensemble_population_size=self.trainer_param.ensemble_population_size,
            num_elites=self.trainer_param.num_elites,
            plan_horizon_length=self.trainer_param.plan_horizon_length,
            state_dim=get_num_output_features(
                self.state_normalization_data.dense_normalization_parameters),
            action_dim=get_num_output_features(
                self.action_normalization_data.dense_normalization_parameters),
            discrete_action=discrete_action,
            terminal_effective=terminal_effective,
            gamma=self.trainer_param.rl.gamma,
            alpha=self.trainer_param.alpha,
            epsilon=self.trainer_param.epsilon,
            action_upper_bounds=action_upper_bounds,
            action_lower_bounds=action_lower_bounds,
        )
        # store for building policy
        # pyre-fixme[16]: `CrossEntropyMethod` has no attribute `discrete_action`.
        self.discrete_action = discrete_action
        # pyre-fixme[16]: `CrossEntropyMethod` has no attribute `cem_planner_network`.
        self.cem_planner_network = cem_planner_network
        logger.info(
            f"Built CEM network with discrete action = {discrete_action}, "
            f"action_upper_bound={action_upper_bounds}, "
            f"action_lower_bounds={action_lower_bounds}")
        return CEMTrainer(
            cem_planner_network=cem_planner_network,
            world_model_trainers=world_model_trainers,
            parameters=self.trainer_param,
            use_gpu=self.use_gpu,
        )
Example #7
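A variant of Example #6 in which the normalization data and reward options are passed explicitly to build_trainer (as normalization_data_map) instead of being read from the manager's attributes.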
    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> CEMTrainer:
        # pyre-fixme[45]: Cannot instantiate abstract class `WorldModel`.
        world_model_manager: WorldModel = WorldModel(
            trainer_param=self.trainer_param.mdnrnn)
        world_model_manager.build_trainer(
            use_gpu=use_gpu,
            reward_options=reward_options,
            normalization_data_map=normalization_data_map,
        )
        world_model_trainers = [
            world_model_manager.build_trainer(normalization_data_map,
                                              reward_options=reward_options,
                                              use_gpu=use_gpu)
            for _ in range(self.trainer_param.num_world_models)
        ]
        world_model_nets = [
            trainer.memory_network for trainer in world_model_trainers
        ]
        terminal_effective = self.trainer_param.mdnrnn.not_terminal_loss_weight > 0

        action_normalization_parameters = normalization_data_map[
            NormalizationKey.ACTION].dense_normalization_parameters
        sorted_action_norm_vals = list(
            action_normalization_parameters.values())
        discrete_action = (
            sorted_action_norm_vals[0].feature_type != CONTINUOUS_ACTION)
        action_upper_bounds, action_lower_bounds = None, None
        if not discrete_action:
            action_upper_bounds = np.array(
                [v.max_value for v in sorted_action_norm_vals])
            action_lower_bounds = np.array(
                [v.min_value for v in sorted_action_norm_vals])

        cem_planner_network = CEMPlannerNetwork(
            mem_net_list=world_model_nets,
            cem_num_iterations=self.trainer_param.cem_num_iterations,
            cem_population_size=self.trainer_param.cem_population_size,
            ensemble_population_size=self.trainer_param.ensemble_population_size,
            num_elites=self.trainer_param.num_elites,
            plan_horizon_length=self.trainer_param.plan_horizon_length,
            state_dim=get_num_output_features(normalization_data_map[
                NormalizationKey.STATE].dense_normalization_parameters),
            action_dim=get_num_output_features(normalization_data_map[
                NormalizationKey.ACTION].dense_normalization_parameters),
            discrete_action=discrete_action,
            terminal_effective=terminal_effective,
            gamma=self.trainer_param.rl.gamma,
            alpha=self.trainer_param.alpha,
            epsilon=self.trainer_param.epsilon,
            action_upper_bounds=action_upper_bounds,
            action_lower_bounds=action_lower_bounds,
        )
        # store for building policy
        # pyre-fixme[16]: `CrossEntropyMethod` has no attribute `discrete_action`.
        self.discrete_action = discrete_action
        logger.info(
            f"Built CEM network with discrete action = {discrete_action}, "
            f"action_upper_bound={action_upper_bounds}, "
            f"action_lower_bounds={action_lower_bounds}")
        return CEMTrainer(
            cem_planner_network=cem_planner_network,
            world_model_trainers=world_model_trainers,
            parameters=self.trainer_param,
        )
Example #8
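_get_input_dim maps NormalizationData to a network input dimension via get_num_output_features.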
 def _get_input_dim(self,
                    state_normalization_data: NormalizationData) -> int:
     return get_num_output_features(
         state_normalization_data.dense_normalization_parameters)
Example #9
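A test helper that builds DQN, QR-DQN or C51 trainers (with optional dueling architecture and CPE networks); every Q-network gets its state_dim from get_num_output_features(environment.normalization).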
    def get_modular_sarsa_trainer_reward_boost(
        self,
        environment,
        reward_shape,
        dueling,
        categorical,
        quantile,
        use_gpu=False,
        use_all_avail_gpus=False,
        clip_grad_norm=None,
    ):
        assert not quantile or not categorical
        parameters = self.get_sarsa_parameters(
            environment, reward_shape, dueling, categorical, quantile, clip_grad_norm
        )

        def make_dueling_dqn(num_atoms=None):
            return models.DuelingQNetwork.make_fully_connected(
                state_dim=get_num_output_features(environment.normalization),
                action_dim=len(environment.ACTIONS),
                layers=parameters.training.layers[1:-1],
                activations=parameters.training.activations[:-1],
                num_atoms=num_atoms,
            )

        if quantile:
            if dueling:
                q_network = make_dueling_dqn(num_atoms=parameters.rainbow.num_atoms)

            else:
                q_network = models.FullyConnectedDQN(
                    state_dim=get_num_output_features(environment.normalization),
                    action_dim=len(environment.ACTIONS),
                    num_atoms=parameters.rainbow.num_atoms,
                    sizes=parameters.training.layers[1:-1],
                    activations=parameters.training.activations[:-1],
                )
        elif categorical:
            assert not dueling
            distributional_network = models.FullyConnectedDQN(
                state_dim=get_num_output_features(environment.normalization),
                action_dim=len(environment.ACTIONS),
                num_atoms=parameters.rainbow.num_atoms,
                sizes=parameters.training.layers[1:-1],
                activations=parameters.training.activations[:-1],
            )
            q_network = models.CategoricalDQN(
                distributional_network,
                qmin=-100,
                qmax=200,
                num_atoms=parameters.rainbow.num_atoms,
            )
        else:
            if dueling:
                q_network = make_dueling_dqn()
            else:
                q_network = models.FullyConnectedDQN(
                    state_dim=get_num_output_features(environment.normalization),
                    action_dim=len(environment.ACTIONS),
                    sizes=parameters.training.layers[1:-1],
                    activations=parameters.training.activations[:-1],
                )

        q_network_cpe, q_network_cpe_target, reward_network = None, None, None

        if parameters.evaluation and parameters.evaluation.calc_cpe_in_training:
            q_network_cpe = models.FullyConnectedDQN(
                state_dim=get_num_output_features(environment.normalization),
                action_dim=len(environment.ACTIONS),
                sizes=parameters.training.layers[1:-1],
                activations=parameters.training.activations[:-1],
            )
            q_network_cpe_target = q_network_cpe.get_target_network()
            reward_network = models.FullyConnectedDQN(
                state_dim=get_num_output_features(environment.normalization),
                action_dim=len(environment.ACTIONS),
                sizes=parameters.training.layers[1:-1],
                activations=parameters.training.activations[:-1],
            )

        if use_gpu:
            q_network = q_network.cuda()
            if parameters.evaluation.calc_cpe_in_training:
                reward_network = reward_network.cuda()
                q_network_cpe = q_network_cpe.cuda()
                q_network_cpe_target = q_network_cpe_target.cuda()
            if use_all_avail_gpus and not categorical:
                q_network = q_network.get_distributed_data_parallel_model()
                reward_network = reward_network.get_distributed_data_parallel_model()
                q_network_cpe = q_network_cpe.get_distributed_data_parallel_model()
                q_network_cpe_target = (
                    q_network_cpe_target.get_distributed_data_parallel_model()
                )

        if quantile:
            parameters = QRDQNTrainerParameters.from_discrete_action_model_parameters(
                parameters
            )
            trainer = QRDQNTrainer(
                q_network,
                q_network.get_target_network(),
                parameters,
                use_gpu,
                reward_network=reward_network,
                q_network_cpe=q_network_cpe,
                q_network_cpe_target=q_network_cpe_target,
            )
        elif categorical:
            parameters = C51TrainerParameters.from_discrete_action_model_parameters(
                parameters
            )
            trainer = C51Trainer(
                q_network, q_network.get_target_network(), parameters, use_gpu
            )
        else:
            parameters = DQNTrainerParameters.from_discrete_action_model_parameters(
                parameters
            )
            trainer = DQNTrainer(
                q_network,
                q_network.get_target_network(),
                reward_network,
                parameters,
                use_gpu,
                q_network_cpe=q_network_cpe,
                q_network_cpe_target=q_network_cpe_target,
            )
        return trainer
Example #10
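Same idea as Example #8, but operating directly on a Dict[int, NormalizationParameters].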
 def _get_input_dim(
     self, state_normalization_parameters: Dict[int, NormalizationParameters]
 ) -> int:
     return get_num_output_features(state_normalization_parameters)
Example #11
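create_dqn_trainer_from_params selects among QuantileDQN, CategoricalDQN, DuelingQNetwork and FullyConnectedDQN and builds the matching trainer; get_num_output_features(normalization_parameters) provides state_dim for the Q-, reward- and CPE networks.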
def create_dqn_trainer_from_params(
    model: DiscreteActionModelParameters,
    normalization_parameters: Dict[int, NormalizationParameters],
    use_gpu: bool = False,
    use_all_avail_gpus: bool = False,
    metrics_to_score=None,
):
    metrics_to_score = metrics_to_score or []

    if model.rainbow.quantile:
        q_network = QuantileDQN(
            state_dim=get_num_output_features(normalization_parameters),
            action_dim=len(model.actions),
            num_atoms=model.rainbow.num_atoms,
            sizes=model.training.layers[1:-1],
            activations=model.training.activations[:-1],
            dropout_ratio=model.training.dropout_ratio,
        )
    elif model.rainbow.categorical:
        q_network = CategoricalDQN(  # type: ignore
            state_dim=get_num_output_features(normalization_parameters),
            action_dim=len(model.actions),
            num_atoms=model.rainbow.num_atoms,
            qmin=model.rainbow.qmin,
            qmax=model.rainbow.qmax,
            sizes=model.training.layers[1:-1],
            activations=model.training.activations[:-1],
            dropout_ratio=model.training.dropout_ratio,
            use_gpu=use_gpu,
        )
    elif model.rainbow.dueling_architecture:
        q_network = DuelingQNetwork(  # type: ignore
            layers=[get_num_output_features(normalization_parameters)] +
            model.training.layers[1:-1] + [len(model.actions)],
            activations=model.training.activations,
        )
    else:
        q_network = FullyConnectedDQN(  # type: ignore
            state_dim=get_num_output_features(normalization_parameters),
            action_dim=len(model.actions),
            sizes=model.training.layers[1:-1],
            activations=model.training.activations[:-1],
            dropout_ratio=model.training.dropout_ratio,
        )

    if use_gpu and torch.cuda.is_available():
        q_network = q_network.cuda()

    q_network_target = q_network.get_target_network()

    reward_network, q_network_cpe, q_network_cpe_target = None, None, None
    if model.evaluation.calc_cpe_in_training:
        # Metrics + reward
        num_output_nodes = (len(metrics_to_score) + 1) * len(model.actions)
        reward_network = FullyConnectedDQN(
            state_dim=get_num_output_features(normalization_parameters),
            action_dim=num_output_nodes,
            sizes=model.training.layers[1:-1],
            activations=model.training.activations[:-1],
            dropout_ratio=model.training.dropout_ratio,
        )
        q_network_cpe = FullyConnectedDQN(
            state_dim=get_num_output_features(normalization_parameters),
            action_dim=num_output_nodes,
            sizes=model.training.layers[1:-1],
            activations=model.training.activations[:-1],
            dropout_ratio=model.training.dropout_ratio,
        )

        if use_gpu and torch.cuda.is_available():
            reward_network.cuda()
            q_network_cpe.cuda()

        q_network_cpe_target = q_network_cpe.get_target_network()

    if (use_all_avail_gpus and not model.rainbow.categorical
            and not model.rainbow.quantile):
        q_network = q_network.get_distributed_data_parallel_model()
        reward_network = (reward_network.get_distributed_data_parallel_model()
                          if reward_network else None)
        q_network_cpe = (q_network_cpe.get_distributed_data_parallel_model()
                         if q_network_cpe else None)

    if model.rainbow.quantile:
        assert (not use_all_avail_gpus
                ), "use_all_avail_gpus not implemented for distributional RL"
        parameters = QRDQNTrainerParameters.from_discrete_action_model_parameters(
            model)
        return QRDQNTrainer(
            q_network,
            q_network_target,
            parameters,
            use_gpu,
            metrics_to_score=metrics_to_score,
            reward_network=reward_network,
            q_network_cpe=q_network_cpe,
            q_network_cpe_target=q_network_cpe_target,
        )

    elif model.rainbow.categorical:
        assert (not use_all_avail_gpus
                ), "use_all_avail_gpus not implemented for distributional RL"
        return C51Trainer(
            q_network,
            q_network_target,
            C51TrainerParameters.from_discrete_action_model_parameters(model),
            use_gpu,
            metrics_to_score=metrics_to_score,
        )

    else:
        parameters = DQNTrainerParameters.from_discrete_action_model_parameters(
            model)
        return DQNTrainer(
            q_network,
            q_network_target,
            reward_network,
            parameters,
            use_gpu,
            q_network_cpe=q_network_cpe,
            q_network_cpe_target=q_network_cpe_target,
            metrics_to_score=metrics_to_score,
        )
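The repeated layers[1:-1] / activations[:-1] slicing in Examples #9 and #11 suggests a convention where the config lists spell out the full stack, input and output included, while the network constructors only want the hidden part; a small sketch of that assumed convention:

# Assumed convention: `layers` includes the input and output widths,
# `activations` includes the output activation.
layers = [256, 128, 64, 2]               # state_dim, two hidden layers, action_dim
activations = ["relu", "relu", "linear"]

sizes = layers[1:-1]                     # hidden sizes -> [128, 64]
hidden_activations = activations[:-1]    # hidden activations -> ["relu", "relu"]
assert len(sizes) == len(hidden_activations)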
Example #12
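get_sac_trainer assembles a SACTrainer for an OpenAI Gym environment: state and action dimensions come from get_num_output_features on env.normalization and env.normalization_action, the training action-range tensors are filled with ±(1 - 1e-6), presumably to keep squashed actions strictly inside the open interval (-1, 1), and the serving ranges come from env.action_space.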
def get_sac_trainer(
    env: OpenAIGymEnvironment,
    rl_parameters: RLParameters,
    trainer_parameters: SACTrainerParameters,
    critic_training: FeedForwardParameters,
    actor_training: FeedForwardParameters,
    sac_value_training: Optional[FeedForwardParameters],
    use_gpu: bool,
) -> SACTrainer:
    assert rl_parameters == trainer_parameters.rl
    state_dim = get_num_output_features(env.normalization)
    action_dim = get_num_output_features(env.normalization_action)
    q1_network = FullyConnectedParametricDQN(state_dim, action_dim,
                                             critic_training.layers,
                                             critic_training.activations)
    q2_network = None
    # TODO:
    # if trainer_parameters.use_2_q_functions:
    #     q2_network = FullyConnectedParametricDQN(
    #         state_dim,
    #         action_dim,
    #         critic_training.layers,
    #         critic_training.activations,
    #     )
    value_network = None
    if sac_value_training:
        value_network = FullyConnectedNetwork(
            [state_dim] + sac_value_training.layers + [1],
            sac_value_training.activations + ["linear"],
        )
    actor_network = GaussianFullyConnectedActor(state_dim, action_dim,
                                                actor_training.layers,
                                                actor_training.activations)

    min_action_range_tensor_training = torch.full((1, action_dim), -1 + 1e-6)
    max_action_range_tensor_training = torch.full((1, action_dim), 1 - 1e-6)
    min_action_range_tensor_serving = (
        torch.from_numpy(env.action_space.low).float().unsqueeze(
            dim=0)  # type: ignore
    )
    max_action_range_tensor_serving = (
        torch.from_numpy(env.action_space.high).float().unsqueeze(
            dim=0)  # type: ignore
    )

    if use_gpu:
        q1_network.cuda()
        if q2_network:
            q2_network.cuda()
        if value_network:
            value_network.cuda()
        actor_network.cuda()

        min_action_range_tensor_training = min_action_range_tensor_training.cuda()
        max_action_range_tensor_training = max_action_range_tensor_training.cuda()
        min_action_range_tensor_serving = min_action_range_tensor_serving.cuda()
        max_action_range_tensor_serving = max_action_range_tensor_serving.cuda()

    return SACTrainer(
        q1_network,
        actor_network,
        trainer_parameters,
        use_gpu=use_gpu,
        value_network=value_network,
        q2_network=q2_network,
        min_action_range_tensor_training=min_action_range_tensor_training,
        max_action_range_tensor_training=max_action_range_tensor_training,
        min_action_range_tensor_serving=min_action_range_tensor_serving,
        max_action_range_tensor_serving=max_action_range_tensor_serving,
    )
Example #13
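A test helper that builds a SACTrainer with optional twin Q-networks, a Dirichlet or Gaussian actor, and an optional value network; state_dim and the continuous action_dim again come from get_num_output_features.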
    def get_sac_trainer(
        self,
        env,
        use_gpu,
        use_2_q_functions=False,
        logged_action_uniform_prior=True,
        constrain_action_sum=False,
        use_value_network=True,
        use_alpha_optimizer=True,
        entropy_temperature=None,
    ):
        q_network_params = FeedForwardParameters(layers=[128, 64],
                                                 activations=["relu", "relu"])
        value_network_params = FeedForwardParameters(
            layers=[128, 64], activations=["relu", "relu"])
        actor_network_params = FeedForwardParameters(
            layers=[128, 64], activations=["relu", "relu"])

        state_dim = get_num_output_features(env.normalization)
        action_dim = get_num_output_features(
            env.normalization_continuous_action)
        q1_network = FullyConnectedParametricDQN(state_dim, action_dim,
                                                 q_network_params.layers,
                                                 q_network_params.activations)
        q2_network = None
        if use_2_q_functions:
            q2_network = FullyConnectedParametricDQN(
                state_dim,
                action_dim,
                q_network_params.layers,
                q_network_params.activations,
            )
        if constrain_action_sum:
            actor_network = DirichletFullyConnectedActor(
                state_dim,
                action_dim,
                actor_network_params.layers,
                actor_network_params.activations,
            )
        else:
            actor_network = GaussianFullyConnectedActor(
                state_dim,
                action_dim,
                actor_network_params.layers,
                actor_network_params.activations,
            )

        value_network = None
        if use_value_network:
            value_network = FullyConnectedNetwork(
                [state_dim] + value_network_params.layers + [1],
                value_network_params.activations + ["linear"],
            )

        if use_gpu:
            q1_network.cuda()
            if q2_network:
                q2_network.cuda()
            if value_network:
                value_network.cuda()
            actor_network.cuda()

        parameters = SACTrainerParameters(
            rl=RLParameters(gamma=DISCOUNT, target_update_rate=0.5),
            minibatch_size=self.minibatch_size,
            q_network_optimizer=OptimizerParameters(),
            value_network_optimizer=OptimizerParameters(),
            actor_network_optimizer=OptimizerParameters(),
            alpha_optimizer=OptimizerParameters()
            if use_alpha_optimizer else None,
            entropy_temperature=entropy_temperature,
            logged_action_uniform_prior=logged_action_uniform_prior,
        )

        return SACTrainer(
            q1_network,
            actor_network,
            parameters,
            use_gpu=use_gpu,
            value_network=value_network,
            q2_network=q2_network,
        )