Code Example #1
File: train_a2c.py Project: amnonh-uw/block-world
def train(make_env,
          num_timesteps,
          seed,
          policy,
          lrschedule,
          num_cpu,
          vf_coef=0.5,
          ent_coef=0.01):
    def _make_env(rank):
        def _thunk():
            env = make_env()
            env.seed(seed + rank)
            return env

        return _thunk

    set_global_seeds(seed)
    env = SubprocVecEnv([_make_env(i) for i in range(num_cpu)])

    learn(policy,
          env,
          seed,
          nstack=1,
          total_timesteps=num_timesteps,
          lrschedule=lrschedule,
          vf_coef=vf_coef,
          ent_coef=ent_coef)
    env.close()
Code Example #2
def train(env_id, num_timesteps, seed, policy, lr_schedule, num_env):
    """
    Train A2C model for atari environment, for testing purposes

    :param env_id: (str) Environment ID
    :param num_timesteps: (int) The total number of samples
    :param seed: (int) The initial seed for training
    :param policy: (A2CPolicy) The policy model to use (MLP, CNN, LSTM, ...)
    :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
                                 'double_linear_con', 'middle_drop' or 'double_middle_drop')
    :param num_env: (int) The number of environments
    """
    policy_fn = None
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    if policy_fn is None:
        raise ValueError("Error: policy {} not implemented".format(policy))

    env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4)
    learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), lr_schedule=lr_schedule)
    env.close()
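The docstring in Code Example #2 spells out what each argument expects. Below is a minimal sketch of how that helper might be invoked; the environment id, timestep budget, and schedule are illustrative placeholders rather than values from the source project, and the snippet assumes train() and its imports live in the same module.

if __name__ == '__main__':
    # Hypothetical call of the train() helper defined above; the Atari
    # environment id and hyperparameter values are placeholders only.
    train(env_id='BreakoutNoFrameskip-v4',
          num_timesteps=int(1e6),
          seed=0,
          policy='cnn',
          lr_schedule='constant',
          num_env=4)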
Code Example #3
def train(num_timesteps,
          env_name,
          seed,
          policy,
          lrschedule,
          num_env,
          entrophy,
          lr,
          save_name=None):
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    elif policy == 'i2a':
        policy_fn = I2ANetwork
    env = VecFrameStack(make_doom_env(num_env, 0, env_name), 4)
    if save_name is None:
        save_name = env_name
    learn(policy_fn,
          env,
          seed,
          save_name=save_name,
          total_timesteps=int(num_timesteps * 1.1),
          lrschedule=lrschedule,
          log_interval=500,
          save_interval=1000,
          cont=True,
          ent_coef=entrophy,
          lr=lr)
    env.close()
Code Example #4
def train(env_id, num_frames, seed, policy, lrschedule, num_cpu):
    num_timesteps = int(num_frames / 4 * 1.1)

    # divide by 4 due to frameskip, then do a little extras so episodes end
    def make_env(rank):
        def _thunk():
            env = gym.make(env_id)
            env.seed(seed + rank)
            env = bench.Monitor(
                env,
                logger.get_dir() and os.path.join(
                    logger.get_dir(), "{}.monitor.json".format(rank)))
            gym.logger.setLevel(logging.WARN)
            return wrap_deepmind(env)

        return _thunk

    set_global_seeds(seed)
    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    learn(policy_fn,
          env,
          seed,
          total_timesteps=num_timesteps,
          lrschedule=lrschedule)
    env.close()
Code Example #5
def train(env_id, num_timesteps, seed, policy, lrschedule, num_env, ckpt_path,
          hparams):
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    elif policy == 'cnn_attention':
        policy_fn = CnnAttentionPolicy

    video_log_dir = os.path.join(hparams['base_dir'], 'videos',
                                 hparams['experiment_name'])
    env = VecFrameStack(
        make_atari_env(env_id,
                       num_env,
                       seed,
                       video_log_dir=video_log_dir,
                       write_attention_video='attention' in policy,
                       hparams=hparams), 4)

    learn(policy_fn,
          env,
          seed,
          total_timesteps=int(num_timesteps * 1.1),
          lrschedule=lrschedule,
          ckpt_path=ckpt_path,
          hparams=hparams)
    env.close()
Code Example #6
def train():
    env_args = dict(map_name=FLAGS.map_name,
                    step_mul=FLAGS.step_mul,
                    game_steps_per_episode=0,
                    screen_size_px=(FLAGS.resolution, ) * 2,
                    minimap_size_px=(FLAGS.resolution, ) * 2,
                    visualize=FLAGS.visualize)

    envs = SubprocVecEnv(
        [partial(make_sc2env, id=i, **env_args) for i in range(FLAGS.n_envs)])
    policy_fn = FullyConvPolicy
    try:
        learn(
            policy_fn,
            envs,
            seed=1,
            total_timesteps=int(1e6) * FLAGS.frames,
            lrschedule=FLAGS.lrschedule,
            nstack=1,  #must be 1 for FullyConvPolicy above
            ent_coef=FLAGS.entropy_weight,
            vf_coef=FLAGS.value_weight,
            max_grad_norm=1.0,
            lr=FLAGS.learning_rate)
    except KeyboardInterrupt:
        pass

    envs.close()
Code Example #7
def train(env_id, num_timesteps, seed, policy, lrschedule, num_cpu):
    def make_env(rank):
        def _thunk():
            env = make_atari(env_id)
            env.seed(seed + rank)
            env = bench.Monitor(
                env,
                logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
            gym.logger.setLevel(logging.WARN)
            return wrap_deepmind(env)

        return _thunk

    set_global_seeds(seed)
    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    learn(policy_fn,
          env,
          seed,
          total_timesteps=int(num_timesteps * 1.1),
          lrschedule=lrschedule)
    env.close()
Code Example #8
def train(env_id, num_timesteps, seed, policy, lrschedule, num_cpu, continuous_actions=False, numAgents=2, benchmark=False):
    # Create environment
    # env = envVec([make_env(idx, benchmark) for idx in range(num_cpu)])
    env = make_env(env_id, benchmark)
    # print('action space: ', env.action_space)
    # env = GymVecEnv([make_env(idx) for idx in range(num_cpu)])
    policy_fn = policy_fn_name(policy)
    learn(policy_fn, env, seed, nsteps=128, nstack=1,
          total_timesteps=int(num_timesteps * 1.1), lr=1e-2, lrschedule=lrschedule,
          continuous_actions=continuous_actions, numAgents=numAgents,
          continueTraining=False, debug=False, particleEnv=True,
          model_name='partEnv_model_')
Code Example #9
File: run_atari.py Project: MulixBF/baselines
def train(env_id, num_timesteps, seed, policy, lrschedule, num_env):
    if policy == u'cnn':
        policy_fn = CnnPolicy
    elif policy == u'lstm':
        policy_fn = LstmPolicy
    elif policy == u'lnlstm':
        policy_fn = LnLstmPolicy
    env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4)
    learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), lrschedule=lrschedule)
    env.close()
Code Example #10
File: run_atari.py Project: Divyankpandey/baselines
def train(env_id, num_timesteps, seed, policy, lrschedule, num_env):
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4)
    learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), lrschedule=lrschedule)
    env.close()
Code Example #11
def train(env_id,
          num_frames,
          seed,
          nsteps,
          policy,
          lrschedule,
          num_cpu,
          model_path,
          lr=7e-4,
          pg_coef=1.0,
          ent_coef=0.01,
          vf_coef=0.5):
    num_timesteps = int(num_frames / 4)

    # divide by 4 due to frameskip
    def make_env(rank, isTraining=True):
        def _thunk():
            env = gym.make(env_id)
            env.seed(seed + rank)
            env = bench.Monitor(
                env,
                logger.get_dir() and os.path.join(
                    logger.get_dir(), "{}.monitor.json".format(rank)),
                allow_early_resets=(not isTraining))
            gym.logger.setLevel(logging.WARN)
            return wrap_deepmind(env,
                                 episode_life=isTraining,
                                 clip_rewards=isTraining)

        return _thunk

    set_global_seeds(seed)
    env = SubprocVecEnv([make_env(i, isTraining=True) for i in range(num_cpu)])
    eval_env = SubprocVecEnv(
        [make_env(num_cpu + i, isTraining=False) for i in range(num_cpu)])
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    learn(policy_fn,
          env,
          eval_env,
          seed,
          nsteps=nsteps,
          total_timesteps=num_timesteps,
          lr=lr,
          pg_coef=pg_coef,
          ent_coef=ent_coef,
          vf_coef=vf_coef,
          lrschedule=lrschedule,
          model_path=model_path)
    eval_env.close()
    env.close()
Code Example #12
def train(env_id, num_timesteps, seed, lrschedule, num_env):
    def make_env():
        env = gym.make(env_id)
        env = bench.Monitor(env, logger.get_dir(), allow_early_resets=True)
        return env

    env = DummyVecEnv([make_env])
    env = VecNormalize(env)

    set_global_seeds(seed)
    policy_fn = MlpPolicy

    learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), lrschedule=lrschedule)
    env.close()
Code Example #13
def main():
    num_env = 1
    env_id = "CartPole-v1"
    env_type = "classic_control"
    seed = None

    env = make_vec_env(env_id,
                       env_type,
                       num_env,
                       seed,
                       wrapper_kwargs=None,
                       start_index=0,
                       reward_scale=1.0,
                       flatten_dict_observations=True,
                       gamestate=None)

    act = a2c.learn(env=env,
                    network='mlp',
                    total_timesteps=0,
                    load_path="cartpole_model.pkl")

    while True:
        obs, done = env.reset(), False
        episode_rew = 0
        while not done:
            env.render()
            obs, rew, done, _ = env.step(act(obs[None])[0])
            episode_rew += rew
        print("Episode reward", episode_rew)
Code Example #14
File: run_atari.py Project: hitersyw/baselines
def train(env_id, num_timesteps, seed, policy, lrschedule, num_env,
          replay_lambda=1, replay_loss=None, ss_rate=1, thetas=None):
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4)
    if replay_loss is not None:
        learn_staged(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1),
                     lrschedule=lrschedule, replay_lambda=replay_lambda, ss_rate=ss_rate,
                     replay_loss=replay_loss, thetas=thetas)
    else:
        learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), lrschedule=lrschedule)
    env.close()
Code Example #15
def main():

    args = parse_args()

    format_strs = ['log', 'csv', 'stdout']

    if args.tensorboard:
        format_strs.append('tensorboard')

    config = parse_config(args.config)

    outdir = os.path.join(args.outdir,
                          os.path.splitext(os.path.basename(args.config))[0])
    logger.configure(dir=outdir, format_strs=format_strs)

    env_type, env_id = get_env_type(GAME_ENVIRONMENT)
    env = make_vec_env(env_id, env_type, 1, args.seed)

    model = a2c.learn(env=env,
                      network=NETWORK_ARCHITECTURE,
                      total_timesteps=args.total_timesteps,
                      **config)

    env.close()

    if args.save:
        model.save(os.path.join(outdir, 'model'))
Code Example #16
def train(env_id, num_timesteps, seed, policy, lrschedule, num_env):
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    env = MultiWrapper()
    policy_fn = LnLstmPolicy
    # env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4)
    print("Important!", env.action_space)
    learn(policy_fn,
          env,
          seed,
          total_timesteps=int(num_timesteps * 1.1),
          lrschedule=lrschedule)
    env.close()
Code Example #17
def train(env_id, num_timesteps, seed, policy, lrschedule, num_env, env_name):
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    elif policy == 'mlp':
        policy_fn = MlpPolicy
    env = make_gym_control_multi_env(env_id, num_env, seed)
    learn(policy_fn,
          env,
          seed,
          total_timesteps=int(num_timesteps * 1.1),
          lrschedule=lrschedule,
          env_name=env_name)
    env.close()
Code Example #18
File: run_atari.py Project: Badhreesh/Thesis
def train(env_id, num_timesteps, seed, policy, lrschedule, num_env):
    print('train() called')
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    env = VecFrameStack(make_atari_env(env_id, num_env, seed),
                        4)  # Make "num_env" environments

    learn(policy_fn,
          env,
          seed,
          total_timesteps=int(num_timesteps * 1.1),
          lrschedule=lrschedule)  # Learn
    env.close()
Code Example #19
def train(env_id, num_timesteps, seed, policy, lrschedule, num_env,
          v_ex_coef, r_ex_coef, r_in_coef, lr_alpha, lr_beta, no_ex, no_in):
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    elif policy == 'cnn_int':
        policy_fn = CnnPolicyIntrinsicReward
    else:
        raise NotImplementedError
    env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4)
    learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.01), lrschedule=lrschedule,
          v_ex_coef=v_ex_coef, r_ex_coef=r_ex_coef, r_in_coef=r_in_coef,
          lr_alpha=lr_alpha, lr_beta=lr_beta, no_ex=no_ex, no_in=no_in)
    env.close()
Code Example #20
File: a2c.py Project: lychanl/duck-driving-golem
def train(env, save_path, nsteps=20, timesteps=1e3):
    model = a2c.learn(
        a2c_discrete_cnn,
        env,
        nsteps=nsteps,
        total_timesteps=int(timesteps),
        load_path=save_path if os.path.isfile(save_path) else None)
    model.save(save_path)
Code Example #21
def train():
    # Fetch the requested environment set in flags.
    env_class = attrgetter(FLAGS.env)(sc2g.env)

    env_args = dict(
        map_name=FLAGS.map_name,
        feature_screen_size=FLAGS.screen_size,
        feature_minimap_size=FLAGS.minimap_size,
        visualize=FLAGS.visualize,
        save_replay_episodes=FLAGS.save_replay_episodes,
        replay_dir=FLAGS.replay_dir,
    )

    envs = SubprocVecEnv([
        partial(env_class.make_env, id=i, **env_args)
        for i in range(FLAGS.envs)
    ])

    policy_fn = CnnPolicy
    if FLAGS.policy == 'cnn':
        policy_fn = CnnPolicy
    elif FLAGS.policy == 'lstm':
        policy_fn = LstmPolicy
    elif FLAGS.policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    elif FLAGS.policy == 'fullyconv':
        policy_fn = FullyConvPolicy
    else:
        print("Invalid policy function! Defaulting to {}.".format(policy_fn))

    try:
        learn(policy_fn,
              envs,
              seed=1,
              total_timesteps=int(1e6 * FLAGS.max_timesteps),
              lrschedule=FLAGS.lrschedule,
              ent_coef=FLAGS.entropy_weight,
              vf_coef=FLAGS.value_weight,
              max_grad_norm=1.0,
              lr=FLAGS.learning_rate)
    except KeyboardInterrupt:
        pass

    print("Closing environment...")
    envs.close()
Code Example #22
File: run_doom.py Project: Badhreesh/Thesis
def train(doom_lvl, num_timesteps, seed, policy, lrschedule, num_env, adda_lr,
          adda_batch, training, use_adda):
    print('train() called')
    if policy == 'cnn':
        policy_fn = CnnPolicy

    env = make_doom_env(doom_lvl, num_env, seed)  # Make "num_env" environments

    learn(policy_fn,
          env,
          seed,
          training,
          use_adda,
          adda_lr,
          adda_batch,
          total_timesteps=int(num_timesteps * 1.1),
          lrschedule=lrschedule)
    env.close()
Code Example #23
def train(env_id,
          num_timesteps,
          seed,
          policy,
          lrschedule,
          num_cpu,
          continuous_actions=False,
          numAgents=2,
          benchmark=False):
    # Create environment
    test = True
    communication = False
    if env_id == 'simple_reference':
        communication = True
        test = False
    env = EnvVec([
        make_env(env_id, benchmark=benchmark, rank=idx, seed=seed)
        for idx in range(num_cpu)
    ],
                 particleEnv=True,
                 test=test,
                 communication=communication,
                 env_id=env_id)
    # env = make_env(env_id, benchmark)
    # print('action space: ', env.action_space)
    # env = GymVecEnv([make_env(idx) for idx in range(num_cpu)])
    policy_fn = policy_fn_name(policy)
    learn(policy_fn,
          env,
          seed,
          nsteps=5,
          nstack=1,
          total_timesteps=int(num_timesteps * 1.1),
          lr=2e-3,
          lrschedule=lrschedule,
          continuous_actions=continuous_actions,
          numAgents=numAgents,
          continueTraining=False,
          debug=True,
          particleEnv=True,
          model_name='partEnv_model_',
          log_interval=50,
          communication=communication)
Code Example #24
File: run_mapEnv.py Project: scottbarnesg/baselines
def train(env_id,
          num_timesteps,
          seed,
          policy,
          lrschedule,
          num_cpu,
          continuous_actions=False,
          numAgents=2):
    def make_env(rank):
        def _thunk():
            env = gym.make(env_id)
            env.seed(seed + rank)
            env.ID = rank
            # print("logger dir: ", logger.get_dir())
            env = bench.Monitor(
                env,
                logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
            if env_id == 'Pendulum-v0':
                if continuous_actions:
                    env.action_space.n = env.action_space.shape[0]
                else:
                    env.action_space.n = 10
            gym.logger.setLevel(logging.WARN)
            return env

        return _thunk

    env = GymVecEnv([make_env(idx) for idx in range(num_cpu)])
    policy_fn = policy_fn_name(policy)
    learn(policy_fn,
          env,
          seed,
          nsteps=30,
          nstack=1,
          total_timesteps=int(num_timesteps * 1.1),
          lr=7e-4,
          lrschedule=lrschedule,
          continuous_actions=continuous_actions,
          numAgents=numAgents,
          continueTraining=True,
          debug=False,
          particleEnv=False)
Code Example #25
def train(env_id, num_timesteps, seed, policy, lrschedule, num_env):
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    # VecFrameStack
    # make_atari_env() : launches 'num_env' subprocess each with 'env_id' and for i in num_env: seed+=seed+i
    env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4)
    print("~~~~~~~~~~~~~ run_atari: len(env): " + str(env.nstack))
    print("~~~~~~~~~~~~~ run_atari: str(env): " + str(env))
    # above prints : run_atari: str(env): <baselines.common.vec_env.vec_frame_stack.VecFrameStack object at 0x1c22ee06d8>
    print("_____________________________________________ policy: " +
          str(policy))
    learn(policy_fn,
          env,
          seed,
          total_timesteps=int(num_timesteps * 1.1),
          lrschedule=lrschedule)
    env.close()
Code Example #26
File: run_atari.py Project: IcarusTan/baselines
def train(env_id, num_frames, seed, policy, lrschedule, num_cpu):
    num_timesteps = int(num_frames / 4 * 1.1) 
    # divide by 4 due to frameskip, then do a little extras so episodes end
    def make_env(rank):
        def _thunk():
            env = gym.make(env_id)
            env.seed(seed + rank)
            env = bench.Monitor(env, logger.get_dir() and 
                os.path.join(logger.get_dir(), "{}.monitor.json".format(rank)))
            gym.logger.setLevel(logging.WARN)
            return wrap_deepmind(env)
        return _thunk
    set_global_seeds(seed)
    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    learn(policy_fn, env, seed, total_timesteps=num_timesteps, lrschedule=lrschedule)
    env.close()
Code Example #27
File: run_doom.py Project: akashin/baselines
def train(env_id, num_frames, seed, policy, lrschedule, num_cpu):
    num_timesteps = int(num_frames / 4 * 1.1)

    # divide by 4 due to frameskip, then do a little extras so episodes end
    def make_env(rank):
        def _thunk():
            env_spec = gym.spec('ppaquette/DoomBasic-v0')
            env_spec.id = 'DoomBasic-v0'
            env = env_spec.make()
            env.seed(seed + rank)
            env = PreprocessImage((SkipWrapper(4)(ToDiscrete("minimal")(env))))
            env = bench.Monitor(
                env,
                logger.get_dir() and os.path.join(
                    logger.get_dir(), "{}.monitor.json".format(rank)))
            gym.logger.setLevel(logging.WARN)
            return ScaleRewardEnv(env)

        return _thunk

    set_global_seeds(seed)
    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    learn(policy_fn,
          env,
          seed,
          total_timesteps=num_timesteps,
          lrschedule=lrschedule,
          lr=1e-4,
          nsteps=10,
          nstack=1)
    env.close()
Code Example #28
def train(env_id, num_timesteps, seed, policy, lrschedule, num_env, param):
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin': ncpu //= 2
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=ncpu,
                            inter_op_parallelism_threads=ncpu)
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    tf.Session(config=config).__enter__()
    # change parameter of env to start multi envs
    env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4)
    learn(policy_fn,
          env,
          seed,
          total_timesteps=int(num_timesteps * 1.1),
          lrschedule=lrschedule,
          param=param,
          nsteps=16)
    env.close()
Code Example #29
def train(num_timesteps, seed, policy, lrschedule, num_cpu):
    # TODO: Just f****n ugly handle that better
    def make_env(rank):
        def _thunk():
            print(rank)
            if num_cpu == 0:
                env = MarioEnv(num_steering_dir=11)
            else:
                env = MarioEnv(num_steering_dir=11, num_env=rank)
            env.seed(seed + rank)
            env = bench.Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
            gym.logger.setLevel(logging.WARN)
            return env
        return _thunk
    set_global_seeds(seed)
    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    learn(policy_fn, env, seed, nsteps=128, total_timesteps=int(num_timesteps * 1.1), lrschedule=lrschedule)
    env.close()
Code Example #30
File: a2c_open.py Project: sterlingsomers/cnn_drone
def train(env_id, num_timesteps, seed, policy, lrschedule, num_env):
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    #env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4)
    env = VecFrameStack(make_custom_env('gridworld-v0', num_env, seed), 1)
    act = learn(policy_fn,
                env,
                seed,
                total_timesteps=int(num_timesteps * 1.1),
                lrschedule=lrschedule)
    act.save('a2c_bopen.pkl')
    env.close()
Code Example #31
def main():
    num_env = 1
    env_id = "CartPole-v1"
    env_type = "classic_control"
    seed = None

    env = make_vec_env(env_id,
                       env_type,
                       num_env,
                       seed,
                       wrapper_kwargs=None,
                       start_index=0,
                       reward_scale=1.0,
                       flatten_dict_observations=True,
                       gamestate=None)

    act = a2c.learn(env=env, network='mlp', total_timesteps=80000)
    print("Saving model to cartpole_model.pkl")
    act.save("cartpole_model.pkl")
Code Example #32
File: train.py Project: ChrisFugl/DoomRL
def train(config, env, logger):
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    intra_op_parallelism_threads=1,
                                    inter_op_parallelism_threads=1)
    session_config.gpu_options.allow_growth = True
    get_session(config=session_config)
    model = learn(env=env,
                  total_timesteps=config.timesteps,
                  network='cnn',
                  lr=config.learning_rate,
                  alpha=config.rmsp_decay,
                  gamma=config.discount_factor,
                  nsteps=config.number_of_steps,
                  epsilon=config.rmsp_epsilon,
                  max_grad_norm=config.max_grad_norm,
                  ent_coef=config.entropy_weight,
                  vf_coef=config.critic_weight,
                  log_interval=config.log_every)
    model.save(config.save_path)
Code Example #33
import pytest
import tensorflow as tf
import random
import numpy as np
from gym.spaces import np_random

from baselines.a2c import a2c
from baselines.ppo2 import ppo2
from baselines.common.identity_env import IdentityEnv
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.ppo2.policies import MlpPolicy


learn_func_list = [
    lambda e: a2c.learn(policy=MlpPolicy, env=e, seed=0, total_timesteps=50000),
    lambda e: ppo2.learn(policy=MlpPolicy, env=e, total_timesteps=50000, lr=1e-3, nsteps=128, ent_coef=0.01)
]


@pytest.mark.slow
@pytest.mark.parametrize("learn_func", learn_func_list)
def test_identity(learn_func):
    '''
    Test if the algorithm (with a given policy) 
    can learn an identity transformation (i.e. return observation as an action)
    '''
    np.random.seed(0)
    np_random.seed(0)
    random.seed(0)

    env = DummyVecEnv([lambda: IdentityEnv(10)])