def load_envs_and_config(file_path, device):
    """Rebuild the virtual and real environments stored in a checkpoint.

    Args:
        file_path: Path of a checkpoint readable by ``torch.load``. The loaded
            object must be a mapping with a ``'config'`` entry and a
            ``'model'`` entry (the state dict of the virtual environment).
        device: Device to use. It is written to ``config['device']`` before
            the environments are built and is also used as ``map_location``
            while deserialising the checkpoint.

    Returns:
        Tuple ``(virtual_env, real_env, config)``, where ``config`` is the
        checkpoint's config with ``'device'`` overwritten.
    """
    # Without map_location, torch.load restores every tensor onto the device
    # it was saved from, which fails on a host that lacks that device (e.g. a
    # CUDA checkpoint opened on a CPU-only machine) even though the caller
    # explicitly asked for `device`.
    save_dict = torch.load(file_path, map_location=device)

    config = save_dict['config']
    config['device'] = device

    env_factory = EnvFactory(config=config)

    # The virtual environment carries the learned parameters; the real
    # environment is created fresh from the same config.
    virtual_env = env_factory.generate_virtual_env()
    virtual_env.load_state_dict(save_dict['model'])
    real_env = env_factory.generate_real_env()

    return virtual_env, real_env, config
def load_envs_and_config(dir, model_file_name):
    """Load a saved model file and recreate both environments from it.

    Args:
        dir: Directory that contains the model file.
        model_file_name: Name of the file inside ``dir``; it is read with
            ``torch.load`` and must provide the keys ``'config'`` and
            ``'model'``.

    Returns:
        Tuple ``(virtual_env, real_env, config)``.
    """
    checkpoint = torch.load(os.path.join(dir, model_file_name))
    config = checkpoint['config']
    # NOTE: earlier revisions overrode 'solved_reward' (195) and 'max_steps'
    # (200) of config['envs']['CartPole-v0'] at this point; those overrides
    # are disabled, so the stored config is used unchanged.

    factory = EnvFactory(config=config)

    # Restore the saved weights into the freshly generated virtual env.
    virtual_env = factory.generate_virtual_env()
    virtual_env.load_state_dict(checkpoint['model'])

    real_env = factory.generate_real_env()
    return virtual_env, real_env, config
def load_envs_and_config(dir, file_name):
    """Load a saved model file and recreate both environments from it.

    The file name must consist of exactly three ``'_'``-separated fields whose
    first field has the form ``"<M>x<N>"`` (presumably the grid dimensions --
    confirm against the code that writes these files).

    Args:
        dir: Directory that contains the model file.
        file_name: Name of the file inside ``dir``; it is read with
            ``torch.load`` and must provide the keys ``'config'`` and
            ``'model'``.

    Returns:
        Tuple ``(virtual_env, real_env, config, M, N)`` with ``M`` and ``N``
        converted to ``int``.

    Raises:
        ValueError: If ``file_name`` does not split into three fields, the
            first field does not split into two parts on ``'x'``, or either
            part is not an integer literal.
    """
    checkpoint = torch.load(os.path.join(dir, file_name))
    config = checkpoint['config']

    factory = EnvFactory(config=config)
    virtual_env = factory.generate_virtual_env()
    virtual_env.load_state_dict(checkpoint['model'])
    real_env = factory.generate_real_env()

    # Only the leading field is used; the other two are required but ignored.
    size_field, _unused_a, _unused_b = file_name.split('_')
    m_text, n_text = size_field.split('x')

    return virtual_env, real_env, config, int(m_text), int(n_text)
def load_envs_and_config(file_name, model_dir, device,
                         agents_config_path="../default_config_acrobot.yaml"):
    """Rebuild both environments from a checkpoint and add extra agent configs.

    Args:
        file_name: Name of the checkpoint file inside ``model_dir``. It is
            read with ``torch.load`` and must provide the keys ``'config'``
            and ``'model'``.
        model_dir: Directory that contains the checkpoint.
        device: Device to use. It is written to ``config['device']`` and is
            also used as ``map_location`` while deserialising the checkpoint.
        agents_config_path: YAML file whose ``"agents"`` section supplies the
            ``"duelingddqn"`` and ``"duelingddqn_vary"`` entries. Defaults to
            the previously hard-coded path, which is resolved relative to the
            current working directory.

    Returns:
        Tuple ``(virtual_env, real_env, config)``; ``config["agents"]`` has
        its ``"duelingddqn"`` and ``"duelingddqn_vary"`` entries replaced by
        the ones from ``agents_config_path``.
    """
    file_path = os.path.join(model_dir, file_name)

    # Without map_location, torch.load restores every tensor onto the device
    # it was saved from, which fails on a host that lacks that device (e.g. a
    # CUDA checkpoint opened on a CPU-only machine) even though the caller
    # explicitly asked for `device`.
    save_dict = torch.load(file_path, map_location=device)

    config = save_dict['config']
    config['device'] = device

    env_factory = EnvFactory(config=config)
    virtual_env = env_factory.generate_virtual_env()
    virtual_env.load_state_dict(save_dict['model'])
    real_env = env_factory.generate_real_env()

    # load additional agent configs (applied after the environments are built,
    # so they only affect consumers of the returned config)
    with open(agents_config_path, "r") as stream:
        config_new = yaml.safe_load(stream)["agents"]

    for agent_name in ("duelingddqn", "duelingddqn_vary"):
        config["agents"][agent_name] = config_new[agent_name]

    return virtual_env, real_env, config
config_mod['agents'][ self.agent_name]['hidden_size'] = config['hidden_size'] config_mod['agents'][ self.agent_name]['hidden_layer'] = config['hidden_layer'] print("full config: ", config_mod['agents'][self.agent_name]) return config_mod if __name__ == "__main__": with open("../default_config_cartpole.yaml", "r") as stream: config = yaml.safe_load(stream) torch.set_num_threads(1) # generate environment env_fac = EnvFactory(config) virt_env = env_fac.generate_virtual_env() real_env = env_fac.generate_real_env() timing = [] for i in range(10): ddqn = DDQN_vary(env=real_env, config=config, icm=True) # ddqn.train(env=virt_env, time_remaining=50) print('TRAIN') ddqn.train(env=real_env, time_remaining=500) # print('TEST') # ddqn.test(env=real_env, time_remaining=500) print('avg. ' + str(sum(timing) / len(timing)))