def make_env(env_id,
             env_type,
             mpi_rank=0,
             subrank=0,
             seed=None,
             reward_scale=1.0,
             gamestate=None,
             flatten_dict_observations=True,
             wrapper_kwargs=None,
             logger_dir=None):
    wrapper_kwargs = wrapper_kwargs or {}
    if env_type == 'atari':
        env = make_atari(env_id)
    elif env_type == 'retro':
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(
            game=env_id,
            max_episode_steps=10000,
            use_restricted_actions=retro.Actions.DISCRETE,
            state=gamestate)
    else:
        env = gym.make(env_id)

    if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
        keys = env.observation_space.spaces.keys()
        env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))

    env.seed(seed + subrank if seed is not None else None)
    env = Monitor(env,
                  logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
                  allow_early_resets=True)

    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)

    return env
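A minimal usage sketch for this helper, hedged: it assumes the usual baselines-style imports (make_atari, Monitor, wrap_deepmind, retro_wrappers) are already in scope, and the env ID and log directory below are placeholders.

# Hypothetical call: one seeded, monitored Atari env with DeepMind wrappers.
env = make_env('PongNoFrameskip-v4', 'atari', seed=0,
               wrapper_kwargs={'frame_stack': True}, logger_dir='/tmp/run0')
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
env.close()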
Example #2
def _thunk():
    env = make_atari(env_id)
    env.seed(seed + rank)
    env = Monitor(
        env,
        logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
    return wrap_deepmind(env, **wrapper_kwargs)
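For context, this `_thunk` is an env factory: code like this builds one closure per worker and hands the list to a vectorized env. A hedged sketch of that pattern, assuming `env_id`, `seed`, and `wrapper_kwargs` are defined in the enclosing scope (SubprocVecEnv is the standard baselines class):

from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv

def make_thunk(rank):
    # Each closure captures its own rank so every worker gets a distinct seed.
    def _thunk():
        env = make_atari(env_id)
        env.seed(seed + rank)
        return wrap_deepmind(env, **wrapper_kwargs)
    return _thunk

vec_env = SubprocVecEnv([make_thunk(i) for i in range(4)])  # 4 parallel envs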
Example #3
    def make_env(self, rank):
        env = make_atari(self.config.env_name)
        env.seed(self.config.seed + rank)
        gym.logger.setLevel(logger.WARN)
        env = wrap_deepmind(env)

        # wrap the env once more so the Monitor records the total episode reward
        env = Monitor(env, rank)
        return env
Example #4
def main():
    """
    Run the atari test
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    parser.add_argument('--num-timesteps', type=int, default=int(1e7))

    args = parser.parse_args()
    logger.configure()
    set_global_seeds(args.seed)
    env = make_atari(args.env)
    env = bench.Monitor(env, logger.get_dir())
    env = wrap_atari_dqn(env)
    policy = partial(CnnPolicy, dueling=args.dueling == 1)  # note: unused below; the active DQN call passes policy_class=CnnPolicy directly

    # model = DQN(
    #     env=env,
    #     policy=policy,
    #     learning_rate=1e-4,
    #     buffer_size=10000,
    #     exploration_fraction=0.1,
    #     exploration_final_eps=0.01,
    #     train_freq=4,
    #     learning_starts=10000,
    #     target_network_update_freq=1000,
    #     gamma=0.99,
    #     prioritized_replay=bool(args.prioritized),
    #     prioritized_replay_alpha=args.prioritized_replay_alpha,
    # )
    model = DQN(
        env=env,
        policy_class=CnnPolicy,
        learning_rate=1e-4,
        buffer_size=10000,
        double_q=False,
        prioritized_replay=True,
        prioritized_replay_alpha=0.6,
        dueling=True,
        train_freq=4,
        learning_starts=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        target_network_update_freq=1000,
        model_path='atari_Breakout_duel'
    )
    # model.learn(total_timesteps=args.num_timesteps, seed=args.seed)
    model.load('atari_Breakout_duel')
    model.evaluate(100)
    env.close()
Example #5
def env_agent_config(cfg, seed=1):
    ''' Create the environment and the agent
    '''
    env = make_atari(cfg.env_name)  # create the environment
    # env    = wrap_deepmind(env)
    # env    = wrap_pytorch(env)
    env.seed(seed)  # set the random seed
    state_dim = env.observation_space.shape[0]  # state dimension (note: with the pixel wrappers above commented out, this is the raw image height for an Atari env)
    action_dim = env.action_space.n  # action dimension
    agent = DQN(state_dim, action_dim, cfg)  # create the agent
    return env, agent
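A hedged sketch of how the returned pair might be driven; `cfg` is the same config object used above, and the agent's `choose_action` method is a hypothetical name, not necessarily this codebase's actual DQN API:

env, agent = env_agent_config(cfg, seed=1)
state = env.reset()
done = False
while not done:
    action = agent.choose_action(state)  # hypothetical method name
    state, reward, done, _ = env.step(action)
env.close()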
Example #6
        def _thunk():
            env = make_atari(env_id) if env_type == 'atari' else gym.make(
                env_id)
            env.seed(seed + 10000 * mpi_rank +
                     rank if seed is not None else None)
            env = Monitor(env,
                          logger.get_dir()
                          and os.path.join(logger.get_dir(),
                                           str(mpi_rank) + '.' + str(rank)),
                          allow_early_resets=True)

            if env_type == 'atari':
                return wrap_deepmind(env, **wrapper_kwargs)
            elif reward_scale != 1:
                return RewardScaler(env, reward_scale)
            else:
                return env
Example #7
def create_gvgai_environment(env_id):
    from common.atari_wrappers import wrap_deepmind, make_atari, ActionDirectionEnv
    initial_direction = {'gvgai-testgame1': 3, 'gvgai-testgame2': 3}
    logger.configure()
    game_name = env_id.split('-lvl')[0]
    does_need_action_direction = False

    # Environment creation
    env = make_atari(env_id)
    env = bench.Monitor(env, logger.get_dir())
    env = wrap_deepmind(env, episode_life=False, clip_rewards=False, frame_stack=False, scale=True)
    if game_name in initial_direction:
        print("We should model with action direction")
        env = ActionDirectionEnv(env, initial_direction=initial_direction[game_name])
        does_need_action_direction = True
    return env, does_need_action_direction, game_name
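A hedged sketch of consuming the tuple this helper returns; the env ID below is a made-up placeholder whose prefix matches the `initial_direction` table above:

env, needs_direction, game_name = create_gvgai_environment('gvgai-testgame1-lvl0-v0')
obs = env.reset()
print(game_name, needs_direction, obs.shape)
env.close()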
Example #8
def main():
    """
    Run the atari test
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env',
                        help='environment ID',
                        default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    parser.add_argument('--num-timesteps', type=int, default=int(1e7))

    args = parser.parse_args()
    logger.configure()
    set_global_seeds(args.seed)

    env = make_atari(args.env)
    env.action_space.seed(args.seed)
    env = bench.Monitor(env, logger.get_dir())
    env = wrap_atari_dqn(env)

    model = DQN(env=env,
                policy_class=CnnPolicy,
                buffer_size=10000,
                learning_rate=1e-4,
                learning_starts=10000,
                target_network_update_freq=1000,
                train_freq=4,
                exploration_final_eps=0.01,
                exploration_fraction=0.1,
                prioritized_replay=True,
                model_path='atari_test_Breakout')
    model.learn(total_timesteps=args.num_timesteps)
    env.close()
Example #9
def testbaselines(args):
    # Use baselines' environment-wrapping approach and try to run a minimal example
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args(args)
    #extra_args = parse_cmdline_kwargs(unknown_args)

    env_type, env_id = get_env_type(args)
    print('env_type: {}'.format(env_type))
    env = make_atari(env_id)
    #env = build_env(args)
    print("env builded ",env)
    obs = env.reset()
    reset = True

    print("env reseted")
    for t in range(10000):
        env.render()
        action = env.action_space.sample()
        new_obs, rew, done, _ = env.step(action)
        obs = new_obs
        if done:
            print(done)
            obs = env.reset()
            reset = True
Example #10
def make_env(env_id,
             env_type,
             mpi_rank=0,
             subrank=0,
             seed=None,
             reward_scale=1.0,
             gamestate=None,
             flatten_dict_observations=True,
             wrapper_kwargs=None,
             logger_dir=None,
             cloth_cfg_path=None,
             render_path=None,
             start_state_path=None):
    """Daniel: make single instance of env, to be wrapped in VecEnv for parallelism.

    We need to have a special if case for the clothenv, which doesn't actually
    use `gym.make(...)` because we have a custom configuration.
    """
    wrapper_kwargs = wrapper_kwargs or {}

    if env_type == 'cloth':
        print("Env Type is Cloth")
        assert cloth_cfg_path is not None
        from gym_cloth.envs import ClothEnv
        env = ClothEnv(cloth_cfg_path,
                       subrank=subrank,
                       start_state_path=start_state_path)
        print('Created ClothEnv, seed {}, mpi_rank {}, subrank {}.'.format(
            seed, mpi_rank, subrank))
        print('start_state_path: {}'.format(start_state_path))
        # Daniel: render, but currently only works if we have one env, not a vec ...
        if render_path is not None:
            env.render(filepath=render_path)
    elif env_type == 'atari':
        env = make_atari(env_id)
    elif env_type == 'retro':
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(
            game=env_id,
            max_episode_steps=10000,
            use_restricted_actions=retro.Actions.DISCRETE,
            state=gamestate)
    else:
        print("USING WRONG COMMAND")
        env = gym.make(env_id)

    if flatten_dict_observations and isinstance(env.observation_space,
                                                gym.spaces.Dict):
        keys = env.observation_space.spaces.keys()
        env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))

    env.seed(seed + subrank if seed is not None else None)
    env = Monitor(env,
                  logger_dir
                  and os.path.join(logger_dir,
                                   str(mpi_rank) + '.' + str(subrank)),
                  allow_early_resets=True)

    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        if 'frame_stack' not in wrapper_kwargs:
            wrapper_kwargs['frame_stack'] = 1
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)

    # Adi: Let's return the actual env for now instead of the wrapped version for
    # simplicity. Can change this back later. (Note: `unwrapped` discards the
    # Monitor and DeepMind wrappers applied above.)
    env = env.unwrapped

    return env
Example #11
def main(args):
    # mpi communicator.
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # seed.
    workerseed = args.seed + 10000 * comm.Get_rank() if args.seed is not None else None
    if workerseed is not None:
        tc.manual_seed(workerseed % 2 ** 32)
        np.random.seed(workerseed % 2 ** 32)
        random.seed(workerseed % 2 ** 32)

    # logger.
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])

    # env.
    env = make_atari(args.env_name)
    env.seed(workerseed)
    env = Monitor(env, logger.get_dir() and
                  os.path.join(logger.get_dir(), str(rank)))
    print(f"frame_stacking: {args.frame_stacking}")
    env = wrap_deepmind(env, frame_stack=args.frame_stacking,
                        clip_rewards=(args.mode == 'train'),
                        episode_life=(args.mode == 'train'))  # See Mnih et al., 2015 -> Methods -> Training Details.
    env.seed(workerseed)

    # agent.
    agent = CnnPolicy(
        img_channels=env.observation_space.shape[-1],
        num_actions=env.action_space.n,
        kind=args.model_type)

    # optimizer and scheduler.
    max_grad_steps = args.optim_epochs * args.env_steps // (comm.Get_size() * args.optim_batchsize)

    optimizer = tc.optim.Adam(agent.parameters(), lr=args.optim_stepsize, eps=1e-5)
    scheduler = tc.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer, max_lr=args.optim_stepsize, total_steps=max_grad_steps,
        pct_start=0.0, anneal_strategy='linear', cycle_momentum=False,
        div_factor=1.0)
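    # Note: with pct_start=0.0, div_factor=1.0, and a linear anneal strategy,
    # OneCycleLR degenerates to a plain linear decay from optim_stepsize down
    # to optim_stepsize / 1e4 (the default final_div_factor) over
    # max_grad_steps optimizer updates.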

    # checkpoint.
    if rank == 0:
        try:
            state_dict = tc.load(os.path.join(args.checkpoint_dir, args.model_name, 'model.pth'))
            agent.load_state_dict(state_dict)
            print(f"Continuing from checkpoint found at {os.path.join(args.checkpoint_dir, args.model_name, 'model.pth')}")
        except FileNotFoundError:
            print("Bad checkpoint or none on process 0. Continuing from scratch.")

    # sync: broadcast rank 0's parameters so every MPI worker starts from identical weights.
    with tc.no_grad():
        for p in agent.parameters():
            p_data = p.data.numpy()
            comm.Bcast(p_data, root=0)
            p.data.copy_(tc.tensor(p_data).float())

    # operations.
    if args.mode == 'train':
        learn(env=env, agent=agent, optimizer=optimizer, scheduler=scheduler, comm=comm,
              timesteps_per_actorbatch=args.timesteps_per_actorbatch, max_timesteps=args.env_steps,
              optim_epochs=args.optim_epochs, optim_batchsize=args.optim_batchsize,
              gamma=args.gamma, lam=args.lam, clip_param=args.epsilon, entcoeff=args.ent_coef,
              checkpoint_dir=args.checkpoint_dir, model_name=args.model_name)
        env.close()

    elif args.mode == 'play':
        if comm.Get_rank() == 0:
            play(env=env, agent=agent, args=args)
            env.close()

    elif args.mode == 'movie':
        if comm.Get_rank() == 0:
            movie(env=env, agent=agent, args=args)
            env.close()

    else:
        raise NotImplementedError("Mode of operation not supported!")
Example #12
def make_env(env_id,
             env_type,
             mpi_rank=0,
             subrank=0,
             seed=None,
             reward_scale=1.0,
             gamestate=None,
             flatten_dict_observations=True,
             wrapper_kwargs=None,
             env_kwargs=None,
             logger_dir=None,
             initializer=None):
    if initializer is not None:
        initializer(mpi_rank=mpi_rank, subrank=subrank)

    wrapper_kwargs = wrapper_kwargs or {}
    env_kwargs = env_kwargs or {}
    if ':' in env_id:
        import re
        import importlib
        module_name = re.sub(':.*', '', env_id)
        env_id = re.sub('.*:', '', env_id)
        importlib.import_module(module_name)
    if env_type == 'atari':
        env = make_atari(env_id)
    elif env_type == 'retro':
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(
            game=env_id,
            max_episode_steps=10000,
            use_restricted_actions=retro.Actions.DISCRETE,
            state=gamestate)
    else:
        env = gym.make(env_id, **env_kwargs)

    if flatten_dict_observations and isinstance(env.observation_space,
                                                gym.spaces.Dict):
        env = FlattenObservation(env)

    env.seed(seed + subrank if seed is not None else None)
    env = Monitor(env,
                  logger_dir
                  and os.path.join(logger_dir,
                                   str(mpi_rank) + '.' + str(subrank)),
                  allow_early_resets=True)

    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        if 'frame_stack' not in wrapper_kwargs:
            wrapper_kwargs['frame_stack'] = 1
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    if isinstance(env.action_space, gym.spaces.Box):
        env = ClipActionsWrapper(env)

    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)

    return env
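The `':'` branch above lets the env ID name the module that registers it; a hedged usage sketch where both the package name and env ID are made-up placeholders:

# Hypothetical: importing 'my_envs' registers 'MyEnv-v0' with gym, after which
# gym.make resolves it in the final else branch.
env = make_env('my_envs:MyEnv-v0', env_type='custom', seed=0)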
Example #13
def make_env(env_id, seed=None, max_episode_steps=None, wrapper_kwargs=None):
    wrapper_kwargs = wrapper_kwargs or {}
    env = make_atari(env_id, max_episode_steps)
    env.seed(seed)
    env = wrap_deepmind(env, **wrapper_kwargs)
    return env
Example #14
import common.atari_wrappers as wrappers
import gym

if __name__ == '__main__':
    # env = gym.make("SpaceInvaders-v0")
    env = wrappers.make_atari("SpaceInvaders-v0", 1000)
    env.reset()
    for episode in range(1000):
        env.reset()
        for t in range(1000):
            env.render()
            observation, reward, done, info = env.step(
                env.action_space.sample())
            if done:
                print("episode {} ends after {} steps.".format(
                    episode + 1, t + 1))
                print("         left live: {}".format(info))
                break