Пример #1
0
    def __init__(self,
                 registry,
                 env_creator,
                 config,
                 logdir,
                 start_sampler=True):
        """Build the env, policy, filters and sampler for this evaluator.

        Args:
            registry: Passed through to the preprocessor wrapper and to the
                policy constructor.
            env_creator (func): Called with ``config["env_config"]`` to create
                the raw environment.
            config (dict): Algorithm config. Keys read here: ``env_config``,
                ``model``, ``observation_filter``, ``reward_filter`` and
                ``batch_size``; the whole dict is also handed to the policy
                and stored as ``self.config``.
            logdir (str): Directory for logging; stored as ``self.logdir``.
            start_sampler (bool): If True and the sampler's ``async``
                attribute is truthy, start the sampler immediately.
        """
        env = ModelCatalog.get_preprocessor_as_wrapper(
            registry, env_creator(config["env_config"]), config["model"])
        self.env = env
        policy_cls = get_policy_cls(config)
        # TODO(rliaw): should change this to be just env.observation_space
        self.policy = policy_cls(registry, env.observation_space.shape,
                                 env.action_space, config)
        self.config = config

        # Technically not needed when not remote
        self.obs_filter = get_filter(config["observation_filter"],
                                     env.observation_space.shape)
        self.rew_filter = get_filter(config["reward_filter"], ())
        self.filters = {
            "obs_filter": self.obs_filter,
            "rew_filter": self.rew_filter
        }
        self.sampler = AsyncSampler(env, self.policy, self.obs_filter,
                                    config["batch_size"])
        # `async` is a reserved keyword since Python 3.7, so
        # `self.sampler.async` no longer parses; getattr reads the very same
        # attribute on every Python version.
        if start_sampler and getattr(self.sampler, "async"):
            self.sampler.start()
        self.logdir = logdir
Пример #2
0
 def __init__(self, env_creator, config, logdir):
     """Create the wrapped env, the policy, the filters and an async sampler.

     Args:
         env_creator (func): Environment factory handed to
             ``create_and_wrap`` along with ``config["model"]``.
         config (dict): Reads ``model``, ``observation_filter``,
             ``reward_filter`` and ``batch_size``.
         logdir (str): Logging directory, kept as ``self.logdir``.
     """
     wrapped = create_and_wrap(env_creator, config["model"])
     self.env = wrapped
     policy_factory = get_policy_cls(config)
     # TODO(rliaw): should change this to be just env.observation_space
     self.policy = policy_factory(wrapped.observation_space.shape,
                                  wrapped.action_space)
     # The observation filter is only handed to the sampler; this version
     # does not keep it as an attribute.
     observation_filter = get_filter(config["observation_filter"],
                                     wrapped.observation_space.shape)
     self.rew_filter = get_filter(config["reward_filter"], ())
     self.sampler = AsyncSampler(
         wrapped, self.policy, observation_filter, config["batch_size"])
     self.logdir = logdir
Пример #3
0
    def __init__(self, env_creator, config, logdir, start_sampler=True):
        """Build the env, policy, filters and sampler for this evaluator.

        Args:
            env_creator (func): Environment factory passed to
                ``create_and_wrap`` together with ``config["preprocessing"]``.
            config (dict): Algorithm config. Keys read here:
                ``preprocessing``, ``observation_filter``, ``reward_filter``
                and ``batch_size``; the whole dict is also passed to the
                policy constructor and stored as ``self.config``.
            logdir (str): Directory for logging; stored as ``self.logdir``.
            start_sampler (bool): If True and the sampler's ``async``
                attribute is truthy, start the sampler immediately.
        """
        self.env = env = create_and_wrap(env_creator, config["preprocessing"])
        policy_cls = get_policy_cls(config)
        # TODO(rliaw): should change this to be just env.observation_space
        self.policy = policy_cls(env.observation_space.shape, env.action_space,
                                 config)
        self.config = config

        # Technically not needed when not remote
        self.obs_filter = get_filter(config["observation_filter"],
                                     env.observation_space.shape)
        self.rew_filter = get_filter(config["reward_filter"], ())
        self.sampler = AsyncSampler(env, self.policy, self.obs_filter,
                                    config["batch_size"])
        # `async` is a reserved keyword since Python 3.7, so
        # `self.sampler. async` no longer parses; getattr reads the very same
        # attribute on every Python version.
        if start_sampler and getattr(self.sampler, "async"):
            self.sampler.start()
        self.logdir = logdir
Пример #4
0
    def __init__(
            self, registry, env_creator, config, logdir, start_sampler=True):
        """Build the env, policy, filters and sampler for this evaluator.

        Args:
            registry: Passed through to the preprocessor wrapper and to the
                policy constructor.
            env_creator (func): Called with ``config["env_config"]`` to create
                the raw environment.
            config (dict): Algorithm config. Keys read here: ``env_config``,
                ``model``, ``observation_filter``, ``reward_filter`` and
                ``batch_size``; stored as ``self.config``.
            logdir (str): Directory for logging; stored as ``self.logdir``.
            start_sampler (bool): If True and the sampler's ``async``
                attribute is truthy, start the sampler immediately.
        """
        env = ModelCatalog.get_preprocessor_as_wrapper(
            registry, env_creator(config["env_config"]), config["model"])
        self.env = env
        policy_cls = get_policy_cls(config)
        # TODO(rliaw): should change this to be just env.observation_space
        self.policy = policy_cls(
            registry, env.observation_space.shape, env.action_space, config)
        self.config = config

        # Technically not needed when not remote
        self.obs_filter = get_filter(
            config["observation_filter"], env.observation_space.shape)
        self.rew_filter = get_filter(config["reward_filter"], ())
        self.filters = {"obs_filter": self.obs_filter,
                        "rew_filter": self.rew_filter}
        self.sampler = AsyncSampler(env, self.policy, self.obs_filter,
                                    config["batch_size"])
        # `async` is a reserved keyword since Python 3.7, so
        # `self.sampler.async` no longer parses; getattr reads the very same
        # attribute on every Python version.
        if start_sampler and getattr(self.sampler, "async"):
            self.sampler.start()
        self.logdir = logdir
Пример #5
0
    def __init__(self,
                 env_creator,
                 policy_graph,
                 tf_session_creator=None,
                 batch_steps=100,
                 batch_mode="truncate_episodes",
                 episode_horizon=None,
                 preprocessor_pref="rllib",
                 sample_async=False,
                 compress_observations=False,
                 num_envs=1,
                 observation_filter="NoFilter",
                 env_config=None,
                 model_config=None,
                 policy_config=None):
        """Initialize a policy evaluator.

        Arguments:
            env_creator (func): Function that returns a gym.Env given an
                env config dict.
            policy_graph (class): A class implementing rllib.PolicyGraph or
                rllib.TFPolicyGraph.
            tf_session_creator (func): A function that returns a TF session.
                This is optional and only useful with TFPolicyGraph.
            batch_steps (int): The target number of env transitions to include
                in each sample batch returned from this evaluator.
            batch_mode (str): One of the following batch modes:
                "truncate_episodes": Each call to sample() will return a batch
                    of exactly `batch_steps` in size. Episodes may be truncated
                    in order to meet this size requirement. When
                    `num_envs > 1`, episodes will be truncated to sequences of
                    `batch_steps / num_envs` in length.
                "complete_episodes": Each call to sample() will return a batch
                    of at least `batch_steps` in size. Episodes will not be
                    truncated, but multiple episodes may be packed within one
                    batch to meet the batch size. Note that when
                    `num_envs > 1`, episode steps will be buffered until the
                    episode completes, and hence batches may contain
                    significant amounts of off-policy data.
            episode_horizon (int): If set, episodes are stopped at this
                horizon (forwarded to the sampler as `horizon`).
            preprocessor_pref (str): Whether to prefer RLlib preprocessors
                ("rllib") or deepmind ("deepmind") when applicable.
            sample_async (bool): Whether to compute samples asynchronously in
                the background, which improves throughput but can cause samples
                to be slightly off-policy.
            compress_observations (bool): If true, compress the observations
                returned.
            num_envs (int): If more than one, will create multiple envs
                and vectorize the computation of actions. This has no effect
                if the env already implements VectorEnv.
            observation_filter (str): Name of observation filter to use.
            env_config (dict): Config to pass to the env creator.
            model_config (dict): Config to use when creating the policy model.
            policy_config (dict): Config to pass to the policy.

        Raises:
            ValueError: If `batch_mode` is not a supported mode, or if in
                "truncate_episodes" mode `num_envs` is less than 1 or does
                not evenly divide `batch_steps`.
        """

        env_config = env_config or {}
        policy_config = policy_config or {}
        model_config = model_config or {}
        self.env_creator = env_creator
        self.policy_graph = policy_graph
        self.batch_steps = batch_steps
        self.batch_mode = batch_mode
        self.compress_observations = compress_observations

        self.env = env_creator(env_config)
        if isinstance(self.env, (VectorEnv, ServingEnv, AsyncVectorEnv)):

            def wrap(env):
                return env  # we can't auto-wrap these env types
        elif is_atari(self.env) and \
                "custom_preprocessor" not in model_config and \
                preprocessor_pref == "deepmind":

            def wrap(env):
                return wrap_deepmind(env, dim=model_config.get("dim", 80))
        else:

            def wrap(env):
                return ModelCatalog.get_preprocessor_as_wrapper(
                    env, model_config)

        self.env = wrap(self.env)

        def make_env():
            # Factory handed to VectorEnv.wrap below; builds an env wrapped
            # exactly like self.env.
            return wrap(env_creator(env_config))

        if issubclass(policy_graph, TFPolicyGraph):
            # TF policies are built inside a dedicated graph and session.
            with tf.Graph().as_default():
                if tf_session_creator:
                    self.sess = tf_session_creator()
                else:
                    self.sess = tf.Session(config=tf.ConfigProto(
                        gpu_options=tf.GPUOptions(allow_growth=True)))
                with self.sess.as_default():
                    policy = policy_graph(self.env.observation_space,
                                          self.env.action_space, policy_config)
        else:
            policy = policy_graph(self.env.observation_space,
                                  self.env.action_space, policy_config)
        self.policy_map = {"default": policy}

        self.obs_filter = get_filter(observation_filter,
                                     self.env.observation_space.shape)
        self.filters = {"obs_filter": self.obs_filter}

        # Always use vector env for consistency even if num_envs = 1
        if isinstance(self.env, AsyncVectorEnv):
            self.vector_env = self.env
        elif isinstance(self.env, ServingEnv):
            self.vector_env = _ServingEnvToAsync(self.env)
        else:
            if not isinstance(self.env, VectorEnv):
                self.env = VectorEnv.wrap(make_env, [self.env],
                                          num_envs=num_envs)
            self.vector_env = _VectorEnvToAsync(self.env)

        if self.batch_mode == "truncate_episodes":
            # Guard the division below: num_envs == 0 would raise
            # ZeroDivisionError and a negative value would silently produce
            # a negative per-env batch size.
            if num_envs < 1:
                raise ValueError(
                    "`num_envs` must be at least 1, got {}.".format(num_envs))
            if batch_steps % num_envs != 0:
                raise ValueError(
                    "In 'truncate_episodes' batch mode, `batch_steps` must be "
                    "evenly divisible by `num_envs`. Got {} and {}.".format(
                        batch_steps, num_envs))
            batch_steps = batch_steps // num_envs
            pack_episodes = True
        elif self.batch_mode == "complete_episodes":
            batch_steps = float("inf")  # never cut episodes
            pack_episodes = False  # sampler will return 1 episode per poll
        else:
            raise ValueError("Unsupported batch mode: {}".format(
                self.batch_mode))
        if sample_async:
            self.sampler = AsyncSampler(self.vector_env,
                                        self.policy_map["default"],
                                        self.obs_filter,
                                        batch_steps,
                                        horizon=episode_horizon,
                                        pack=pack_episodes)
            self.sampler.start()
        else:
            self.sampler = SyncSampler(self.vector_env,
                                       self.policy_map["default"],
                                       self.obs_filter,
                                       batch_steps,
                                       horizon=episode_horizon,
                                       pack=pack_episodes)
Пример #6
0
class A3CEvaluator(Evaluator):
    """Actor object to start running simulation on workers.

    The gradient computation is also executed from this object.

    Attributes:
        env: Environment wrapped by ModelCatalog's preprocessor wrapper.
        policy: Copy of graph used for policy. Used by sampler and gradients.
        config (dict): Algorithm config this evaluator was built with.
        obs_filter: Observation filter used in environment sampling.
        rew_filter: Reward filter used in rollout post-processing.
        filters (dict): Maps "obs_filter" / "rew_filter" to the filters above.
        sampler: Component for interacting with environment and generating
            rollouts.
        logdir: Directory for logging.
    """
    def __init__(
            self, registry, env_creator, config, logdir, start_sampler=True):
        """Build the env, policy, filters and sampler.

        Args:
            registry: Passed through to the preprocessor wrapper and to the
                policy constructor.
            env_creator (func): Called with ``config["env_config"]`` to create
                the raw environment.
            config (dict): Algorithm config; reads ``env_config``, ``model``,
                ``observation_filter``, ``reward_filter`` and ``batch_size``
                here, and ``gamma`` / ``lambda`` later in ``sample``.
            logdir (str): Directory for logging.
            start_sampler (bool): If True and the sampler's ``async``
                attribute is truthy, start the sampler immediately.
        """
        env = ModelCatalog.get_preprocessor_as_wrapper(
            registry, env_creator(config["env_config"]), config["model"])
        self.env = env
        policy_cls = get_policy_cls(config)
        # TODO(rliaw): should change this to be just env.observation_space
        self.policy = policy_cls(
            registry, env.observation_space.shape, env.action_space, config)
        self.config = config

        # Technically not needed when not remote
        self.obs_filter = get_filter(
            config["observation_filter"], env.observation_space.shape)
        self.rew_filter = get_filter(config["reward_filter"], ())
        self.filters = {"obs_filter": self.obs_filter,
                        "rew_filter": self.rew_filter}
        self.sampler = AsyncSampler(env, self.policy, self.obs_filter,
                                    config["batch_size"])
        # `async` is a reserved keyword since Python 3.7, so
        # `self.sampler.async` no longer parses; getattr reads the very same
        # attribute on every Python version.
        if start_sampler and getattr(self.sampler, "async"):
            self.sampler.start()
        self.logdir = logdir

    def sample(self):
        """Pull one rollout from the sampler and post-process it.

        Returns:
            The output of ``process_rollout`` (GAE enabled) using the
            configured ``gamma`` and ``lambda`` and the reward filter.
        """
        rollout = self.sampler.get_data()
        samples = process_rollout(
            rollout, self.rew_filter, gamma=self.config["gamma"],
            lambda_=self.config["lambda"], use_gae=True)
        return samples

    def get_completed_rollout_metrics(self):
        """Returns metrics on previously completed rollouts.

        Calling this clears the queue of completed rollout metrics.
        """
        return self.sampler.get_metrics()

    def compute_gradients(self, samples):
        """Return the policy gradient for ``samples``; the info is dropped."""
        gradient, _ = self.policy.compute_gradients(samples)
        return gradient

    def apply_gradients(self, grads):
        """Apply ``grads`` to the local policy."""
        self.policy.apply_gradients(grads)

    def get_weights(self):
        """Return the current policy weights."""
        return self.policy.get_weights()

    def set_weights(self, params):
        """Overwrite the policy weights with ``params``."""
        self.policy.set_weights(params)

    def save(self):
        """Serialize filter snapshots and policy weights with pickle.

        Filter buffers are cleared as part of taking the snapshot.
        """
        filters = self.get_filters(flush_after=True)
        weights = self.get_weights()
        return pickle.dumps({
            "filters": filters,
            "weights": weights})

    def restore(self, objs):
        """Restore filters and weights from the output of ``save``."""
        # NOTE(review): pickle.loads can execute arbitrary code; only feed it
        # checkpoints from a trusted source.
        objs = pickle.loads(objs)
        self.sync_filters(objs["filters"])
        self.set_weights(objs["weights"])

    def sync_filters(self, new_filters):
        """Changes self's filter to given and rebases any accumulated delta.

        Args:
            new_filters (dict): Filters with new state to update local copy.
        """
        missing = [k for k in self.filters if k not in new_filters]
        assert not missing, "Missing filters in update: {}".format(missing)
        for k in self.filters:
            self.filters[k].sync(new_filters[k])

    def get_filters(self, flush_after=False):
        """Returns a snapshot of filters.

        Args:
            flush_after (bool): Clears the filter buffer state.

        Returns:
            return_filters (dict): Dict for serializable filters
        """
        return_filters = {}
        for k, f in self.filters.items():
            return_filters[k] = f.as_serializable()
            if flush_after:
                f.clear_buffer()
        return return_filters
Пример #7
0
class A3CEvaluator(Evaluator):
    """Actor object to start running simulation on workers.

    The gradient computation is also executed from this object.

    Attributes:
        env: Environment built by ``create_and_wrap``.
        policy: Copy of graph used for policy. Used by sampler and gradients.
        config (dict): Algorithm config this evaluator was built with.
        obs_filter: Observation filter used in environment sampling.
        rew_filter: Reward filter used in rollout post-processing.
        sampler: Component for interacting with environment and generating
            rollouts.
        logdir: Directory for logging.
    """
    def __init__(self, env_creator, config, logdir, start_sampler=True):
        """Build the env, policy, filters and sampler.

        Args:
            env_creator (func): Environment factory passed to
                ``create_and_wrap`` together with ``config["preprocessing"]``.
            config (dict): Algorithm config; reads ``preprocessing``,
                ``observation_filter``, ``reward_filter`` and ``batch_size``
                here, and ``gamma`` / ``lambda`` later in ``sample``.
            logdir (str): Directory for logging.
            start_sampler (bool): If True and the sampler's ``async``
                attribute is truthy, start the sampler immediately.
        """
        self.env = env = create_and_wrap(env_creator, config["preprocessing"])
        policy_cls = get_policy_cls(config)
        # TODO(rliaw): should change this to be just env.observation_space
        self.policy = policy_cls(env.observation_space.shape, env.action_space,
                                 config)
        self.config = config

        # Technically not needed when not remote
        self.obs_filter = get_filter(config["observation_filter"],
                                     env.observation_space.shape)
        self.rew_filter = get_filter(config["reward_filter"], ())
        self.sampler = AsyncSampler(env, self.policy, self.obs_filter,
                                    config["batch_size"])
        # `async` is a reserved keyword since Python 3.7, so
        # `self.sampler. async` no longer parses; getattr reads the very same
        # attribute on every Python version.
        if start_sampler and getattr(self.sampler, "async"):
            self.sampler.start()
        self.logdir = logdir

    def sample(self):
        """Pull one rollout from the sampler and post-process it.

        Returns:
            trajectory (PartialRollout): Experience Samples from evaluator"""
        rollout = self.sampler.get_data()
        samples = process_rollout(rollout,
                                  self.rew_filter,
                                  gamma=self.config["gamma"],
                                  lambda_=self.config["lambda"],
                                  use_gae=True)
        return samples

    def get_completed_rollout_metrics(self):
        """Returns metrics on previously completed rollouts.

        Calling this clears the queue of completed rollout metrics.
        """
        return self.sampler.get_metrics()

    def compute_gradients(self, samples):
        """Return the policy gradient for ``samples``; the info is dropped."""
        gradient, _ = self.policy.compute_gradients(samples)
        return gradient

    def apply_gradients(self, grads):
        """Apply ``grads`` to the local policy."""
        self.policy.apply_gradients(grads)

    def get_weights(self):
        """Return the current policy weights."""
        return self.policy.get_weights()

    def set_weights(self, params):
        """Overwrite the policy weights with ``params``."""
        self.policy.set_weights(params)

    def update_filters(self, obs_filter=None, rew_filter=None):
        """Replace the reward filter and/or push a new observation filter.

        Args:
            obs_filter: If truthy, forwarded to the sampler.
            rew_filter: If truthy, a copy replaces ``self.rew_filter``.
        """
        if rew_filter:
            # No special handling required since outside of threaded code
            self.rew_filter = rew_filter.copy()
        if obs_filter:
            self.sampler.update_obs_filter(obs_filter)

    def save(self):
        """Serialize the policy weights with pickle."""
        weights = self.get_weights()
        return pickle.dumps({"weights": weights})

    def restore(self, objs):
        """Restore policy weights from the output of ``save``."""
        # NOTE(review): pickle.loads can execute arbitrary code; only feed it
        # checkpoints from a trusted source.
        objs = pickle.loads(objs)
        self.set_weights(objs["weights"])
Пример #8
0
class A3CEvaluator(Evaluator):
    """Actor object to start running simulation on workers.

    The gradient computation is also executed from this object.

    Attributes:
        env: Environment built by ``create_and_wrap``.
        policy: Copy of graph used for policy. Used by sampler and gradients.
        config (dict): Algorithm config this evaluator was built with.
        rew_filter: Reward filter used in rollout post-processing.
        sampler: Component for interacting with environment and generating
            rollouts.
        logdir: Directory for logging.
    """
    def __init__(self, env_creator, config, logdir):
        """Build the env, policy, filters and sampler.

        Args:
            env_creator (func): Environment factory passed to
                ``create_and_wrap`` together with ``config["model"]``.
            config (dict): Algorithm config; reads ``model``,
                ``observation_filter``, ``reward_filter`` and ``batch_size``
                here, and optionally ``gamma`` / ``lambda`` in
                ``compute_gradient``.
            logdir (str): Directory for logging.
        """
        self.env = env = create_and_wrap(env_creator, config["model"])
        # Kept so compute_gradient can read the discounting parameters.
        self.config = config
        policy_cls = get_policy_cls(config)
        # TODO(rliaw): should change this to be just env.observation_space
        self.policy = policy_cls(env.observation_space.shape, env.action_space)
        obs_filter = get_filter(config["observation_filter"],
                                env.observation_space.shape)
        self.rew_filter = get_filter(config["reward_filter"], ())
        self.sampler = AsyncSampler(env, self.policy, obs_filter,
                                    config["batch_size"])
        self.logdir = logdir

    def sample(self):
        """Return the next raw rollout from the sampler.

        Returns:
            trajectory (PartialRollout): Experience Samples from evaluator"""
        rollout = self.sampler.get_data()
        return rollout

    def get_completed_rollout_metrics(self):
        """Returns metrics on previously completed rollouts.

        Calling this clears the queue of completed rollout metrics.
        """
        return self.sampler.get_metrics()

    def compute_gradient(self):
        """Sample a rollout and compute the policy gradient on it.

        The observation filter is fetched from the sampler with
        ``flush=True``; it and the current reward filter are attached to the
        returned info dict.

        Returns:
            tuple: ``(gradient, info)`` from ``policy.compute_gradients``,
            with ``info["obs_filter"]`` and ``info["rew_filter"]`` added.
        """
        rollout = self.sampler.get_data()
        obs_filter = self.sampler.get_obs_filter(flush=True)

        # Discounting parameters come from the config, as in the other
        # evaluator versions; the defaults are the values previously
        # hard-coded here.
        traj = process_rollout(rollout,
                               self.rew_filter,
                               gamma=self.config.get("gamma", 0.99),
                               lambda_=self.config.get("lambda", 1.0),
                               use_gae=True)
        gradient, info = self.policy.compute_gradients(traj)
        info["obs_filter"] = obs_filter
        info["rew_filter"] = self.rew_filter
        return gradient, info

    def apply_gradient(self, grads):
        """Apply ``grads`` to the local policy."""
        self.policy.apply_gradients(grads)

    def set_weights(self, params):
        """Overwrite the policy weights with ``params``."""
        self.policy.set_weights(params)

    def update_filters(self, obs_filter=None, rew_filter=None):
        """Replace the reward filter and/or push a new observation filter.

        Args:
            obs_filter: If truthy, forwarded to the sampler.
            rew_filter: If truthy, a copy replaces ``self.rew_filter``.
        """
        if rew_filter:
            # No special handling required since outside of threaded code
            self.rew_filter = rew_filter.copy()
        if obs_filter:
            self.sampler.update_obs_filter(obs_filter)
Пример #9
0
class A3CEvaluator(Evaluator):
    """Actor object to start running simulation on workers.

    The gradient computation is also executed from this object.

    Attributes:
        env: Environment wrapped by ModelCatalog's preprocessor wrapper.
        policy: Copy of graph used for policy. Used by sampler and gradients.
        config (dict): Algorithm config this evaluator was built with.
        obs_filter: Observation filter used in environment sampling.
        rew_filter: Reward filter used in rollout post-processing.
        filters (dict): Maps "obs_filter" / "rew_filter" to the filters above.
        sampler: Component for interacting with environment and generating
            rollouts.
        logdir: Directory for logging.
    """
    def __init__(self,
                 registry,
                 env_creator,
                 config,
                 logdir,
                 start_sampler=True):
        """Build the env, policy, filters and sampler.

        Args:
            registry: Passed through to the preprocessor wrapper and to the
                policy constructor.
            env_creator (func): Called with ``config["env_config"]`` to create
                the raw environment.
            config (dict): Algorithm config; reads ``env_config``, ``model``,
                ``observation_filter``, ``reward_filter`` and ``batch_size``
                here, and ``gamma`` / ``lambda`` later in ``sample``.
            logdir (str): Directory for logging.
            start_sampler (bool): If True and the sampler's ``async``
                attribute is truthy, start the sampler immediately.
        """
        env = ModelCatalog.get_preprocessor_as_wrapper(
            registry, env_creator(config["env_config"]), config["model"])
        self.env = env
        policy_cls = get_policy_cls(config)
        # TODO(rliaw): should change this to be just env.observation_space
        self.policy = policy_cls(registry, env.observation_space.shape,
                                 env.action_space, config)
        self.config = config

        # Technically not needed when not remote
        self.obs_filter = get_filter(config["observation_filter"],
                                     env.observation_space.shape)
        self.rew_filter = get_filter(config["reward_filter"], ())
        self.filters = {
            "obs_filter": self.obs_filter,
            "rew_filter": self.rew_filter
        }
        self.sampler = AsyncSampler(env, self.policy, self.obs_filter,
                                    config["batch_size"])
        # `async` is a reserved keyword since Python 3.7, so
        # `self.sampler. async` no longer parses; getattr reads the very same
        # attribute on every Python version.
        if start_sampler and getattr(self.sampler, "async"):
            self.sampler.start()
        self.logdir = logdir

    def sample(self):
        """Pull one rollout from the sampler and post-process it.

        Returns:
            The output of ``process_rollout`` (GAE enabled) using the
            configured ``gamma`` and ``lambda`` and the reward filter.
        """
        rollout = self.sampler.get_data()
        samples = process_rollout(rollout,
                                  self.rew_filter,
                                  gamma=self.config["gamma"],
                                  lambda_=self.config["lambda"],
                                  use_gae=True)
        return samples

    def get_completed_rollout_metrics(self):
        """Returns metrics on previously completed rollouts.

        Calling this clears the queue of completed rollout metrics.
        """
        return self.sampler.get_metrics()

    def compute_gradients(self, samples):
        """Return the policy gradient for ``samples``; the info is dropped."""
        gradient, _ = self.policy.compute_gradients(samples)
        return gradient

    def apply_gradients(self, grads):
        """Apply ``grads`` to the local policy."""
        self.policy.apply_gradients(grads)

    def get_weights(self):
        """Return the current policy weights."""
        return self.policy.get_weights()

    def set_weights(self, params):
        """Overwrite the policy weights with ``params``."""
        self.policy.set_weights(params)

    def save(self):
        """Serialize filter snapshots and policy weights with pickle.

        Filter buffers are cleared as part of taking the snapshot.
        """
        filters = self.get_filters(flush_after=True)
        weights = self.get_weights()
        return pickle.dumps({"filters": filters, "weights": weights})

    def restore(self, objs):
        """Restore filters and weights from the output of ``save``."""
        # NOTE(review): pickle.loads can execute arbitrary code; only feed it
        # checkpoints from a trusted source.
        objs = pickle.loads(objs)
        self.sync_filters(objs["filters"])
        self.set_weights(objs["weights"])

    def sync_filters(self, new_filters):
        """Changes self's filter to given and rebases any accumulated delta.

        Args:
            new_filters (dict): Filters with new state to update local copy.
        """
        missing = [k for k in self.filters if k not in new_filters]
        assert not missing, "Missing filters in update: {}".format(missing)
        for k in self.filters:
            self.filters[k].sync(new_filters[k])

    def get_filters(self, flush_after=False):
        """Returns a snapshot of filters.

        Args:
            flush_after (bool): Clears the filter buffer state.

        Returns:
            return_filters (dict): Dict for serializable filters
        """
        return_filters = {}
        for k, f in self.filters.items():
            return_filters[k] = f.as_serializable()
            if flush_after:
                f.clear_buffer()
        return return_filters
Пример #10
0
    def __init__(self,
                 env_creator,
                 policy_graph,
                 policy_mapping_fn=None,
                 tf_session_creator=None,
                 batch_steps=100,
                 batch_mode="truncate_episodes",
                 episode_horizon=None,
                 preprocessor_pref="rllib",
                 sample_async=False,
                 compress_observations=False,
                 num_envs=1,
                 observation_filter="NoFilter",
                 env_config=None,
                 model_config=None,
                 policy_config=None,
                 worker_index=0):
        """Initialize a policy evaluator.

        Arguments:
            env_creator (func): Function that returns a gym.Env given an
                EnvContext wrapped configuration.
            policy_graph (class|dict): Either a class implementing
                PolicyGraph, or a dictionary of policy id strings to
                (PolicyGraph, obs_space, action_space, config) tuples. If a
                dict is specified, then we are in multi-agent mode and a
                policy_mapping_fn should also be set.
            policy_mapping_fn (func): A function that maps agent ids to
                policy ids in multi-agent mode. This function will be called
                each time a new agent appears in an episode, to bind that agent
                to a policy for the duration of the episode.
            tf_session_creator (func): A function that returns a TF session.
                This is optional and only useful with TFPolicyGraph.
            batch_steps (int): The target number of env transitions to include
                in each sample batch returned from this evaluator.
            batch_mode (str): One of the following batch modes:
                "truncate_episodes": Each call to sample() will return a batch
                    of exactly `batch_steps` in size. Episodes may be truncated
                    in order to meet this size requirement. When
                    `num_envs > 1`, episodes will be truncated to sequences of
                    `batch_steps / num_envs` in length.
                "complete_episodes": Each call to sample() will return a batch
                    of at least `batch_steps` in size. Episodes will not be
                    truncated, but multiple episodes may be packed within one
                    batch to meet the batch size. Note that when
                    `num_envs > 1`, episode steps will be buffered until the
                    episode completes, and hence batches may contain
                    significant amounts of off-policy data.
            episode_horizon (int): If set, episodes are stopped at this
                horizon (forwarded to the sampler as `horizon`).
            preprocessor_pref (str): Whether to prefer RLlib preprocessors
                ("rllib") or deepmind ("deepmind") when applicable.
            sample_async (bool): Whether to compute samples asynchronously in
                the background, which improves throughput but can cause samples
                to be slightly off-policy.
            compress_observations (bool): If true, compress the observations
                returned.
            num_envs (int): If more than one, will create multiple envs
                and vectorize the computation of actions. This has no effect
                if the env already implements VectorEnv.
            observation_filter (str): Name of observation filter to use.
            env_config (dict): Config to pass to the env creator.
            model_config (dict): Config to use when creating the policy model.
            policy_config (dict): Config to pass to the policy. In the
                multi-agent case, this config will be merged with the
                per-policy configs specified by `policy_graph`.
            worker_index (int): For remote evaluators, this should be set to a
                non-zero and unique value. This index is passed to created envs
                through EnvContext so that envs can be configured per worker.

        Raises:
            ValueError: If `batch_mode` is not a supported mode, or if in
                "truncate_episodes" mode `num_envs` is less than 1 or does
                not evenly divide `batch_steps`.
        """

        env_context = EnvContext(env_config or {}, worker_index)
        policy_config = policy_config or {}
        self.policy_config = policy_config
        model_config = model_config or {}
        policy_mapping_fn = (policy_mapping_fn
                             or (lambda agent_id: DEFAULT_POLICY_ID))
        self.env_creator = env_creator
        self.policy_graph = policy_graph
        self.batch_steps = batch_steps
        self.batch_mode = batch_mode
        self.compress_observations = compress_observations

        self.env = env_creator(env_context)
        if isinstance(self.env, (VectorEnv, ServingEnv, MultiAgentEnv,
                                 AsyncVectorEnv)):

            def wrap(env):
                return env  # we can't auto-wrap these env types
        elif is_atari(self.env) and \
                "custom_preprocessor" not in model_config and \
                preprocessor_pref == "deepmind":

            def wrap(env):
                return wrap_deepmind(env, dim=model_config.get("dim", 80))
        else:

            def wrap(env):
                return ModelCatalog.get_preprocessor_as_wrapper(
                    env, model_config)

        self.env = wrap(self.env)

        def make_env():
            # Factory handed to AsyncVectorEnv.wrap_async below; builds an env
            # wrapped exactly like self.env.
            return wrap(env_creator(env_context))

        self.tf_sess = None
        policy_dict = _validate_and_canonicalize(policy_graph, self.env)
        if _has_tensorflow_graph(policy_dict):
            # TF policies are built inside a dedicated graph and session.
            with tf.Graph().as_default():
                if tf_session_creator:
                    self.tf_sess = tf_session_creator()
                else:
                    self.tf_sess = tf.Session(config=tf.ConfigProto(
                        gpu_options=tf.GPUOptions(allow_growth=True)))
                with self.tf_sess.as_default():
                    self.policy_map = self._build_policy_map(
                        policy_dict, policy_config)
        else:
            self.policy_map = self._build_policy_map(policy_dict,
                                                     policy_config)

        # Multi-agent unless the only policy is the default one. The
        # comparison set must be a one-element literal: set(DEFAULT_POLICY_ID)
        # would iterate the id string into its characters, making this flag
        # always True.
        self.multiagent = set(self.policy_map) != {DEFAULT_POLICY_ID}

        # One observation filter per policy.
        self.filters = {
            policy_id: get_filter(observation_filter,
                                  policy.observation_space.shape)
            for (policy_id, policy) in self.policy_map.items()
        }

        # Always use vector env for consistency even if num_envs = 1
        self.async_env = AsyncVectorEnv.wrap_async(self.env,
                                                   make_env=make_env,
                                                   num_envs=num_envs)

        if self.batch_mode == "truncate_episodes":
            # Guard the division below: num_envs == 0 would raise
            # ZeroDivisionError and a negative value would silently produce
            # a negative per-env batch size.
            if num_envs < 1:
                raise ValueError(
                    "`num_envs` must be at least 1, got {}.".format(num_envs))
            if batch_steps % num_envs != 0:
                raise ValueError(
                    "In 'truncate_episodes' batch mode, `batch_steps` must be "
                    "evenly divisible by `num_envs`. Got {} and {}.".format(
                        batch_steps, num_envs))
            batch_steps = batch_steps // num_envs
            pack_episodes = True
        elif self.batch_mode == "complete_episodes":
            batch_steps = float("inf")  # never cut episodes
            pack_episodes = False  # sampler will return 1 episode per poll
        else:
            raise ValueError("Unsupported batch mode: {}".format(
                self.batch_mode))
        if sample_async:
            self.sampler = AsyncSampler(self.async_env,
                                        self.policy_map,
                                        policy_mapping_fn,
                                        self.filters,
                                        batch_steps,
                                        horizon=episode_horizon,
                                        pack=pack_episodes,
                                        tf_sess=self.tf_sess)
            self.sampler.start()
        else:
            self.sampler = SyncSampler(self.async_env,
                                       self.policy_map,
                                       policy_mapping_fn,
                                       self.filters,
                                       batch_steps,
                                       horizon=episode_horizon,
                                       pack=pack_episodes,
                                       tf_sess=self.tf_sess)
Пример #11
0
    def __init__(self,
                 env_creator,
                 policy_graph,
                 tf_session_creator=None,
                 batch_steps=100,
                 batch_mode="truncate_episodes",
                 preprocessor_pref="rllib",
                 sample_async=False,
                 compress_observations=False,
                 observation_filter="NoFilter",
                 registry=None,
                 env_config=None,
                 model_config=None,
                 policy_config=None):
        """Initialize a policy evaluator.

        Creates the env (optionally wrapped with DeepMind Atari wrappers or
        an RLlib preprocessor), builds a single policy stored under the
        "default" key of ``self.policy_map``, creates an observation filter,
        and constructs a sync or async sampler.

        Arguments:
            env_creator (func): Function that returns a gym.Env given an
                env config dict.
            policy_graph (class): A class implementing rllib.PolicyGraph or
                rllib.TFPolicyGraph.
            tf_session_creator (func): A function that returns a TF session.
                This is optional and only useful with TFPolicyGraph.
            batch_steps (int): The target number of env transitions to include
                in each sample batch returned from this evaluator.
            batch_mode (str): One of the following choices:
                complete_episodes: each batch will be at least batch_steps
                    in size, and will include one or more complete episodes.
                truncate_episodes: each batch will be around batch_steps
                    in size, and include transitions from one episode only.
                pack_episodes: each batch will be exactly batch_steps in
                    size, and may include transitions from multiple episodes.
            preprocessor_pref (str): Whether to prefer RLlib preprocessors
                ("rllib") or deepmind ("deepmind") when applicable.
            sample_async (bool): Whether to compute samples asynchronously in
                the background, which improves throughput but can cause samples
                to be slightly off-policy.
            compress_observations (bool): If true, compress the observations
                returned.
            observation_filter (str): Name of observation filter to use.
            registry (tune.Registry): User-registered objects. Pass in the
                value from tune.registry.get_registry() if you're having
                trouble resolving things like custom envs.
            env_config (dict): Config to pass to the env creator.
            model_config (dict): Config to use when creating the policy model.
            policy_config (dict): Config to pass to the policy.

        Raises:
            ValueError: If ``batch_mode`` is not one of the supported modes.
            NotImplementedError: If the created env is vectorized (i.e. it
                has a ``vector_reset`` attribute).
        """

        # Validate with a real exception rather than ``assert`` (which is
        # stripped under ``python -O``), and do it before any env or policy
        # is built so a bad config fails fast.
        supported_modes = ("complete_episodes", "truncate_episodes",
                           "pack_episodes")
        if batch_mode not in supported_modes:
            raise ValueError("Unsupported batch mode: {}".format(batch_mode))

        registry = registry or get_registry()
        env_config = env_config or {}
        policy_config = policy_config or {}
        model_config = model_config or {}

        self.env_creator = env_creator
        self.policy_graph = policy_graph
        self.batch_steps = batch_steps
        self.batch_mode = batch_mode
        self.compress_observations = compress_observations

        self.env = env_creator(env_config)
        # Treat the env as Atari when its unwrapped env has an ``ale``
        # attribute; DeepMind wrappers are applied only when requested and
        # no custom preprocessor is configured.
        is_atari = hasattr(self.env.unwrapped, "ale")
        if is_atari and "custom_preprocessor" not in model_config and \
                preprocessor_pref == "deepmind":
            self.env = wrap_deepmind(self.env, dim=model_config.get("dim", 80))
        else:
            self.env = ModelCatalog.get_preprocessor_as_wrapper(
                registry, self.env, model_config)

        self.vectorized = hasattr(self.env, "vector_reset")

        # TF session backing the policy; remains None for non-TF policy
        # graphs so the attribute is always defined.
        self.sess = None
        if issubclass(policy_graph, TFPolicyGraph):
            # Build the policy inside a fresh TF graph, with a session
            # bound to that graph as the default.
            with tf.Graph().as_default():
                if tf_session_creator:
                    self.sess = tf_session_creator()
                else:
                    self.sess = tf.Session(config=tf.ConfigProto(
                        gpu_options=tf.GPUOptions(allow_growth=True)))
                with self.sess.as_default():
                    policy = policy_graph(self.env.observation_space,
                                          self.env.action_space, registry,
                                          policy_config)
        else:
            policy = policy_graph(self.env.observation_space,
                                  self.env.action_space, registry,
                                  policy_config)
        self.policy_map = {"default": policy}

        self.obs_filter = get_filter(observation_filter,
                                     self.env.observation_space.shape)
        self.filters = {"obs_filter": self.obs_filter}

        if self.vectorized:
            raise NotImplementedError("Vector envs not yet supported")
        else:
            # batch_mode was already validated above.
            pack = batch_mode == "pack_episodes"
            if batch_mode == "complete_episodes":
                # Large sentinel step budget so the sampler is effectively
                # never asked to cut an episode short.
                batch_steps = 999999
            if sample_async:
                self.sampler = AsyncSampler(self.env,
                                            self.policy_map["default"],
                                            self.obs_filter,
                                            batch_steps,
                                            pack=pack)
                self.sampler.start()
            else:
                self.sampler = SyncSampler(self.env,
                                           self.policy_map["default"],
                                           self.obs_filter,
                                           batch_steps,
                                           pack=pack)