Example #1
def load(bsuite_id: Text,
         record: bool = True,
         save_path: Optional[Text] = None,
         logging_mode: Text = 'csv',
         overwrite: bool = False) -> py_environment.PyEnvironment:
    """Loads the selected environment.

  Args:
    bsuite_id: a bsuite_id specifies a bsuite experiment. For an example
      `bsuite_id` "deep_sea/7" will be 7th level of the "deep_sea" task.
    record: whether to log bsuite results.
    save_path: the directory to save bsuite results.
    logging_mode: which form of logging to use for bsuite results
      ['csv', 'sqlite', 'terminal'].
    overwrite: overwrite csv logging if found.

  Returns:
    A PyEnvironment instance.
  """
    if record:
        raw_env = bsuite.load_and_record(bsuite_id=bsuite_id,
                                         save_path=save_path,
                                         logging_mode=logging_mode,
                                         overwrite=overwrite)
    else:
        raw_env = bsuite.load_from_id(bsuite_id=bsuite_id)
    gym_env = gym_wrapper.GymFromDMEnv(raw_env)
    return suite_gym.wrap_env(gym_env)
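A minimal usage sketch for the loader above, driving one episode through the tf_agents PyEnvironment interface (the bsuite_id and the fixed-action choice are illustrative, not part of the original):

env = load('deep_sea/7', record=False)
time_step = env.reset()
while not time_step.is_last():
    # Take the smallest valid action each step; a real agent would pick one.
    time_step = env.step(env.action_spec().minimum)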
Example #2
    def __init__(self,
                 env_id,
                 agent,
                 verbose=True,
                 log_interval=100,
                 eval=False):
        '''
        PARAMETERS:
        'env_id'       - Environment ID, e.g. environments.CARTPOLE
        'agent'        - Instance of an Agent class with the necessary methods implemented
        'verbose'      - True: prints logs; False: doesn't print logs
        'log_interval' - Interval (in episodes) between log prints
        'eval'         - Use the private results path (PRIVATE_RESULTS_DIR) as the results dir
        '''
        self.agent = agent
        self.env_id = env_id
        self.verbose = verbose
        self.log_interval = log_interval

        if eval:
            results_dir = os.environ.get('PRIVATE_RESULTS_DIR')
        else:
            results_dir = os.environ.get('RESULTS_DIR')

        env = bsuite.load_and_record_to_csv(env_id,
                                            results_dir=results_dir,
                                            overwrite=True)
        self.env = gym_wrapper.GymFromDMEnv(env)
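A hedged construction sketch for the class above (the class name `Runner` and the agent instance are assumptions; only `__init__` is shown in the original, and the environment variable mirrors the one the constructor reads):

import os
os.environ.setdefault('RESULTS_DIR', '/tmp/bsuite_results')  # illustrative path
runner = Runner(env_id='bandit/0', agent=my_agent, log_interval=50)  # my_agent: hypothetical Agent instance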
Example #3
def experiment_factory(opt, only_env=False):
    env = gym_wrapper.GymFromDMEnv(bsuite.load_from_id(opt.env.name))
    env = TorchWrapper(env, opt.device)
    if only_env:
        return env

    replay = ExperienceReplay(**opt.replay)
    layers = [
        reduce(lambda x, y: x * y, env.observation_space.shape),  # input
        *opt.estimator["layers"],  # hidden
        env.action_space.n,  # output
    ]
    estimator = MLP(layers, spectral=opt.spectral, **opt.estimator)
    estimator.to(opt.device)

    optimizer = getattr(torch.optim, opt.optim.name)(
        estimator.parameters(), **opt.optim.kwargs
    )
    policy_improvement = C51PolicyImprovement(
        estimator, opt.epsilon, env.action_space.n
    )
    policy_evaluation = C51PolicyEvaluation(estimator, optimizer, opt.gamma)
    rlog.info(replay)
    rlog.info(estimator)
    return env, (replay, policy_improvement, policy_evaluation)
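A small equivalence note on the input-width computation in Example #3: the reduce over the observation shape is just the flattened observation size, which on Python 3.8+ can be written more directly (illustrative only):

import math
in_features = math.prod(env.observation_space.shape)  # == reduce(lambda x, y: x * y, shape)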
Example #4
def create_environment() -> gym.Env:
    """Factory method for environment initialization in Dopamine."""
    # `raw_env` is assumed to be a dm_env environment created in the
    # enclosing scope (e.g. via bsuite.load_and_record, as in Example #7).
    env = wrappers.ImageObservation(raw_env, OBSERVATION_SHAPE)
    if FLAGS.verbose:
        env = terminal_logging.wrap_environment(env, log_every=True)  # pytype: disable=wrong-arg-types
    env = gym_wrapper.GymFromDMEnv(env)
    env.game_over = False  # Dopamine looks for this
    return env
Example #5
def _load_env():
    raw_env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )
    if FLAGS.verbose:
        raw_env = terminal_logging.wrap_environment(raw_env, log_every=True)
    return gym_wrapper.GymFromDMEnv(raw_env)
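A hedged sketch of the absl flag definitions that `_load_env` reads (the flag names come from the calls above; the defaults and help strings are illustrative):

from absl import flags
flags.DEFINE_string('save_path', '/tmp/bsuite', 'directory for bsuite results')
flags.DEFINE_string('logging_mode', 'csv', "one of 'csv', 'sqlite', 'terminal'")
flags.DEFINE_boolean('overwrite', False, 'overwrite existing results')
flags.DEFINE_boolean('verbose', False, 'also log results to the terminal')
FLAGS = flags.FLAGS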
Example #6
def get_env(*args, **kwargs):
    # Load from bsuite (with CSV recording) unless an explicit gym id is given.
    if gym_id:
        env = gym.make(gym_id)
    else:
        env = gym_wrapper.GymFromDMEnv(
            bsuite.load_and_record_to_csv(
                bsuite_id=bsuite_id,
                results_dir=results_dir,
                overwrite=True,
            ))
    # Stack the last 4 observations, then flatten the stack into one vector.
    env = FrameStack(env, num_stack=4)
    env = TransformObservation(
        env, f=lambda lazy_frames: np.reshape(np.stack(lazy_frames._frames), -1))
    return GymEnvWrapper(env)
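A side note on Example #6: gym's LazyFrames implements __array__, so the flattening lambda can avoid touching the private `_frames` attribute (an equivalent form, assuming gym's standard FrameStack wrapper):

f = lambda lazy_frames: np.asarray(lazy_frames).reshape(-1)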
Example #7
def create_environment() -> gym.Env:
    """Factory method for environment initialization in Dopamine."""
    env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )
    env = wrappers.ImageObservation(env, OBSERVATION_SHAPE)
    if FLAGS.verbose:
        env = terminal_logging.wrap_environment(env, log_every=True)
    env = gym_wrapper.GymFromDMEnv(env)
    env.game_over = False  # Dopamine looks for this
    return env
Example #8
def run(bsuite_id: str) -> str:
    """Runs a DQN agent on a given bsuite environment, logging to CSV."""

    raw_env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )
    if FLAGS.verbose:
        raw_env = terminal_logging.wrap_environment(raw_env, log_every=True)  # pytype: disable=wrong-arg-types
    env = gym_wrapper.GymFromDMEnv(raw_env)

    num_episodes = FLAGS.num_episodes or raw_env.bsuite_num_episodes

    def callback(lcl, unused_glb):
        # Terminate after `num_episodes`.
        try:
            return lcl['num_episodes'] > num_episodes
        except KeyError:
            return False

    # Note: we should never run for this many steps as we end after `num_episodes`
    total_timesteps = FLAGS.total_timesteps

    deepq.learn(
        env=env,
        network='mlp',
        hiddens=[FLAGS.num_units] * FLAGS.num_hidden_layers,
        batch_size=FLAGS.batch_size,
        lr=FLAGS.learning_rate,
        total_timesteps=total_timesteps,
        buffer_size=FLAGS.replay_capacity,
        exploration_fraction=1. / total_timesteps,  # i.e. immediately anneal.
        exploration_final_eps=FLAGS.epsilon,  # constant epsilon.
        print_freq=None,  # pylint: disable=wrong-arg-types
        learning_starts=FLAGS.min_replay_size,
        target_network_update_freq=FLAGS.target_update_period,
        callback=callback,  # pytype: disable=wrong-arg-types
        gamma=FLAGS.agent_discount,
        checkpoint_freq=None,
    )

    return bsuite_id
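A hedged sketch of sweeping this runner across the whole suite; bsuite publishes the full list of ids as `bsuite.sweep.SWEEP` (a serial loop is shown, whereas real runs are usually parallelized):

from bsuite import sweep
for bid in sweep.SWEEP:
    run(bid)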
Example #9
        def _thunk():
            random_seed(seed)
            if env_id.startswith('bsuite'):
                bsuite_id = env_id.split('bsuite-')[1]  # avoid shadowing the builtin `id`
                self.video_enabled = False
                bsuite_env = bsuite.load_from_id(bsuite_id)
                env = gym_wrapper.GymFromDMEnv(bsuite_env)

            elif env_id.startswith("dm"):
                import dm_control2gym
                _, domain, task = env_id.split('-')
                env = dm_control2gym.make(domain_name=domain, task_name=task)

            else:
                if special_args is not None and 'NChain' in special_args[0]:
                    print('starting chain N = ', special_args[1])
                    env = gym.make(env_id, n=special_args[1])
                else:
                    # Previously `env` was left undefined when special_args was
                    # set but did not request NChain; fall back to a plain make.
                    env = gym.make(env_id)

            if self.video_enabled:
                env = Monitor(env,
                              self.log_dir,
                              video_callable=self.video_callable)

            is_atari = hasattr(gym.envs, 'atari') and isinstance(
                env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
            if is_atari:
                env = make_atari(env_id)
            env.seed(seed + rank)
            env = OriginalReturnWrapper(env)
            if is_atari:
                env = wrap_deepmind(env,
                                    episode_life=episode_life,
                                    clip_rewards=False,
                                    frame_stack=False,
                                    scale=False)
                obs_shape = env.observation_space.shape
                if len(obs_shape) == 3:
                    env = TransposeImage(env)
                env = FrameStack(env, 4)
            return env
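The thunk in Example #9 dispatches on an id-prefix convention; illustrative values for each branch:

env_id = 'bsuite-deep_sea/7'    # -> bsuite.load_from_id('deep_sea/7')
env_id = 'dm-cartpole-balance'  # -> dm_control2gym.make(domain_name='cartpole', task_name='balance')
env_id = 'CartPole-v0'          # -> gym.make(env_id)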
Example #10
def load_env(env_id):
    env = bsuite.load_from_id(env_id)
    env = gym_wrapper.GymFromDMEnv(env)
    return env
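A minimal interaction sketch for the loader above, using the classic 4-tuple gym step API that GymFromDMEnv exposes (the bsuite_id is illustrative):

env = load_env('catch/0')
obs = env.reset()
done = False
while not done:
    obs, reward, done, info = env.step(env.action_space.sample())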
Example #11
    # envs = sweep.BANDIT
    envs = ["bandit/0"]

    for bsuite_id in envs:
        b_env = bsuite_id.split('/')[0]  # environment name, e.g. 'bandit'
        env_plot_path = Path(plot_dir + bsuite_id.replace("/", "-") + "/")
        env_plot_path.mkdir(parents=True, exist_ok=True)
        env_plot_path = str(env_plot_path.resolve())

        args = get_args()

        # Initialize the environment
        bsuite_env = load_and_record_to_csv(bsuite_id,
                                            results_dir=csv_dir,
                                            overwrite=True)
        gym_env = gym_wrapper.GymFromDMEnv(bsuite_env)
        env = GymEnv(gym_env)
        env_builder = lambda: env

        algo = setup_test(args, env)

        off_policy_trainer = OffPolicyTrainer()
        off_policy_trainer.train(args, env_builder, algo)

        # Analyze performance
        df, sweep_vars = csv_load.load_bsuite(csv_dir)

        bandit_df = df[df.bsuite_env == b_env].copy()

        bsuite_score = summary_analysis.bsuite_score(df, sweep_vars)
        bsuite_summary = summary_analysis.ave_score_by_tag(bsuite_score, sweep_vars)
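A hedged sketch of the usual final step in a bsuite analysis, rendering the scores as a bar plot (standard summary_analysis helpers; the exact plotting call can vary by bsuite version):

from bsuite.experiments import summary_analysis
bar_plot = summary_analysis.bsuite_bar_plot(bsuite_score, sweep_vars)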