Exemplo n.º 1
0
def main_DQN_plus_greedy():
    """Train a DQN agent, seeding its replay memory with greedy-policy episodes.

    The first third of ``GREEDY_TOTAL_NUM_EPISODES`` episodes are played by the
    ``Greedy`` policy. Its angle actions are converted to the agent's discrete
    action index (``agent.angle_to_action``) so those transitions can be pushed
    into ``agent.memory``. The remaining episodes are played by the DQN agent
    itself.

    Every step (of ``NUM_SKIP_FRAMES`` environment frames), one transition is
    pushed to replay memory and ``agent.optimize()`` is called. Only DQN
    episodes contribute to ``agent.max_masses``.

    When training finishes, the function:
      * saves the policy network weights to
        ``model_GREEDY_DQN_<episodes>_<timestamp>_episodes.model`` in the
        working directory;
      * prints final statistics;
      * closes the environment (always, even if training raises).

    Relies on module-level configuration: RENDER, SPEED_SCALE, DISPLAY_TEXT,
    GRID_RESOLUTION, ACTION_DISCRETIZATION, NUM_SKIP_FRAMES, MAX_STEPS.
    """
    GREEDY_TOTAL_NUM_EPISODES = 1000
    # The leading third of the episodes are played by the greedy policy.
    GREEDY_NUM_EPISODES = GREEDY_TOTAL_NUM_EPISODES // 3
    env = AgarioEnv(render=RENDER,
                    speed_scale=SPEED_SCALE,
                    display_text=DISPLAY_TEXT,
                    grid_resolution=GRID_RESOLUTION)
    try:
        agent = DQNAgent(height=GRID_RESOLUTION,
                         width=GRID_RESOLUTION,
                         input_channels=2,
                         num_actions=ACTION_DISCRETIZATION,
                         loadpath='')
        greedy = Greedy()
        env.seed(41)
        agent.seed(41)
        for episode in range(GREEDY_TOTAL_NUM_EPISODES):
            state = env.reset()
            done = False
            num_steps = 0
            is_greedy_episode = episode < GREEDY_NUM_EPISODES
            label = "Greedy" if is_greedy_episode else "DQN"
            while not done:
                if is_greedy_episode:
                    # Greedy yields an angle; map it to the discrete action index
                    # so the transition can be stored in the agent's memory.
                    action = greedy.get_action(state)
                    raw_action = agent.angle_to_action(action)
                else:
                    raw_action = agent.get_action(state)
                    action = agent.action_to_angle(raw_action)

                # Frame skipping: repeat the action, summing the per-frame
                # rewards, and stop as soon as the episode terminates so the
                # environment is never stepped past `done`.
                new_state, reward = None, 0
                for _ in range(NUM_SKIP_FRAMES):
                    if RENDER:
                        env.render()
                    new_state, frame_reward, done, _ = env.step(action)
                    reward += frame_reward
                    if done:
                        break
                num_steps += 1

                # Terminal (or truncated once num_steps exceeds MAX_STEPS)
                # transitions are stored with new_state=None.
                if done or num_steps > MAX_STEPS:
                    new_state = None
                    done = True
                agent.memory.push(state, raw_action, new_state, reward)
                agent.optimize()

                if done:
                    # `state` is the observation the final action was taken from.
                    print(f'{label} episode done, max_mass: {state.mass}')
                    if not is_greedy_episode:
                        agent.max_masses.append(state.mass)

                # NOTE(review): num_steps resets every episode, so the target
                # network is only synced in episodes lasting at least
                # TARGET_UPDATE steps — confirm this is intended.
                if num_steps % agent.TARGET_UPDATE == 0:
                    agent.target_net.load_state_dict(agent.policy_net.state_dict())
                state = new_state
        print('Complete')
        timestamp = str(datetime.now()).replace(" ", "_")
        # The filename records this function's own episode count (it previously
        # used the unrelated global NUM_EPISODES).
        torch.save(
            agent.policy_net.state_dict(),
            f'model_GREEDY_DQN_{GREEDY_TOTAL_NUM_EPISODES}_{timestamp}_episodes.model'
        )
        agent.print_final_stats()
    finally:
        env.close()
Exemplo n.º 2
0
def main_DQN():
    """Train a DQN agent from scratch, checkpointing periodically.

    Runs ``NUM_EPISODES`` episodes in which the agent chooses every action.
    Every step (of ``NUM_SKIP_FRAMES`` environment frames), one transition is
    pushed to replay memory and ``agent.optimize()`` is called. The episode's
    final mass is appended to ``agent.max_masses``, and running statistics are
    printed after each episode.

    Every ``WEIGHTS_SAVE_EPISODE_STEP`` episodes (starting at episode 0), a
    checkpoint pair is written into ``DQN_weights/``:
      * ``model_<episode>_<timestamp>_episodes.model`` — policy weights;
      * ``model_<episode>_<timestamp>_episodes.performance`` — max masses so far.
    Both files of a checkpoint share one timestamp. The directory is created if
    it is missing.

    After training, the final weights and performance history are saved, and
    the environment is closed (always, even if training raises).

    Relies on module-level configuration: RENDER, SPEED_SCALE, DISPLAY_TEXT,
    GRID_RESOLUTION, ACTION_DISCRETIZATION, NUM_EPISODES, NUM_SKIP_FRAMES,
    MAX_STEPS, WEIGHTS_SAVE_EPISODE_STEP.
    """
    import os

    env = AgarioEnv(render=RENDER,
                    speed_scale=SPEED_SCALE,
                    display_text=DISPLAY_TEXT,
                    grid_resolution=GRID_RESOLUTION)
    try:
        agent = DQNAgent(height=GRID_RESOLUTION,
                         width=GRID_RESOLUTION,
                         input_channels=2,
                         num_actions=ACTION_DISCRETIZATION,
                         loadpath='')
        # Seeding is disabled; uncomment for reproducible runs.
        # env.seed(41)
        # agent.seed(41)

        # torch.save / np.savetxt do not create missing parent directories.
        os.makedirs('DQN_weights', exist_ok=True)

        for episode in range(NUM_EPISODES):
            state = env.reset()
            done = False
            num_steps = 0
            while not done:
                raw_action = agent.get_action(state)
                action = agent.action_to_angle(raw_action)

                # Frame skipping: repeat the action, summing the per-frame
                # rewards, and stop as soon as the episode terminates so the
                # environment is never stepped past `done`.
                new_state, reward = None, 0
                for _ in range(NUM_SKIP_FRAMES):
                    if RENDER:
                        env.render()
                    new_state, frame_reward, done, _ = env.step(action)
                    reward += frame_reward
                    if done:
                        break
                num_steps += 1

                # Terminal (or truncated once num_steps exceeds MAX_STEPS)
                # transitions are stored with new_state=None.
                if done or num_steps > MAX_STEPS:
                    new_state = None
                    done = True
                agent.memory.push(state, raw_action, new_state, reward)
                agent.optimize()

                if done:
                    # `state` is the observation the final action was taken from.
                    print(f'Episode {episode} done, max_mass = {state.mass}')
                    agent.max_masses.append(state.mass)
                    agent.print_final_stats()

                # NOTE(review): num_steps resets every episode, so the target
                # network is only synced in episodes lasting at least
                # TARGET_UPDATE steps — confirm this is intended.
                if num_steps % agent.TARGET_UPDATE == 0:
                    agent.target_net.load_state_dict(agent.policy_net.state_dict())
                state = new_state

            if episode % WEIGHTS_SAVE_EPISODE_STEP == 0:
                # One timestamp per checkpoint so the paired files match.
                timestamp = str(datetime.now()).replace(" ", "_")
                torch.save(
                    agent.policy_net.state_dict(),
                    f'DQN_weights/model_{episode}_{timestamp}_episodes.model'
                )
                np.savetxt(
                    f'DQN_weights/model_{episode}_{timestamp}_episodes.performance',
                    np.array(agent.max_masses))

        print('Complete')
        timestamp = str(datetime.now()).replace(" ", "_")
        # NOTE(review): the final model is written to the working directory,
        # while its performance file goes to DQN_weights/ — confirm intended.
        torch.save(
            agent.policy_net.state_dict(),
            f'model_{NUM_EPISODES}_{timestamp}_episodes.model'
        )
        np.savetxt(
            f'DQN_weights/model_{NUM_EPISODES}_{timestamp}_episodes.performance',
            np.array(agent.max_masses))
        agent.print_final_stats()
    finally:
        env.close()