import gym
import matplotlib.pyplot as plt
import torch

# Project-local helpers (module paths assumed): nets, Memory, Agent, expt,
# load_network, get_net_activity. `version`, `rep_type`, `learning_rate`,
# `cache_size`, and `load_from` are assumed to be defined upstream
# (e.g. CLI arguments or an earlier notebook cell).

directory = '../../Data/'  # ../../Data if you are in Tests/CH2
env_name = f'gridworld:gridworld-v{version}'

# make gym environment
env = gym.make(env_name)
plt.close()
print(env.rewards)

num_trials = 10
num_events = 250

# load in trained network
net, state_dict = load_network(version, rep_type, directory)

# load weights to head_ac network from previously learned agent
AC_head_agent = nets.flat_ActorCritic(400, output_dims=env.action_space.n,
                                      lr=learning_rate)

# remap the trained network's output-layer weights onto the head's
# policy ('pol') and value ('val') layers
top_layer_dict = {
    'pol.weight': state_dict['output.0.weight'],
    'pol.bias':   state_dict['output.0.bias'],
    'val.weight': state_dict['output.1.weight'],
    'val.bias':   state_dict['output.1.bias'],
}
AC_head_agent.load_state_dict(top_layer_dict)

# get state inputs: hidden-layer activity of the trained network
h0, h1 = get_net_activity(env_name, rep_type, net)
state_reps, representation_name = h1, f'h1_{rep_type}_latents'

# no episodic memory in this condition
memory = None  # Memory.EpisodicMemory(cache_limit=cache_size_for_env, entry_size=env.action_space.n)

agent = Agent(AC_head_agent, memory=memory, state_representations=state_reps)
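# --- Illustration (not from the original source) ---------------------------
# get_net_activity() is project-specific; below is a minimal sketch of the
# underlying idea, assuming `net` is a torch.nn.Module with two hidden layers
# exposed as attributes `hidden0` and `hidden1` (hypothetical names) and
# `inputs` is a batch of state vectors. Hidden activations are captured with
# standard PyTorch forward hooks.
import torch

def collect_hidden_activity(net, inputs):
    """Return the activations of two hidden layers for a batch of inputs."""
    captured = {}

    def make_hook(name):
        def hook(module, module_in, module_out):
            captured[name] = module_out.detach()
        return hook

    # NOTE: layer attribute names are assumptions; adapt to the actual model.
    handles = [net.hidden0.register_forward_hook(make_hook('h0')),
               net.hidden1.register_forward_hook(make_hook('h1'))]
    with torch.no_grad():
        net(inputs)
    for h in handles:
        h.remove()  # detach hooks so later forward passes are unaffected
    return captured['h0'], captured['h1']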
env = gym.make(env_name)
plt.close()

# episodic-memory cache sized as a percentage of the environment's usable
# states (e.g. cache_size=75 on a 400-state grid gives a 300-entry cache)
cache_size_for_env = int(len(env.useable) * (cache_size / 100))
print(env.rewards)

# available state-representation constructors
rep_types = {'onehot': onehot, 'random': random, 'place_cell': place_cell,
             'sr': sr, 'latent': latents}
state_reps, representation_name, input_dims, _ = rep_types[rep_type](env)

# load weights to head_ac network from previously learned agent
AC_head_agent = nets.flat_ActorCritic(input_dims, env.action_space.n,
                                      lr=learning_rate)

if load_from is not None:
    AC_head_agent.load_state_dict(
        torch.load(directory + f'agents/{load_from}.pt'))
    print(f'weights loaded from {load_from}')

memory = Memory.EpisodicMemory(cache_limit=cache_size_for_env,
                               entry_size=env.action_space.n)

agent = Agent(AC_head_agent, memory=memory, state_representations=state_reps)

run = expt(agent, env)
run.run(NUM_TRIALS=num_trials, NUM_EVENTS=num_events)
run.record_log(env_name,
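# --- Illustration (not from the original source) ---------------------------
# nets.flat_ActorCritic is project-specific; here is a minimal sketch that is
# consistent with the state-dict keys used above ('pol.weight'/'pol.bias',
# 'val.weight'/'val.bias') and with both constructor calls in this file.
# The class name, value-head width, and optimizer choice are assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F

class FlatActorCritic(nn.Module):
    """Single-layer actor-critic head over a flat state representation."""

    def __init__(self, input_dims, output_dims, lr=5e-4):
        super().__init__()
        self.pol = nn.Linear(input_dims, output_dims)  # policy head
        self.val = nn.Linear(input_dims, 1)            # state-value head
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        # action probabilities and scalar state-value estimate
        return F.softmax(self.pol(x), dim=-1), self.val(x)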