Example #1
0
def create_trainer(model_type, params, rl_parameters, use_gpu, env):
    """Construct a trainer for the requested model type.

    Dispatches on ``model_type`` to build the matching parameter bundle
    from the raw ``params`` dict and hand it to the corresponding
    trainer factory.

    Raises:
        NotImplementedError: if ``model_type`` is not a supported
            ModelType value.
    """
    if model_type == ModelType.PYTORCH_DISCRETE_DQN.value:
        training = params["training"]
        if isinstance(training, dict):
            training = TrainingParameters(**training)
        rainbow = params["rainbow"]
        if isinstance(rainbow, dict):
            rainbow = RainbowDQNParameters(**rainbow)
        if env.img:
            assert (
                training.cnn_parameters is not None
            ), "Missing CNN parameters for image input"
            if isinstance(training.cnn_parameters, dict):
                training.cnn_parameters = CNNParameters(**training.cnn_parameters)
            # Overwrite the CNN geometry with what the environment reports.
            cnn = training.cnn_parameters
            cnn.conv_dims[0] = env.num_input_channels
            cnn.input_height = env.height
            cnn.input_width = env.width
            cnn.num_input_channels = env.num_input_channels
        else:
            assert (
                training.cnn_parameters is None
            ), "Extra CNN parameters for non-image input"
        return create_dqn_trainer_from_params(
            DiscreteActionModelParameters(
                actions=env.actions,
                rl=rl_parameters,
                training=training,
                rainbow=rainbow,
            ),
            env.normalization,
            use_gpu,
        )

    if model_type == ModelType.PYTORCH_PARAMETRIC_DQN.value:
        training = params["training"]
        if isinstance(training, dict):
            training = TrainingParameters(**training)
        rainbow = params["rainbow"]
        if isinstance(rainbow, dict):
            rainbow = RainbowDQNParameters(**rainbow)
        if env.img:
            assert (
                training.cnn_parameters is not None
            ), "Missing CNN parameters for image input"
            training.cnn_parameters.conv_dims[0] = env.num_input_channels
        else:
            assert (
                training.cnn_parameters is None
            ), "Extra CNN parameters for non-image input"
        return create_parametric_dqn_trainer_from_params(
            ContinuousActionModelParameters(
                rl=rl_parameters, training=training, rainbow=rainbow
            ),
            env.normalization,
            env.normalization_action,
            use_gpu,
        )

    if model_type == ModelType.TD3.value:
        td3 = params["td3_training"]
        td3_params = TD3ModelParameters(
            rl=rl_parameters,
            training=TD3TrainingParameters(
                minibatch_size=td3["minibatch_size"],
                q_network_optimizer=OptimizerParameters(
                    **td3["q_network_optimizer"]
                ),
                actor_network_optimizer=OptimizerParameters(
                    **td3["actor_network_optimizer"]
                ),
                use_2_q_functions=td3["use_2_q_functions"],
                exploration_noise=td3["exploration_noise"],
                initial_exploration_ts=td3["initial_exploration_ts"],
                target_policy_smoothing=td3["target_policy_smoothing"],
                noise_clip=td3["noise_clip"],
                delayed_policy_update=td3["delayed_policy_update"],
            ),
            q_network=FeedForwardParameters(**params["critic_training"]),
            actor_network=FeedForwardParameters(**params["actor_training"]),
        )
        return get_td3_trainer(env, td3_params, use_gpu)

    if model_type == ModelType.SOFT_ACTOR_CRITIC.value:
        sac = params["sac_training"]
        # Optional pieces: the value network (plus its optimizer) and the
        # alpha optimizer are only built when the config asks for them.
        value_network = None
        value_network_optimizer = None
        if sac["use_value_network"]:
            value_network = FeedForwardParameters(**params["sac_value_training"])
            value_network_optimizer = OptimizerParameters(
                **sac["value_network_optimizer"]
            )
        alpha_optimizer = None
        if "alpha_optimizer" in sac:
            alpha_optimizer = OptimizerParameters(**sac["alpha_optimizer"])
        sac_params = SACModelParameters(
            rl=rl_parameters,
            training=SACTrainingParameters(
                minibatch_size=sac["minibatch_size"],
                use_2_q_functions=sac["use_2_q_functions"],
                use_value_network=sac["use_value_network"],
                q_network_optimizer=OptimizerParameters(
                    **sac["q_network_optimizer"]
                ),
                value_network_optimizer=value_network_optimizer,
                actor_network_optimizer=OptimizerParameters(
                    **sac["actor_network_optimizer"]
                ),
                entropy_temperature=sac.get("entropy_temperature", None),
                target_entropy=sac.get("target_entropy", None),
                alpha_optimizer=alpha_optimizer,
            ),
            q_network=FeedForwardParameters(**params["critic_training"]),
            value_network=value_network,
            actor_network=FeedForwardParameters(**params["actor_training"]),
        )
        return get_sac_trainer(env, sac_params, use_gpu)

    if model_type == ModelType.CEM.value:
        cem_params = CEMParameters(**params["cem"])
        cem_params.mdnrnn = MDNRNNParameters(**params["cem"]["mdnrnn"])
        cem_params.rl = rl_parameters
        return get_cem_trainer(env, cem_params, use_gpu)

    raise NotImplementedError("Model of type {} not supported".format(model_type))
Example #2
0
File: run_gym.py  Project: odellus/ReAgent
def create_trainer(params: OpenAiGymParameters, env: OpenAIGymEnvironment):
    """Build a trainer for the model type requested in ``params``.

    Args:
        params: typed Gym run configuration; the parameter section
            matching ``params.model_type`` must be populated (asserted
            per branch below).
        env: Gym environment wrapper supplying normalization data,
            actions, and (for image inputs) frame geometry.

    Returns:
        The trainer produced by the factory for ``params.model_type``.

    Raises:
        NotImplementedError: if ``params.model_type`` is not supported.
    """
    use_gpu = params.use_gpu
    model_type = params.model_type
    assert params.rl is not None
    rl_parameters = params.rl

    if model_type == ModelType.PYTORCH_DISCRETE_DQN.value:
        assert params.training is not None
        training_parameters = params.training
        assert params.rainbow is not None
        if env.img:
            assert (
                training_parameters.cnn_parameters is not None
            ), "Missing CNN parameters for image input"
            # conv_dims is a mutable list, so in-place assignment sticks.
            training_parameters.cnn_parameters.conv_dims[0] = env.num_input_channels
            # _replace returns a NEW instance rather than mutating; the
            # result must be reassigned, otherwise the environment's
            # height/width/channel values are silently discarded.
            training_parameters = training_parameters._replace(
                cnn_parameters=training_parameters.cnn_parameters._replace(
                    input_height=env.height,
                    input_width=env.width,
                    num_input_channels=env.num_input_channels,
                )
            )
        else:
            assert (
                training_parameters.cnn_parameters is None
            ), "Extra CNN parameters for non-image input"
        discrete_trainer_params = DiscreteActionModelParameters(
            actions=env.actions,
            rl=rl_parameters,
            training=training_parameters,
            rainbow=params.rainbow,
            evaluation=params.evaluation,
        )
        trainer = create_dqn_trainer_from_params(
            discrete_trainer_params, env.normalization, use_gpu
        )

    elif model_type == ModelType.PYTORCH_PARAMETRIC_DQN.value:
        assert params.training is not None
        training_parameters = params.training
        assert params.rainbow is not None
        if env.img:
            assert (
                training_parameters.cnn_parameters is not None
            ), "Missing CNN parameters for image input"
            training_parameters.cnn_parameters.conv_dims[0] = env.num_input_channels
        else:
            assert (
                training_parameters.cnn_parameters is None
            ), "Extra CNN parameters for non-image input"
        continuous_trainer_params = ContinuousActionModelParameters(
            rl=rl_parameters, training=training_parameters, rainbow=params.rainbow
        )
        trainer = create_parametric_dqn_trainer_from_params(
            continuous_trainer_params,
            env.normalization,
            env.normalization_action,
            use_gpu,
        )

    elif model_type == ModelType.TD3.value:
        assert params.td3_training is not None
        assert params.critic_training is not None
        assert params.actor_training is not None
        td3_trainer_params = TD3ModelParameters(
            rl=rl_parameters,
            training=params.td3_training,
            q_network=params.critic_training,
            actor_network=params.actor_training,
        )
        trainer = get_td3_trainer(env, td3_trainer_params, use_gpu)

    elif model_type == ModelType.SOFT_ACTOR_CRITIC.value:
        assert params.sac_training is not None
        assert params.critic_training is not None
        assert params.actor_training is not None
        # The value network is optional; only wire it in when requested.
        value_network = None
        if params.sac_training.use_value_network:
            value_network = params.sac_value_training

        sac_trainer_params = SACModelParameters(
            rl=rl_parameters,
            training=params.sac_training,
            q_network=params.critic_training,
            value_network=value_network,
            actor_network=params.actor_training,
        )
        trainer = get_sac_trainer(env, sac_trainer_params, use_gpu)
    elif model_type == ModelType.CEM.value:
        assert params.cem is not None
        cem_trainer_params = params.cem._replace(rl=params.rl)
        trainer = get_cem_trainer(env, cem_trainer_params, use_gpu)
    else:
        raise NotImplementedError("Model of type {} not supported".format(model_type))

    return trainer