Code example #1
def _thunk():
    env = make_atari(env_id)
    env.seed(seed + rank)
    env = Monitor(
        env,
        logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
    return wrap_deepmind(env, **wrapper_kwargs)
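In this snippet, env_id, seed, rank, and wrapper_kwargs are closure variables supplied by an enclosing factory. A minimal sketch of that enclosing pattern, assuming the standard baselines helpers (make_atari, wrap_deepmind, Monitor, SubprocVecEnv); the wrapper name make_atari_vec_env is illustrative:

import os

from baselines import logger
from baselines.bench import Monitor
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv


def make_atari_vec_env(env_id, num_env, seed, wrapper_kwargs=None):
    wrapper_kwargs = wrapper_kwargs or {}

    def make_env(rank):
        def _thunk():
            env = make_atari(env_id)
            env.seed(seed + rank)
            # Monitor accepts None as the filename when no log dir is configured.
            env = Monitor(
                env,
                logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
            return wrap_deepmind(env, **wrapper_kwargs)
        return _thunk

    # One thunk per worker; each subprocess builds its own environment.
    return SubprocVecEnv([make_env(i) for i in range(num_env)])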
Code example #2
def make_robotics_env(env_id, seed, rank=0):
    """
    Create a wrapped, monitored gym.Env for the goal-based robotics tasks.
    """
    set_global_seeds(seed)
    env = gym.make(env_id)
    env = FlattenDictWrapper(env, ['observation', 'desired_goal'])
    env = Monitor(env,
                  logger.get_dir()
                  and os.path.join(logger.get_dir(), str(rank)),
                  info_keywords=('is_success', ))
    env.seed(seed)
    return env
Code example #3
def make_mujoco_env(env_id, seed):
    """
    Create a wrapped, monitored gym.Env for MuJoCo.
    """
    set_global_seeds(seed)
    env = gym.make(env_id)
    env = Monitor(env, logger.get_dir())
    env.seed(seed)
    return env
Code example #4
def make_mujoco_env(env_id, seed):
    """
    Create a wrapped, monitored gym.Env for MuJoCo.
    """
    rank = MPI.COMM_WORLD.Get_rank()
    set_global_seeds(seed + 10000 * rank)
    env = gym.make(env_id)
    logger.configure()
    env = Monitor(env, os.path.join(logger.get_dir(), str(rank)))
    env.seed(seed)
    return env
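Because each rank passes os.path.join(logger.get_dir(), str(rank)) to Monitor, every worker writes its own <rank>.monitor.csv file under the logger directory. A minimal sketch of reading those files back, assuming baselines' bench.monitor.load_results (which requires pandas):

from baselines import logger
from baselines.bench.monitor import load_results

# Concatenates every *monitor.csv found in the logger directory into one DataFrame.
df = load_results(logger.get_dir())
# 'r', 'l', 't' are the per-episode reward, length, and elapsed wall-clock time.
print(df[['r', 'l', 't']].describe())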
Code example #5
File: run_atari.py  Project: SFU-MARS/SL_optCtrl
def train(env_id, num_timesteps, seed):
    from baselines.ppo1 import pposgd_simple, cnn_policy
    import ppo1.common.tf_util as U
    rank = MPI.COMM_WORLD.Get_rank()
    sess = U.single_threaded_session()
    sess.__enter__()
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)
    env = make_atari(env_id)

    def policy_fn(name, ob_space, ac_space):  #pylint: disable=W0613
        return cnn_policy.CnnPolicy(name=name,
                                    ob_space=ob_space,
                                    ac_space=ac_space)

    env = bench.Monitor(
        env,
        logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)

    env = wrap_deepmind(env)
    env.seed(workerseed)

    pposgd_simple.learn(env,
                        policy_fn,
                        max_timesteps=int(num_timesteps * 1.1),
                        timesteps_per_actorbatch=256,
                        clip_param=0.2,
                        entcoeff=0.01,
                        optim_epochs=4,
                        optim_stepsize=1e-3,
                        optim_batchsize=64,
                        gamma=0.99,
                        lam=0.95,
                        schedule='linear')
    env.close()
Code example #6
File: run_deepq.py  Project: kirk86/baselines
def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID',
                        default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    parser.add_argument('--checkpoint-freq', type=int, default=10000)
    parser.add_argument('--checkpoint-path', type=str, default=None)

    args = parser.parse_args()
    logger.configure()
    set_global_seeds(args.seed)
    env = make_atari(args.env)
    env = Monitor(env, logger.get_dir())
    env = wrap_deepmind(env)
    model = cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=bool(args.dueling),
    )

    fit(
        env,
        q_func=model,
        lr=1e-4,
        max_timesteps=args.num_timesteps,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=bool(args.prioritized),
        prioritized_replay_alpha=args.prioritized_replay_alpha,
        checkpoint_freq=args.checkpoint_freq,
        checkpoint_path=args.checkpoint_path,
    )

    env.close()
    sess = tf.get_default_session()
    del sess
Code example #7
File: run_deepq.py  Project: kirk86/baselines
    def save(self, path=None):
        """Save model to a pickle located at `path`"""
        if path is None:
            path = os.path.join(logger.get_dir(), "model.pkl")

        with tempfile.TemporaryDirectory() as td:
            self.save_state(os.path.join(td, "model"))
            arc_name = os.path.join(td, "packed.zip")
            with zipfile.ZipFile(arc_name, 'w') as zipf:
                for root, dirs, files in os.walk(td):
                    for fname in files:
                        file_path = os.path.join(root, fname)
                        if file_path != arc_name:
                            zipf.write(file_path, os.path.relpath(
                                file_path, td))
            with open(arc_name, "rb") as f:
                model_data = f.read()

        with open(path, "wb") as f:
            cloudpickle.dump((model_data, self._act_params), f)
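A hypothetical load counterpart is sketched below. It simply reverses save(), assuming the same module-level imports (os, tempfile, zipfile, cloudpickle, logger) and that load_state() mirrors save_state(); it is not taken from the project above.

    def load(self, path=None):
        """Restore the model from a pickle written by save() (sketch)."""
        if path is None:
            path = os.path.join(logger.get_dir(), "model.pkl")

        with open(path, "rb") as f:
            model_data, self._act_params = cloudpickle.load(f)

        with tempfile.TemporaryDirectory() as td:
            arc_path = os.path.join(td, "packed.zip")
            with open(arc_path, "wb") as f:
                f.write(model_data)
            # Unpack the archived checkpoint files and restore the session state.
            zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
            self.load_state(os.path.join(td, "model"))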
Code example #8
File: main_samples.py  Project: sff1019/f-IRL
    # logs
    exp_id = f"logs/{env_name}/exp-{num_expert_trajs}/{v['obj']}"  # task/obj/date structure
    # exp_id = 'debug'
    if not os.path.exists(exp_id):
        os.makedirs(exp_id)

    now = datetime.datetime.now(dateutil.tz.tzlocal())
    log_folder = exp_id + '/' + now.strftime('%Y_%m_%d_%H_%M_%S')
    logger.configure(dir=log_folder)
    print(f"Logging to directory: {log_folder}")
    os.system(f'cp baselines/main_samples.py {log_folder}')
    os.system(f'cp baselines/adv_smm.py {log_folder}')
    os.system(f'cp {sys.argv[1]} {log_folder}/variant_{pid}.yml')
    print('pid', os.getpid())
    with open(os.path.join(logger.get_dir(), 'variant.json'), 'w') as f:
        json.dump(v, f, indent=2, sort_keys=True)
    # os.makedirs(os.path.join(log_folder, 'plt'))
    os.makedirs(os.path.join(log_folder, 'model'))

    # environment
    env_fn = lambda: gym.make(env_name)
    gym_env = env_fn()
    state_size = gym_env.observation_space.shape[0]
    action_size = gym_env.action_space.shape[0]
    if state_indices == 'all':
        state_indices = list(range(state_size))

    # load expert samples from trained policy
    if v['obj'] != 'airl':
        load_path = f'expert_data/states/{env_name}.pt'
Code example #9
def make_env():
    env = make_mujoco_env(args.env, args.seed)
    # env = gym.make(env_id)
    env = Monitor(env, logger.get_dir(), allow_early_resets=True)
    return env
Code example #10
def fit(
        policy,
        env,
        nsteps,
        total_timesteps,
        ent_coef,
        lr,
        vf_coef=0.5,
        max_grad_norm=0.5,
        gamma=0.99,
        lam=0.95,
        log_interval=10,
        nminibatches=4,
        noptepochs=4,
        cliprange=0.2,
        save_interval=0,
        load_path=None
):

    if isinstance(lr, float):
        lr = constfn(lr)
    else:
        assert callable(lr)
    if isinstance(cliprange, float):
        cliprange = constfn(cliprange)
    else:
        assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    nenvs = env.num_envs
    # nenvs = 8
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches

    model = PPO2(
        policy=policy,
        observation_space=ob_space,
        action_space=ac_space,
        nbatch_act=nenvs,
        nbatch_train=nbatch_train,
        nsteps=nsteps,
        ent_coef=ent_coef,
        vf_coef=vf_coef,
        max_grad_norm=max_grad_norm
    )
    Agent().init_vars()
    # if save_interval and logger.get_dir():
    #     import cloudpickle
    #     with open(os.path.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh:
    #         fh.write(cloudpickle.dumps(make_model))
    # model = make_model()
    # if load_path is not None:
    #     model.load(load_path)
    runner = Environment(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)

    epinfobuf = deque(maxlen=100)
    tfirststart = time.time()

    nupdates = total_timesteps//nbatch
    for update in range(1, nupdates+1):
        assert nbatch % nminibatches == 0
        nbatch_train = nbatch // nminibatches
        tstart = time.time()
        frac = 1.0 - (update - 1.0) / nupdates
        lrnow = lr(frac)
        cliprangenow = cliprange(frac)
        obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run()
        epinfobuf.extend(epinfos)
        mblossvals = []
        if states is None:  # nonrecurrent version
            inds = np.arange(nbatch)
            for _ in range(noptepochs):
                np.random.shuffle(inds)
                for start in range(0, nbatch, nbatch_train):
                    end = start + nbatch_train
                    mbinds = inds[start:end]
                    slices = (
                        arr[mbinds] for arr in (obs,
                                                returns,
                                                masks,
                                                actions,
                                                values,
                                                neglogpacs)
                    )
                    mblossvals.append(model.predict(lrnow, cliprangenow, *slices))
        else:  # recurrent version
            assert nenvs % nminibatches == 0
            envsperbatch = nenvs // nminibatches
            envinds = np.arange(nenvs)
            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
            envsperbatch = nbatch_train // nsteps
            for _ in range(noptepochs):
                np.random.shuffle(envinds)
                for start in range(0, nenvs, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    mbflatinds = flatinds[mbenvinds].ravel()
                    slices = (
                        arr[mbflatinds] for arr in (obs,
                                                    returns,
                                                    masks,
                                                    actions,
                                                    values,
                                                    neglogpacs)
                    )
                    mbstates = states[mbenvinds]
                    mblossvals.append(model.predict(lrnow, cliprangenow, *slices, mbstates))

        lossvals = np.mean(mblossvals, axis=0)
        tnow = time.time()
        fps = int(nbatch / (tnow - tstart))
        if update % log_interval == 0 or update == 1:
            ev = explained_variance(values, returns)
            logger.logkv("serial_timesteps", update*nsteps)
            logger.logkv("nupdates", update)
            logger.logkv("total_timesteps", update*nbatch)
            logger.logkv("fps", fps)
            logger.logkv("explained_variance", float(ev))
            logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
            logger.logkv('time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv(lossname, lossval)
            logger.dumpkvs()
        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir():
            checkdir = os.path.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = os.path.join(checkdir, '%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)
    env.close()
    return model
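fit() relies on a couple of small helpers that the snippet does not show. Minimal sketches, assumed here to match their upstream baselines counterparts:

import numpy as np


def constfn(val):
    # Wrap a constant so it can be called like a schedule: f(frac) -> value.
    def f(_):
        return val
    return f


def safemean(xs):
    # Mean that returns NaN for an empty buffer instead of emitting a warning.
    return np.nan if len(xs) == 0 else np.mean(xs)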
Code example #11
File: ppo_map.py  Project: yldang/exp4nav
    def train(self):
        epinfobuf = deque(maxlen=10)
        t_trainstart = time.time()
        for iter in range(self.global_iter, self.max_iters):
            self.global_iter = iter
            t_iterstart = time.time()
            if iter % self.config['save_interval'] == 0 and logger.get_dir():
                with torch.no_grad():
                    res = self.rollout(val=True)
                    obs, actions, returns, values, advs, log_probs, epinfos = res
                avg_reward = safemean([epinfo['reward'] for epinfo in epinfos])
                avg_area = safemean(
                    [epinfo['seen_area'] for epinfo in epinfos])
                logger.logkv("iter", iter)
                logger.logkv("test/total_timesteps", iter * self.nbatch)
                logger.logkv('test/avg_area', avg_area)
                logger.logkv('test/avg_reward', avg_reward)
                logger.dumpkvs()
                if avg_reward > self.best_avg_reward:
                    self.best_avg_reward = avg_reward
                    is_best = True
                else:
                    is_best = False
                self.save_model(is_best=is_best, step=iter)
            with torch.no_grad():
                obs, actions, returns, values, advs, log_probs, epinfos = self.n_rollout(
                    repeat_num=self.config['train_rollout_repeat'])
            if epinfos:
                epinfobuf.extend(epinfos)
            lossvals = {
                'policy_loss': [],
                'value_loss': [],
                'policy_entropy': [],
                'approxkl': [],
                'clipfrac': []
            }
            opt_start_t = time.time()
            noptepochs = self.noptepochs
            for _ in range(noptepochs):
                num_batches = int(
                    np.ceil(actions.shape[1] / self.config['batch_size']))
                for x in range(num_batches):
                    b_start = x * self.config['batch_size']
                    b_end = min(b_start + self.config['batch_size'],
                                actions.shape[1])
                    if self.config['use_rgb_with_map']:
                        rgbs, large_maps, small_maps = obs
                        b_rgbs, b_large_maps, b_small_maps = map(
                            lambda p: p[:, b_start:b_end],
                            (rgbs, large_maps, small_maps))
                    else:
                        large_maps, small_maps = obs
                        b_large_maps, b_small_maps = map(
                            lambda p: p[:, b_start:b_end],
                            (large_maps, small_maps))
                    b_actions, b_returns, b_advs, b_log_probs = map(
                        lambda p: p[:, b_start:b_end],
                        (actions, returns, advs, log_probs))
                    hidden_state = self.net_model.init_hidden(
                        batch_size=b_end - b_start)
                    for start in range(0, actions.shape[0], self.rnn_seq_len):
                        end = start + self.rnn_seq_len
                        slices = (arr[start:end]
                                  for arr in (b_large_maps, b_small_maps,
                                              b_actions, b_returns, b_advs,
                                              b_log_probs))

                        if self.config['use_rgb_with_map']:
                            info, hidden_state = self.net_train(
                                *slices,
                                hidden_state=hidden_state,
                                rgbs=b_rgbs[start:end])
                        else:
                            info, hidden_state = self.net_train(
                                *slices, hidden_state=hidden_state)
                        lossvals['policy_loss'].append(info['pg_loss'])
                        lossvals['value_loss'].append(info['vf_loss'])
                        lossvals['policy_entropy'].append(info['entropy'])
                        lossvals['approxkl'].append(info['approxkl'])
                        lossvals['clipfrac'].append(info['clipfrac'])
            tnow = time.time()
            int_t_per_epo = (tnow - opt_start_t) / float(self.noptepochs)
            print_cyan(
                'Net training time per epoch: {0:.4f}s'.format(int_t_per_epo))
            fps = int(self.nbatch / (tnow - t_iterstart))
            if iter % self.config['log_interval'] == 0:
                logger.logkv("Learning rate",
                             self.optimizer.param_groups[0]['lr'])
                logger.logkv("per_env_timesteps", iter * self.num_steps)
                logger.logkv("iter", iter)
                logger.logkv("total_timesteps", iter * self.nbatch)
                logger.logkv("fps", fps)
                logger.logkv(
                    'ep_rew_mean',
                    safemean([epinfo['reward'] for epinfo in epinfobuf]))
                logger.logkv(
                    'ep_area_mean',
                    safemean([epinfo['seen_area'] for epinfo in epinfobuf]))
                logger.logkv('time_elapsed', tnow - t_trainstart)
                for name, value in lossvals.items():
                    logger.logkv(name, np.mean(value))
                logger.dumpkvs()
Code example #12
def main():
    parser = arg_parser()
    parser.add_argument('--platform',
                        help='environment choice',
                        choices=['atari', 'mujoco', 'humanoid', 'robotics'],
                        default='atari')
    platform_args, environ_args = parser.parse_known_args()
    platform = platform_args.platform

    # atari
    if platform == 'atari':
        args = atari_arg_parser().parse_known_args()[0]
        pi = fit(platform,
                 args.env,
                 num_timesteps=args.num_timesteps,
                 seed=args.seed)

    # mujoco
    if platform == 'mujoco':
        args = mujoco_arg_parser().parse_known_args()[0]
        logger.configure()
        pi = fit(platform,
                 args.env,
                 num_timesteps=args.num_timesteps,
                 seed=args.seed)

    # robotics
    if platform == 'robotics':
        args = robotics_arg_parser().parse_known_args()[0]
        pi = fit(platform,
                 args.env,
                 num_timesteps=args.num_timesteps,
                 seed=args.seed)

    # humanoids
    if platform == 'humanoid':
        logger.configure()
        parser = mujoco_arg_parser()
        parser.add_argument('--model-path',
                            default=os.path.join(logger.get_dir(),
                                                 'humanoid_policy'))
        parser.set_defaults(num_timesteps=int(2e7))

        args = parser.parse_known_args()[0]

        if not args.play:
            # train the model
            pi = fit(platform,
                     args.env,
                     num_timesteps=args.num_timesteps,
                     seed=args.seed,
                     model_path=args.model_path)
        else:
            # construct the model object, load pre-trained model and render
            from utils.cmd import make_mujoco_env
            pi = fit(platform, args.env, num_timesteps=1, seed=args.seed)
            Model().load_state(args.model_path)
            env = make_mujoco_env('Humanoid-v2', seed=0)

            ob = env.reset()
            while True:
                action = pi.act(stochastic=False, ob=ob)[0]
                ob, _, done, _ = env.step(action)
                env.render()
                if done:
                    ob = env.reset()
Code example #13
def fit(environ, env_id, num_timesteps, seed, model_path=None):
    # atari
    if environ == 'atari':
        rank = MPI.COMM_WORLD.Get_rank()
        sess = Model().single_threaded_session()
        sess.__enter__()
        if rank == 0:
            logger.configure()
        else:
            logger.configure(format_strs=[])
        workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank() if seed \
            is not None else None
        set_global_seeds(workerseed)
        env = make_atari(env_id)

        def policy_fn(name, ob_space, ac_space):
            return PPO1Cnn(name=name, ob_space=ob_space, ac_space=ac_space)

        env = Monitor(
            env,
            logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
        env.seed(workerseed)

        env = wrap_deepmind(env)
        env.seed(workerseed)

        pi = PPOSGD(env,
                    policy_fn,
                    env.observation_space,
                    env.action_space,
                    timesteps_per_actorbatch=256,
                    clip_param=0.2,
                    entcoeff=0.01,
                    optim_epochs=4,
                    optim_stepsize=1e-3,
                    optim_batchsize=64,
                    gamma=0.99,
                    lam=0.95,
                    max_timesteps=int(num_timesteps * 1.1),
                    schedule='linear')

        env.close()
        sess.close()
        return pi

    # mujoco
    if environ == 'mujoco':
        from utils.cmd import make_mujoco_env

        sess = Model().init_session(num_cpu=1).__enter__()

        def policy_fn(name, ob_space, ac_space):
            return PPO1Mlp(name=name,
                           ob_space=ob_space,
                           ac_space=ac_space,
                           hid_size=64,
                           num_hid_layers=2)

        env = make_mujoco_env(env_id, seed)
        pi = PPOSGD(
            env,
            policy_fn,
            env.observation_space,
            env.action_space,
            max_timesteps=num_timesteps,
            timesteps_per_actorbatch=2048,
            clip_param=0.2,
            entcoeff=0.0,
            optim_epochs=10,
            optim_stepsize=3e-4,
            optim_batchsize=64,
            gamma=0.99,
            lam=0.95,
            schedule='linear',
        )
        env.close()
        sess.close()
        return pi

    if environ == 'humanoid':
        import gym
        from utils.cmd import make_mujoco_env

        env_id = 'Humanoid-v2'

        class RewScale(gym.RewardWrapper):
            def __init__(self, env, scale):
                gym.RewardWrapper.__init__(self, env)
                self.scale = scale

            def reward(self, r):
                return r * self.scale

        sess = Model().init_session(num_cpu=1).__enter__()

        def policy_fn(name, ob_space, ac_space):
            return PPO1Mlp(name=name,
                           ob_space=ob_space,
                           ac_space=ac_space,
                           hid_size=64,
                           num_hid_layers=2)

        env = make_mujoco_env(env_id, seed)

        # The parameters below were the best found in a simple random search;
        # they are good enough to make the humanoid walk, but they are not
        # necessarily optimal.
        env = RewScale(env, 0.1)
        pi = PPOSGD(
            env,
            policy_fn,
            env.observation_space,
            env.action_space,
            max_timesteps=num_timesteps,
            timesteps_per_actorbatch=2048,
            clip_param=0.2,
            entcoeff=0.0,
            optim_epochs=10,
            optim_stepsize=3e-4,
            optim_batchsize=64,
            gamma=0.99,
            lam=0.95,
            schedule='linear',
        )
        env.close()
        if model_path:
            Model().save_state(model_path)

        sess.close()
        return pi

    if environ == 'robotics':
        import mujoco_py
        from utils.cmd import make_robotics_env
        rank = MPI.COMM_WORLD.Get_rank()
        sess = Model().single_threaded_session()
        sess.__enter__()
        mujoco_py.ignore_mujoco_warnings().__enter__()
        workerseed = seed + 10000 * rank
        set_global_seeds(workerseed)
        env = make_robotics_env(env_id, workerseed, rank=rank)

        def policy_fn(name, ob_space, ac_space):
            return PPO1Mlp(name=name,
                           ob_space=ob_space,
                           ac_space=ac_space,
                           hid_size=256,
                           num_hid_layers=3)

        pi = PPOSGD(
            env,
            policy_fn,
            env.observation_space,
            env.action_space,
            max_timesteps=num_timesteps,
            timesteps_per_actorbatch=2048,
            clip_param=0.2,
            entcoeff=0.0,
            optim_epochs=5,
            optim_stepsize=3e-4,
            optim_batchsize=256,
            gamma=0.99,
            lam=0.95,
            schedule='linear',
        )
        env.close()
        sess.close()
        return pi
Code example #14
    assert v['obj'] in ['fkl', 'rkl', 'js', 'emd', 'maxentirl']
    assert v['IS'] == False

    # logs
    exp_id = f"logs/{env_name}/exp-{num_expert_trajs}/{v['obj']}" # task/obj/date structure
    # exp_id = 'debug'
    if not os.path.exists(exp_id):
        os.makedirs(exp_id)

    now = datetime.datetime.now(dateutil.tz.tzlocal())
    log_folder = exp_id + '/' + now.strftime('%Y_%m_%d_%H_%M_%S')
    logger.configure(dir=log_folder)            
    print(f"Logging to directory: {log_folder}")
    os.system(f'cp firl/irl_samples.py {log_folder}')
    os.system(f'cp {sys.argv[1]} {log_folder}/variant_{pid}.yml')
    with open(os.path.join(logger.get_dir(), 'variant.json'), 'w') as f:
        json.dump(v, f, indent=2, sort_keys=True)
    print('pid', pid)
    os.makedirs(os.path.join(log_folder, 'plt'))
    os.makedirs(os.path.join(log_folder, 'model'))

    # environment
    env_fn = lambda: gym.make(env_name)
    gym_env = env_fn()
    state_size = gym_env.observation_space.shape[0]
    action_size = gym_env.action_space.shape[0]
    if state_indices == 'all':
        state_indices = list(range(state_size))

    # load expert samples from trained policy
    expert_trajs = torch.load(f'expert_data/states/{env_name}.pt').numpy()[:, :, state_indices]
Code example #15
File: run_trpo.py  Project: kirk86/baselines
def main():
    parser = arg_parser()
    parser.add_argument('--platform',
                        help='environment choice',
                        choices=['atari', 'mujoco'],
                        default='atari')

    platform_args, environ_args = parser.parse_known_args()
    platform = platform_args.platform

    rank = MPI.COMM_WORLD.Get_rank()

    # atari
    if platform == 'atari':
        from bench import Monitor
        from utils.cmd import atari_arg_parser, make_atari, \
            wrap_deepmind
        from policies.nohashingcnn import CnnPolicy

        args = atari_arg_parser().parse_known_args()[0]
        if rank == 0:
            logger.configure()
        else:
            logger.configure(format_strs=[])

        workerseed = args.seed + 10000 * rank
        set_global_seeds(workerseed)
        env = make_atari(args.env)

        env = Monitor(
            env,
            logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
        env.seed(workerseed)

        env = wrap_deepmind(env)
        env.seed(workerseed)

        model = TRPO(CnnPolicy, env.observation_space, env.action_space)
        sess = model.single_threaded_session().__enter__()
        # model.reset_graph_and_vars()
        model.init_vars()

        fit(model,
            env,
            timesteps_per_batch=512,
            max_kl=0.001,
            cg_iters=10,
            cg_damping=1e-3,
            max_timesteps=int(args.num_timesteps * 1.1),
            gamma=0.98,
            lam=1.0,
            vf_iters=3,
            vf_stepsize=1e-4,
            entcoeff=0.00)
        sess.close()
        env.close()

    # mujoco
    if platform == 'mujoco':
        from policies.ppo1mlp import PPO1Mlp
        from utils.cmd import make_mujoco_env, mujoco_arg_parser
        args = mujoco_arg_parser().parse_known_args()[0]

        if rank == 0:
            logger.configure()
        else:
            logger.configure(format_strs=[])
            logger.set_level(logger.DISABLED)

        workerseed = args.seed + 10000 * rank

        env = make_mujoco_env(args.env, workerseed)

        def policy(name, observation_space, action_space):
            return PPO1Mlp(name,
                           env.observation_space,
                           env.action_space,
                           hid_size=32,
                           num_hid_layers=2)

        model = TRPO(policy, env.observation_space, env.action_space)
        sess = model.single_threaded_session().__enter__()
        model.init_vars()

        fit(model,
            env,
            timesteps_per_batch=1024,
            max_kl=0.01,
            cg_iters=10,
            cg_damping=0.1,
            max_timesteps=args.num_timesteps,
            gamma=0.99,
            lam=0.98,
            vf_iters=5,
            vf_stepsize=1e-3)
        sess.close()
        env.close()
Code example #16
File: run_acktr.py  Project: kirk86/baselines
def fit(policy,
        env,
        seed,
        total_timesteps=int(40e6),
        gamma=0.99,
        log_interval=1,
        nprocs=32,
        nsteps=20,
        ent_coef=0.01,
        vf_coef=0.5,
        vf_fisher_coef=1.0,
        lr=0.25,
        max_grad_norm=0.5,
        kfac_clip=0.001,
        save_interval=None,
        lrschedule='linear'):
    tf.reset_default_graph()
    set_global_seeds(seed)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    model = AcktrDiscrete(policy,
                          ob_space,
                          ac_space,
                          nenvs,
                          total_timesteps,
                          nsteps=nsteps,
                          ent_coef=ent_coef,
                          vf_coef=vf_fisher_coef,
                          lr=lr,
                          max_grad_norm=max_grad_norm,
                          kfac_clip=kfac_clip,
                          lrschedule=lrschedule)
    # if save_interval and logger.get_dir():
    #     import cloudpickle
    #     with open(os.path.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh:
    #         fh.write(cloudpickle.dumps(make_model))
    # model = make_model()

    runner = Environment(env, model, nsteps=nsteps, gamma=gamma)
    nbatch = nenvs * nsteps
    tstart = time.time()
    coord = tf.train.Coordinator()
    enqueue_threads = model.q_runner.create_threads(model.sess,
                                                    coord=coord,
                                                    start=True)
    for update in range(1, total_timesteps // nbatch + 1):
        obs, states, rewards, masks, actions, values = runner.run()
        policy_loss, value_loss, policy_entropy = model.train(
            obs, states, rewards, masks, actions, values)
        model.old_obs = obs
        nseconds = time.time() - tstart
        fps = int((update * nbatch) / nseconds)
        if update % log_interval == 0 or update == 1:
            ev = explained_variance(values, rewards)
            logger.record_tabular("nupdates", update)
            logger.record_tabular("total_timesteps", update * nbatch)
            logger.record_tabular("fps", fps)
            logger.record_tabular("policy_entropy", float(policy_entropy))
            logger.record_tabular("policy_loss", float(policy_loss))
            logger.record_tabular("value_loss", float(value_loss))
            logger.record_tabular("explained_variance", float(ev))
            logger.dump_tabular()

        if save_interval and (update % save_interval == 0 or update == 1) \
           and logger.get_dir():
            savepath = os.path.join(logger.get_dir(),
                                    'checkpoint%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)
    coord.request_stop()
    coord.join(enqueue_threads)
    env.close()
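A hypothetical driver for this fit() could look as follows. The environment construction and the CnnPolicy import are assumptions borrowed from upstream baselines rather than taken from kirk86/baselines itself; note that fit() closes the environment before returning.

from baselines import logger
from baselines.common.cmd_util import make_atari_env
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.ppo2.policies import CnnPolicy  # assumed policy class


def main():
    logger.configure()
    # 32 parallel Atari workers, frame-stacked to 4 as in the DeepMind setup.
    env = VecFrameStack(make_atari_env('BreakoutNoFrameskip-v4', num_env=32, seed=0), 4)
    fit(CnnPolicy, env, seed=0, total_timesteps=int(40e6), nprocs=32)


if __name__ == '__main__':
    main()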