# --- Experiment hyperparameters ---
num_trials = 25000   # number of training trials to run
num_events = 250     # max steps (events) per trial
relative_path_to_data = '../../Data'  # ../../Data if you are in Tests/CH2

# make gym environment
# NOTE(review): `env_name` is a free variable — presumably defined earlier in
# the file or passed in via argparse; confirm against the full script.
env = gym.make(env_name)
plt.close()  # gym.make may open a figure; close it so plots don't accumulate

# Dispatch table mapping a representation-type string to the function that
# builds state representations for the environment.
rep_types = {
    'onehot': onehot,
    'random': random,
    'place_cell': place_cell,
    'sr': sr,
    'latent': latents,
}
# Each rep builder returns (state_reps, name, input_dims, _unused).
state_reps, representation_name, input_dims, _ = rep_types[representation_type](env)

# load weights to head_ac network from previously learned agent
# NOTE(review): despite the comment above, no weight-loading call is visible
# in this chunk — either it happens elsewhere or the comment is stale; verify.
# NOTE(review): `test_env` is used here but only `env` is created above —
# looks like an env/test_env inconsistency; confirm `test_env` exists upstream.
AC_head_agent = head_AC(input_dims, test_env.action_space.n, lr=0.0005)

agent = Agent(AC_head_agent, state_representations=state_reps)

# Run the flat (non-hierarchical) experiment and append results to the log.
ex = flat_expt(agent, test_env)
ex.run(num_trials, num_events, snapshot_logging=False)
ex.record_log(env_name=test_env_name,
              representation_type=representation_name,
              n_trials=num_trials,
              n_steps=num_events,
              dir=relative_path_to_data,
              file=write_to_file)
# NOTE(review): this chunk begins mid-structure — the opening of the
# `conv_ids = {` dict literal and the `if` branch matching the `else:` below
# are above this excerpt (presumably `if rep_type == 'latents':`); the
# indentation here is reconstructed and should be confirmed against the
# original file.
    # Map each gridworld variant to the saved-run UUID of a previously
    # trained agent whose network provides the latent representation.
    'gridworld:gridworld-v1': 'c34544ac-45ed-492c-b2eb-4431b403a3a8',
    'gridworld:gridworld-v3': '32301262-cd74-4116-b776-57354831c484',
    'gridworld:gridworld-v4': 'b50926a2-0186-4bb9-81ec-77063cac6861',
    'gridworld:gridworld-v5': '15b5e27b-444f-4fc8-bf25-5b7807df4c7f'
    }
    # Resolve the saved-agent checkpoint for this gridworld version and build
    # latent-state representations from its learned network.
    run_id = conv_ids[f'gridworld:gridworld-v{version}']
    agent_path = relative_path_to_data + f'agents/saved_agents/{run_id}.pt'
    state_reps, representation_name, input_dims, _ = latents(env, agent_path)
else:
    # Non-latent representation: build directly from the dispatch table.
    state_reps, representation_name, input_dims, _ = rep_types[rep_type](env)

# Actor-critic head sized to the representation dims and the env's action set.
AC_head_agent = head_AC(input_dims, env.action_space.n, lr=learning_rate)

# Episodic memory cache; one entry slot per action, bounded by cache_size_for_env.
memory = Memory(entry_size=env.action_space.n,
                cache_limit=cache_size_for_env,
                distance=distance_metric)

agent = Agent(AC_head_agent, memory=memory, state_representations=state_reps)

# Run the flat experiment and append the results to the shared CSV log.
ex = flat_expt(agent, env)
print(
    f"Experiment running {env.unwrapped.spec.id} \nRepresentation: {representation_name} \nCache Limit:{cache_size_for_env} \nDistance: {distance_metric}"
)
ex.run(num_trials, num_events, snapshot_logging=False)
# NOTE(review): training ran on `env` but the log records `test_env_name` —
# possible copy-paste inconsistency with the sibling script; verify intent.
ex.record_log(env_name=test_env_name,
              representation_type=representation_name,
              n_trials=num_trials,
              n_steps=num_events,
              dir=relative_path_to_data,
              file=write_to_file)
# NOTE(review): this chunk begins mid-call — the line below is the tail of an
# unseen constructor call (presumably `memory = Memory(..., `); confirm
# against the original file.
                distance=distance_metric)

if load_mem:
    # Look up the save_id of a previous run for this (env, representation)
    # pair in the results CSV, then restore its pickled episodic-memory dict.
    df = pd.read_csv(relative_path_to_data + write_to_file)
    df_gb = df.groupby(['env_name', 'representation'])["save_id"]
    # Take the first matching save_id for this env/representation pair.
    # NOTE(review): `id` shadows the builtin; consider renaming (left as-is).
    id = list(df_gb.get_group((test_env_name, representation_name)))[0]
    print(id)
    # NOTE(review): pickle.load is only safe on trusted, locally-produced
    # files — these are this project's own saved caches.
    with open(relative_path_to_data + f'/ec_dicts/{id}_EC.p', 'rb') as f:
        loaded_memory_dict = pickle.load(f)
    memory.cache_list = loaded_memory_dict
else:
    # No memory preloaded; empty id marks "trained from scratch" in the log.
    id = ''

agent = Agent(AC_head_agent, memory=memory, state_representations=state_reps)

#run = flat_expt(agent, env)
#run.run(NUM_TRIALS=num_trials, NUM_EVENTS=num_events)

# Build the evaluation environment and run the test phase (5x the trials).
test_env = gym.make(test_env_name)
plt.close()  # close any figure opened during env construction

print(test_env.rewards)

test_run = flat_expt(agent, test_env)
#test_run.data = run.data
test_run.run(NUM_TRIALS=num_trials * 5, NUM_EVENTS=num_events)
# `load_from=id` records which saved memory (if any) seeded this run.
test_run.record_log(test_env_name,
                    representation_name,
                    num_trials * 5,
                    num_events,
                    dir=relative_path_to_data,
                    file=write_to_file,
                    load_from=id)