Example #1
def test_sac_save_load_buffer(tmpdir, dummy_config):
    env, mock_brain, _ = mb.setup_mock_env_and_brains(
        mock.Mock(),
        False,
        False,
        num_agents=NUM_AGENTS,
        vector_action_space=VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
        discrete_action_space=DISCRETE_ACTION_SPACE,
    )
    trainer_params = dummy_config
    trainer_params["summary_path"] = str(tmpdir)
    trainer_params["model_path"] = str(tmpdir)
    trainer_params["save_replay_buffer"] = True
    trainer = SACTrainer(mock_brain.brain_name, 1, trainer_params, True, False,
                         0, 0)
    policy = trainer.create_policy(mock_brain)
    trainer.add_policy(mock_brain.brain_name, policy)

    trainer.update_buffer = mb.simulate_rollout(env, trainer.policy,
                                                BUFFER_INIT_SAMPLES)
    buffer_len = trainer.update_buffer.num_experiences
    trainer.save_model(mock_brain.brain_name)

    # Wipe Trainer and try to load
    trainer2 = SACTrainer(mock_brain.brain_name, 1, trainer_params, True, True,
                          0, 0)

    policy = trainer2.create_policy(mock_brain)
    trainer2.add_policy(mock_brain.brain_name, policy)
    assert trainer2.update_buffer.num_experiences == buffer_len
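These snippets are excerpts from the ml-agents trainer test suite and lean on module-level imports, constants, and fixtures that the listing omits. A minimal sketch of that shared setup follows; the constant values, the import paths (which match the 0.13-era package layout), and the dummy_config body are assumptions, since each test module defines its own variants.

from unittest import mock

import pytest
import yaml

from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.sac.trainer import SACTrainer
from mlagents.trainers.sac.policy import SACPolicy

# Assumed shapes for the mocked brains; the real modules pick their own values.
NUM_AGENTS = 12
VECTOR_ACTION_SPACE = [2]
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 32


@pytest.fixture
def dummy_config():
    # Hypothetical minimal trainer config; the real fixtures load fuller YAML.
    return yaml.safe_load(
        """
        trainer: sac
        batch_size: 32
        buffer_size: 10240
        hidden_units: 32
        learning_rate: 3.0e-4
        max_steps: 1024
        memory_size: 8
        sequence_length: 16
        summary_freq: 1000
        reward_signals:
          extrinsic:
            strength: 1.0
            gamma: 0.99
        """
    )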
Example #2
def test_trainer_update_policy(mock_env, dummy_config, use_discrete):
    env, mock_brain, _ = mb.setup_mock_env_and_brains(
        mock_env,
        use_discrete,
        False,
        num_agents=NUM_AGENTS,
        vector_action_space=VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
        discrete_action_space=DISCRETE_ACTION_SPACE,
    )

    trainer_params = dummy_config
    trainer_params["use_recurrent"] = True

    trainer = PPOTrainer(mock_brain, 0, trainer_params, True, False, 0, "0",
                         False)
    # Test update with sequence length smaller than batch size
    buffer = mb.simulate_rollout(env, trainer.policy, BUFFER_INIT_SAMPLES)
    # Mock out reward signal eval
    buffer.update_buffer["extrinsic_rewards"] = buffer.update_buffer["rewards"]
    buffer.update_buffer["extrinsic_returns"] = buffer.update_buffer["rewards"]
    buffer.update_buffer["extrinsic_value_estimates"] = buffer.update_buffer[
        "rewards"]
    trainer.training_buffer = buffer
    trainer.update_policy()
    # Make batch length a larger multiple of sequence length
    trainer.trainer_parameters["batch_size"] = 128
    trainer.update_policy()
    # Make batch length a larger non-multiple of sequence length
    trainer.trainer_parameters["batch_size"] = 100
    trainer.update_policy()
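The two batch_size overrides exercise the recurrent update path: with use_recurrent enabled, the policy is updated on whole sequences, so a batch size that is an exact multiple of the sequence length behaves differently from one that is not. A toy illustration of the truncation involved, assuming a sequence length of 16 purely for illustration (the real value comes from dummy_config):

sequence_length = 16  # assumed for illustration; defined by dummy_config
for batch_size in (128, 100):
    # Only whole sequences fit into an update batch; any remainder is dropped.
    usable = batch_size - batch_size % sequence_length
    print(f"batch_size={batch_size} -> {usable} usable experiences")
# batch_size=128 -> 128 usable experiences
# batch_size=100 -> 96 usable experiences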
Example #3
def create_sac_policy_mock(mock_env, dummy_config, use_rnn, use_discrete, use_visual):
    env, mock_brain, _ = mb.setup_mock_env_and_brains(
        mock_env,
        use_discrete,
        use_visual,
        num_agents=NUM_AGENTS,
        vector_action_space=VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
        discrete_action_space=DISCRETE_ACTION_SPACE,
    )

    trainer_parameters = dummy_config
    model_path = env.external_brain_names[0]
    trainer_parameters["model_path"] = model_path
    trainer_parameters["keep_checkpoints"] = 3
    trainer_parameters["use_recurrent"] = use_rnn
    policy = SACPolicy(0, mock_brain, trainer_parameters, False, False)
    return env, policy
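A typical caller parametrizes this helper over its flags and then drives the resulting policy against the mocked environment. The test below is an illustrative sketch, not taken from the source; it assumes the old BrainInfo-based evaluate API of this ml-agents era.

@pytest.mark.parametrize("use_discrete", [True, False])
def test_sac_policy_evaluate(mock_env, dummy_config, use_discrete):
    env, policy = create_sac_policy_mock(
        mock_env, dummy_config, use_rnn=False,
        use_discrete=use_discrete, use_visual=False,
    )
    # Reset the mocked env and ask the policy for one batch of actions.
    brain_infos = env.reset()
    brain_info = brain_infos[env.external_brain_names[0]]
    run_out = policy.evaluate(brain_info)
    assert run_out["action"].shape[0] == NUM_AGENTS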
Example #4
def create_policy_mock(mock_env, trainer_config, reward_signal_config, use_rnn,
                       use_discrete, use_visual):
    env, mock_brain, _ = mb.setup_mock_env_and_brains(
        mock_env,
        use_discrete,
        use_visual,
        num_agents=NUM_AGENTS,
        vector_action_space=VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
        discrete_action_space=DISCRETE_ACTION_SPACE,
    )

    trainer_parameters = trainer_config
    model_path = env.external_brain_names[0]
    trainer_parameters["model_path"] = model_path
    trainer_parameters["keep_checkpoints"] = 3
    trainer_parameters["reward_signals"].update(reward_signal_config)
    trainer_parameters["use_recurrent"] = use_rnn
    if trainer_config["trainer"] == "ppo":
        policy = PPOPolicy(0, mock_brain, trainer_parameters, False, False)
    else:
        policy = SACPolicy(0, mock_brain, trainer_parameters, False, False)
    return env, policy
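Compared with the SAC-only helper in Example #3, this variant also merges a reward-signal configuration into the trainer parameters and dispatches on trainer_config["trainer"]. A hedged usage sketch (mock_env and trainer_config would come from fixtures; the curiosity settings are made-up values, and the final assertion assumes the policy exposes a reward_signals dict, as the 0.13-era policies did):

reward_signal_config = {
    # Hypothetical values; real tests pull these from a config fixture.
    "curiosity": {"strength": 0.1, "gamma": 0.9, "encoding_size": 128}
}
env, policy = create_policy_mock(
    mock_env, trainer_config, reward_signal_config,
    use_rnn=False, use_discrete=True, use_visual=False,
)
assert "curiosity" in policy.reward_signals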
Example #5
def test_sac_save_load_buffer(tmpdir):
    env, mock_brain, _ = mb.setup_mock_env_and_brains(
        mock.Mock(),
        False,
        False,
        num_agents=NUM_AGENTS,
        vector_action_space=VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
        discrete_action_space=DISCRETE_ACTION_SPACE,
    )
    trainer_params = dummy_config()
    trainer_params["summary_path"] = str(tmpdir)
    trainer_params["model_path"] = str(tmpdir)
    trainer_params["save_replay_buffer"] = True
    trainer = SACTrainer(mock_brain, 1, trainer_params, True, False, 0, 0)
    trainer.training_buffer = mb.simulate_rollout(env, trainer.policy,
                                                  BUFFER_INIT_SAMPLES)
    buffer_len = len(trainer.training_buffer.update_buffer["actions"])
    trainer.save_model()

    # Wipe Trainer and try to load
    trainer2 = SACTrainer(mock_brain, 1, trainer_params, True, True, 0, 0)
    assert len(trainer2.training_buffer.update_buffer["actions"]) == buffer_len
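Note that this last snippet predates the trainer/policy split used in Example #1: the SACTrainer constructor takes the brain object rather than its name, the rollout is assigned to trainer.training_buffer instead of trainer.update_buffer, buffer length is measured via the "actions" field rather than num_experiences, and save_model takes no argument. Both versions assert the same behavior: with save_replay_buffer enabled, a freshly loaded trainer restores the full replay buffer from disk.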