Example #1
    def test_nested_action_spaces(self):
        config = DEFAULT_CONFIG.copy()
        config["env"] = RandomEnv
        # Write output to check whether actions are written correctly.
        tmp_dir = os.popen("mktemp -d").read()[:-1]
        if not os.path.exists(tmp_dir):
            # Last resort: Resolve via the underlying tempdir (and cut the leading "/tmp").
            tmp_dir = ray._private.utils.tempfile.gettempdir() + tmp_dir[4:]
            assert os.path.exists(tmp_dir), f"'{tmp_dir}' not found!"
        config["output"] = tmp_dir
        # Switch off OPE as we don't write action-probs.
        # TODO: We should probably always write those if `output` is given.
        config["input_evaluation"] = []

        # Pretend actions in offline files are already normalized.
        config["actions_in_input_normalized"] = True

        for _ in framework_iterator(config):
            for name, action_space in SPACES.items():
                config["env_config"] = {
                    "action_space": action_space,
                }
                for flatten in [False, True]:
                    print(f"A={action_space} flatten={flatten}")
                    shutil.rmtree(config["output"])
                    config["_disable_action_flattening"] = not flatten
                    trainer = PGTrainer(config)
                    trainer.train()
                    trainer.stop()

                    # Check actions in output file (whether properly flattened
                    # or not).
                    reader = JsonReader(
                        inputs=config["output"],
                        ioctx=trainer.workers.local_worker().io_context,
                    )
                    sample_batch = reader.next()
                    if flatten:
                        assert isinstance(sample_batch["actions"], np.ndarray)
                        assert len(sample_batch["actions"].shape) == 2
                        assert sample_batch["actions"].shape[0] == len(
                            sample_batch)
                    else:
                        tree.assert_same_structure(
                            trainer.get_policy().action_space_struct,
                            sample_batch["actions"],
                        )

                    # Test whether the offline data can be properly read by a
                    # BCTrainer configured accordingly.
                    config["input"] = config["output"]
                    del config["output"]
                    bc_trainer = BCTrainer(config=config)
                    bc_trainer.train()
                    bc_trainer.stop()
                    config["output"] = tmp_dir
                    config["input"] = "sampler"
Example #2
MARWIL_agent = MARWILTrainer(config=marwil_config, env=SSA_Tasker_Env)
MARWIL_agent.restore(marwil_checkpoint)
MARWIL_agent.get_policy().config['explore'] = False

pg_config = PG_CONFIG.copy()
pg_config['batch_mode'] = 'complete_episodes'
pg_config['train_batch_size'] = 2000
pg_config['lr'] = 0.0001
pg_config['evaluation_interval'] = None
pg_config['postprocess_inputs'] = True
pg_config['env_config'] = env_config
pg_config['explore'] = False

PGR_agent = PGTrainer(config=pg_config, env=SSA_Tasker_Env)
PGR_agent.restore(pgr_checkpoint)
PGR_agent.get_policy().config['explore'] = False

PGRE_agent = PGTrainer(config=pg_config, env=SSA_Tasker_Env)
PGRE_agent.restore(pgre_checkpoint)
PGRE_agent.get_policy().config['explore'] = False

OLR_agent = PGTrainer(config=pg_config, env=SSA_Tasker_Env)
OLR_agent.restore(olr_checkpoint)
OLR_agent.get_policy().config['explore'] = False


def ppo_agent(obs, env):
    return PPO_agent.compute_action(obs)


def marwil_agent(obs, env):
    return MARWIL_agent.compute_action(obs)


# Restore the MARWIL trainer from the checkpoint matching the configured RSO count.
if env_config['rso_count'] == 20:
    MARWIL_trainer.restore(
        '/home/ash/ray_results/MARWIL_SSA_Tasker_Env_2020-09-10_08-24-14dv96mkld/checkpoint_5000/checkpoint-5000'
    )  # 20 SSA Complete
elif env_config['rso_count'] == 40:
    MARWIL_trainer.restore(
        '/home/ash/ray_results/MARWIL_SSA_Tasker_Env_2020-09-10_08-46-52wmigl7hj/checkpoint_10000/checkpoint-10000'
    )  # 40 SSA Complete
else:
    print(str(env_config['rso_count']) + ' is not a valid number of RSOs')

# Transfer the MARWIL-trained weights into the PG policy: save the Keras base
# model's weights from within the MARWIL policy's TF session, then load them
# into the PG policy's model.
with MARWIL_trainer.get_policy()._sess.graph.as_default():
    with MARWIL_trainer.get_policy()._sess.as_default():
        MARWIL_trainer.get_policy().model.base_model.save_weights(
            "/tmp/pgr/weights3.h5")

with PG_trainer.get_policy()._sess.graph.as_default():
    with PG_trainer.get_policy()._sess.as_default():
        PG_trainer.get_policy().model.base_model.load_weights(
            "/tmp/pgr/weights3.h5")

# Bookkeeping for the training loop: stop condition, per-iteration metrics, and wall-clock timing.
result = {'timesteps_total': 0}
num_steps_train = 5000000
best_athlete = 480
episode_len_mean = []
episode_reward_mean = []
episode_reward_max = []
episode_reward_min = []
start = datetime.datetime.now()
num_steps_trained = []
clock_time = []
training_iteration = []
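
The lists above are set up to record learning curves across training iterations. A minimal sketch of the loop they appear to feed, assuming it iterates one of the trainers built above (PG_trainer is used here purely as an illustrative stand-in) and that the standard RLlib result keys are reported:

while result['timesteps_total'] < num_steps_train:
    result = PG_trainer.train()  # Assumed stand-in for whichever trainer is being trained.

    # Append the per-iteration statistics reported by RLlib.
    episode_len_mean.append(result['episode_len_mean'])
    episode_reward_mean.append(result['episode_reward_mean'])
    episode_reward_max.append(result['episode_reward_max'])
    episode_reward_min.append(result['episode_reward_min'])
    training_iteration.append(result['training_iteration'])
    num_steps_trained.append(result['info']['num_steps_trained'])  # Key layout assumed (Ray ~1.x).
    clock_time.append((datetime.datetime.now() - start).total_seconds())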