def load_envs_and_config(file_path, device):
    """Restore a saved virtual environment and build the matching real one.

    Args:
        file_path: Path to a checkpoint dict (as written by ``torch.save``)
            containing at least the keys ``'config'`` and ``'model'``.
        device: Device the checkpoint tensors are mapped onto while loading;
            it is also written into ``config['device']`` before the
            environments are created.

    Returns:
        Tuple ``(virtual_env, real_env, config)`` where ``virtual_env`` has the
        saved ``'model'`` state dict loaded into it.
    """
    # map_location makes a checkpoint saved on a GPU loadable on a machine
    # without one (or onto another device) instead of failing in torch.load.
    save_dict = torch.load(file_path, map_location=device)
    config = save_dict['config']
    # Override the device stored at save time; presumably EnvFactory reads
    # config['device'] when building the environments — confirm.
    config['device'] = device

    env_factory = EnvFactory(config=config)
    virtual_env = env_factory.generate_virtual_env()
    virtual_env.load_state_dict(save_dict['model'])
    real_env = env_factory.generate_real_env()

    return virtual_env, real_env, config
Example #2
0
def load_envs_and_config(dir, model_file_name):
    """Rebuild the virtual and real environments from a saved checkpoint.

    Args:
        dir: Directory that holds the checkpoint file.
        model_file_name: Name of the checkpoint file inside ``dir``; the file
            must contain a dict with ``'config'`` and ``'model'`` entries.

    Returns:
        Tuple ``(virtual_env, real_env, config)``; the saved ``'model'`` state
        dict is loaded into ``virtual_env`` and ``config`` is used unchanged.
    """
    checkpoint = torch.load(os.path.join(dir, model_file_name))
    config = checkpoint['config']

    factory = EnvFactory(config=config)
    virtual_env = factory.generate_virtual_env()
    virtual_env.load_state_dict(checkpoint['model'])

    return virtual_env, factory.generate_real_env(), config
def load_envs_and_config(dir, file_name):
    """Load a checkpoint plus the grid dimensions encoded in its file name.

    ``file_name`` must start with ``<M>x<N>`` followed by ``_`` (for example
    ``3x4_run.pt``); everything after the first underscore is ignored. The
    checkpoint must be a dict with ``'config'`` and ``'model'`` entries.

    Args:
        dir: Directory containing the checkpoint file.
        file_name: Checkpoint file name, relative to ``dir``.

    Returns:
        Tuple ``(virtual_env, real_env, config, M, N)`` where ``M`` and ``N``
        are the integer grid dimensions parsed from ``file_name``.

    Raises:
        ValueError: If ``<M>x<N>`` cannot be parsed from ``file_name``.
    """
    # Parse the name before loading so a malformed name fails fast, and use
    # only the first '_'-separated token of the basename so names with more
    # (or fewer) underscore-separated fields are still accepted.
    grid_size = os.path.basename(file_name).split('_', 1)[0]
    try:
        m_str, n_str = grid_size.split('x')
        M, N = int(m_str), int(n_str)
    except ValueError as err:
        raise ValueError(
            f"cannot parse grid size '<M>x<N>' from file name {file_name!r}"
        ) from err

    save_dict = torch.load(os.path.join(dir, file_name))
    config = save_dict['config']

    env_factory = EnvFactory(config=config)
    virtual_env = env_factory.generate_virtual_env()
    virtual_env.load_state_dict(save_dict['model'])
    real_env = env_factory.generate_real_env()

    return virtual_env, real_env, config, M, N
def load_envs_and_config(file_name, model_dir, device,
                         agent_config_path="../default_config_acrobot.yaml"):
    """Restore environments from a checkpoint and add extra agent configs.

    Args:
        file_name: Checkpoint file name inside ``model_dir``; the file must
            contain a dict with ``'config'`` and ``'model'`` entries.
        model_dir: Directory containing the checkpoint.
        device: Device the checkpoint tensors are mapped onto while loading;
            it is also written into ``config['device']``.
        agent_config_path: YAML file whose top-level ``agents`` section
            supplies the ``duelingddqn`` and ``duelingddqn_vary`` configs.
            The default is the previously hard-coded path, which is resolved
            relative to the current working directory.

    Returns:
        Tuple ``(virtual_env, real_env, config)``. ``config['agents']`` gains
        the ``duelingddqn`` and ``duelingddqn_vary`` entries after the
        environments have been created.
    """
    file_path = os.path.join(model_dir, file_name)
    # map_location makes a checkpoint saved on a GPU loadable on a machine
    # without one (or onto another device) instead of failing in torch.load.
    save_dict = torch.load(file_path, map_location=device)
    config = save_dict['config']
    config['device'] = device

    env_factory = EnvFactory(config=config)
    virtual_env = env_factory.generate_virtual_env()
    virtual_env.load_state_dict(save_dict['model'])
    real_env = env_factory.generate_real_env()

    # Load additional agent configs and merge them into the checkpoint config.
    with open(agent_config_path, "r", encoding="utf-8") as stream:
        config_new = yaml.safe_load(stream)["agents"]

    for agent_name in ("duelingddqn", "duelingddqn_vary"):
        config["agents"][agent_name] = config_new[agent_name]

    return virtual_env, real_env, config
        # Copy 'hidden_size' and 'hidden_layer' from `config` into the entry
        # for this agent (self.agent_name) in `config_mod`.
        # NOTE(review): the start of this method is not visible here; where
        # config_mod comes from (copy of a base config?) should be confirmed.
        config_mod['agents'][
            self.agent_name]['hidden_size'] = config['hidden_size']
        config_mod['agents'][
            self.agent_name]['hidden_layer'] = config['hidden_layer']

        # Log the resulting per-agent config for inspection.
        print("full config: ", config_mod['agents'][self.agent_name])

        return config_mod


if __name__ == "__main__":
    # Benchmark script: train DDQN_vary on the real CartPole environment
    # 10 times and report the average wall-clock training time.
    import time

    with open("../default_config_cartpole.yaml", "r", encoding="utf-8") as stream:
        config = yaml.safe_load(stream)

    torch.set_num_threads(1)

    # generate environment
    env_fac = EnvFactory(config)
    # NOTE(review): virt_env is not used for training below; it is still
    # generated to keep the original call sequence (in case generation has
    # side effects such as consuming RNG state).
    virt_env = env_fac.generate_virtual_env()
    real_env = env_fac.generate_real_env()

    # Seconds spent in each ddqn.train call. Previously this list was never
    # filled, so the average below always raised ZeroDivisionError.
    timing = []
    for _ in range(10):
        ddqn = DDQN_vary(env=real_env, config=config, icm=True)
        print('TRAIN')
        start = time.perf_counter()
        ddqn.train(env=real_env, time_remaining=500)
        timing.append(time.perf_counter() - start)
    print('avg. ' + str(sum(timing) / len(timing)))