Example #1
def test(game_name,
         num_timesteps,
         policy,
         load_path,
         save_path,
         noops=False,
         sticky=False,
         epsgreedy=False):
    import os
    import gym
    import tensorflow as tf
    import horovod.tensorflow as hvd
    hvd.init()
    print('initialized worker %d' % hvd.rank(), flush=True)
    from baselines import bench
    from baselines.common import set_global_seeds
    from atari_reset.wrappers import VecFrameStack, VideoWriter, my_wrapper,\
        EpsGreedyEnv, StickyActionEnv, NoopResetEnv, SubprocVecEnv
    from atari_reset.ppo import learn
    from atari_reset.policies import CnnPolicy, GRUPolicy

    set_global_seeds(hvd.rank())
    ncpu = 2
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=ncpu,
                            inter_op_parallelism_threads=ncpu)
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    tf.Session(config=config).__enter__()

    def make_env(rank):
        def env_fn():
            env = gym.make(game_name + 'NoFrameskip-v4')
            env = bench.Monitor(env, "{}.monitor.json".format(rank))
            if rank % nenvs == 0 and hvd.local_rank() == 0:  # record a video from one env per machine
                os.makedirs('results/' + game_name, exist_ok=True)
                videofile_prefix = 'results/' + game_name
                env = VideoWriter(env, videofile_prefix)
            if noops:
                env = NoopResetEnv(env)
            if sticky:
                env = StickyActionEnv(env)
            env = my_wrapper(env, clip_rewards=True)
            if epsgreedy:
                env = EpsGreedyEnv(env)
            return env

        return env_fn

    nenvs = 8  # environments per Horovod worker; referenced by the env_fn closures above
    env = SubprocVecEnv(
        [make_env(i + nenvs * hvd.rank()) for i in range(nenvs)])
    env = VecFrameStack(env, 4)

    policy = {'cnn': CnnPolicy, 'gru': GRUPolicy}[policy]
    learn(policy=policy,
          env=env,
          nsteps=256,
          log_interval=1,
          save_interval=100,
          total_timesteps=num_timesteps,
          load_path=load_path,
          save_path=save_path,
          game_name=game_name,
          test_mode=True)
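
A minimal sketch of how this test entry point might be invoked. The game name, timestep budget, and checkpoint paths below are illustrative placeholders rather than values from the original listing, and in practice the script would be launched via horovodrun/mpirun so that hvd.init() can set up the workers:

if __name__ == '__main__':
    # Hypothetical call: evaluate a trained MontezumaRevenge checkpoint with
    # the GRU policy and sticky actions for one million timesteps.
    test(game_name='MontezumaRevenge',
         num_timesteps=int(1e6),
         policy='gru',
         load_path='checkpoints/montezuma_gru',
         save_path='results/montezuma_test',
         noops=False,
         sticky=True,
         epsgreedy=False)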
Example #2
def train(game_name, policy, num_timesteps, lr, entropy_coef, load_path,
          starting_point, save_path):
    import os
    import gym
    import numpy as np
    import tensorflow as tf
    import horovod.tensorflow as hvd
    hvd.init()
    print('initialized worker %d' % hvd.rank(), flush=True)
    from baselines.common import set_global_seeds
    set_global_seeds(hvd.rank())
    from atari_reset.ppo import learn
    from atari_reset.policies import CnnPolicy, GRUPolicy
    from atari_reset.wrappers import ReplayResetEnv, ResetManager, SubprocVecEnv, VideoWriter, VecFrameStack, my_wrapper

    ncpu = 2
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=ncpu,
                            inter_op_parallelism_threads=ncpu)
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    tf.Session(config=config).__enter__()

    nrstartsteps = 320  # number of non-frameskipped demo steps over which to divide the workers
    nenvs = 16
    nrworkers = hvd.size() * nenvs
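    # Worked example (assumed values): with 8 Horovod processes, nrworkers is
    # 8 * 16 = 128; spreading them over the 320 demo starting steps gives
    # workers_per_sp = ceil(128 / 320) = 1, presumably the number of workers
    # sharing each demo starting point.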
    workers_per_sp = int(np.ceil(nrworkers / nrstartsteps))

    def make_env(rank):
        def env_fn():
            env = gym.make(game_name + 'NoFrameskip-v4')
            env = ReplayResetEnv(env,
                                 demo_file_name='demos/' + game_name + '.demo',
                                 seed=rank,
                                 workers_per_sp=workers_per_sp)
            # write videos during training to track progress
            if rank % nenvs == 0 and hvd.local_rank() == 0:
                video_dir = os.path.join(save_path, game_name)
                os.makedirs(video_dir, exist_ok=True)
                videofile_prefix = os.path.join(video_dir, 'episode')
                env = VideoWriter(env, videofile_prefix)
            env = my_wrapper(env, clip_rewards=True)
            return env

        return env_fn

    env = SubprocVecEnv(
        [make_env(i + nenvs * hvd.rank()) for i in range(nenvs)])
    env = ResetManager(env)
    env = VecFrameStack(env, 4)

    if starting_point is not None:
        env.set_max_starting_point(starting_point)

    policy = {'cnn': CnnPolicy, 'gru': GRUPolicy}[policy]
    learn(policy=policy,
          env=env,
          nsteps=128,
          lam=.95,
          gamma=.999,
          noptepochs=4,
          log_interval=1,
          save_interval=100,
          ent_coef=entropy_coef,
          l2_coef=1e-7,
          lr=lr,
          cliprange=0.1,
          total_timesteps=num_timesteps,
          norm_adv=True,
          load_path=load_path,
          save_path=save_path,
          game_name=game_name)
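
A minimal sketch of a possible invocation, again with purely illustrative hyperparameters and paths that are not part of the original listing; as above, the script is expected to run under horovodrun/mpirun:

if __name__ == '__main__':
    # Hypothetical call: train on MontezumaRevenge from the demo's starting
    # points with the GRU policy.
    train(game_name='MontezumaRevenge',
          policy='gru',
          num_timesteps=int(1e8),
          lr=1e-4,
          entropy_coef=1e-4,
          load_path=None,
          starting_point=None,
          save_path='results/montezuma_train')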
Example #3
def train(args, extra_data):
    import filelock
    with filelock.FileLock('/tmp/robotstify.lock'):
        import gym
        import sys
        try:
            import goexplore_py.complex_fetch_env
        except Exception:
            print(
                'Could not import complex_fetch_env, is goexplore_py in PYTHONPATH?'
            )

    import os
    import numpy as np
    import tensorflow as tf
    import horovod.tensorflow as hvd
    from baselines import logger
    hvd.init()
    print('initialized worker %d' % hvd.rank(), flush=True)
    if hvd.rank() == 0:
        while os.path.exists(args.save_path + '/progress.csv'):
            while args.save_path[-1] == '/':
                args.save_path = args.save_path[:-1]
            args.save_path += '_retry'
            # assert False, 'The save path already exists, something is wrong. If retrying the job, please clear this manually first.'
        logger.configure(args.save_path)
        os.makedirs(args.save_path + '/' + args.game, exist_ok=True)
        for k in list(extra_data):
            if 'prev_progress' in k:
                extra_data[k].to_csv(args.save_path + '/' + k + '.csv',
                                     index=False)
                del extra_data[k]

    frameskip = 1 if 'fetch' in args.game else 4

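    # If autoscaling is requested, rewards are rescaled so that the demo's mean
    # (optionally discounted via args.gamma) return maps to roughly
    # args.autoscale, and reward clipping is disabled.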
    if args.autoscale is not None:
        max_reward = get_mean_reward(
            args.demo, args.autoscale_fn, frameskip,
            (args.gamma if args.autoscale_value else None)) / args.autoscale
        args.scale_rewards = 1.0 / max_reward
        print(
            f'Autoscaling with scaling factor 1 / {max_reward} ({args.scale_rewards})'
        )
        args.clip_rewards = False

    if 'Pitfall' in args.game and not args.scale_rewards:
        print('Forcing reward scaling because game is Pitfall!')
        args.scale_rewards = 0.001
        args.clip_rewards = False

    import json
    os.makedirs(args.save_path, exist_ok=True)
    with open(args.save_path + '/kwargs.json', 'w') as f:
        json.dump(vars(args), f, indent=True, sort_keys=True)
    from baselines.common import set_global_seeds
    set_global_seeds(hvd.rank())
    from atari_reset.ppo import learn
    from atari_reset.policies import CnnPolicy, GRUPolicy, FFPolicy, FetchCNNPolicy
    from atari_reset.wrappers import ReplayResetEnv, ResetManager, SubprocVecEnv, VideoWriter, VecFrameStack, SuperDumbVenvWrapper, my_wrapper, MyResizeFrame, WarpFrame, MyResizeFrameOld, TanhWrap, SubprocWrapper, prepare_subproc

    if args.frame_resize == "MyResizeFrame":
        frame_resize_wrapper = MyResizeFrame
    elif args.frame_resize == "WarpFrame":
        frame_resize_wrapper = WarpFrame
    elif args.frame_resize == "MyResizeFrameOld":
        frame_resize_wrapper = MyResizeFrameOld
    else:
        raise NotImplementedError("No such frame-size wrapper: " +
                                  args.frame_resize)
    ncpu = 2
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=ncpu,
                            inter_op_parallelism_threads=ncpu)
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    tf.Session(config=config).__enter__()

    # nrstartsteps = 320  # number of non frameskipped steps to divide workers over
    nrworkers = hvd.size() * args.nenvs
    workers_per_sp = int(np.ceil(nrworkers / args.nrstartsteps))

    if args.demo is None:
        args.demo = 'demos/' + args.game + '.demo'
    print('Using demo', args.demo)

    subproc_data = None

    def make_env(rank, is_extra_sil, subproc_idx):
        # print('WOW', rank, is_extra_sil)
        def env_fn():
            if args.game == 'fetch':
                assert args.fetch_target_location is not None, 'For now, we require a target location for fetch'
                kwargs = {}
                dargs = vars(args)
                for attr in dargs:
                    if attr.startswith('fetch_'):
                        if attr == 'fetch_type':
                            kwargs['model_file'] = f'teleOp_{args.fetch_type}.xml'
                        elif attr != 'fetch_total_timestep':
                            kwargs[attr[len('fetch_'):]] = dargs[attr]

                env = goexplore_py.complex_fetch_env.ComplexFetchEnv(**kwargs)
            elif args.game == 'fetch_dumb':
                env = goexplore_py.dumb_fetch_env.ComplexFetchEnv(
                    incl_extra_full_state=args.fetch_incl_extra_full_state)
            else:
                env = gym.make(args.game + 'NoFrameskip-v4')
            env = ReplayResetEnv(env,
                                 args,
                                 seed=rank,
                                 workers_per_sp=workers_per_sp,
                                 is_extra_sil=is_extra_sil,
                                 frameskip=frameskip)
            if 'fetch' not in args.game:
                # write videos during training to track progress
                if rank % args.nenvs == 0 and hvd.local_rank() == 0:
                    video_dir = os.path.join(args.save_path, args.game)
                    os.makedirs(video_dir, exist_ok=True)
                    if args.videos:
                        videofile_prefix = os.path.join(video_dir, 'episode')
                        env = VideoWriter(env, videofile_prefix)
                env = my_wrapper(
                    env,
                    #  clip_rewards=args.clip_rewards,
                    frame_resize_wrapper=frame_resize_wrapper,
                    #  scale_rewards=args.scale_rewards,
                    sticky=args.sticky)
            else:
                env = TanhWrap(env)
            return env

        return env_fn

    env_types = [(i, False) for i in range(args.nenvs)]
    if args.n_sil_envs:
        # For cases where we start from the current starting points
        env_types += [(args.nenvs - 1, True)] * args.n_sil_envs
        # For cases where we start from the beginning
        # env_types += [(0, True)] * n_sil_envs

    env = SubprocVecEnv([
        make_env(i + args.nenvs * hvd.rank(), is_extra_sil, subproc_idx)
        for subproc_idx, (i, is_extra_sil) in enumerate(env_types)
    ])
    env = ResetManager(
        env,
        move_threshold=args.move_threshold,
        steps_per_demo=args.steps_per_demo,
        fast_increase_starting_point=args.fast_increase_starting_point)
    if args.starting_points is not None:
        for i, e in enumerate(args.starting_points.split(',')):
            env.set_max_starting_point(
                int(e), i, args.move_threshold if args.sp_set_mt else 0)
    if 'fetch' not in args.game:
        env = VecFrameStack(env, frameskip)
    else:
        env = SuperDumbVenvWrapper(env)

    print('About to start PPO')
    if 'fetch' in args.game:
        if args.fetch_state_is_pixels:
            args.policy = FetchCNNPolicy
        else:
            print('Fetch environment, using the feedforward policy.')
            args.policy = FFPolicy
    else:
        args.policy = {'cnn': CnnPolicy, 'gru': GRUPolicy}[args.policy]
    args.im_cells = extra_data.get('im_cells')
    learn(env, args, False)
Example #4
def test(args):
    import filelock
    with filelock.FileLock('/tmp/robotstify.lock'):
        import gym
        import sys
        try:
            import goexplore_py.complex_fetch_env
        except Exception:
            print('Could not import complex_fetch_env, is goexplore_py in PYTHONPATH?')

    import os
    import tensorflow as tf
    import horovod.tensorflow as hvd
    hvd.init()
    print('initialized worker %d' % hvd.rank(), flush=True)
    from baselines import bench
    from baselines.common import set_global_seeds
    from atari_reset.wrappers import VecFrameStack, VideoWriter, my_wrapper,\
        EpsGreedyEnv, StickyActionEnv, NoopResetEnv, SubprocVecEnv, PreventSlugEnv, FetchSaveEnv, TanhWrap
    from atari_reset.ppo import learn
    from atari_reset.policies import CnnPolicy, GRUPolicy, FFPolicy

    set_global_seeds(hvd.rank())
    ncpu = 2
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=ncpu,
                            inter_op_parallelism_threads=ncpu)
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    tf.Session(config=config).__enter__()

    max_noops = 30 if args.noops else 0
    print('SAVE PATH', args.save_path)

    def make_env(rank):
        def env_fn():
            if args.game == 'fetch':
                assert args.fetch_target_location is not None, 'For now, we require a target location for fetch'
                kwargs = {}
                dargs = vars(args)
                for attr in dargs:
                    if attr.startswith('fetch_'):
                        if attr == 'fetch_type':
                            kwargs['model_file'] = f'teleOp_{args.fetch_type}.xml'
                        elif attr != 'fetch_total_timestep':
                            kwargs[attr[len('fetch_'):]] = dargs[attr]

                env = goexplore_py.complex_fetch_env.ComplexFetchEnv(
                    **kwargs
                )
            elif args.game == 'fetch_dumb':
                env = goexplore_py.dumb_fetch_env.ComplexFetchEnv()
            else:
                env = gym.make(args.game + 'NoFrameskip-v4')
                if args.seed_env:
                    env.seed(0)
                # if args.unlimited_score:
                #     # This removes the TimeLimit wrapper around the env
                #     env = env.env
                # env = PreventSlugEnv(env)
            # change for long runs
            # env._max_episode_steps *= 1000
            env = bench.Monitor(env, "{}.monitor.json".format(rank), allow_early_resets=True)
            if False and rank % nenvs == 0 and hvd.local_rank() == 0:  # video recording disabled here
                os.makedirs(args.save_path + '/vids/' + args.game, exist_ok=True)
                videofile_prefix = args.save_path + '/vids/' + args.game
                env = VideoWriter(env, videofile_prefix)
            if 'fetch' not in args.game:
                if args.noops:
                    os.makedirs(args.save_path, exist_ok=True)
                    env = NoopResetEnv(env, 30, nenvs, args.save_path, num_per_noop=args.num_per_noop, unlimited_score=args.unlimited_score)
                    env = my_wrapper(env, clip_rewards=True, sticky=args.sticky)
                if args.epsgreedy:
                    env = EpsGreedyEnv(env)
            else:
                os.makedirs(f'{args.save_path}', exist_ok=True)
                env = FetchSaveEnv(env, rank=rank, n_ranks=nenvs, save_path=f'{args.save_path}/', demo_path=args.demo)
                env = TanhWrap(env)
            return env
        return env_fn

    nenvs = args.nenvs  # referenced by the env_fn closures above
    env = SubprocVecEnv([make_env(i + nenvs * hvd.rank()) for i in range(nenvs)])
    env = VecFrameStack(env, 1 if 'fetch' in args.game else 4)

    if 'fetch' in args.game:
        print('Fetch environment, using the feedforward policy.')
        args.policy = FFPolicy
    else:
        args.policy = {'cnn': CnnPolicy, 'gru': GRUPolicy}[args.policy]

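    # Evaluation-time overrides: the loss coefficients used during training are
    # zeroed out; the final positional argument to learn() below presumably
    # switches it into test mode (cf. test_mode=True in Example #1).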
    args.sil_pg_weight_by_value = False
    args.sil_vf_relu = False
    args.sil_vf_coef = 0
    args.sil_coef = 0
    args.sil_ent_coef = 0
    args.ent_coef = 0
    args.vf_coef = 0
    args.cliprange = 1
    args.l2_coef = 0
    args.adam_epsilon = 1e-8
    args.gamma = 0.99
    args.lam = 0.10
    args.scale_rewards = 1.0
    args.sil_weight_success_rate = True
    args.norm_adv = 1.0
    args.log_interval = 1
    args.save_interval = 100
    args.subtract_rew_avg = True
    args.clip_rewards = False
    learn(env, args, True)