def create_trainer(model_type, params, rl_parameters, use_gpu, env):
    """Build an RL trainer for ``env`` from a raw parameter dict.

    Args:
        model_type: a ``ModelType`` enum *value* (string) selecting the
            algorithm (discrete DQN, parametric DQN, TD3, SAC, or CEM).
        params: dict of raw training parameters; nested sections may be
            plain dicts and are coerced into their typed counterparts.
        rl_parameters: shared RL hyper-parameters passed to every trainer.
        use_gpu: whether to place the trainer on GPU.
        env: environment wrapper supplying normalization data, the action
            set, and image-input metadata (``img``, ``height``, ``width``,
            ``num_input_channels``).

    Returns:
        A constructed trainer instance for the requested model type.

    Raises:
        NotImplementedError: if ``model_type`` is not one of the supported
            ``ModelType`` values.
    """
    if model_type == ModelType.PYTORCH_DISCRETE_DQN.value:
        training_parameters = params["training"]
        if isinstance(training_parameters, dict):
            training_parameters = TrainingParameters(**training_parameters)
        rainbow_parameters = params["rainbow"]
        if isinstance(rainbow_parameters, dict):
            rainbow_parameters = RainbowDQNParameters(**rainbow_parameters)
        if env.img:
            assert (
                training_parameters.cnn_parameters is not None
            ), "Missing CNN parameters for image input"
            if isinstance(training_parameters.cnn_parameters, dict):
                training_parameters.cnn_parameters = CNNParameters(
                    **training_parameters.cnn_parameters
                )
            # Overwrite the CNN input geometry with the environment's actual
            # image dimensions so config files need not hard-code them.
            training_parameters.cnn_parameters.conv_dims[0] = env.num_input_channels
            training_parameters.cnn_parameters.input_height = env.height
            training_parameters.cnn_parameters.input_width = env.width
            training_parameters.cnn_parameters.num_input_channels = (
                env.num_input_channels
            )
        else:
            assert (
                training_parameters.cnn_parameters is None
            ), "Extra CNN parameters for non-image input"
        trainer_params = DiscreteActionModelParameters(
            actions=env.actions,
            rl=rl_parameters,
            training=training_parameters,
            rainbow=rainbow_parameters,
        )
        trainer = create_dqn_trainer_from_params(
            trainer_params, env.normalization, use_gpu
        )
    elif model_type == ModelType.PYTORCH_PARAMETRIC_DQN.value:
        training_parameters = params["training"]
        if isinstance(training_parameters, dict):
            training_parameters = TrainingParameters(**training_parameters)
        rainbow_parameters = params["rainbow"]
        if isinstance(rainbow_parameters, dict):
            rainbow_parameters = RainbowDQNParameters(**rainbow_parameters)
        if env.img:
            assert (
                training_parameters.cnn_parameters is not None
            ), "Missing CNN parameters for image input"
            training_parameters.cnn_parameters.conv_dims[0] = env.num_input_channels
        else:
            assert (
                training_parameters.cnn_parameters is None
            ), "Extra CNN parameters for non-image input"
        trainer_params = ContinuousActionModelParameters(
            rl=rl_parameters, training=training_parameters, rainbow=rainbow_parameters
        )
        trainer = create_parametric_dqn_trainer_from_params(
            trainer_params, env.normalization, env.normalization_action, use_gpu
        )
    elif model_type == ModelType.TD3.value:
        # Hoist the repeated params["td3_training"] subscripting into a local.
        td3_training = params["td3_training"]
        trainer_params = TD3ModelParameters(
            rl=rl_parameters,
            training=TD3TrainingParameters(
                minibatch_size=td3_training["minibatch_size"],
                q_network_optimizer=OptimizerParameters(
                    **td3_training["q_network_optimizer"]
                ),
                actor_network_optimizer=OptimizerParameters(
                    **td3_training["actor_network_optimizer"]
                ),
                use_2_q_functions=td3_training["use_2_q_functions"],
                exploration_noise=td3_training["exploration_noise"],
                initial_exploration_ts=td3_training["initial_exploration_ts"],
                target_policy_smoothing=td3_training["target_policy_smoothing"],
                noise_clip=td3_training["noise_clip"],
                delayed_policy_update=td3_training["delayed_policy_update"],
            ),
            q_network=FeedForwardParameters(**params["critic_training"]),
            actor_network=FeedForwardParameters(**params["actor_training"]),
        )
        trainer = get_td3_trainer(env, trainer_params, use_gpu)
    elif model_type == ModelType.SOFT_ACTOR_CRITIC.value:
        # Hoist the repeated params["sac_training"] subscripting into a local.
        sac_training = params["sac_training"]
        # The value network (and its optimizer) is optional in SAC.
        value_network = None
        value_network_optimizer = None
        alpha_optimizer = None
        if sac_training["use_value_network"]:
            value_network = FeedForwardParameters(**params["sac_value_training"])
            value_network_optimizer = OptimizerParameters(
                **sac_training["value_network_optimizer"]
            )
        if "alpha_optimizer" in sac_training:
            alpha_optimizer = OptimizerParameters(**sac_training["alpha_optimizer"])
        entropy_temperature = sac_training.get("entropy_temperature", None)
        target_entropy = sac_training.get("target_entropy", None)
        trainer_params = SACModelParameters(
            rl=rl_parameters,
            training=SACTrainingParameters(
                minibatch_size=sac_training["minibatch_size"],
                use_2_q_functions=sac_training["use_2_q_functions"],
                use_value_network=sac_training["use_value_network"],
                q_network_optimizer=OptimizerParameters(
                    **sac_training["q_network_optimizer"]
                ),
                value_network_optimizer=value_network_optimizer,
                actor_network_optimizer=OptimizerParameters(
                    **sac_training["actor_network_optimizer"]
                ),
                entropy_temperature=entropy_temperature,
                target_entropy=target_entropy,
                alpha_optimizer=alpha_optimizer,
            ),
            q_network=FeedForwardParameters(**params["critic_training"]),
            value_network=value_network,
            actor_network=FeedForwardParameters(**params["actor_training"]),
        )
        trainer = get_sac_trainer(env, trainer_params, use_gpu)
    elif model_type == ModelType.CEM.value:
        trainer_params = CEMParameters(**params["cem"])
        # Re-assign the nested sections with their typed counterparts.
        trainer_params.mdnrnn = MDNRNNParameters(**params["cem"]["mdnrnn"])
        trainer_params.rl = rl_parameters
        trainer = get_cem_trainer(env, trainer_params, use_gpu)
    else:
        raise NotImplementedError("Model of type {} not supported".format(model_type))
    return trainer
# NOTE(review): this second ``create_trainer`` shadows the dict-based one
# above if both live in this module — confirm only one is meant to survive.
def create_trainer(params: OpenAiGymParameters, env: OpenAIGymEnvironment):
    """Build an RL trainer for ``env`` from typed ``OpenAiGymParameters``.

    Args:
        params: fully-typed run configuration; the section matching
            ``params.model_type`` must be populated (asserted below).
        env: environment wrapper supplying normalization data, the action
            set, and image-input metadata.

    Returns:
        A constructed trainer instance for ``params.model_type``.

    Raises:
        NotImplementedError: if ``params.model_type`` is not supported.
    """
    use_gpu = params.use_gpu
    model_type = params.model_type
    assert params.rl is not None
    rl_parameters = params.rl
    if model_type == ModelType.PYTORCH_DISCRETE_DQN.value:
        assert params.training is not None
        training_parameters = params.training
        assert params.rainbow is not None
        if env.img:
            assert (
                training_parameters.cnn_parameters is not None
            ), "Missing CNN parameters for image input"
            # conv_dims is a list, so this in-place write is visible through
            # the (immutable) parameter tuples.
            training_parameters.cnn_parameters.conv_dims[0] = env.num_input_channels
            # BUG FIX: ``_replace`` returns a new namedtuple instead of
            # mutating; the original discarded the result, so the image
            # geometry (height/width/channels) never reached the trainer.
            # Assign it back so the updated parameters are actually used.
            training_parameters = training_parameters._replace(
                cnn_parameters=training_parameters.cnn_parameters._replace(
                    input_height=env.height,
                    input_width=env.width,
                    num_input_channels=env.num_input_channels,
                )
            )
        else:
            assert (
                training_parameters.cnn_parameters is None
            ), "Extra CNN parameters for non-image input"
        discrete_trainer_params = DiscreteActionModelParameters(
            actions=env.actions,
            rl=rl_parameters,
            training=training_parameters,
            rainbow=params.rainbow,
            evaluation=params.evaluation,
        )
        trainer = create_dqn_trainer_from_params(
            discrete_trainer_params, env.normalization, use_gpu
        )
    elif model_type == ModelType.PYTORCH_PARAMETRIC_DQN.value:
        assert params.training is not None
        training_parameters = params.training
        assert params.rainbow is not None
        if env.img:
            assert (
                training_parameters.cnn_parameters is not None
            ), "Missing CNN parameters for image input"
            training_parameters.cnn_parameters.conv_dims[0] = env.num_input_channels
        else:
            assert (
                training_parameters.cnn_parameters is None
            ), "Extra CNN parameters for non-image input"
        continuous_trainer_params = ContinuousActionModelParameters(
            rl=rl_parameters, training=training_parameters, rainbow=params.rainbow
        )
        trainer = create_parametric_dqn_trainer_from_params(
            continuous_trainer_params,
            env.normalization,
            env.normalization_action,
            use_gpu,
        )
    elif model_type == ModelType.TD3.value:
        assert params.td3_training is not None
        assert params.critic_training is not None
        assert params.actor_training is not None
        td3_trainer_params = TD3ModelParameters(
            rl=rl_parameters,
            training=params.td3_training,
            q_network=params.critic_training,
            actor_network=params.actor_training,
        )
        trainer = get_td3_trainer(env, td3_trainer_params, use_gpu)
    elif model_type == ModelType.SOFT_ACTOR_CRITIC.value:
        assert params.sac_training is not None
        assert params.critic_training is not None
        assert params.actor_training is not None
        # The value network is optional in SAC; only pass it when enabled.
        value_network = None
        if params.sac_training.use_value_network:
            value_network = params.sac_value_training
        sac_trainer_params = SACModelParameters(
            rl=rl_parameters,
            training=params.sac_training,
            q_network=params.critic_training,
            value_network=value_network,
            actor_network=params.actor_training,
        )
        trainer = get_sac_trainer(env, sac_trainer_params, use_gpu)
    elif model_type == ModelType.CEM.value:
        assert params.cem is not None
        cem_trainer_params = params.cem._replace(rl=params.rl)
        trainer = get_cem_trainer(env, cem_trainer_params, use_gpu)
    else:
        raise NotImplementedError("Model of type {} not supported".format(model_type))
    return trainer