Example #1
0
 def test_train_tf_sac_agent_initialized_from_class(self):
     """Train a TF-Agents SAC agent passed as a class and check it is saved.

     The environment is closed and ``save_path`` is removed even when
     training or the save check raises, so a failing run does not leave
     artifacts on disk.
     """
     model_name = "sac_tf"
     env = gym.make("MountainCarContinuous-v0")
     try:
         # 1700 is presumably the total number of training timesteps.
         train_tf_agent(sac_agent.SacAgent, env, 1700, model_name)
         model_saved = check_model_is_saved(model_name)
     finally:
         env.close()
         # The directory may not exist if training failed before saving.
         shutil.rmtree(save_path, ignore_errors=True)
     self.assertTrue(model_saved)
Example #2
0
def train(
    model: Union[BaseAlgorithm, TFAgent, Type[BaseAlgorithm], Type[TFAgent]],
    env: Union[Env, TimeLimit],
    total_timesteps: int,
    stop_threshold: int,
    model_name: Optional[str] = None,
    maximum_episode_reward: Optional[int] = None,
):
    """Train a Stable-Baselines or TF-Agents model on ``env``.

    ``model`` may be an already-constructed agent instance or the agent
    class itself; both forms are dispatched to the matching backend trainer.

    Args:
        model: A ``BaseAlgorithm`` / ``TFAgent`` instance, or a subclass of
            either.
        env: Environment to train on. If it is a ``TimeLimit`` wrapper, one
            level of wrapping is stripped before training.
        total_timesteps: Forwarded to the trainer as ``total_timesteps``.
        stop_threshold: Forwarded to the trainer as
            ``stop_training_threshold``.
        model_name: Name forwarded to the trainer; when falsy, one is
            generated with ``utils.compile_random_model_name(model)``.
        maximum_episode_reward: Forwarded as ``maximum_episode_reward``.

    Raises:
        ValueError: If ``model`` is neither a ``BaseAlgorithm`` nor a
            ``TFAgent`` instance or subclass.
    """
    # Pick the backend first so unsupported models are rejected before any
    # other work. ``issubclass`` is only called on actual classes: on a
    # non-class it would raise TypeError instead of the intended ValueError.
    is_class = isinstance(model, type)
    if isinstance(model, BaseAlgorithm) or (
        is_class and issubclass(model, BaseAlgorithm)
    ):
        trainer = train_baselines_model
    elif isinstance(model, TFAgent) or (
        is_class and issubclass(model, TFAgent)
    ):
        trainer = train_tf_agent
    else:
        # For a class, ``model.__class__`` is its metaclass (e.g. ``type``),
        # so report the class's own name instead.
        unsupported = model.__name__ if is_class else model.__class__.__name__
        raise ValueError(f"Model of class `{unsupported}` is not supported")

    # Strip a single TimeLimit wrapper so the trainer gets the inner env.
    env = env.env if isinstance(env, TimeLimit) else env
    model_name = model_name or utils.compile_random_model_name(model)
    trainer(
        model=model,
        env=env,
        total_timesteps=total_timesteps,
        model_name=model_name,
        maximum_episode_reward=maximum_episode_reward,
        stop_training_threshold=stop_threshold,
    )
Example #3
0
 def test_train_tf_reinforce_agent_initialized_from_class(self):
     """Train a TF-Agents REINFORCE agent passed as a class and check it is saved.

     The environment is closed and ``save_path`` is removed even when
     training or the save check raises, so a failing run does not leave
     artifacts on disk.
     """
     env_name = "CartPole-v0"
     model_name = "reinforce_tf"
     env = gym.make(env_name)
     try:
         # 1700 is presumably the total number of training timesteps.
         train_tf_agent(reinforce_agent.ReinforceAgent, env, 1700, model_name)
         model_saved = check_model_is_saved(model_name)
     finally:
         env.close()
         # The directory may not exist if training failed before saving.
         shutil.rmtree(save_path, ignore_errors=True)
     self.assertTrue(model_saved)