示例#1
0
                        choices=['DQN', 'Double', 'Dueling'],
                        default='Double')
    args = parser.parse_args()

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # Validate explicitly instead of with `assert`: asserts are stripped under
    # `python -O`, which would let a missing args.load reach the model loader.
    if args.task != 'train' and args.load is None:
        raise SystemExit(
            "argument 'load' is required for task {!r}".format(args.task))
    ROM_FILE = args.rom
    METHOD = args.algo

    # A temporary AtariPlayer is created only to read the number of actions;
    # it is discarded immediately afterwards.
    pl = AtariPlayer(ROM_FILE, viz=False)
    NUM_ACTIONS = pl.get_action_space().num_actions()
    del pl

    if args.task != 'train':
        # Inference: build a predictor initialised from args.load that maps
        # the 'state' input to the 'Qvalue' output.
        cfg = PredictConfig(model=Model(),
                            session_init=get_model_loader(args.load),
                            input_names=['state'],
                            output_names=['Qvalue'])
        if args.task == 'play':
            play_model(cfg, get_player(viz=0.01))
        elif args.task == 'eval':
            eval_model_multithread(cfg, EVAL_EPISODE, get_player)
    else:
        # Training, optionally warm-started from args.load.
        config = get_config()
        if args.load:
            config.session_init = SaverRestore(args.load)
        QueueInputTrainer(config).train()
示例#2
0
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # Validate explicitly instead of with `assert`, which `python -O` strips.
    if args.task != 'train' and args.load is None:
        raise SystemExit(
            "argument 'load' is required for task {!r}".format(args.task))

    if args.task != 'train':
        # Inference: restore args.load and expose 'state' -> 'logits:0'.
        cfg = PredictConfig(
            model=Model(),
            session_init=SaverRestore(args.load),
            input_var_names=['state'],
            output_var_names=['logits:0'])
        if args.task == 'play':
            play_model(cfg)
        elif args.task == 'eval':
            eval_model_multithread(cfg, EVAL_EPISODE)
    else:
        nr_gpu = get_nr_gpu()
        # Split the GPUs reported by get_nr_gpu(): the last ceil(n/2) run
        # prediction and the remaining floor(n/2) run training.
        # `//` keeps slice bounds integral on Python 3. With `/` the bound is a
        # float and slicing raises TypeError. list() makes the towers real lists
        # on both Python 2 and 3.
        if nr_gpu > 1:
            predict_tower = list(range(nr_gpu))[-nr_gpu // 2:]
        else:
            predict_tower = [0]
        # Module-level name, presumably read by get_config(). TODO confirm.
        PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
        config = get_config()
        if args.load:
            config.session_init = SaverRestore(args.load)
        # With 0 or 1 GPU the training slice is empty; fall back to GPU 0.
        config.tower = list(range(nr_gpu))[:-nr_gpu // 2] or [0]
        logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format(
            ','.join(map(str, config.tower)), ','.join(map(str, predict_tower))))
        AsyncMultiGPUTrainer(config, predict_tower=predict_tower).train()
示例#3
0
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # Validate explicitly instead of with `assert` (stripped under `python -O`):
    # every non-train task builds its predictor from args.load.
    if args.task != 'train' and args.load is None:
        raise SystemExit(
            "argument 'load' is required for task {!r}".format(args.task))

    if args.task != 'train':
        cfg = PredictConfig(
            model=Model(),
            session_init=SaverRestore(args.load),
            input_names=['state'],
            output_names=['policy'])
        if args.task == 'play':
            play_model(cfg)
        elif args.task == 'eval':
            eval_model_multithread(cfg, args.episode)
        elif args.task == 'gen_submit':
            run_submission(cfg, args.output, args.episode)
    else:
        nr_gpu = get_nr_gpu()
        if nr_gpu > 0:
            if nr_gpu > 1:
                predict_tower = list(range(nr_gpu))[-nr_gpu // 2:]
            else:
                predict_tower = [0]
            PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
            train_tower = list(range(nr_gpu))[:-nr_gpu // 2] or [0]
            logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format(
                ','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
            trainer = AsyncMultiGPUTrainer
        else:
示例#4
0
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # Validate explicitly instead of with `assert` (stripped under `python -O`):
    # every non-train task builds its predictor from args.load.
    if args.task != 'train' and args.load is None:
        raise SystemExit(
            "argument 'load' is required for task {!r}".format(args.task))

    if args.task != 'train':
        cfg = PredictConfig(
            model=Model(),
            session_init=SaverRestore(args.load),
            input_names=['state'],
            output_names=['policy'])
        if args.task == 'play':
            play_model(cfg, get_player(viz=0.01))
        elif args.task == 'eval':
            eval_model_multithread(cfg, args.episode, get_player)
        elif args.task == 'gen_submit':
            play_n_episodes(
                get_player(train=False, dumpdir=args.output),
                OfflinePredictor(cfg), args.episode)
            # gym.upload(output, api_key='xxx')
    else:
        nr_gpu = get_nr_gpu()
        if nr_gpu > 0:
            if nr_gpu > 1:
                predict_tower = list(range(nr_gpu))[-nr_gpu // 2:]
            else:
                predict_tower = [0]
            PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
            train_tower = list(range(nr_gpu))[:-nr_gpu // 2] or [0]
            logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format(
示例#5
0
    ENV_NAME = args.env
    # A path ending in '.bin' (presumably an Atari ROM file) selects the
    # non-gym player; any other value is treated as a gym environment.
    USE_GYM = not ENV_NAME.endswith('.bin')

    # Ask a player instance for the size of its discrete action space.
    num_actions = get_player().action_space.n
    logger.info("ENV: {}, Num Actions: {}".format(args.env, num_actions))

    # Gym states get an extra trailing axis of size 3; ROM states use
    # IMAGE_SIZE as-is.
    state_shape = IMAGE_SIZE + (3, ) if USE_GYM else IMAGE_SIZE
    model = Model(state_shape, FRAME_HISTORY, args.algo, num_actions)

    if args.task != 'train':
        # Validate explicitly instead of with `assert`, which `python -O`
        # strips.
        if args.load is None:
            raise SystemExit(
                "argument 'load' is required for task {!r}".format(args.task))
        pred = OfflinePredictor(
            PredictConfig(model=model,
                          session_init=SmartInit(args.load),
                          input_names=['state'],
                          output_names=['Qvalue']))
        if args.task == 'play':
            # Render 100 episodes with the loaded policy.
            play_n_episodes(get_player(viz=0.01), pred, 100, render=True)
        elif args.task == 'eval':
            eval_model_multithread(pred, args.num_eval, get_player)
    else:
        # Logs go to train_log/DQN-<env basename up to its first '.'>.
        logger.set_logger_dir(
            os.path.join(
                'train_log',
                'DQN-{}'.format(os.path.basename(args.env).split('.')[0])))
        config = get_config(model)
        # args.load may be None here; SmartInit is expected to treat that as
        # "no initialisation". NOTE(review): confirm.
        config.session_init = SmartInit(args.load)
        launch_train_with_config(config, SimpleTrainer())