def train_agent(layout: str, episodes: int = 10000, frames_to_skip: int = 4):
    """Train a DQN agent on the Pacman environment.

    Args:
        layout: Name of the Pacman board layout to load.
        episodes: Number of training episodes to run.
        frames_to_skip: Frames repeated per action by the SkipFrame wrapper.

    Side effects:
        Renders the environment every step, checkpoints the target network
        to 'pacman.pth' every 1000 episodes, and saves it again at the end.
    """
    GAMMA = 0.99        # discount factor for bootstrapped returns
    EPS_START = 1.0     # initial exploration rate
    EPS_END = 0.1       # exploration-rate floor
    # NOTE(review): epsilon is updated once per *episode* (at most `episodes`
    # = 1e4 steps of the schedule), so with a 1e7 time-constant it barely
    # decays at all — confirm whether this should be driven by frame count.
    EPS_DECAY = 1e7
    TARGET_UPDATE = 10  # episodes between target-network syncs
    BATCH_SIZE = 64

    def epsilon_by_frame(frame_idx: int) -> float:
        # Exponential decay from EPS_START down to EPS_END.
        # Bug fix: the original lambda read the mutated EPSILON variable from
        # the enclosing scope, so each call decayed from the *previous* value
        # instead of the fixed starting value, compounding the decay.
        return EPS_END + (EPS_START - EPS_END) * math.exp(-1. * frame_idx / EPS_DECAY)

    # Build the wrapped environment: frame skipping, grayscale conversion,
    # 84x84 resize, and a stack of 4 consecutive frames as the observation.
    env = PacmanEnv(layout=layout)
    env = SkipFrame(env, skip=frames_to_skip)
    env = GrayScaleObservation(env)
    env = ResizeObservation(env, shape=84)
    env = FrameStack(env, num_stack=4)

    # Initial observation determines the network input shape.
    screen = env.reset(mode='rgb_array')

    # Get number of actions from gym action space.
    n_actions = env.action_space.n

    policy_net = DQN(screen.shape, n_actions).to(device)
    target_net = DQN(screen.shape, n_actions).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()  # target net is only read from, never trained directly

    optimizer = optim.RMSprop(policy_net.parameters())
    # NOTE(review): the buffer is constructed with BATCH_SIZE (64); if this
    # argument is the buffer *capacity*, that is almost certainly too small
    # for DQN — verify ReplayBuffer's constructor signature.
    memory = ReplayBuffer(BATCH_SIZE)

    for i_episode in range(episodes):
        # Initialize the environment and state.
        state = env.reset(mode='rgb_array')
        ep_reward = 0.
        epsilon = epsilon_by_frame(i_episode)

        for t in count():
            # Select and perform an action.
            env.render(mode='human')
            action = select_action(state, epsilon, policy_net, n_actions)
            next_state, reward, done, info = env.step(action)
            reward = max(-1.0, min(reward, 1.0))  # clip reward to [-1, 1]
            ep_reward += reward
            memory.cache(state, next_state, action, reward, done)

            # Terminal states have no successor.
            if done:
                next_state = None

            # Move to the next state.
            state = next_state

            # One optimization step on the policy network, using the target
            # network for the bootstrap targets.
            optimize_model(memory, policy_net, optimizer, target_net, GAMMA)
            if done:
                print("Episode #{}, lasts for {} timestep, total reward: {}".
                      format(i_episode, t + 1, ep_reward))
                break

        # Update the target network, copying all weights and biases in DQN.
        if i_episode % TARGET_UPDATE == 0:
            target_net.load_state_dict(policy_net.state_dict())
        # Periodic checkpoint (also fires at episode 0).
        if i_episode % 1000 == 0:
            save_model(target_net, 'pacman.pth')

    print('Complete')
    env.render()
    env.close()
    save_model(target_net, 'pacman.pth')