Exemplo n.º 1
0
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Attempting to solve Atari games using A2C.

@author: Ayman Jabri
"""
import torch
import ptan
from lib import data, utils, model
from tensorboardX import SummaryWriter

if __name__ == '__main__':
    # Banner text handed to the argument-parser helper.
    stars = '*' * 10
    message = stars + '  A2C on Atari ' + stars

    # Resolve the parameter set for the chosen environment and let
    # utils.update_params combine it with the parsed arguments
    # (presumably command-line overrides -- confirm in lib.utils).
    args = utils.argpars_dqn(message)
    params = data.params[args.env]
    utils.update_params(params, args)

    # Use at least 8 environments, whatever the configuration says.
    params.n_envs = max(params.n_envs, 8)

    if args.cuda:
        device = 'cuda'
    else:
        device = 'cpu'

    # The first environment supplies the observation shape and the number
    # of actions used to size the network.
    envs = utils.createEnvs(params, stack_frames=2)
    first_env = envs[0]
    shape = first_env.observation_space.shape
    actions = first_env.action_space.n

    # Build the network and place it on the selected device
    # (assumes model.A2CNet is a torch module).
    net = model.A2CNet(shape, actions)
    net.to(device)
    agent = ptan.agent.ActorCriticAgent(
        net,
        device=device,
        apply_softmax=True,
    )

    # Experience pipeline: source -> batch generator, plus a reward monitor
    # bound to the first environment.
    exp_src = ptan.experience.ExperienceSourceFirstLast(
        envs,
        agent,
        params.gamma,
        steps_count=params.steps,
    )
    generator = utils.BatchGenerator(exp_src, params)
    mean_monitor = utils.MeanRewardsMonitor(
        first_env,
        net,
        'A2C',
        params.solve_rewards,
    )
Exemplo n.º 2
0
@author: ayman
"""

import ptan
import gym
import torch

from tensorboardX import SummaryWriter

from lib import data, utils, model, atari_wrappers

# Label identifying this algorithm.
ALGORITHM = 'RAINBOW_DDQN'
# Every key of data.params, i.e. the identifiers that data.params can be
# indexed with (the script below indexes it with args.env).
GAMES = [*data.params.keys()]

if __name__ == '__main__':
    # Script entry point: resolve the configuration, then build the
    # environments the rest of the script works with.

    # Look up the parameter set for the environment named in args.env and
    # let utils.update_params combine it with the parsed arguments
    # (presumably command-line overrides -- confirm in lib.utils).
    args = utils.argpars_dqn()
    params = data.params[args.env]
    utils.update_params(params, args)

    device = 'cuda' if args.cuda else 'cpu'

    # Create params.n_envs environments, each wrapped by
    # atari_wrappers.wrap_dqn_light with the configured frame stack and skip.
    envs = []
    for _ in range(params.n_envs):
        env = gym.make(params.env)
        env = atari_wrappers.wrap_dqn_light(env, params.frame_stack, args.skip)
        # NOTE(review): this is a truthiness check, so a seed of 0 is treated
        # the same as "no seed". Every environment also receives the same
        # seed value -- confirm both are intended.
        if params.seed:
            env.seed(params.seed)
        envs.append(env)

    # `env` is the last environment created by the loop above.
    # NOTE(review): if params.n_envs is 0 the loop never runs and `env` is
    # undefined here, so this line would raise NameError.
    shape = env.observation_space.shape
    actions = env.action_space.n