示例#1
0
def get_player(viz=False, train=False):
    """Build the wrapped Atari environment.

    Args:
        viz: forwarded to ``AtariPlayer`` (visualization setting).
        train: when True, losing a life ends the episode and frame
            stacking is skipped — history is handled by the expreplay
            buffer during training.

    Returns:
        The fully wrapped environment.
    """
    player = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT, viz=viz,
                         live_lost_as_eoe=train, max_num_frames=30000)
    player = FireResetEnv(player)
    player = WarpFrame(player, IMAGE_SIZE)
    if train:
        # in training, history is taken care of in expreplay buffer
        return player
    return FrameStack(player, FRAME_HISTORY)
示例#2
0
def get_player(viz=False, train=False):
    """Build the Atari player, adding evaluation-only wrappers.

    Args:
        viz: forwarded to ``AtariPlayer`` (visualization setting).
        train: when True, losing a life ends the episode and the
            history/anti-stuck wrappers are skipped.

    Returns:
        The wrapped player, length-limited to 30000 steps.
    """
    player = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT,
                         image_shape=IMAGE_SIZE[::-1], viz=viz,
                         live_lost_as_eoe=train)
    if not train:
        # Evaluation only: add a channel axis, stack recent frames,
        # and apply PreventStuckPlayer (presumably to break action
        # loops — TODO confirm against its implementation).
        player = MapPlayerState(player, lambda frame: frame[:, :, np.newaxis])
        player = HistoryFramePlayer(player, FRAME_HISTORY)
        player = PreventStuckPlayer(player, 30, 1)
    # The length cap applies in both training and evaluation.
    return LimitLengthPlayer(player, 30000)
示例#3
0
def get_player(rom, viz=False, train=False):
    """Build the wrapped Atari environment for the given ROM.

    Args:
        rom: path/name of the Atari ROM to load.
        viz: forwarded to ``AtariPlayer`` (visualization setting).
        train: when True, losing a life ends the episode and frame
            stacking is skipped (the expreplay buffer keeps context).

    Returns:
        The fully wrapped environment.
    """
    def _resize(im):
        return cv2.resize(im, IMAGE_SIZE)

    player = AtariPlayer(rom, frame_skip=ACTION_REPEAT, viz=viz,
                         live_lost_as_eoe=train, max_num_frames=60000)
    player = FireResetEnv(player)
    player = MapState(player, _resize)
    if train:
        # in training, context is taken care of in expreplay buffer
        return player
    return FrameStack(player, CONTEXT_LEN)
示例#4
0
文件: DQN.py 项目: zzuxzt/tensorpack
def get_player(viz=False, train=False):
    """Build the wrapped Atari environment.

    Args:
        viz: forwarded to ``AtariPlayer`` (visualization setting).
        train: when True, losing a life ends the episode and frame
            stacking is skipped (expreplay buffer keeps the history).

    Returns:
        The fully wrapped environment.
    """
    def _preprocess(im):
        # Resize and restore the channel axis dropped by cv2.resize.
        return cv2.resize(im, IMAGE_SIZE)[:, :, np.newaxis]

    player = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT, viz=viz,
                         live_lost_as_eoe=train, max_num_frames=60000)
    player = FireResetEnv(player)
    player = MapState(player, _preprocess)
    if train:
        # in training, history is taken care of in expreplay buffer
        return player
    return FrameStack(player, FRAME_HISTORY)
示例#5
0
def get_player(viz=False, train=False):
    """Build the environment, from gym or the built-in AtariPlayer.

    Args:
        viz: forwarded to ``AtariPlayer`` (ignored on the gym path).
        train: when True, frame stacking is skipped (expreplay keeps
            history) and, on the gym path, episodes are length-limited.

    Returns:
        The fully wrapped environment.
    """
    if USE_GYM:
        env = gym.make(ENV_NAME)
    else:
        from atari import AtariPlayer
        env = AtariPlayer(ENV_NAME, frame_skip=4, viz=viz,
                          live_lost_as_eoe=train, max_num_frames=60000)
    env = FireResetEnv(env)
    env = MapState(env, lambda frame: resize_keepdims(frame, IMAGE_SIZE))
    if not train:
        # in training, history is taken care of in expreplay buffer
        env = FrameStack(env, FRAME_HISTORY)
    # AtariPlayer already caps episode length via max_num_frames;
    # only the gym path needs an explicit limit during training.
    if train and USE_GYM:
        env = LimitLength(env, 60000)
    return env
示例#6
0
文件: DQN.py 项目: tranlm/tensorpack
def get_player(viz=False, train=False):
    """Build the Atari player and refresh the NUM_ACTIONS global.

    NOTE(review): mutating ``NUM_ACTIONS`` as a side effect is kept
    from the original — callers appear to depend on it being set here.

    Args:
        viz: forwarded to ``AtariPlayer`` (visualization setting).
        train: when True, losing a life ends the episode and the
            history/anti-stuck wrappers are skipped.

    Returns:
        The wrapped player, length-limited to 30000 steps.
    """
    global NUM_ACTIONS
    player = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT,
                         image_shape=IMAGE_SIZE[::-1], viz=viz,
                         live_lost_as_eoe=train)
    NUM_ACTIONS = player.get_action_space().num_actions()
    if not train:
        # Evaluation only: stack frame history and apply
        # PreventStuckPlayer (presumably to break repeated-action
        # loops — TODO confirm against its implementation).
        player = HistoryFramePlayer(player, FRAME_HISTORY)
        player = PreventStuckPlayer(player, 30, 1)
    return LimitLengthPlayer(player, 30000)
示例#7
0
def get_player(rom,
               image_size,
               viz=False,
               train=False,
               frame_skip=1,
               context_len=1):
    """Build the wrapped Atari environment for the given ROM.

    Args:
        rom: path/name of the Atari ROM to load.
        image_size: target (width, height) for ``cv2.resize``.
        viz: forwarded to ``AtariPlayer`` (visualization setting).
        train: when True, losing a life ends the episode and frame
            stacking is skipped (expreplay buffer keeps the context).
        frame_skip: number of frames each action is repeated.
        context_len: number of frames stacked at evaluation time.

    Returns:
        The fully wrapped environment.
    """
    def _resize(im):
        return cv2.resize(im, image_size)

    env = AtariPlayer(rom,
                      frame_skip=frame_skip,
                      viz=viz,
                      live_lost_as_eoe=train,
                      max_num_frames=60000)
    env = FireResetEnv(env)
    env = MapState(env, _resize)
    if train:
        # in training, context is taken care of in expreplay buffer
        return env
    return FrameStack(env, context_len)
示例#8
0
def get_player(train=False, dumpdir=None):
    """Build the environment, from gym or the built-in AtariPlayer.

    The backend is chosen by ENV_NAME: names ending in ``.bin`` use
    ``AtariPlayer`` (color frames), everything else goes through gym.

    Args:
        train: when True, losing a life ends the episode and, on the
            gym path, episodes are length-limited.
        dumpdir: if given, record videos of every episode there via
            ``gym.wrappers.Monitor``.

    Returns:
        The fully wrapped environment with 4 stacked frames.
    """
    is_gym_env = not ENV_NAME.endswith(".bin")
    if is_gym_env:
        env = gym.make(ENV_NAME)
    else:
        from atari import AtariPlayer
        env = AtariPlayer(ENV_NAME,
                          frame_skip=4,
                          viz=False,
                          live_lost_as_eoe=train,
                          max_num_frames=60000,
                          grayscale=False)
    if dumpdir:
        # video_callable=lambda _: True records every episode.
        env = gym.wrappers.Monitor(env, dumpdir, video_callable=lambda _: True)
    env = FireResetEnv(env)
    env = MapState(env, lambda frame: cv2.resize(frame, IMAGE_SIZE))
    env = FrameStack(env, 4)
    # AtariPlayer caps length via max_num_frames; only the gym path
    # needs an explicit limit during training.
    if train and is_gym_env:
        env = LimitLength(env, 60000)
    return env
示例#9
0
    parser.add_argument('--rom', help='atari rom', required=True)
    parser.add_argument('--algo',
                        help='algorithm',
                        choices=['DQN', 'Double', 'Dueling'],
                        default='Double')
    args = parser.parse_args()

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    if args.task != 'train':
        assert args.load is not None
    ROM_FILE = args.rom
    METHOD = args.algo

    # set num_actions
    pl = AtariPlayer(ROM_FILE, viz=False)
    NUM_ACTIONS = pl.get_action_space().num_actions()
    del pl

    if args.task != 'train':
        cfg = PredictConfig(model=Model(),
                            session_init=get_model_loader(args.load),
                            input_names=['state'],
                            output_names=['Qvalue'])
        if args.task == 'play':
            play_model(cfg, get_player(viz=0.01))
        elif args.task == 'eval':
            eval_model_multithread(cfg, EVAL_EPISODE, get_player)
    else:
        config = get_config()
        if args.load:
示例#10
0
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--load', help='load model')
    parser.add_argument('--task', help='task to perform',
                        choices=['play', 'eval', 'train'], default='train')
    parser.add_argument('--rom', help='atari rom', required=True)
    parser.add_argument('--algo', help='algorithm',
                        choices=['DQN', 'Double', 'Dueling'], default='Double')
    args = parser.parse_args()

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    ROM_FILE = args.rom
    METHOD = args.algo
    # set num_actions
    NUM_ACTIONS = AtariPlayer(ROM_FILE).action_space.n
    logger.info("ROM: {}, Num Actions: {}".format(ROM_FILE, NUM_ACTIONS))

    if args.task != 'train':
        assert args.load is not None
        pred = OfflinePredictor(PredictConfig(
            model=Model(),
            session_init=get_model_loader(args.load),
            input_names=['state'],
            output_names=['Qvalue']))
        if args.task == 'play':
            play_n_episodes(get_player(viz=0.01), pred, 100)
        elif args.task == 'eval':
            eval_model_multithread(pred, EVAL_EPISODE, get_player)
    else:
        logger.set_logger_dir(
示例#11
0
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--load', help='load model')
    parser.add_argument('--task', help='task to perform',
                        choices=['play', 'eval', 'train'], default='train')
    parser.add_argument('--rom', help='atari rom', required=True)
    parser.add_argument('--algo', help='algorithm',
                        choices=['DQN', 'Double', 'Dueling'], default='Double')
    args = parser.parse_args()

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    ROM_FILE = args.rom
    METHOD = args.algo
    # set num_actions
    NUM_ACTIONS = AtariPlayer(ROM_FILE).get_action_space().num_actions()
    logger.info("ROM: {}, Num Actions: {}".format(ROM_FILE, NUM_ACTIONS))

    if args.task != 'train':
        assert args.load is not None
        cfg = PredictConfig(
            model=Model(),
            session_init=get_model_loader(args.load),
            input_names=['state'],
            output_names=['Qvalue'])
        if args.task == 'play':
            play_model(cfg, get_player(viz=0.01))
        elif args.task == 'eval':
            eval_model_multithread(cfg, EVAL_EPISODE, get_player)
    else:
        logger.set_logger_dir(