def main_DQN_plus_greedy():
    GREEDY_TOTAL_NUM_EPISODES = 1000
    GREEDY_NUM_EPISODES = GREEDY_TOTAL_NUM_EPISODES // 3

    env = AgarioEnv(render=RENDER,
                    speed_scale=SPEED_SCALE,
                    display_text=DISPLAY_TEXT,
                    grid_resolution=GRID_RESOLUTION)
    agent = DQNAgent(height=GRID_RESOLUTION,
                     width=GRID_RESOLUTION,
                     input_channels=2,
                     num_actions=ACTION_DISCRETIZATION,
                     loadpath='')
    greedy = Greedy()
    env.seed(41)
    agent.seed(41)

    for episode in range(GREEDY_TOTAL_NUM_EPISODES):
        state = env.reset()
        done = False
        new_state = None
        raw_action, action = None, None
        reward = 0
        num_steps = 0
        # The first third of the episodes follow the greedy baseline so the
        # replay memory is seeded with reasonable transitions before the DQN
        # policy takes over.
        is_greedy_episode = episode < GREEDY_NUM_EPISODES
        while not done:
            if is_greedy_episode:
                action = greedy.get_action(state)
                raw_action = agent.angle_to_action(action)
            else:
                raw_action = agent.get_action(state)
                action = agent.action_to_angle(raw_action)
            # Frame skipping: repeat the chosen action for NUM_SKIP_FRAMES frames.
            for _ in range(NUM_SKIP_FRAMES):
                if RENDER:
                    env.render()
                new_state, reward, done, _ = env.step(action)
            num_steps += 1
            if done or num_steps > MAX_STEPS:
                # Terminal transitions are stored with next_state = None.
                new_state = None
                done = True
            agent.memory.push(state, raw_action, new_state, reward)
            agent.optimize()
            if done:
                print(f'{"Greedy" if is_greedy_episode else "DQN"} episode done, '
                      f'max_mass: {state.mass}')
                # Only track performance for episodes the DQN actually played.
                if not is_greedy_episode:
                    agent.max_masses.append(state.mass)
            # Periodically sync the target network with the policy network.
            if num_steps % agent.TARGET_UPDATE == 0:
                agent.target_net.load_state_dict(agent.policy_net.state_dict())
            state = new_state

    print('Complete')
    torch.save(
        agent.policy_net.state_dict(),
        f'model_GREEDY_DQN_{GREEDY_TOTAL_NUM_EPISODES}_'
        f'{str(datetime.now()).replace(" ", "_")}_episodes.model')
    agent.print_final_stats()
    env.close()
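
# Both training loops translate between a discrete action index and a
# continuous movement angle via agent.angle_to_action / agent.action_to_angle.
# The sketch below shows one plausible mapping, assuming ACTION_DISCRETIZATION
# evenly spaced directions over the full circle; it mirrors the agent's method
# names for illustration only and is not the DQNAgent implementation itself.
import math

def action_to_angle_sketch(raw_action, num_actions=ACTION_DISCRETIZATION):
    # Map index k in [0, num_actions) to the angle k * 2*pi / num_actions.
    return raw_action * 2 * math.pi / num_actions

def angle_to_action_sketch(angle, num_actions=ACTION_DISCRETIZATION):
    # Snap a continuous angle (radians) to the nearest discrete action index.
    fraction = (angle % (2 * math.pi)) / (2 * math.pi)
    return round(fraction * num_actions) % num_actions
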
def main_DQN():
    env = AgarioEnv(render=RENDER,
                    speed_scale=SPEED_SCALE,
                    display_text=DISPLAY_TEXT,
                    grid_resolution=GRID_RESOLUTION)
    agent = DQNAgent(height=GRID_RESOLUTION,
                     width=GRID_RESOLUTION,
                     input_channels=2,
                     num_actions=ACTION_DISCRETIZATION,
                     loadpath='')
    # env.seed(41)
    # agent.seed(41)

    for episode in range(NUM_EPISODES):
        state = env.reset()
        done = False
        new_state = None
        reward = 0
        num_steps = 0
        while not done:
            # Pick a discrete action from the policy network and map it to a
            # movement angle for the environment.
            raw_action = agent.get_action(state)
            action = agent.action_to_angle(raw_action)
            # Frame skipping: repeat the chosen action for NUM_SKIP_FRAMES frames.
            for _ in range(NUM_SKIP_FRAMES):
                if RENDER:
                    env.render()
                new_state, reward, done, _ = env.step(action)
            num_steps += 1
            if done or num_steps > MAX_STEPS:
                # Terminal transitions are stored with next_state = None.
                new_state = None
                done = True
            agent.memory.push(state, raw_action, new_state, reward)
            agent.optimize()
            if done:
                print(f'Episode {episode} done, max_mass = {state.mass}')
                agent.max_masses.append(state.mass)
                agent.print_final_stats()
            # Periodically sync the target network with the policy network.
            if num_steps % agent.TARGET_UPDATE == 0:
                agent.target_net.load_state_dict(agent.policy_net.state_dict())
            state = new_state

        # Checkpoint the weights and the per-episode max-mass history; a single
        # timestamp keeps the .model and .performance filenames matched.
        if episode % WEIGHTS_SAVE_EPISODE_STEP == 0:
            timestamp = str(datetime.now()).replace(' ', '_')
            torch.save(agent.policy_net.state_dict(),
                       f'DQN_weights/model_{episode}_{timestamp}_episodes.model')
            np.savetxt(f'DQN_weights/model_{episode}_{timestamp}_episodes.performance',
                       np.array(agent.max_masses))

    print('Complete')
    timestamp = str(datetime.now()).replace(' ', '_')
    torch.save(agent.policy_net.state_dict(),
               f'model_{NUM_EPISODES}_{timestamp}_episodes.model')
    np.savetxt(f'DQN_weights/model_{NUM_EPISODES}_{timestamp}_episodes.performance',
               np.array(agent.max_masses))
    agent.print_final_stats()
    env.close()
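
# agent.optimize() performs the standard DQN update. The sketch below shows
# what one such step typically looks like, assuming the replay memory yields
# (state, action, next_state, reward) tuples with tensor states and that
# terminal transitions carry next_state = None, as pushed by the loops above.
# The batch_size and gamma defaults are illustrative assumptions, not this
# repo's constants.
import random

import torch.nn.functional as F

def dqn_optimize_sketch(policy_net, target_net, memory, optimizer,
                        batch_size=32, gamma=0.99):
    if len(memory) < batch_size:
        return
    states, actions, next_states, rewards = zip(*random.sample(memory, batch_size))

    state_batch = torch.stack(states)                          # (B, C, H, W)
    action_batch = torch.tensor(actions).unsqueeze(1)          # (B, 1)
    reward_batch = torch.tensor(rewards, dtype=torch.float32)  # (B,)

    # Q(s, a) for the actions that were actually taken.
    q_values = policy_net(state_batch).gather(1, action_batch).squeeze(1)

    # max_a' Q_target(s', a'); terminal transitions contribute zero.
    next_q = torch.zeros(batch_size)
    non_final_mask = torch.tensor([s is not None for s in next_states])
    if non_final_mask.any():
        non_final_next = torch.stack([s for s in next_states if s is not None])
        with torch.no_grad():
            next_q[non_final_mask] = target_net(non_final_next).max(1).values

    # Huber loss on the TD error, then one gradient step on the policy net.
    loss = F.smooth_l1_loss(q_values, reward_batch + gamma * next_q)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Typical entry point, assuming one of the two mains is chosen at run time:
# if __name__ == '__main__':
#     main_DQN()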