コード例 #1
0
        def env_fn():
            """Build and wrap the environment for one worker.

            Reads ``args``, ``rank`` and ``nenvs`` from the enclosing scope.

            Base environment selection:

            * ``args.game == 'fetch'``: a ``ComplexFetchEnv`` configured from
              the ``fetch_*`` command-line options.
            * ``args.game == 'fetch_dumb'``: the dumb fetch env with default
              constructor arguments.
            * anything else: ``gym.make(args.game + 'NoFrameskip-v4')``,
              seeded with 0 when ``args.seed_env`` is set.

            The base env is always wrapped in ``bench.Monitor``. Games whose
            name contains ``'fetch'`` then get ``FetchSaveEnv`` + ``TanhWrap``;
            all other games get the optional no-op / epsilon-greedy wrappers.

            Returns:
                The fully wrapped environment.

            Raises:
                AssertionError: if the game is ``'fetch'`` and
                    ``args.fetch_target_location`` is ``None``.
            """
            if args.game == 'fetch':
                assert args.fetch_target_location is not None, 'For now, we require a target location for fetch'
                # Forward the ``fetch_*`` options to the constructor with the
                # prefix stripped. ``fetch_type`` selects the model XML file and
                # ``fetch_total_timestep`` is not a constructor argument.
                prefix = 'fetch_'
                kwargs = {}
                for attr, value in vars(args).items():
                    if not attr.startswith(prefix):
                        continue
                    if attr == 'fetch_type':
                        kwargs['model_file'] = f'teleOp_{args.fetch_type}.xml'
                    elif attr != 'fetch_total_timestep':
                        kwargs[attr[len(prefix):]] = value
                env = goexplore_py.complex_fetch_env.ComplexFetchEnv(**kwargs)
            elif args.game == 'fetch_dumb':
                env = goexplore_py.dumb_fetch_env.ComplexFetchEnv()
            else:
                env = gym.make(args.game + 'NoFrameskip-v4')
                if args.seed_env:
                    env.seed(0)

            env = bench.Monitor(env, "{}.monitor.json".format(rank), allow_early_resets=True)
            # No VideoWriter is attached by this factory: video recording is
            # disabled for this training setup.

            if 'fetch' in args.game:
                os.makedirs(f'{args.save_path}', exist_ok=True)
                env = FetchSaveEnv(env, rank=rank, n_ranks=nenvs, save_path=f'{args.save_path}/', demo_path=args.demo)
                return TanhWrap(env)

            if args.noops:
                os.makedirs(args.save_path, exist_ok=True)
                env = NoopResetEnv(env, 30, nenvs, args.save_path, num_per_noop=args.num_per_noop, unlimited_score=args.unlimited_score)
                # NOTE(review): my_wrapper is only applied when no-ops are
                # enabled, so a run without args.noops gets no reward clipping
                # or sticky actions. This looks unintended -- confirm before
                # changing, behaviour is preserved here.
                env = my_wrapper(env, clip_rewards=True, sticky=args.sticky)
            if args.epsgreedy:
                env = EpsGreedyEnv(env)
            return env
コード例 #2
0
 def env_fn():
     """Create the replay-reset Atari environment for one worker.

     Reads ``game_name``, ``rank``, ``nenvs``, ``workers_per_sp`` and
     ``save_path`` from the enclosing scope.

     Returns:
         The env wrapped in ``ReplayResetEnv``, optionally ``VideoWriter``,
         and finally ``my_wrapper`` with reward clipping enabled.
     """
     base_env = gym.make(game_name + 'NoFrameskip-v4')
     demo_file = 'demos/' + game_name + '.demo'
     wrapped = ReplayResetEnv(
         base_env,
         demo_file_name=demo_file,
         seed=rank,
         workers_per_sp=workers_per_sp,
     )

     # Write videos during training to track progress, but only from one
     # env per group of ``nenvs`` and only on Horovod local rank 0.
     records_video = rank % nenvs == 0 and hvd.local_rank() == 0
     if records_video:
         video_dir = os.path.join(save_path, game_name)
         os.makedirs(video_dir, exist_ok=True)
         wrapped = VideoWriter(wrapped, os.path.join(video_dir, 'episode'))

     return my_wrapper(wrapped, clip_rewards=True)
コード例 #3
0
 def env_fn():
     """Create the monitored Atari environment for one worker.

     Reads ``game_name``, ``rank``, ``nenvs``, ``noops``, ``sticky`` and
     ``epsgreedy`` from the enclosing scope.

     Wrapper order (innermost first): ``bench.Monitor``, optional
     ``VideoWriter``, optional ``NoopResetEnv``, optional
     ``StickyActionEnv``, ``my_wrapper`` with reward clipping, optional
     ``EpsGreedyEnv``.

     Returns:
         The fully wrapped environment.
     """
     atari_env = gym.make(game_name + 'NoFrameskip-v4')
     wrapped = bench.Monitor(atari_env, "{}.monitor.json".format(rank))

     # Only one env per group of ``nenvs``, on Horovod local rank 0,
     # records videos.
     if rank % nenvs == 0 and hvd.local_rank() == 0:
         results_prefix = 'results/' + game_name
         os.makedirs(results_prefix, exist_ok=True)
         wrapped = VideoWriter(wrapped, results_prefix)

     if noops:
         wrapped = NoopResetEnv(wrapped)
     if sticky:
         wrapped = StickyActionEnv(wrapped)

     wrapped = my_wrapper(wrapped, clip_rewards=True)
     return EpsGreedyEnv(wrapped) if epsgreedy else wrapped
コード例 #4
0
ファイル: train_atari.py プロジェクト: o7s8r6/atari-reset
        def env_fn():
            """Build the replay-reset environment for one worker.

            Reads ``args``, ``rank``, ``workers_per_sp``, ``is_extra_sil``,
            ``frameskip`` and ``frame_resize_wrapper`` from the enclosing
            scope.

            Returns:
                For games whose name contains ``'fetch'``: the env wrapped in
                ``ReplayResetEnv`` and ``TanhWrap``. Otherwise: the env
                wrapped in ``ReplayResetEnv``, optionally ``VideoWriter``,
                and ``my_wrapper``.

            Raises:
                AssertionError: if the game is ``'fetch'`` and
                    ``args.fetch_target_location`` is ``None``.
            """
            game = args.game
            if game == 'fetch':
                assert args.fetch_target_location is not None, 'For now, we require a target location for fetch'
                # Pass the ``fetch_*`` options through with the prefix
                # removed; ``fetch_type`` picks the model file instead and
                # ``fetch_total_timestep`` is not forwarded.
                prefix = 'fetch_'
                fetch_kwargs = {}
                for option, value in vars(args).items():
                    if not option.startswith(prefix):
                        continue
                    if option == 'fetch_type':
                        fetch_kwargs['model_file'] = f'teleOp_{args.fetch_type}.xml'
                    elif option != 'fetch_total_timestep':
                        fetch_kwargs[option[len(prefix):]] = value
                base_env = goexplore_py.complex_fetch_env.ComplexFetchEnv(**fetch_kwargs)
            elif game == 'fetch_dumb':
                base_env = goexplore_py.dumb_fetch_env.ComplexFetchEnv(
                    incl_extra_full_state=args.fetch_incl_extra_full_state)
            else:
                base_env = gym.make(game + 'NoFrameskip-v4')

            wrapped = ReplayResetEnv(
                base_env,
                args,
                seed=rank,
                workers_per_sp=workers_per_sp,
                is_extra_sil=is_extra_sil,
                frameskip=frameskip,
            )

            if 'fetch' in args.game:
                return TanhWrap(wrapped)

            # Write videos during training to track progress. The output
            # directory is created on the recording worker even when
            # ``args.videos`` is off.
            if rank % args.nenvs == 0 and hvd.local_rank() == 0:
                video_dir = os.path.join(args.save_path, args.game)
                os.makedirs(video_dir, exist_ok=True)
                if args.videos:
                    wrapped = VideoWriter(wrapped, os.path.join(video_dir, 'episode'))

            return my_wrapper(
                wrapped,
                frame_resize_wrapper=frame_resize_wrapper,
                sticky=args.sticky,
            )