Exemplo n.º 1
0
def make_env(params):
    """Create a wrapped Atari environment from a parameter mapping.

    Args:
        params: mapping that must provide the ``'env_name'`` and ``'fsa'``
            keys. ``'env_name'`` is passed to ``atari_wrappers.make_atari``;
            ``'fsa'`` is forwarded unchanged to both ``make_atari`` and
            ``wrap_deepmind``.

    Returns:
        The environment produced by ``atari_wrappers.wrap_deepmind`` with the
        ``frame_stack`` and ``pytorch_img`` options switched on.
    """
    env_name = params['env_name']
    use_fsa = params['fsa']

    base_env = atari_wrappers.make_atari(env_name, fsa=use_fsa)
    return atari_wrappers.wrap_deepmind(
        base_env,
        pytorch_img=True,
        frame_stack=True,
        fsa=use_fsa,
    )
Exemplo n.º 2
0
def play_func(params, net, cuda, exp_queue):
    """Play the environment with an epsilon-greedy DQN agent and stream results.

    Builds its own Atari environment, lets an epsilon-greedy agent driven by
    ``net`` act in it, and pushes onto ``exp_queue``:

    * every transition yielded by ``ExperienceSourceFirstLast``;
    * an ``EpisodeEnded(reward, steps, epsilon)`` record for each episode
      reported by ``pop_rewards_steps()``.

    The only exit from the loop is the experience source ceasing to yield.
    Presumably this runs in a child process that the caller terminates;
    confirm at the call site.

    Args:
        params: hyperparameter namespace. ``env_name``, ``epsilon_start`` and
            ``gamma`` are read here, and the whole object is also passed to
            ``common.EpsilonTracker``.
        net: network the agent uses to choose actions.
        cuda: truthy to place the agent on ``"cuda"``, otherwise ``"cpu"``.
        exp_queue: object with a ``put`` method that receives transitions and
            ``EpisodeEnded`` records.
    """
    env = atari_wrappers.make_atari(
        params.env_name, skip_noop=True, skip_maxskip=True)
    env = atari_wrappers.wrap_deepmind(
        env, pytorch_img=True, frame_stack=True, frame_stack_count=2)
    env.seed(common.SEED)

    device_name = "cuda" if cuda else "cpu"
    device = torch.device(device_name)

    action_selector = ptan.actions.EpsilonGreedyActionSelector(
        epsilon=params.epsilon_start)
    eps_tracker = common.EpsilonTracker(action_selector, params)
    dqn_agent = ptan.agent.DQNAgent(net, action_selector, device=device)
    source = ptan.experience.ExperienceSourceFirstLast(
        env, dqn_agent, gamma=params.gamma)

    for step_idx, transition in enumerate(source):
        # The frame counter is divided by BATCH_MUL before the epsilon update;
        # presumably this keeps the decay schedule in training-step units.
        # Confirm against BATCH_MUL's definition.
        eps_tracker.frame(step_idx / BATCH_MUL)
        exp_queue.put(transition)

        finished = source.pop_rewards_steps()
        for total_reward, episode_steps in finished:
            # Report each completed episode with the epsilon in effect now.
            exp_queue.put(EpisodeEnded(
                total_reward, episode_steps, action_selector.epsilon))
Exemplo n.º 3
0
if __name__ == "__main__":
    # get rid of missing metrics warning
    warnings.simplefilter("ignore", category=UserWarning)

    mp.set_start_method('spawn')
    params = common.HYPERPARAMS['pong']
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda",
                        default=False,
                        action="store_true",
                        help="Enable cuda")
    args = parser.parse_args()
    device = torch.device("cuda" if args.cuda else "cpu")

    env = atari_wrappers.make_atari(params.env_name,
                                    skip_noop=True,
                                    skip_maxskip=True)
    env = atari_wrappers.wrap_deepmind(env,
                                       pytorch_img=True,
                                       frame_stack=True,
                                       frame_stack_count=2)

    net = dqn_model.DQN(env.observation_space.shape,
                        env.action_space.n).to(device)
    tgt_net = ptan.agent.TargetNet(net)

    buffer = ptan.experience.ExperienceReplayBuffer(
        experience_source=None, buffer_size=params.replay_size)
    optimizer = optim.Adam(net.parameters(), lr=params.learning_rate)

    # start subprocess and experience queue