Code Example #1
import gym
import matplotlib.pyplot as plt
# Project-specific imports (nets, Memory, Agent, load_network, get_net_activity)
# and the variables version, rep_type and learning_rate are assumed to be
# defined earlier in the script.

directory = '../../Data/'  # ../../Data if you are in Tests/CH2
env_name = f'gridworld:gridworld-v{version}'
# make gym environment
env = gym.make(env_name)
plt.close()
print(env.rewards)

num_trials = 10
num_events = 250

# load the previously trained network and its state_dict from the data directory
net, state_dict = load_network(version, rep_type, directory)

# build the actor-critic head and load its output-layer weights from the previously trained agent
AC_head_agent = nets.flat_ActorCritic(400,
                                      output_dims=env.action_space.n,
                                      lr=learning_rate)
top_layer_dict = {}
top_layer_dict['pol.weight'] = state_dict['output.0.weight']
top_layer_dict['pol.bias'] = state_dict['output.0.bias']
top_layer_dict['val.weight'] = state_dict['output.1.weight']
top_layer_dict['val.bias'] = state_dict['output.1.bias']
AC_head_agent.load_state_dict(top_layer_dict)
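# only the output-layer (policy/value) weights are reused; the state inputs below
# come from the trained network's hidden-layer activity (h1)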

# get state inputs: hidden-layer activations of the trained network
h0, h1 = get_net_activity(env_name, rep_type, net)
state_reps, representation_name = h1, f'h1_{rep_type}_latents'

memory = None  # no episodic cache for this run; alternatively Memory.EpisodicMemory(cache_limit=cache_size_for_env, entry_size=env.action_space.n)

agent = Agent(AC_head_agent, memory=memory, state_representations=state_reps)
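The key step above is remapping the trained network's output-layer weights ('output.0.*' for the policy head, 'output.1.*' for the value head) onto the flat actor-critic's 'pol.*' and 'val.*' keys. Below is a minimal, self-contained sketch of that pattern in plain PyTorch; FullNet and HeadAC are illustrative stand-ins, not the repository's actual classes.

import torch
import torch.nn as nn

# Illustrative stand-ins (not the repository's modules): a trained network with
# policy/value output layers, and a small head that reuses only those layers.
class FullNet(nn.Module):
    def __init__(self, hidden=400, n_actions=4):
        super().__init__()
        self.hidden = nn.Linear(600, hidden)
        self.output = nn.ModuleList([nn.Linear(hidden, n_actions),  # policy head
                                     nn.Linear(hidden, 1)])         # value head

class HeadAC(nn.Module):
    def __init__(self, input_dims=400, n_actions=4):
        super().__init__()
        self.pol = nn.Linear(input_dims, n_actions)
        self.val = nn.Linear(input_dims, 1)

trained = FullNet()
head = HeadAC()

# remap 'output.0.*' -> 'pol.*' and 'output.1.*' -> 'val.*', as in the example above
sd = trained.state_dict()
head.load_state_dict({'pol.weight': sd['output.0.weight'],
                      'pol.bias': sd['output.0.bias'],
                      'val.weight': sd['output.1.weight'],
                      'val.bias': sd['output.1.bias']})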
Code Example #2
import gym
import torch
import matplotlib.pyplot as plt
# Project-specific imports (nets, Memory, Agent, expt, and the representation
# constructors onehot, random, place_cell, sr, latents) and the variables
# env_name, cache_size, rep_type, learning_rate, load_from, directory,
# num_trials and num_events are assumed to be defined earlier in the script.

env = gym.make(env_name)
plt.close()
cache_size_for_env = int(len(env.useable) * (cache_size / 100))
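# e.g. with 100 usable states and cache_size = 75, the limit is int(100 * 0.75) = 75 entries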
print(env.rewards)
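# map representation names to the project's state-representation constructors;
# each returns (state_reps, representation_name, input_dims, ...) for the given environment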
rep_types = {
    'onehot': onehot,
    'random': random,
    'place_cell': place_cell,
    'sr': sr,
    'latent': latents
}
state_reps, representation_name, input_dims, _ = rep_types[rep_type](env)

# build the actor-critic head; if specified, load weights saved from a previously trained agent
AC_head_agent = nets.flat_ActorCritic(input_dims,
                                      env.action_space.n,
                                      lr=learning_rate)

if load_from is not None:
    AC_head_agent.load_state_dict(
        torch.load(directory + f'agents/{load_from}.pt'))
    print(f"weights loaded from {load_from}")

memory = Memory.EpisodicMemory(cache_limit=cache_size_for_env,
                               entry_size=env.action_space.n)

agent = Agent(AC_head_agent, memory=memory, state_representations=state_reps)

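# run the training experiment and record a log of the results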
run = expt(agent, env)
run.run(NUM_TRIALS=num_trials, NUM_EVENTS=num_events)
run.record_log(env_name,