Example #1
File: sampler.py Project: Yard1/ray
    def __init__(
        self,
        *,
        worker: "RolloutWorker",
        env: BaseEnv,
        clip_rewards: bool,
        rollout_fragment_length: int,
        count_steps_by: str = "env_steps",
        callbacks: "DefaultCallbacks",
        horizon: int = None,
        multiple_episodes_in_batch: bool = False,
        normalize_actions: bool = True,
        clip_actions: bool = False,
        blackhole_outputs: bool = False,
        soft_horizon: bool = False,
        no_done_at_end: bool = False,
        observation_fn: "ObservationFunction" = None,
        sample_collector_class: Optional[Type[SampleCollector]] = None,
        render: bool = False,
        # Obsolete.
        policies=None,
        policy_mapping_fn=None,
        preprocessors=None,
        obs_filters=None,
        tf_sess=None,
    ):
        """Initializes a AsyncSampler object.

        Args:
            worker (RolloutWorker): The RolloutWorker that will use this
                Sampler for sampling.
            env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
            clip_rewards (Union[bool, float]): True for +/-1.0 clipping, actual
                float value for +/- value clipping. False for no clipping.
            rollout_fragment_length (int): The length of a fragment to collect
                before building a SampleBatch from the data and resetting
                the SampleBatchBuilder object.
            count_steps_by (str): Either "env_steps" or "agent_steps".
                Refers to the unit of `rollout_fragment_length`.
            callbacks (Callbacks): The Callbacks object to use when episode
                events happen during rollout.
            horizon (Optional[int]): Hard-reset the Env after this many
                timesteps.
            multiple_episodes_in_batch (bool): Whether to pack multiple
                episodes into each batch. This guarantees batches will be
                exactly `rollout_fragment_length` in size.
            normalize_actions (bool): Whether to normalize actions to the
                action space's bounds.
            clip_actions (bool): Whether to clip actions according to the
                given action_space's bounds.
            blackhole_outputs (bool): Whether to collect samples, but then
                not further process or store them (throw away all samples).
            soft_horizon (bool): If True, calculate bootstrapped values as if
                episode had ended, but don't physically reset the environment
                when the horizon is hit.
            no_done_at_end (bool): Ignore the done=True at the end of the
                episode and instead record done=False.
            observation_fn (Optional[ObservationFunction]): Optional
                multi-agent observation func to use for preprocessing
                observations.
            sample_collector_class (Optional[Type[SampleCollector]]): An
                optional SampleCollector sub-class to use to collect, store,
                and retrieve environment-, model-, and sampler data.
            render (bool): Whether to try to render the environment after each
                step.
        """
        # All of the following arguments are deprecated. They will instead be
        # provided via the passed in `worker` arg, e.g. `worker.policy_map`.
        if log_once("deprecated_async_sampler_args"):
            if policies is not None:
                deprecation_warning(old="policies")
            if policy_mapping_fn is not None:
                deprecation_warning(old="policy_mapping_fn")
            if preprocessors is not None:
                deprecation_warning(old="preprocessors")
            if obs_filters is not None:
                deprecation_warning(old="obs_filters")
            if tf_sess is not None:
                deprecation_warning(old="tf_sess")

        self.worker = worker

        for _, f in worker.filters.items():
            assert getattr(f, "is_concurrent", False), \
                "Observation Filter must support concurrent updates."

        self.base_env = BaseEnv.to_base_env(env)
        threading.Thread.__init__(self)
        self.queue = queue.Queue(5)
        self.extra_batches = queue.Queue()
        self.metrics_queue = queue.Queue()
        self.rollout_fragment_length = rollout_fragment_length
        self.horizon = horizon
        self.clip_rewards = clip_rewards
        self.daemon = True
        self.multiple_episodes_in_batch = multiple_episodes_in_batch
        self.callbacks = callbacks
        self.normalize_actions = normalize_actions
        self.clip_actions = clip_actions
        self.blackhole_outputs = blackhole_outputs
        self.soft_horizon = soft_horizon
        self.no_done_at_end = no_done_at_end
        self.perf_stats = _PerfStats()
        self.shutdown = False
        self.observation_fn = observation_fn
        self.render = render
        if not sample_collector_class:
            sample_collector_class = SimpleListCollector
        self.sample_collector = sample_collector_class(
            self.worker.policy_map,
            self.clip_rewards,
            self.callbacks,
            self.multiple_episodes_in_batch,
            self.rollout_fragment_length,
            count_steps_by=count_steps_by)
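
The constructor above funnels all of its deprecation warnings through a single `log_once` key, so they fire at most once per process no matter how many samplers are created. Below is a minimal, self-contained sketch of that warn-once pattern using the same `log_once` and `deprecation_warning` helpers; the class and key names are illustrative only, and the `ray.rllib.utils.deprecation` import path is assumed.

from ray.rllib.utils.deprecation import deprecation_warning
from ray.util.debug import log_once


class MySampler:
    """Hypothetical class, used only to illustrate the warn-once pattern."""

    def __init__(self, *, rollout_fragment_length: int, tf_sess=None):
        # `tf_sess` is obsolete; warn about it only the first time any
        # instance in this process is constructed with it.
        if log_once("deprecated_my_sampler_args"):
            if tf_sess is not None:
                deprecation_warning(old="tf_sess")
        self.rollout_fragment_length = rollout_fragment_length


MySampler(rollout_fragment_length=200, tf_sess=object())  # warns once
MySampler(rollout_fragment_length=200, tf_sess=object())  # stays silent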
Example #2
File: pbt.py Project: yynst2/ray
    def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
                        trial: Trial, result: Dict) -> str:
        if self._time_attr not in result:
            time_missing_msg = "Cannot find time_attr {} " \
                               "in trial result {}. Make sure that this " \
                               "attribute is returned in the " \
                               "results of your Trainable.".format(
                                self._time_attr, result)
            if self._require_attrs:
                raise RuntimeError(
                    time_missing_msg +
                    " If this error is expected, you can change this to "
                    "a warning message by "
                    "setting PBT(require_attrs=False)")
            else:
                if log_once("pbt-time_attr-error"):
                    logger.warning(time_missing_msg)
        if self._metric not in result:
            metric_missing_msg = "Cannot find metric {} in trial result {}. " \
                                 "Make sure that this attribute is returned " \
                                 "in the " \
                                 "results of your Trainable.".format(
                                    self._metric, result)
            if self._require_attrs:
                raise RuntimeError(
                    metric_missing_msg + " If this error is expected, "
                    "you can change this to a warning message by "
                    "setting PBT(require_attrs=False)")
            else:
                if log_once("pbt-metric-error"):
                    logger.warning(metric_missing_msg)

        if self._metric not in result or self._time_attr not in result:
            return TrialScheduler.CONTINUE

        time = result[self._time_attr]
        state = self._trial_state[trial]

        # Continue training if perturbation interval has not been reached yet.
        if time - state.last_perturbation_time < self._perturbation_interval:
            return TrialScheduler.CONTINUE  # avoid checkpoint overhead

        # This trial has reached its perturbation interval
        score = self._metric_op * result[self._metric]
        state.last_score = score
        state.last_train_time = time
        state.last_result = result

        if not self._synch:
            state.last_perturbation_time = time
            lower_quantile, upper_quantile = self._quantiles()
            self._perturb_trial(trial, trial_runner, upper_quantile,
                                lower_quantile)
            for t in trial_runner.get_trials():
                if t.status in [Trial.PENDING, Trial.PAUSED]:
                    return TrialScheduler.PAUSE  # yield time to other trials

            return TrialScheduler.CONTINUE
        else:
            # Synchronous mode.
            if any(self._trial_state[t].last_train_time <
                   self._next_perturbation_sync and t != trial
                   for t in trial_runner.get_trials()):
                logger.debug("Pausing trial {}".format(trial))
            else:
                # All trials are synced at the same timestep.
                lower_quantile, upper_quantile = self._quantiles()
                all_trials = trial_runner.get_trials()
                not_in_quantile = []
                for t in all_trials:
                    if t not in lower_quantile and t not in upper_quantile:
                        not_in_quantile.append(t)
                # Move upper quantile trials to beginning and lower quantile
                # to end. This ensures that checkpointing of strong trials
                # occurs before exploiting of weaker ones.
                all_trials = upper_quantile + not_in_quantile + lower_quantile
                for t in all_trials:
                    logger.debug("Perturbing Trial {}".format(t))
                    self._trial_state[t].last_perturbation_time = time
                    self._perturb_trial(t, trial_runner, upper_quantile,
                                        lower_quantile)

                all_train_times = [
                    self._trial_state[trial].last_train_time
                    for trial in trial_runner.get_trials()
                ]
                max_last_train_time = max(all_train_times)
                self._next_perturbation_sync = max(
                    self._next_perturbation_sync + self._perturbation_interval,
                    max_last_train_time)
            # In sync mode we should pause all trials once result comes in.
            # Once a perturbation step happens for all trials, they should
            # still all be paused.
            # choose_trial_to_run will then pick the next trial to run out of
            # the paused trials.
            return TrialScheduler.PAUSE
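
The missing-attribute handling above is a raise-or-warn-once pattern: with `require_attrs=True` a missing key aborts the run, otherwise it is logged a single time per key via `log_once`. A self-contained sketch of the same decision logic follows; `check_result_key` is a hypothetical helper, not part of the scheduler.

import logging

from ray.util.debug import log_once

logger = logging.getLogger(__name__)


def check_result_key(result: dict, key: str, require: bool) -> bool:
    # Hypothetical helper mirroring the raise-or-warn-once logic above.
    if key in result:
        return True
    msg = ("Cannot find {} in trial result {}. Make sure that this "
           "attribute is returned in the results of your "
           "Trainable.".format(key, result))
    if require:
        raise RuntimeError(msg)
    if log_once("missing-{}".format(key)):
        logger.warning(msg)  # logged once per key, then silenced
    return False


check_result_key({"episode_reward_mean": 1.0}, "time_total_s", require=False)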
Example #3
import logging
import os
import warnings

from ray._private.metrics_agent import MetricsAgent, Gauge, Record
from ray.util.debug import log_once
import psutil

logger = logging.getLogger(__name__)

enable_gpu_usage_check = True

# Are we in a K8s pod?
IN_KUBERNETES_POD = "KUBERNETES_SERVICE_HOST" in os.environ

try:
    import gpustat.core as gpustat
except (ModuleNotFoundError, ImportError):
    gpustat = None
    if log_once("gpustat_import_warning"):
        warnings.warn(
            "`gpustat` package is not installed. GPU monitoring is "
            "not available. To have full functionality of the "
            "dashboard, please install it via `pip install ray[default]`."
        )


def recursive_asdict(o):
    if isinstance(o, tuple) and hasattr(o, "_asdict"):
        return recursive_asdict(o._asdict())

    if isinstance(o, (tuple, list)):
        L = []
        for k in o:
            L.append(recursive_asdict(k))
        return L

    if isinstance(o, dict):
        return {k: recursive_asdict(v) for k, v in o.items()}

    return o
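
A quick usage sketch for `recursive_asdict`: nested namedtuples (for example, the structs returned by `psutil`) become plain dicts and lists, so the result can be JSON-serialized directly. The `Mem` namedtuple below is illustrative only.

import collections
import json

Mem = collections.namedtuple("Mem", ["total", "available"])
stats = {"memory": Mem(total=16, available=8), "loads": (0.1, 0.2, 0.3)}

# Namedtuples -> dicts, tuples -> lists, scalars pass through unchanged.
print(json.dumps(recursive_asdict(stats)))
# {"memory": {"total": 16, "available": 8}, "loads": [0.1, 0.2, 0.3]}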
Example #4
def _process_observations(
        worker: "RolloutWorker", base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        batch_builder_pool: List[MultiAgentSampleBatchBuilder],
        active_episodes: Dict[str, MultiAgentEpisode],
        unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
        rewards: Dict[EnvID, Dict[AgentID, float]],
        dones: Dict[EnvID, Dict[AgentID, bool]],
        infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]], horizon: int,
        preprocessors: Dict[PolicyID, Preprocessor],
        obs_filters: Dict[PolicyID, Filter], rollout_fragment_length: int,
        pack_multiple_episodes_in_batch: bool, callbacks: "DefaultCallbacks",
        soft_horizon: bool, no_done_at_end: bool,
        observation_fn: "ObservationFunction",
        _use_trajectory_view_api: bool = False
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
        RolloutMetrics, SampleBatchType]]]:
    """Record new data from the environment and prepare for policy evaluation.

    Args:
        worker (RolloutWorker): Reference to the current rollout worker.
        base_env (BaseEnv): Env implementing BaseEnv.
        policies (dict): Map of policy ids to Policy instances.
        batch_builder_pool (List[SampleBatchBuilder]): List of pooled
            SampleBatchBuilder objects for recycling.
        active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
            episode ID to currently ongoing MultiAgentEpisode object.
        unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids ->
            unfiltered observation tensor, returned by a `BaseEnv.poll()` call.
        rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
            rewards tensor, returned by a `BaseEnv.poll()` call.
        dones (dict): Doubly keyed dict of env-ids -> agent ids ->
            boolean done flags, returned by a `BaseEnv.poll()` call.
        infos (dict): Doubly keyed dict of env-ids -> agent ids ->
            info dicts, returned by a `BaseEnv.poll()` call.
        horizon (int): Horizon of the episode.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        rollout_fragment_length (int): Number of episode steps before
            `SampleBatch` is yielded. Set to infinity to yield complete
            episodes.
        pack_multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
        callbacks (DefaultCallbacks): User callbacks to run on episode events.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the episode
            and instead record done=False.
        observation_fn (ObservationFunction): Optional multi-agent
            observation func to use for preprocessing observations.
        _use_trajectory_view_api (bool): Whether to use the (experimental)
            `_use_trajectory_view_api` to make generic trajectory views
            available to Models. Default: False.

    Returns:
        Tuple:
            - active_envs: Set of non-terminated env ids.
            - to_eval: Map of policy_id to list of agent PolicyEvalData.
            - outputs: List of metrics and samples to return from the sampler.
    """

    # Output objects.
    active_envs: Set[EnvID] = set()
    to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list)
    outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

    large_batch_threshold: int = max(1000, rollout_fragment_length * 10) if \
        rollout_fragment_length != float("inf") else 5000

    # For each environment.
    # type: EnvID, Dict[AgentID, EnvObsType]
    for env_id, agent_obs in unfiltered_obs.items():
        is_new_episode: bool = env_id not in active_episodes
        episode: MultiAgentEpisode = active_episodes[env_id]
        if not is_new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > large_batch_threshold
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check "
                "that you set a horizon on your environment correctly and "
                "that it terminates at some point. "
                "Note: In multi-agent environments, `rollout_fragment_length` "
                "sets the batch size based on environment steps, not the "
                "steps of "
                "individual agents, which can result in unexpectedly large "
                "batches. Also, you may be in evaluation waiting for your Env "
                "to terminate (batch_mode=`complete_episodes`). Make sure it "
                "does at some point.")

        # Check episode termination conditions.
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_agents_done = True
            atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(
                base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {},
                                   episode.hist_data))
        else:
            hit_horizon = False
            all_agents_done = False
            active_envs.add(env_id)

        # Custom observation function is applied before preprocessing.
        if observation_fn:
            agent_obs: Dict[AgentID, EnvObsType] = observation_fn(
                agent_obs=agent_obs,
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode)
            if not isinstance(agent_obs, dict):
                raise ValueError(
                    "observe() must return a dict of agent observations")

        # For each agent in the environment.
        # type: AgentID, EnvObsType
        for agent_id, raw_obs in agent_obs.items():
            assert agent_id != "__all__"
            policy_id: PolicyID = episode.policy_for(agent_id)
            prep_obs: EnvObsType = _get_or_raise(preprocessors,
                                                 policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs: EnvObsType = _get_or_raise(obs_filters,
                                                     policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_agents_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            last_observation: EnvObsType = episode.last_observation_for(
                agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable.
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=(False if (no_done_at_end
                                     or (hit_horizon and soft_horizon)) else
                           agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

        # Invoke the step callback after the step is logged to the episode
        callbacks.on_episode_step(
            worker=worker, base_env=base_env, episode=episode)

        # Cut the batch if ...
        # - all-agents-done and not packing multiple episodes into one
        #   (batch_mode="complete_episodes")
        # - or if we've exceeded the rollout_fragment_length.
        if episode.batch_builder.has_pending_agent_data():
            # Sanity check, whether all agents have done=True, if done[__all__]
            # is True.
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()

            # Reached end of episode and we are not allowed to pack the
            # next episode into the same SampleBatch -> Build the SampleBatch
            # and add it to "outputs".
            if (all_agents_done and not pack_multiple_episodes_in_batch) or \
                    episode.batch_builder.count >= rollout_fragment_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            # Make sure postprocessor stays within one episode.
            elif all_agents_done:
                episode.batch_builder.postprocess_batch_so_far(episode)

        # Episode is done.
        if all_agents_done:
            # Handle episode termination.
            batch_builder_pool.append(episode.batch_builder)
            # Call each policy's Exploration.on_episode_end method.
            for p in policies.values():
                if getattr(p, "exploration", None) is not None:
                    p.exploration.on_episode_end(
                        policy=p,
                        environment=base_env,
                        episode=episode,
                        tf_sess=getattr(p, "_sess", None))
            # Call custom on_episode_end callback.
            callbacks.on_episode_end(
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode)
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
            else:
                del active_episodes[env_id]
                resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(
                    env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list.
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # Creates a new episode if this is not async return.
                # If reset is async, we will get its result in some future poll
                episode: MultiAgentEpisode = active_episodes[env_id]
                if observation_fn:
                    resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
                        agent_obs=resetted_obs,
                        worker=worker,
                        base_env=base_env,
                        policies=policies,
                        episode=episode)
                # type: AgentID, EnvObsType
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id: PolicyID = episode.policy_for(agent_id)
                    policy: Policy = _get_or_raise(policies, policy_id)
                    prep_obs: EnvObsType = _get_or_raise(
                        preprocessors, policy_id).transform(raw_obs)
                    filtered_obs: EnvObsType = _get_or_raise(
                        obs_filters, policy_id)(prep_obs)
                    episode._set_last_observation(agent_id, filtered_obs)
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                flatten_to_single_ndarray(
                                    policy.action_space.sample())), 0.0))

    return active_envs, to_eval, outputs
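
`_get_or_raise` is used throughout this function but not shown in the snippet. The sketch below shows what such a helper presumably does: raise a descriptive error rather than a bare KeyError when a policy id has no entry. The exact message in RLlib's sampler.py may differ.

from typing import Any, Dict


def _get_or_raise(mapping: Dict[str, Any], policy_id: str) -> Any:
    # Assumed behavior: look up the preprocessor/filter/policy registered
    # for `policy_id`, failing loudly if the mapping has no such entry.
    if policy_id not in mapping:
        raise ValueError(
            "Could not find policy with id `{}`. Available keys: {}".format(
                policy_id, list(mapping.keys())))
    return mapping[policy_id]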
Example #5
def _bootstrap_config(config: Dict[str, Any],
                      no_config_cache: bool = False) -> Dict[str, Any]:
    config = prepare_config(config)
    # NOTE: multi-node-type autoscaler is guaranteed to be in use after this.

    hasher = hashlib.sha1()
    hasher.update(json.dumps([config], sort_keys=True).encode("utf-8"))
    cache_key = os.path.join(tempfile.gettempdir(),
                             "ray-config-{}".format(hasher.hexdigest()))

    if os.path.exists(cache_key) and not no_config_cache:
        with open(cache_key) as f:
            config_cache = json.loads(f.read())
        if config_cache.get("_version", -1) == CONFIG_CACHE_VERSION:
            # TODO: Is it fine to re-resolve the cached config here? It
            # should be: resolving is relatively cheap, and avoiding it
            # would require config migrations, which seems overcomplicated.
            try_reload_log_state(config_cache["config"]["provider"],
                                 config_cache.get("provider_log_info"))

            if log_once("_printed_cached_config_warning"):
                cli_logger.verbose_warning(
                    "Loaded cached provider configuration "
                    "from " + cf.bold("{}"), cache_key)
                if cli_logger.verbosity == 0:
                    cli_logger.warning("Loaded cached provider configuration")
                cli_logger.warning(
                    "If you experience issues with "
                    "the cloud provider, try re-running "
                    "the command with {}.", cf.bold("--no-config-cache"))

            return config_cache["config"]
        else:
            cli_logger.warning(
                "Found cached cluster config "
                "but the version " + cf.bold("{}") + " "
                "(expected " + cf.bold("{}") + ") does not match.\n"
                "This is normal if cluster launcher was updated.\n"
                "Config will be re-resolved.",
                config_cache.get("_version", "none"), CONFIG_CACHE_VERSION)

    importer = _NODE_PROVIDERS.get(config["provider"]["type"])
    if not importer:
        raise NotImplementedError("Unsupported provider {}".format(
            config["provider"]))

    provider_cls = importer(config["provider"])

    cli_logger.print("Checking {} environment settings",
                     _PROVIDER_PRETTY_NAMES.get(config["provider"]["type"]))
    try:
        config = provider_cls.fillout_available_node_types_resources(config)
    except Exception as exc:
        if cli_logger.verbosity > 2:
            logger.exception("Failed to autodetect node resources.")
        else:
            cli_logger.warning(
                f"Failed to autodetect node resources: {str(exc)}. "
                "You can see full stack trace with higher verbosity.")

    # NOTE: if `resources` field is missing, validate_config for providers
    # other than AWS and Kubernetes will fail (the schema error will ask the
    # user to manually fill the resources) as we currently support autofilling
    # resources for AWS and Kubernetes only.
    validate_config(config)
    resolved_config = provider_cls.bootstrap_config(config)

    if not no_config_cache:
        with open(cache_key, "w") as f:
            config_cache = {
                "_version": CONFIG_CACHE_VERSION,
                "provider_log_info": try_get_log_state(config["provider"]),
                "config": resolved_config
            }
            f.write(json.dumps(config_cache))
    return resolved_config
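
The cache key above is simply the SHA-1 of the JSON-serialized config, placed in the temp directory. The following stand-alone snippet reproduces that derivation for a toy config (ignoring `prepare_config`, which normalizes the dict first).

import hashlib
import json
import os
import tempfile

config = {"provider": {"type": "aws"}, "max_workers": 2}  # toy config

hasher = hashlib.sha1()
hasher.update(json.dumps([config], sort_keys=True).encode("utf-8"))
cache_key = os.path.join(tempfile.gettempdir(),
                         "ray-config-{}".format(hasher.hexdigest()))
print(cache_key)  # e.g. /tmp/ray-config-<40-hex-char digest>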
Example #6
def wrap_function(train_func: Callable[[Any], Any],
                  warn: bool = True,
                  name: Optional[str] = None) -> Type["FunctionTrainable"]:
    inherit_from = (FunctionTrainable, )

    if hasattr(train_func, "__mixins__"):
        inherit_from = train_func.__mixins__ + inherit_from

    func_args = inspect.getfullargspec(train_func).args
    use_checkpoint = detect_checkpoint_function(train_func)
    use_config_single = detect_config_single(train_func)
    use_reporter = detect_reporter(train_func)

    if not any([use_checkpoint, use_config_single, use_reporter]):
        # use_reporter is hidden
        raise ValueError(
            "Unknown argument found in the Trainable function. "
            "The function args must include a 'config' positional "
            "parameter. Any other args must be 'checkpoint_dir'. "
            "Found: {}".format(func_args))

    if use_config_single and not use_checkpoint:
        if log_once("tune_function_checkpoint") and warn:
            logger.warning(
                "Function checkpointing is disabled. This may result in "
                "unexpected behavior when using checkpointing features or "
                "certain schedulers. To enable, set the train function "
                "arguments to be `func(config, checkpoint_dir=None)`.")

    if use_checkpoint:
        if log_once("tune_checkpoint_dir_deprecation") and warn:
            with warnings.catch_warnings():
                warnings.simplefilter("always")
                warning_msg = (
                    "`checkpoint_dir` in `func(config, checkpoint_dir)` is "
                    "being deprecated. "
                    "To save and load checkpoint in trainable functions, "
                    "please use the `ray.air.session` API:\n\n"
                    "from ray.air import session\n\n"
                    "def train(config):\n"
                    "    # ...\n"
                    '    session.report({"metric": metric}, checkpoint=checkpoint)\n\n'
                    "For more information please see "
                    "https://docs.ray.io/en/master/ray-air/key-concepts.html#session\n"
                )
                warnings.warn(
                    warning_msg,
                    DeprecationWarning,
                )

    class ImplicitFunc(*inherit_from):
        _name = name or (train_func.__name__ if hasattr(
            train_func, "__name__") else "func")

        def __repr__(self):
            return self._name

        def _trainable_func(self, config, reporter, checkpoint_dir):
            if not use_checkpoint and not use_reporter:
                fn = partial(train_func, config)
            elif use_checkpoint:
                fn = partial(train_func, config, checkpoint_dir=checkpoint_dir)
            else:
                fn = partial(train_func, config, reporter)

            def handle_output(output):
                if not output:
                    return
                elif isinstance(output, dict):
                    reporter(**output)
                elif isinstance(output, Number):
                    reporter(_metric=output)
                else:
                    raise ValueError(
                        "Invalid return or yield value. Either return/yield "
                        "a single number or a dictionary object in your "
                        "trainable function.")

            output = None
            if inspect.isgeneratorfunction(train_func):
                for output in fn():
                    handle_output(output)
            else:
                output = fn()
                handle_output(output)

            # If train_func returns, we need to notify the main event loop
            # of the last result while avoiding double logging. This is done
            # with the keyword RESULT_DUPLICATE -- see tune/trial_runner.py.
            reporter(**{RESULT_DUPLICATE: True})
            return output

    return ImplicitFunc
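
The two trainable-function shapes that `detect_config_single` and `detect_checkpoint_function` distinguish look as follows; wrapping either one returns a `FunctionTrainable` subclass. This sketch assumes the module's helpers are importable, and the toy train functions are illustrative.

def train_config_only(config):
    # Detected by `detect_config_single`: checkpointing stays disabled and
    # the "Function checkpointing is disabled" warning is logged once.
    return {"score": config["x"] ** 2}


def train_with_checkpoint(config, checkpoint_dir=None):
    # Detected by `detect_checkpoint_function`: triggers the one-time
    # `checkpoint_dir` deprecation warning pointing at `ray.air.session`.
    return {"score": config["x"] ** 2}


trainable_cls = wrap_function(train_config_only)  # a FunctionTrainable subclass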
Example #7
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute actions on observation batches to get next actions.

    Returns:
        eval_results: Dict of policy id to compute_actions() outputs.
    """

    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        rnn_in = [t.rnn_state for t in eval_data]
        policy = _get_or_raise(policies, policy_id)
        if builder and (policy.compute_actions.__code__ is
                        TFPolicy.compute_actions.__code__):

            obs_batch = [t.obs for t in eval_data]
            state_batches = _to_column_format(rnn_in)
            # TODO(ekl): how can we make info batch available to TF code?
            prev_action_batch = [t.prev_action for t in eval_data]
            prev_reward_batch = [t.prev_reward for t in eval_data]

            pending_fetches[policy_id] = policy._build_compute_actions(
                builder,
                obs_batch=obs_batch,
                state_batches=state_batches,
                prev_action_batch=prev_action_batch,
                prev_reward_batch=prev_reward_batch,
                timestep=policy.global_timestep)
        else:
            # TODO(sven): Does this work for LSTM torch?
            rnn_in_cols = [
                np.stack([row[i] for row in rnn_in])
                for i in range(len(rnn_in[0]))
            ]
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                state_batches=rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data],
                timestep=policy.global_timestep)
    if builder:
        for pid, v in pending_fetches.items():
            eval_results[pid] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
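
`_to_column_format` (used on the TF fast path) is not shown here; the non-TF branch inlines the equivalent transposition with `np.stack`. Below is a hedged sketch of that transform, assuming it turns a list of per-agent RNN state lists into one batched array per state slot; the real helper may return plain lists rather than stacked arrays.

import numpy as np


def _to_column_format(rnn_state_rows):
    # rows: one entry per agent, each a list of RNN state components.
    # cols: one entry per state component, batched across agents.
    num_cols = len(rnn_state_rows[0])
    return [
        np.stack([row[i] for row in rnn_state_rows]) for i in range(num_cols)
    ]


rows = [[np.zeros(4), np.ones(4)] for _ in range(3)]  # 3 agents, 2 state slots
cols = _to_column_format(rows)  # 2 arrays, each of shape (3, 4)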
Example #8
    def _connect_channel(self, reconnecting=False) -> None:
        """
        Attempts to connect to the server specified by conn_str. If
        reconnecting after an RPC error, cleans up the old channel and
        continues to attempt to connect until the grace period is over.
        """
        if self.channel is not None:
            self.channel.unsubscribe(self._on_channel_state_change)
            self.channel.close()

        if self._secure:
            if self._credentials is not None:
                credentials = self._credentials
            elif os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"):
                (
                    server_cert_chain,
                    private_key,
                    ca_cert,
                ) = ray._private.utils.load_certs_from_env()
                credentials = grpc.ssl_channel_credentials(
                    certificate_chain=server_cert_chain,
                    private_key=private_key,
                    root_certificates=ca_cert,
                )
            else:
                credentials = grpc.ssl_channel_credentials()
            self.channel = grpc.secure_channel(self._conn_str,
                                               credentials,
                                               options=GRPC_OPTIONS)
        else:
            self.channel = grpc.insecure_channel(self._conn_str,
                                                 options=GRPC_OPTIONS)

        self.channel.subscribe(self._on_channel_state_change)

        # Retry the connection until the channel responds to something
        # looking like a gRPC connection, though it may be a proxy.
        start_time = time.time()
        conn_attempts = 0
        timeout = INITIAL_TIMEOUT_SEC
        service_ready = False
        while conn_attempts < max(self._connection_retries, 1) or reconnecting:
            conn_attempts += 1
            if self._in_shutdown:
                # User manually closed the worker before connection finished
                break
            elapsed_time = time.time() - start_time
            if reconnecting and elapsed_time > self._reconnect_grace_period:
                self._in_shutdown = True
                raise ConnectionError(
                    "Failed to reconnect within the reconnection grace period "
                    f"({self._reconnect_grace_period}s)")
            try:
                # Let gRPC wait for us to see if the channel becomes ready.
                # If it throws, we couldn't connect.
                grpc.channel_ready_future(self.channel).result(timeout=timeout)
                # The HTTP2 channel is ready. Wrap the channel with the
                # RayletDriverStub, allowing for unary requests.
                self.server = ray_client_pb2_grpc.RayletDriverStub(
                    self.channel)
                service_ready = bool(self.ping_server())
                if service_ready:
                    break
                # Ray is not ready yet, wait a timeout
                time.sleep(timeout)
            except grpc.FutureTimeoutError:
                logger.debug(
                    f"Couldn't connect channel in {timeout} seconds, retrying")
                # Note that channel_ready_future constitutes its own timeout,
                # which is why we do not sleep here.
            except grpc.RpcError as e:
                logger.debug("Ray client server unavailable, "
                             f"retrying in {timeout}s...")
                logger.debug(f"Received when checking init: {e.details()}")
                # Ray is not ready yet, wait a timeout.
                time.sleep(timeout)
            # Fallthrough, backoff, and retry at the top of the loop
            logger.debug("Waiting for Ray to become ready on the server, "
                         f"retry in {timeout}s...")
            if not reconnecting:
                # Don't increase backoff when trying to reconnect --
                # we already know the server exists, attempt to reconnect
                # as soon as we can
                timeout = backoff(timeout)

        # If we made it through the loop without service_ready
        # it means we've used up our retries and
        # should error back to the user.
        if not service_ready:
            self._in_shutdown = True
            if log_once("ray_client_security_groups"):
                warnings.warn(
                    "Ray Client connection timed out. Ensure that "
                    "the Ray Client port on the head node is reachable "
                    "from your local machine. See https://docs.ray.io/en"
                    "/latest/cluster/ray-client.html#step-2-check-ports for "
                    "more information.")
            raise ConnectionError("ray client connection timeout")
Example #9
    def _call_schedule_for_task(self, task: ray_client_pb2.ClientTask,
                                num_returns: int) -> List[Future]:
        logger.debug(
            f"Scheduling task {task.name} {task.type} {task.payload_id}")
        task.client_id = self._client_id
        if num_returns is None:
            num_returns = 1

        id_futures = [Future() for _ in range(num_returns)]

        def populate_ids(
                resp: Union[ray_client_pb2.DataResponse, Exception]) -> None:
            if isinstance(resp, Exception):
                if isinstance(resp, grpc.RpcError):
                    resp = decode_exception(resp)
                for future in id_futures:
                    future.set_exception(resp)
                return

            ticket = resp.task_ticket
            if not ticket.valid:
                try:
                    ex = cloudpickle.loads(ticket.error)
                except (pickle.UnpicklingError, TypeError) as e_new:
                    ex = e_new
                for future in id_futures:
                    future.set_exception(ex)
                return

            if len(ticket.return_ids) != num_returns:
                exc = ValueError(
                    f"Expected {num_returns} returns but received "
                    f"{len(ticket.return_ids)}")
                for future in id_futures:
                    future.set_exception(exc)
                return

            for future, raw_id in zip(id_futures, ticket.return_ids):
                future.set_result(raw_id)

        self.data_client.Schedule(task, populate_ids)

        self.total_outbound_message_size_bytes += task.ByteSize()
        if (self.total_outbound_message_size_bytes > MESSAGE_SIZE_THRESHOLD
                and log_once("client_communication_overhead_warning")):
            warnings.warn(
                "More than 10MB of messages have been created to schedule "
                "tasks on the server. This can be slow on Ray Client due to "
                "communication overhead over the network. If you're running "
                "many fine-grained tasks, consider running them inside a "
                'single remote function. See the section on "Too '
                'fine-grained tasks" in the Ray Design Patterns document for '
                f"more details: {DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK}. If "
                "your functions frequently use large objects, consider "
                "storing the objects remotely with ray.put. An example of "
                'this is shown in the "Closure capture of large / '
                'unserializable object" section of the Ray Design Patterns '
                "document, available here: "
                f"{DESIGN_PATTERN_LARGE_OBJECTS_LINK}",
                UserWarning,
            )
        return id_futures
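
The returned futures are resolved asynchronously by `populate_ids` via `set_result` / `set_exception`, which suggests they are `concurrent.futures.Future` objects. A toy illustration of how a caller would consume them, independent of Ray:

from concurrent.futures import Future

ok = Future()
ok.set_result(b"\x01\x02")          # what populate_ids does on success
assert ok.result() == b"\x01\x02"   # caller blocks until resolved

bad = Future()
bad.set_exception(ValueError("ticket invalid"))  # failure path
try:
    bad.result()
except ValueError as e:
    print(e)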
Example #10
    def load_data(self, sess, inputs, state_inputs):
        """Bulk loads the specified inputs into device memory.

        The shape of the inputs must conform to the shapes of the input
        placeholders this optimizer was constructed with.

        The data is split equally across all the devices. If the data is not
        evenly divisible by the batch size, excess data will be discarded.

        Args:
            sess: TensorFlow session.
            inputs: List of arrays matching the input placeholders, of shape
                [BATCH_SIZE, ...].
            state_inputs: List of RNN input arrays. These arrays have size
                [BATCH_SIZE / MAX_SEQ_LEN, ...].

        Returns:
            The number of tuples loaded per device.
        """
        if log_once("load_data"):
            logger.info(
                "Training on concatenated sample batches:\n\n{}\n".format(
                    summarize({
                        "placeholders": self.loss_inputs,
                        "inputs": inputs,
                        "state_inputs": state_inputs
                    })))

        feed_dict = {}
        assert len(self.loss_inputs) == len(inputs + state_inputs), \
            (self.loss_inputs, inputs, state_inputs)

        # Let's suppose we have the following input data, and 2 devices:
        # 1 2 3 4 5 6 7                              <- state inputs shape
        # A A A B B B C C C D D D E E E F F F G G G  <- inputs shape
        # The data is truncated and split across devices as follows:
        # |---| seq len = 3
        # |---------------------------------| seq batch size = 6 seqs
        # |----------------| per device batch size = 9 tuples

        if len(state_inputs) > 0:
            smallest_array = state_inputs[0]
            seq_len = len(inputs[0]) // len(state_inputs[0])
            self._loaded_max_seq_len = seq_len
        else:
            smallest_array = inputs[0]
            self._loaded_max_seq_len = 1

        sequences_per_minibatch = (self.max_per_device_batch_size //
                                   self._loaded_max_seq_len *
                                   len(self.devices))
        if sequences_per_minibatch < 1:
            logger.warning(
                ("Target minibatch size is {}, however the rollout sequence "
                 "length is {}, hence the minibatch size will be raised to "
                 "{}.").format(self.max_per_device_batch_size,
                               self._loaded_max_seq_len,
                               self._loaded_max_seq_len * len(self.devices)))
            sequences_per_minibatch = 1

        if len(smallest_array) < sequences_per_minibatch:
            # Dynamically shrink the batch size if insufficient data
            sequences_per_minibatch = make_divisible_by(
                len(smallest_array), len(self.devices))

        if log_once("data_slicing"):
            logger.info(
                ("Divided {} rollout sequences, each of length {}, among "
                 "{} devices.").format(len(smallest_array),
                                       self._loaded_max_seq_len,
                                       len(self.devices)))

        if sequences_per_minibatch < len(self.devices):
            raise ValueError(
                "Must load at least 1 tuple sequence per device. Try "
                "increasing `sgd_minibatch_size` or reducing `max_seq_len` "
                "to ensure that at least one sequence fits per device.")
        self._loaded_per_device_batch_size = (sequences_per_minibatch //
                                              len(self.devices) *
                                              self._loaded_max_seq_len)

        if len(state_inputs) > 0:
            # First truncate the RNN state arrays to the sequences_per_minib.
            state_inputs = [
                make_divisible_by(arr, sequences_per_minibatch)
                for arr in state_inputs
            ]
            # Then truncate the data inputs to match
            inputs = [arr[:len(state_inputs[0]) * seq_len] for arr in inputs]
            assert len(state_inputs[0]) * seq_len == len(inputs[0]), \
                (len(state_inputs[0]), sequences_per_minibatch, seq_len,
                 len(inputs[0]))
            for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
                feed_dict[ph] = arr
            truncated_len = len(inputs[0])
        else:
            truncated_len = 0
            for ph, arr in zip(self.loss_inputs, inputs):
                truncated_arr = make_divisible_by(arr, sequences_per_minibatch)
                feed_dict[ph] = truncated_arr
                if truncated_len == 0:
                    truncated_len = len(truncated_arr)

        sess.run([t.init_op for t in self._towers], feed_dict=feed_dict)

        self.num_tuples_loaded = truncated_len
        tuples_per_device = truncated_len // len(self.devices)
        assert tuples_per_device > 0, "No data loaded?"
        assert tuples_per_device % self._loaded_per_device_batch_size == 0
        return tuples_per_device
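
Plugging the numbers from the comment diagram into the formulas above (2 devices, a per-device target of 9 tuples, rollout sequences of length 3) reproduces the quoted figures:

num_devices = 2
max_per_device_batch_size = 9
loaded_max_seq_len = 3

sequences_per_minibatch = (
    max_per_device_batch_size // loaded_max_seq_len * num_devices)
per_device_batch_size = (
    sequences_per_minibatch // num_devices * loaded_max_seq_len)

assert sequences_per_minibatch == 6  # "seq batch size = 6 seqs"
assert per_device_batch_size == 9    # "per device batch size = 9 tuples"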
Example #11
def ddpg_actor_critic_loss(policy, model, _, train_batch):
    twin_q = policy.config["twin_q"]
    gamma = policy.config["gamma"]
    n_step = policy.config["n_step"]
    use_huber = policy.config["use_huber"]
    huber_threshold = policy.config["huber_threshold"]
    l2_reg = policy.config["l2_reg"]

    input_dict = {
        "obs": train_batch[SampleBatch.CUR_OBS],
        "is_training": True,
    }
    input_dict_next = {
        "obs": train_batch[SampleBatch.NEXT_OBS],
        "is_training": True,
    }

    model_out_t, _ = model(input_dict, [], None)
    model_out_tp1, _ = model(input_dict_next, [], None)
    target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None)

    # Policy network evaluation.
    policy_t = model.get_policy_output(model_out_t)
    policy_tp1 = \
        policy.target_model.get_policy_output(target_model_out_tp1)

    # Action outputs.
    if policy.config["smooth_target_policy"]:
        target_noise_clip = policy.config["target_noise_clip"]
        clipped_normal_sample = tf.clip_by_value(
            tf.random.normal(
                tf.shape(policy_tp1), stddev=policy.config["target_noise"]),
            -target_noise_clip, target_noise_clip)
        policy_tp1_smoothed = tf.clip_by_value(
            policy_tp1 + clipped_normal_sample,
            policy.action_space.low * tf.ones_like(policy_tp1),
            policy.action_space.high * tf.ones_like(policy_tp1))
    else:
        # No smoothing, just use deterministic actions.
        policy_tp1_smoothed = policy_tp1

    # Q-net(s) evaluation.
    # prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    # Q-values for given actions & observations in given current state.
    q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])

    # Q-values for current policy (no noise) in given current state
    q_t_det_policy = model.get_q_values(model_out_t, policy_t)

    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t,
                                           train_batch[SampleBatch.ACTIONS])

    # Target q-net(s) evaluation.
    q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
                                             policy_tp1_smoothed)

    if twin_q:
        twin_q_tp1 = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1_smoothed)

    q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
    if twin_q:
        twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
        q_tp1 = tf.minimum(q_tp1, twin_q_tp1)

    q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
    q_tp1_best_masked = \
        (1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)) * \
        q_tp1_best

    # Compute RHS of bellman equation.
    q_t_selected_target = tf.stop_gradient(train_batch[SampleBatch.REWARDS] +
                                           gamma**n_step * q_tp1_best_masked)

    # Compute the error (potentially clipped).
    if twin_q:
        td_error = q_t_selected - q_t_selected_target
        twin_td_error = twin_q_t_selected - q_t_selected_target
        if use_huber:
            errors = huber_loss(td_error, huber_threshold) + \
                huber_loss(twin_td_error, huber_threshold)
        else:
            errors = 0.5 * tf.math.square(td_error) + \
                     0.5 * tf.math.square(twin_td_error)
    else:
        td_error = q_t_selected - q_t_selected_target
        if use_huber:
            errors = huber_loss(td_error, huber_threshold)
        else:
            errors = 0.5 * tf.math.square(td_error)

    critic_loss = tf.reduce_mean(
        tf.cast(train_batch[PRIO_WEIGHTS], tf.float32) * errors)
    actor_loss = -tf.reduce_mean(q_t_det_policy)

    # Add l2-regularization if required.
    if l2_reg is not None:
        for var in policy.model.policy_variables():
            if "bias" not in var.name:
                actor_loss += (l2_reg * tf.nn.l2_loss(var))
        for var in policy.model.q_variables():
            if "bias" not in var.name:
                critic_loss += (l2_reg * tf.nn.l2_loss(var))

    # Model self-supervised losses.
    if policy.config["use_state_preprocessor"]:
        # Expand input_dict in case custom_loss' need them.
        input_dict[SampleBatch.ACTIONS] = train_batch[SampleBatch.ACTIONS]
        input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS]
        input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES]
        input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS]
        if log_once("ddpg_custom_loss"):
            logger.warning(
                "You are using a state-preprocessor with DDPG and "
                "therefore, `custom_loss` will be called on your Model! "
                "Please be aware that DDPG now uses the ModelV2 API, which "
                "merges all previously separate sub-models (policy_model, "
                "q_model, and twin_q_model) into one ModelV2, on which "
                "`custom_loss` is called, passing it "
                "[actor_loss, critic_loss] as 1st argument. "
                "You may have to change your custom loss function to handle "
                "this.")
        [actor_loss, critic_loss] = model.custom_loss(
            [actor_loss, critic_loss], input_dict)

    # Store values for stats function.
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.td_error = td_error
    policy.q_t = q_t

    # Return one loss value (even though we treat them separately in our
    # 2 optimizers: actor and critic).
    return policy.critic_loss + policy.actor_loss
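
`huber_loss` is imported from an RLlib utility module not shown here. For reference, a common TensorFlow definition of the Huber loss with threshold `delta` is sketched below; RLlib's helper may differ in minor details.

import tensorflow as tf


def huber_loss(x, delta=1.0):
    # Quadratic for |x| <= delta, linear beyond it.
    return tf.where(
        tf.abs(x) < delta,
        0.5 * tf.math.square(x),
        delta * (tf.abs(x) - 0.5 * delta))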
Example #12
    def _initialize_loss_from_dummy_batch(
            self,
            auto_remove_unneeded_view_reqs: bool = True,
            stats_fn=None) -> None:

        # Create the optimizer/exploration optimizer here. Some initialization
        # steps (e.g. exploration postprocessing) may need this.
        self._optimizer = self.optimizer()

        # Test calls depend on variable init, so initialize model first.
        self.get_session().run(tf1.global_variables_initializer())

        logger.info("Testing `compute_actions` w/ dummy batch.")
        actions, state_outs, extra_fetches = \
            self.compute_actions_from_input_dict(
                self._dummy_batch, explore=False, timestep=0)
        for key, value in extra_fetches.items():
            self._dummy_batch[key] = value
            self._input_dict[key] = get_placeholder(value=value, name=key)
            if key not in self.view_requirements:
                logger.info("Adding extra-action-fetch `{}` to "
                            "view-reqs.".format(key))
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(-1.0,
                                         1.0,
                                         shape=value.shape[1:],
                                         dtype=value.dtype),
                    used_for_compute_actions=False,
                )
        dummy_batch = self._dummy_batch

        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        self.exploration.postprocess_trajectory(self, dummy_batch,
                                                self.get_session())
        _ = self.postprocess_trajectory(dummy_batch)
        # Add new columns automatically to (loss) input_dict.
        for key in dummy_batch.added_keys:
            if key not in self._input_dict:
                self._input_dict[key] = get_placeholder(value=dummy_batch[key],
                                                        name=key)
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(-1.0,
                                         1.0,
                                         shape=dummy_batch[key].shape[1:],
                                         dtype=dummy_batch[key].dtype),
                    used_for_compute_actions=False,
                )

        train_batch = SampleBatch(
            dict(self._input_dict, **self._loss_input_dict))

        if self._state_inputs:
            train_batch["seq_lens"] = self._seq_lens
            self._loss_input_dict.update({"seq_lens": train_batch["seq_lens"]})

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        loss = self._do_loss_init(train_batch)

        all_accessed_keys = \
            train_batch.accessed_keys | dummy_batch.accessed_keys | \
            dummy_batch.added_keys | set(
                self.model.view_requirements.keys())

        TFPolicy._initialize_loss(
            self, loss,
            [(k, v)
             for k, v in train_batch.items() if k in all_accessed_keys] +
            ([("seq_lens",
               train_batch["seq_lens"])] if "seq_lens" in train_batch else []))

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads))

        # Add new columns automatically to view-reqs.
        if auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | \
                                dummy_batch.accessed_keys
            # Tag those only needed for post-processing (with some exceptions).
            for key in dummy_batch.accessed_keys:
                if key not in train_batch.accessed_keys and \
                        key not in self.model.view_requirements and \
                        key not in [
                            SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                            SampleBatch.UNROLL_ID, SampleBatch.DONES,
                            SampleBatch.REWARDS, SampleBatch.INFOS]:
                    if key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave DONES, REWARDS, and INFOS, no matter what.
            for key in list(self.view_requirements.keys()):
                if key not in all_accessed_keys and key not in [
                    SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                    SampleBatch.UNROLL_ID, SampleBatch.DONES,
                    SampleBatch.REWARDS, SampleBatch.INFOS] and \
                        key not in self.model.view_requirements:
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in dummy_batch.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key))
                    else:
                        del self.view_requirements[key]
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Add those data_cols (again) that are missing and have
            # dependencies by view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (vr.data_col is not None
                        and vr.data_col not in self.view_requirements):
                    used_for_training = \
                        vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training)

        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }

        # Initialize again after loss init.
        self.get_session().run(tf1.global_variables_initializer())
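The view-requirement entries above wrap each extra action fetch in a placeholder Box space whose shape and dtype are taken from the dummy-batch value. A minimal standalone sketch of that pattern (not RLlib code; `dummy_logits` is a made-up fetch):

import gym
import numpy as np

def space_for_column(value: np.ndarray) -> gym.spaces.Box:
    # Drop the batch dimension (axis 0); the +/-1.0 bounds are placeholders,
    # just like in the view-requirement construction above.
    return gym.spaces.Box(-1.0, 1.0, shape=value.shape[1:], dtype=value.dtype)

dummy_logits = np.zeros((32, 4), dtype=np.float32)  # hypothetical extra fetch
print(space_for_column(dummy_logits).shape)  # (4,)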
Example #13
File: sampler.py Project: Yard1/ray
def _process_observations(
    *,
    worker: "RolloutWorker",
    base_env: BaseEnv,
    active_episodes: Dict[str, MultiAgentEpisode],
    unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
    rewards: Dict[EnvID, Dict[AgentID, float]],
    dones: Dict[EnvID, Dict[AgentID, bool]],
    infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
    horizon: int,
    multiple_episodes_in_batch: bool,
    callbacks: "DefaultCallbacks",
    soft_horizon: bool,
    no_done_at_end: bool,
    observation_fn: "ObservationFunction",
    sample_collector: SampleCollector,
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
        RolloutMetrics, SampleBatchType]]]:
    """Record new data from the environment and prepare for policy evaluation.

    Args:
        worker (RolloutWorker): Reference to the current rollout worker.
        base_env (BaseEnv): Env implementing BaseEnv.
        active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
            episode ID to currently ongoing MultiAgentEpisode object.
        unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids
            -> unfiltered observation tensor, returned by a `BaseEnv.poll()`
            call.
        rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
            rewards tensor, returned by a `BaseEnv.poll()` call.
        dones (dict): Doubly keyed dict of env-ids -> agent ids ->
            boolean done flags, returned by a `BaseEnv.poll()` call.
        infos (dict): Doubly keyed dict of env-ids -> agent ids ->
            info dicts, returned by a `BaseEnv.poll()` call.
        horizon (int): Horizon of the episode.
        multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
        callbacks (DefaultCallbacks): User callbacks to run on episode events.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the episode
            and instead record done=False.
        observation_fn (ObservationFunction): Optional multi-agent
            observation func to use for preprocessing observations.
        sample_collector (SampleCollector): The SampleCollector object
            used to store and retrieve environment samples.

    Returns:
        Tuple:
            - active_envs: Set of non-terminated env ids.
            - to_eval: Map of policy_id to list of agent PolicyEvalData.
            - outputs: List of metrics and samples to return from the sampler.
    """

    # Output objects.
    active_envs: Set[EnvID] = set()
    to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list)
    outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

    # For each (vectorized) sub-environment.
    # types: EnvID, Dict[AgentID, EnvObsType]
    for env_id, all_agents_obs in unfiltered_obs.items():
        is_new_episode: bool = env_id not in active_episodes
        episode: MultiAgentEpisode = active_episodes[env_id]

        if not is_new_episode:
            sample_collector.episode_step(episode)
            episode._add_agent_rewards(rewards[env_id])

        # Check episode termination conditions.
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_agents_done = True
            atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(
                base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {},
                                   episode.hist_data, episode.media))
            # Check whether we have to create a fake-last observation
            # for some agents (the environment is not required to do so if
            # dones[__all__]=True).
            for ag_id in episode.get_agents():
                if not episode.last_done_for(
                        ag_id) and ag_id not in all_agents_obs:
                    # Create a fake (all-0s) observation.
                    obs_sp = worker.policy_map[episode.policy_for(
                        ag_id)].observation_space
                    obs_sp = getattr(obs_sp, "original_space", obs_sp)
                    all_agents_obs[ag_id] = tree.map_structure(
                        np.zeros_like, obs_sp.sample())
        else:
            hit_horizon = False
            all_agents_done = False
            active_envs.add(env_id)

        # Custom observation function is applied before preprocessing.
        if observation_fn:
            all_agents_obs: Dict[AgentID, EnvObsType] = observation_fn(
                agent_obs=all_agents_obs,
                worker=worker,
                base_env=base_env,
                policies=worker.policy_map,
                episode=episode)
            if not isinstance(all_agents_obs, dict):
                raise ValueError(
                    "observe() must return a dict of agent observations")

        # For each agent in the environment.
        # types: AgentID, EnvObsType
        for agent_id, raw_obs in all_agents_obs.items():
            assert agent_id != "__all__"

            last_observation: EnvObsType = episode.last_observation_for(
                agent_id)
            agent_done = bool(all_agents_done or dones[env_id].get(agent_id))

            # A new agent (initial obs) is already done -> Skip entirely.
            if last_observation is None and agent_done:
                continue

            policy_id: PolicyID = episode.policy_for(agent_id)

            prep_obs: EnvObsType = _get_or_raise(worker.preprocessors,
                                                 policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
            filtered_obs: EnvObsType = _get_or_raise(worker.filters,
                                                     policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_done(agent_id, agent_done)
            # Infos from the environment.
            agent_infos = infos[env_id].get(agent_id, {})
            episode._set_last_info(agent_id, agent_infos)

            # Record transition info if applicable.
            if last_observation is None:
                sample_collector.add_init_obs(episode, agent_id, env_id,
                                              policy_id, episode.length - 1,
                                              filtered_obs)
            else:
                # Add actions, rewards, next-obs to collectors.
                values_dict = {
                    "t":
                    episode.length - 1,
                    "env_id":
                    env_id,
                    "agent_index":
                    episode._agent_index(agent_id),
                    # Action (slot 0) taken at timestep t.
                    "actions":
                    episode.last_action_for(agent_id),
                    # Reward received after taking action a at timestep t.
                    "rewards":
                    rewards[env_id].get(agent_id, 0.0),
                    # After taking action=a, did we reach terminal?
                    "dones":
                    (False if
                     (no_done_at_end or
                      (hit_horizon and soft_horizon)) else agent_done),
                    # Next observation.
                    "new_obs":
                    filtered_obs,
                }
                # Add extra-action-fetches to collectors.
                pol = worker.policy_map[policy_id]
                for key, value in episode.last_pi_info_for(agent_id).items():
                    if key in pol.view_requirements:
                        values_dict[key] = value
                # Env infos for this agent.
                if "infos" in pol.view_requirements:
                    values_dict["infos"] = agent_infos
                sample_collector.add_action_reward_next_obs(
                    episode.episode_id, agent_id, env_id, policy_id,
                    agent_done, values_dict)

            if not agent_done:
                item = PolicyEvalData(
                    env_id, agent_id, filtered_obs, agent_infos,
                    None if last_observation is None else
                    episode.rnn_state_for(agent_id),
                    None if last_observation is None else
                    episode.last_action_for(agent_id),
                    rewards[env_id].get(agent_id, 0.0))
                to_eval[policy_id].append(item)

        # Invoke the `on_episode_step` callback after the step is logged
        # to the episode.
        # Exception: The very first env.poll() call causes the env to get reset
        # (no step taken yet, just a single starting observation logged).
        # We need to skip this callback in this case.
        if episode.length > 0:
            callbacks.on_episode_step(worker=worker,
                                      base_env=base_env,
                                      episode=episode,
                                      env_index=env_id)

        # Episode is done for all agents (dones[__all__] == True)
        # or we hit the horizon.
        if all_agents_done:
            is_done = dones[env_id]["__all__"]
            check_dones = is_done and not no_done_at_end

            # If we are not allowed to pack the next episode into the same
            # SampleBatch (batch_mode=complete_episodes) -> Build the
            # MultiAgentBatch from a single episode and add it to "outputs".
            # Otherwise, just postprocess and continue collecting across
            # episodes.
            ma_sample_batch = sample_collector.postprocess_episode(
                episode,
                is_done=is_done or (hit_horizon and not soft_horizon),
                check_dones=check_dones,
                build=not multiple_episodes_in_batch)
            if ma_sample_batch:
                outputs.append(ma_sample_batch)

            # Call each (in-memory) policy's Exploration.on_episode_end
            # method.
            # Note: This may break the exploration (e.g. ParameterNoise) of
            # policies in the `policy_map` that have not been recently used
            # (and are therefore stashed to disk). However, we certainly do not
            # want to loop through all (even stashed) policies here as that
            # would counter the purpose of the LRU policy caching.
            for p in worker.policy_map.cache.values():
                if getattr(p, "exploration", None) is not None:
                    p.exploration.on_episode_end(policy=p,
                                                 environment=base_env,
                                                 episode=episode,
                                                 tf_sess=p.get_session())
            # Call custom on_episode_end callback.
            callbacks.on_episode_end(
                worker=worker,
                base_env=base_env,
                policies=worker.policy_map,
                episode=episode,
                env_index=env_id,
            )
            # Horizon hit and we have a soft horizon (no hard env reset).
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs: Dict[AgentID, EnvObsType] = all_agents_obs
            else:
                del active_episodes[env_id]
                resetted_obs: Dict[AgentID,
                                   EnvObsType] = base_env.try_reset(env_id)
            # Reset not supported, drop this env from the ready list.
            if resetted_obs is None:
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            # Create a new episode if this was not an async reset return.
            # If the reset is async, we will get its result in some future poll.
            elif resetted_obs != ASYNC_RESET_RETURN:
                new_episode: MultiAgentEpisode = active_episodes[env_id]
                if observation_fn:
                    resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
                        agent_obs=resetted_obs,
                        worker=worker,
                        base_env=base_env,
                        policies=worker.policy_map,
                        episode=new_episode)
                # types: AgentID, EnvObsType
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id: PolicyID = new_episode.policy_for(agent_id)
                    prep_obs: EnvObsType = _get_or_raise(
                        worker.preprocessors, policy_id).transform(raw_obs)
                    filtered_obs: EnvObsType = _get_or_raise(
                        worker.filters, policy_id)(prep_obs)
                    new_episode._set_last_observation(agent_id, filtered_obs)

                    # Add initial obs to buffer.
                    sample_collector.add_init_obs(new_episode, agent_id,
                                                  env_id, policy_id,
                                                  new_episode.length - 1,
                                                  filtered_obs)

                    item = PolicyEvalData(
                        env_id, agent_id, filtered_obs,
                        episode.last_info_for(agent_id) or {},
                        episode.rnn_state_for(agent_id), None, 0.0)
                    to_eval[policy_id].append(item)

    # Try to build a truncated-episode multi-agent batch, if allowed.
    if multiple_episodes_in_batch:
        sample_batches = \
            sample_collector.try_build_truncated_episode_multi_agent_batch()
        if sample_batches:
            outputs.extend(sample_batches)

    return active_envs, to_eval, outputs
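As a quick illustration of the doubly keyed structures this function consumes (hypothetical data, no RLlib calls): each dict maps env id -> agent id -> value, and the special "__all__" key marks whole-episode termination for a sub-environment.

dones = {
    0: {"agent_0": False, "agent_1": True, "__all__": False},
    1: {"agent_0": True, "agent_1": True, "__all__": True},
}
# Envs whose episode is still running stay in the active set.
active_envs = {env_id for env_id, d in dones.items() if not d["__all__"]}
print(active_envs)  # {0}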
Example #14
File: sampler.py Project: Yard1/ray
def _env_runner(
    worker: "RolloutWorker",
    base_env: BaseEnv,
    extra_batch_callback: Callable[[SampleBatchType], None],
    horizon: int,
    normalize_actions: bool,
    clip_actions: bool,
    multiple_episodes_in_batch: bool,
    callbacks: "DefaultCallbacks",
    perf_stats: _PerfStats,
    soft_horizon: bool,
    no_done_at_end: bool,
    observation_fn: "ObservationFunction",
    sample_collector: Optional[SampleCollector] = None,
    render: bool = None,
) -> Iterable[SampleBatchType]:
    """This implements the common experience collection logic.

    Args:
        worker (RolloutWorker): Reference to the current rollout worker.
        base_env (BaseEnv): Env implementing BaseEnv.
        extra_batch_callback (fn): Function to send extra batch data to.
        horizon (int): Horizon of the episode.
        multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
        normalize_actions (bool): Whether to normalize actions to the action
            space's bounds.
        clip_actions (bool): Whether to clip actions to the space range.
        callbacks (DefaultCallbacks): User callbacks to run on episode events.
        perf_stats (_PerfStats): Record perf stats into this object.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the episode
            and instead record done=False.
        observation_fn (ObservationFunction): Optional multi-agent
            observation func to use for preprocessing observations.
        sample_collector (Optional[SampleCollector]): An optional
            SampleCollector object to use.
        render (bool): Whether to try to render the environment after each
            step.

    Yields:
        rollout (SampleBatch): Object containing state, action, reward,
            terminal condition, and other fields as dictated by `policy`.
    """

    # May be populated and used for image rendering.
    simple_image_viewer: Optional["SimpleImageViewer"] = None

    # Try to get Env's `max_episode_steps` prop. If it doesn't exist, ignore
    # error and continue with max_episode_steps=None.
    max_episode_steps = None
    try:
        max_episode_steps = base_env.get_unwrapped()[0].spec.max_episode_steps
    except Exception:
        pass

    # Trainer has a given `horizon` setting.
    if horizon:
        # `horizon` is larger than env's limit.
        if max_episode_steps and horizon > max_episode_steps:
            # Try to override the env's own max-step setting with our horizon.
            # If this won't work, throw an error.
            try:
                base_env.get_unwrapped()[0].spec.max_episode_steps = horizon
                base_env.get_unwrapped()[0]._max_episode_steps = horizon
            except Exception:
                raise ValueError(
                    "Your `horizon` setting ({}) is larger than the Env's own "
                    "timestep limit ({}), which seems to be unsettable! Try "
                    "to increase the Env's built-in limit to be at least as "
                    "large as your wanted `horizon`.".format(
                        horizon, max_episode_steps))
    # Otherwise, set Trainer's horizon to env's max-steps.
    elif max_episode_steps:
        horizon = max_episode_steps
        logger.debug(
            "No episode horizon specified, setting it to Env's limit ({}).".
            format(max_episode_steps))
    # No horizon/max_episode_steps -> Episodes may be infinitely long.
    else:
        horizon = float("inf")
        logger.debug("No episode horizon specified, assuming inf.")

    # Pool of batch builders, which can be shared across episodes to pack
    # trajectory data.
    batch_builder_pool: List[MultiAgentSampleBatchBuilder] = []

    def get_batch_builder():
        if batch_builder_pool:
            return batch_builder_pool.pop()
        else:
            return None

    def new_episode(env_id):
        episode = MultiAgentEpisode(worker.policy_map,
                                    worker.policy_mapping_fn,
                                    get_batch_builder,
                                    extra_batch_callback,
                                    env_id=env_id)
        # Call each policy's Exploration.on_episode_start method.
        # Note: This may break the exploration (e.g. ParameterNoise) of
        # policies in the `policy_map` that have not been recently used
        # (and are therefore stashed to disk). However, we certainly do not
        # want to loop through all (even stashed) policies here as that
        # would counter the purpose of the LRU policy caching.
        for p in worker.policy_map.cache.values():
            if getattr(p, "exploration", None) is not None:
                p.exploration.on_episode_start(policy=p,
                                               environment=base_env,
                                               episode=episode,
                                               tf_sess=p.get_session())
        callbacks.on_episode_start(
            worker=worker,
            base_env=base_env,
            policies=worker.policy_map,
            episode=episode,
            env_index=env_id,
        )
        return episode

    active_episodes: Dict[str, MultiAgentEpisode] = \
        NewEpisodeDefaultDict(new_episode)

    while True:
        perf_stats.iters += 1
        t0 = time.time()
        # Get observations from all ready agents.
        # types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
        unfiltered_obs, rewards, dones, infos, off_policy_actions = \
            base_env.poll()
        perf_stats.env_wait_time += time.time() - t0

        if log_once("env_returns"):
            logger.info("Raw obs from env: {}".format(
                summarize(unfiltered_obs)))
            logger.info("Info return from env: {}".format(summarize(infos)))

        # Process observations and prepare for policy evaluation.
        t1 = time.time()
        # types: Set[EnvID], Dict[PolicyID, List[PolicyEvalData]],
        #       List[Union[RolloutMetrics, SampleBatchType]]
        active_envs, to_eval, outputs = \
            _process_observations(
                worker=worker,
                base_env=base_env,
                active_episodes=active_episodes,
                unfiltered_obs=unfiltered_obs,
                rewards=rewards,
                dones=dones,
                infos=infos,
                horizon=horizon,
                multiple_episodes_in_batch=multiple_episodes_in_batch,
                callbacks=callbacks,
                soft_horizon=soft_horizon,
                no_done_at_end=no_done_at_end,
                observation_fn=observation_fn,
                sample_collector=sample_collector,
            )
        perf_stats.raw_obs_processing_time += time.time() - t1
        for o in outputs:
            yield o

        # Do batched policy eval (across vectorized envs).
        t2 = time.time()
        # types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
        eval_results = _do_policy_eval(
            to_eval=to_eval,
            policies=worker.policy_map,
            policy_mapping_fn=worker.policy_mapping_fn,
            sample_collector=sample_collector,
            active_episodes=active_episodes,
        )
        perf_stats.inference_time += time.time() - t2

        # Process results and update episode state.
        t3 = time.time()
        actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = \
            _process_policy_eval_results(
                to_eval=to_eval,
                eval_results=eval_results,
                active_episodes=active_episodes,
                active_envs=active_envs,
                off_policy_actions=off_policy_actions,
                policies=worker.policy_map,
                normalize_actions=normalize_actions,
                clip_actions=clip_actions,
            )
        perf_stats.action_processing_time += time.time() - t3

        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        t4 = time.time()
        base_env.send_actions(actions_to_send)
        perf_stats.env_wait_time += time.time() - t4

        # Try to render the env, if required.
        if render:
            t5 = time.time()
            # Render can either return an RGB image (uint8 [w x h x 3] numpy
            # array) or take care of rendering itself (returning True).
            rendered = base_env.try_render()
            # Rendering returned an image -> Display it in a SimpleImageViewer.
            if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
                # ImageViewer not defined yet, try to create one.
                if simple_image_viewer is None:
                    try:
                        from gym.envs.classic_control.rendering import \
                            SimpleImageViewer
                        simple_image_viewer = SimpleImageViewer()
                    except (ImportError, ModuleNotFoundError):
                        render = False  # disable rendering
                        logger.warning(
                            "Could not import gym.envs.classic_control."
                            "rendering! Try `pip install gym[all]`.")
                if simple_image_viewer:
                    simple_image_viewer.imshow(rendered)
            elif rendered not in [True, False, None]:
                raise ValueError(
                    "The env's ({base_env}) `try_render()` method returned an"
                    " unsupported value! Make sure you either return a "
                    "uint8/w x h x 3 (RGB) image or handle rendering in a "
                    "window and then return `True`.")
            perf_stats.env_render_time += time.time() - t5
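The overall control flow above is a generator that never returns: poll observations, process them, evaluate policies, send actions back, and yield any completed batches along the way. A toy skeleton of that shape (all names here, including `FakeEnv`, are made up for illustration):

def toy_env_runner(env, max_iters=3):
    for _ in range(max_iters):
        obs = env.poll()                # gather observations from ready envs
        for batch in env.process(obs):  # may yield zero or more finished batches
            yield batch
        env.send_actions({env_id: 0 for env_id in obs})  # dummy actions

class FakeEnv:
    def poll(self):
        return {0: [1.0]}
    def process(self, obs):
        return [len(obs)]
    def send_actions(self, actions):
        pass

print(list(toy_env_runner(FakeEnv())))  # [1, 1, 1]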
Example #15
    def _initialize_loss_from_dummy_batch(
            self, auto_remove_unneeded_view_reqs: bool = True) -> None:
        # Test calls depend on variable init, so initialize model first.
        self.get_session().run(tf1.global_variables_initializer())

        # Fields that have not been accessed are not needed for action
        # computations -> Tag them as `used_for_compute_actions=False`.
        for key, view_req in self.view_requirements.items():
            if (not key.startswith("state_in_")
                    and key not in self._input_dict.accessed_keys):
                view_req.used_for_compute_actions = False
        for key, value in self.extra_action_out_fn().items():
            self._dummy_batch[key] = get_dummy_batch_for_space(
                gym.spaces.Box(-1.0,
                               1.0,
                               shape=value.shape.as_list()[1:],
                               dtype=value.dtype.name),
                batch_size=len(self._dummy_batch),
            )
            self._input_dict[key] = get_placeholder(value=value, name=key)
            if key not in self.view_requirements:
                logger.info(
                    "Adding extra-action-fetch `{}` to view-reqs.".format(key))
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(-1.0,
                                         1.0,
                                         shape=value.shape[1:],
                                         dtype=value.dtype.name),
                    used_for_compute_actions=False,
                )
        dummy_batch = self._dummy_batch

        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        self.exploration.postprocess_trajectory(self, dummy_batch,
                                                self.get_session())
        _ = self.postprocess_trajectory(dummy_batch)
        # Add new columns automatically to (loss) input_dict.
        for key in dummy_batch.added_keys:
            if key not in self._input_dict:
                self._input_dict[key] = get_placeholder(value=dummy_batch[key],
                                                        name=key)
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(
                        -1.0,
                        1.0,
                        shape=dummy_batch[key].shape[1:],
                        dtype=dummy_batch[key].dtype,
                    ),
                    used_for_compute_actions=False,
                )

        train_batch = SampleBatch(
            dict(self._input_dict, **self._loss_input_dict),
            _is_training=True,
        )

        if self._state_inputs:
            train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
            self._loss_input_dict.update(
                {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]})

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        losses = self._do_loss_init(train_batch)

        all_accessed_keys = (train_batch.accessed_keys
                             | dummy_batch.accessed_keys
                             | dummy_batch.added_keys
                             | set(self.model.view_requirements.keys()))

        TFPolicy._initialize_loss(
            self,
            losses,
            [(k, v)
             for k, v in train_batch.items() if k in all_accessed_keys] + ([
                 (SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])
             ] if SampleBatch.SEQ_LENS in train_batch else []),
        )

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        self._stats_fetches.update(self.grad_stats_fn(train_batch,
                                                      self._grads))

        # Add new columns automatically to view-reqs.
        if auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | dummy_batch.accessed_keys
            # Tag those only needed for post-processing (with some exceptions).
            for key in dummy_batch.accessed_keys:
                if (key not in train_batch.accessed_keys
                        and key not in self.model.view_requirements
                        and key not in [
                            SampleBatch.EPS_ID,
                            SampleBatch.AGENT_INDEX,
                            SampleBatch.UNROLL_ID,
                            SampleBatch.DONES,
                            SampleBatch.REWARDS,
                            SampleBatch.INFOS,
                            SampleBatch.T,
                            SampleBatch.OBS_EMBEDS,
                        ]):
                    if key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave DONES, REWARDS, and INFOS, no matter what.
            for key in list(self.view_requirements.keys()):
                if (key not in all_accessed_keys and key not in [
                        SampleBatch.EPS_ID,
                        SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID,
                        SampleBatch.DONES,
                        SampleBatch.REWARDS,
                        SampleBatch.INFOS,
                        SampleBatch.T,
                ] and key not in self.model.view_requirements):
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in dummy_batch.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key))
                    # If we are not writing output to disk, safe to erase
                    # this key to save space in the sample batch.
                    elif self.config["output"] is None:
                        del self.view_requirements[key]

                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Add those data_cols (again) that are missing and have
            # dependencies by view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (vr.data_col is not None
                        and vr.data_col not in self.view_requirements):
                    used_for_training = vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training)

        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }
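The `auto_remove_unneeded_view_reqs` block boils down to set arithmetic: a view-requirement key is kept only if some code path accessed it or it is on the always-keep list. A simplified sketch with plain sets (the column names below are invented for illustration):

protected = {"eps_id", "agent_index", "unroll_id", "dones", "rewards", "infos", "t"}
view_reqs = {"obs", "actions", "rewards", "dones", "vf_preds", "unused_debug_col"}
accessed = {"obs", "actions", "rewards", "dones", "vf_preds"}

kept = {k for k in view_reqs if k in accessed or k in protected}
print(sorted(view_reqs - kept))  # ['unused_debug_col'] gets pruned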
Example #16
    def postprocess_batch_so_far(self, episode=None):
        """Apply policy postprocessors to any unprocessed rows.

        This pushes the postprocessed per-agent batches onto the per-policy
        builders, clearing per-agent state.

        Args:
            episode (Optional[MultiAgentEpisode]): The Episode object that
                holds this MultiAgentBatchBuilder object.
        """

        # Materialize the batches so far.
        pre_batches = {}
        for agent_id, builder in self.agent_builders.items():
            pre_batches[agent_id] = (
                self.policy_map[self.agent_to_policy[agent_id]],
                builder.build_and_reset())

        # Apply postprocessor.
        post_batches = {}
        if self.clip_rewards is True:
            for _, (_, pre_batch) in pre_batches.items():
                pre_batch["rewards"] = np.sign(pre_batch["rewards"])
        elif self.clip_rewards:
            for _, (_, pre_batch) in pre_batches.items():
                pre_batch["rewards"] = np.clip(pre_batch["rewards"],
                                               a_min=-self.clip_rewards,
                                               a_max=self.clip_rewards)
        for agent_id, (_, pre_batch) in pre_batches.items():
            other_batches = pre_batches.copy()
            del other_batches[agent_id]
            policy = self.policy_map[self.agent_to_policy[agent_id]]
            if any(pre_batch["dones"][:-1]) or len(set(
                    pre_batch["eps_id"])) > 1:
                raise ValueError(
                    "Batches sent to postprocessing must only contain steps "
                    "from a single trajectory.", pre_batch)
            post_batches[agent_id] = policy.postprocess_trajectory(
                pre_batch, other_batches, episode)
            # Call the Policy's Exploration's postprocess method.
            if getattr(policy, "exploration", None) is not None:
                policy.exploration.postprocess_trajectory(
                    policy, post_batches[agent_id],
                    getattr(policy, "_sess", None))

        if log_once("after_post"):
            logger.info(
                "Trajectory fragment after postprocess_trajectory():\n\n{}\n".
                format(summarize(post_batches)))

        # Append into policy batches and reset
        from ray.rllib.evaluation.rollout_worker import get_global_worker
        for agent_id, post_batch in sorted(post_batches.items()):
            self.callbacks.on_postprocess_trajectory(
                worker=get_global_worker(),
                episode=episode,
                agent_id=agent_id,
                policy_id=self.agent_to_policy[agent_id],
                policies=self.policy_map,
                postprocessed_batch=post_batch,
                original_batches=pre_batches)
            self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
                post_batch)

        self.agent_builders.clear()
        self.agent_to_policy.clear()
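The two clipping branches at the top of this method behave as follows; a minimal numpy illustration (not RLlib code): `clip_rewards=True` means sign clipping to +/-1.0, while a float value means clipping to +/- that value.

import numpy as np

rewards = np.array([-3.2, 0.0, 0.4, 7.5])
print(np.sign(rewards))                         # [-1.  0.  1.  1.]
print(np.clip(rewards, a_min=-2.0, a_max=2.0))  # [-2.   0.   0.4  2. ]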
Example #17
    def _get_loss_inputs_dict(self, batch, shuffle):
        """Return a feed dict from a batch.

        Args:
            batch (SampleBatch): batch of data to derive inputs from
            shuffle (bool): whether to shuffle batch sequences. Shuffle may
                be done in-place. This only makes sense if you're further
                applying minibatch SGD after getting the outputs.

        Returns:
            feed dict of data
        """

        feed_dict = {}
        if self._batch_divisibility_req > 1:
            meets_divisibility_reqs = (
                len(batch[SampleBatch.CUR_OBS]) % self._batch_divisibility_req
                == 0
                and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
        else:
            meets_divisibility_reqs = True

        # Simple case: not RNN nor do we need to pad
        if not self._state_inputs and meets_divisibility_reqs:
            if shuffle:
                batch.shuffle()
            for k, ph in self._loss_inputs:
                feed_dict[ph] = batch[k]
            return feed_dict

        if self._state_inputs:
            max_seq_len = self._max_seq_len
            dynamic_max = True
        else:
            max_seq_len = self._batch_divisibility_req
            dynamic_max = False

        # RNN or multi-agent case
        feature_keys = [k for k, v in self._loss_inputs]
        state_keys = [
            "state_in_{}".format(i) for i in range(len(self._state_inputs))
        ]
        feature_sequences, initial_states, seq_lens = chop_into_sequences(
            batch[SampleBatch.EPS_ID],
            batch[SampleBatch.UNROLL_ID],
            batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
            [batch[k] for k in state_keys],
            max_seq_len,
            dynamic_max=dynamic_max,
            shuffle=shuffle)
        for k, v in zip(feature_keys, feature_sequences):
            feed_dict[self._loss_input_dict[k]] = v
        for k, v in zip(state_keys, initial_states):
            feed_dict[self._loss_input_dict[k]] = v
        feed_dict[self._seq_lens] = seq_lens

        if log_once("rnn_feed_dict"):
            logger.info("Padded input for RNN:\n\n{}\n".format(
                summarize({
                    "features": feature_sequences,
                    "initial_states": initial_states,
                    "seq_lens": seq_lens,
                    "max_seq_len": max_seq_len,
                })))
        return feed_dict
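For the RNN/multi-agent branch, `chop_into_sequences` pads variable-length trajectories up to `max_seq_len` and records the true lengths in `seq_lens`. The sketch below reproduces only that padding idea with plain numpy (it does not call the RLlib helper):

import numpy as np

episodes = [np.array([1., 2., 3.]), np.array([4., 5.])]
max_seq_len = 4

seq_lens = np.array([len(e) for e in episodes], dtype=np.int32)
padded = np.stack([np.pad(e, (0, max_seq_len - len(e))) for e in episodes])
print(seq_lens)  # [3 2]
print(padded)    # zero-padded rows of length 4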
Example #18
    def _initialize_loss(self, loss: TensorType,
                         loss_inputs: List[Tuple[str, TensorType]]) -> None:
        """Initializes the loss op from given loss tensor and placeholders.

        Args:
            loss (TensorType): The loss op generated by some loss function.
            loss_inputs (List[Tuple[str, TensorType]]): The list of Tuples:
                (name, tf1.placeholders) needed for calculating the loss.
        """
        self._loss_inputs = loss_inputs
        self._loss_input_dict = dict(self._loss_inputs)
        for i, ph in enumerate(self._state_inputs):
            self._loss_input_dict["state_in_{}".format(i)] = ph

        if self.model:
            self._loss = self.model.custom_loss(loss, self._loss_input_dict)
            self._stats_fetches.update({
                "model":
                self.model.metrics() if isinstance(self.model, ModelV2) else
                self.model.custom_stats()
            })
        else:
            self._loss = loss

        self._optimizer = self.optimizer()
        self._grads_and_vars = [
            (g, v) for (g, v) in self.gradients(self._optimizer, self._loss)
            if g is not None
        ]
        self._grads = [g for (g, v) in self._grads_and_vars]

        # TODO(sven/ekl): Deprecate support for v1 models.
        if hasattr(self, "model") and isinstance(self.model, ModelV2):
            self._variables = ray.experimental.tf_utils.TensorFlowVariables(
                [], self._sess, self.variables())
        else:
            self._variables = ray.experimental.tf_utils.TensorFlowVariables(
                self._loss, self._sess)

        # gather update ops for any batch norm layers
        if not self._update_ops:
            self._update_ops = tf1.get_collection(
                tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name)
        if self._update_ops:
            logger.info("Update ops to run on apply gradient: {}".format(
                self._update_ops))
        with tf1.control_dependencies(self._update_ops):
            self._apply_op = self.build_apply_op(self._optimizer,
                                                 self._grads_and_vars)

        if log_once("loss_used"):
            logger.debug(
                "These tensors were used in the loss_fn:\n\n{}\n".format(
                    summarize(self._loss_input_dict)))

        self._sess.run(tf1.global_variables_initializer())
        self._optimizer_variables = None
        if self._optimizer:
            self._optimizer_variables = \
                ray.experimental.tf_utils.TensorFlowVariables(
                    self._optimizer.variables(), self._sess)
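One detail worth noting above: gradients returned as `None` (variables the loss does not touch) are filtered out before the apply op is built. A pure-Python stand-in for that filtering (made-up values, no TensorFlow required):

grads_and_vars = [(0.1, "w1"), (None, "w2"), (-0.3, "b1")]  # hypothetical pairs
kept = [(g, v) for g, v in grads_and_vars if g is not None]
grads = [g for g, _ in kept]
print(kept)   # [(0.1, 'w1'), (-0.3, 'b1')]
print(grads)  # [0.1, -0.3]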
Example #19
def _process_observations(worker, base_env, policies, batch_builder_pool,
                          active_episodes, unfiltered_obs, rewards, dones,
                          infos, off_policy_actions, horizon, preprocessors,
                          obs_filters, rollout_fragment_length, pack,
                          callbacks, soft_horizon, no_done_at_end,
                          observation_fn):
    """Record new data from the environment and prepare for policy evaluation.

    Returns:
        active_envs: set of non-terminated env ids
        to_eval: map of policy_id to list of agent PolicyEvalData
        outputs: list of metrics and samples to return from the sampler
    """

    active_envs = set()
    to_eval = defaultdict(list)
    outputs = []
    large_batch_threshold = max(1000, rollout_fragment_length * 10) if \
        rollout_fragment_length != float("inf") else 5000

    # For each environment
    for env_id, agent_obs in unfiltered_obs.items():
        new_episode = env_id not in active_episodes
        episode = active_episodes[env_id]
        if not new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > large_batch_threshold
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check that "
                "that you set a horizon on your environment correctly and that"
                " it terminates at some point. "
                "Note: In multi-agent environments, `rollout_fragment_length` "
                "sets the batch size based on environment steps, not the "
                "steps of "
                "individual agents, which can result in unexpectedly large "
                "batches. Also, you may be in evaluation waiting for your Env "
                "to terminate (batch_mode=`complete_episodes`). Make sure it "
                "does at some point.")

        # Check episode termination conditions
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_done = True
            atari_metrics = _fetch_atari_metrics(base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {},
                                   episode.hist_data))
        else:
            hit_horizon = False
            all_done = False
            active_envs.add(env_id)

        # Custom observation function is applied before preprocessing.
        if observation_fn:
            agent_obs = observation_fn(agent_obs=agent_obs,
                                       worker=worker,
                                       base_env=base_env,
                                       policies=policies,
                                       episode=episode)
            if not isinstance(agent_obs, dict):
                raise ValueError(
                    "observe() must return a dict of agent observations")

        # For each agent in the environment.
        for agent_id, raw_obs in agent_obs.items():
            assert agent_id != "__all__"
            policy_id = episode.policy_for(agent_id)
            prep_obs = _get_or_raise(preprocessors,
                                     policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            last_observation = episode.last_observation_for(agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=(False if
                           (no_done_at_end or
                            (hit_horizon and soft_horizon)) else agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

        # Invoke the step callback after the step is logged to the episode
        callbacks.on_episode_step(worker=worker,
                                  base_env=base_env,
                                  episode=episode)

        # Cut the batch if we're not packing multiple episodes into one,
        # or if we've exceeded the requested batch size.
        if episode.batch_builder.has_pending_agent_data():
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()
            if (all_done and not pack) or \
                    episode.batch_builder.count >= rollout_fragment_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            elif all_done:
                # Make sure postprocessor stays within one episode
                episode.batch_builder.postprocess_batch_so_far(episode)

        if all_done:
            # Handle episode termination
            batch_builder_pool.append(episode.batch_builder)
            # Call each policy's Exploration.on_episode_end method.
            for p in policies.values():
                if getattr(p, "exploration", None) is not None:
                    p.exploration.on_episode_end(policy=p,
                                                 environment=base_env,
                                                 episode=episode,
                                                 tf_sess=getattr(
                                                     p, "_sess", None))
            # Call custom on_episode_end callback.
            callbacks.on_episode_end(worker=worker,
                                     base_env=base_env,
                                     policies=policies,
                                     episode=episode)
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs = agent_obs
            else:
                del active_episodes[env_id]
                resetted_obs = base_env.try_reset(env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # Create a new episode if this was not an async reset return.
                # If the reset is async, we will get its result in some future poll.
                episode = active_episodes[env_id]
                if observation_fn:
                    resetted_obs = observation_fn(agent_obs=resetted_obs,
                                                  worker=worker,
                                                  base_env=base_env,
                                                  policies=policies,
                                                  episode=episode)
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id = episode.policy_for(agent_id)
                    policy = _get_or_raise(policies, policy_id)
                    prep_obs = _get_or_raise(preprocessors,
                                             policy_id).transform(raw_obs)
                    filtered_obs = _get_or_raise(obs_filters,
                                                 policy_id)(prep_obs)
                    episode._set_last_observation(agent_id, filtered_obs)
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                flatten_to_single_ndarray(
                                    policy.action_space.sample())), 0.0))

    return active_envs, to_eval, outputs
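The large-batch warning near the top of this function fires once the buffered observation count passes a threshold derived from `rollout_fragment_length`. A small sketch of just that arithmetic:

def large_batch_threshold(rollout_fragment_length):
    # Complete-episode mode (fragment length == inf) uses a fixed 5000 cap.
    if rollout_fragment_length == float("inf"):
        return 5000
    return max(1000, rollout_fragment_length * 10)

print(large_batch_threshold(200))           # 2000
print(large_batch_threshold(50))            # 1000
print(large_batch_threshold(float("inf")))  # 5000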
Example #20
    def _initialize_loss(self):
        def fake_array(tensor):
            shape = tensor.shape.as_list()
            shape = [s if s is not None else 1 for s in shape]
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        dummy_batch = {
            SampleBatch.CUR_OBS: fake_array(self._obs_input),
            SampleBatch.NEXT_OBS: fake_array(self._obs_input),
            SampleBatch.DONES: np.array([False], dtype=bool),
            SampleBatch.ACTIONS: fake_array(
                ModelCatalog.get_action_placeholder(self.action_space)),
            SampleBatch.REWARDS: np.array([0], dtype=np.float32),
        }
        if self._obs_include_prev_action_reward:
            dummy_batch.update({
                SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
                SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
            })
        state_init = self.get_initial_state()
        state_batches = []
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
            dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
            state_batches.append(np.expand_dims(h, 0))
        if state_init:
            dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v)

        # postprocessing might depend on variable init, so run it first here
        self._sess.run(tf.global_variables_initializer())

        postprocessed_batch = self.postprocess_trajectory(
            SampleBatch(dummy_batch))

        # model forward pass for the loss (needed after postprocess to
        # overwrite any tensor state from that call)
        self.model(self._input_dict, self._state_in, self._seq_lens)

        if self._obs_include_prev_action_reward:
            train_batch = UsageTrackingDict({
                SampleBatch.PREV_ACTIONS: self._prev_action_input,
                SampleBatch.PREV_REWARDS: self._prev_reward_input,
                SampleBatch.CUR_OBS: self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.PREV_ACTIONS, self._prev_action_input),
                (SampleBatch.PREV_REWARDS, self._prev_reward_input),
                (SampleBatch.CUR_OBS, self._obs_input),
            ]
        else:
            train_batch = UsageTrackingDict({
                SampleBatch.CUR_OBS: self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.CUR_OBS, self._obs_input),
            ]

        for k, v in postprocessed_batch.items():
            if k in train_batch:
                continue
            elif v.dtype == np.object:
                continue  # can't handle arbitrary objects in TF
            elif k == "seq_lens" or k.startswith("state_in_"):
                continue
            shape = (None, ) + v.shape[1:]
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            placeholder = tf.placeholder(dtype, shape=shape, name=k)
            train_batch[k] = placeholder

        for i, si in enumerate(self._state_in):
            train_batch["state_in_{}".format(i)] = si
        train_batch["seq_lens"] = self._seq_lens

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        self._loss_input_dict = train_batch
        loss = self._do_loss_init(train_batch)
        for k in sorted(train_batch.accessed_keys):
            if k != "seq_lens" and not k.startswith("state_in_"):
                loss_inputs.append((k, train_batch[k]))

        TFPolicy._initialize_loss(self, loss, loss_inputs)
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads))
        self._sess.run(tf.global_variables_initializer())
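
The `fake_array` helper above materializes a TF placeholder's partially known shape (with `None` batch dims) as a concrete zero array, so a dummy batch can be pushed through postprocessing and loss construction before any real data exists. A numpy-only sketch of the same idea, with the placeholder replaced by a plain shape/dtype description (the helper and key names here are illustrative, not RLlib API):

import numpy as np

def fake_array(shape, dtype=np.float32):
    # Replace unknown (None) dimensions with 1 so we get a valid, minimal array.
    concrete = [s if s is not None else 1 for s in shape]
    return np.zeros(concrete, dtype=dtype)

# A dummy batch for an observation space of shape (None, 4) and 1-D actions.
dummy_batch = {
    "obs": fake_array((None, 4)),
    "new_obs": fake_array((None, 4)),
    "actions": fake_array((None,), dtype=np.int64),
    "rewards": np.array([0.0], dtype=np.float32),
    "dones": np.array([False]),
}
print({k: (v.shape, v.dtype) for k, v in dummy_batch.items()})
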
Example #21
def _env_runner(worker: "RolloutWorker",
                base_env: BaseEnv,
                extra_batch_callback: Callable[[SampleBatchType], None],
                policies: Dict[PolicyID, Policy],
                policy_mapping_fn: Callable[[AgentID], PolicyID],
                rollout_fragment_length: int,
                horizon: int,
                preprocessors: Dict[PolicyID, Preprocessor],
                obs_filters: Dict[PolicyID, Filter],
                clip_rewards: bool,
                clip_actions: bool,
                pack_multiple_episodes_in_batch: bool,
                callbacks: "DefaultCallbacks",
                tf_sess: Optional["tf.Session"],
                perf_stats: _PerfStats,
                soft_horizon: bool,
                no_done_at_end: bool,
                observation_fn: "ObservationFunction",
                _use_trajectory_view_api: bool = False
                ) -> Iterable[SampleBatchType]:
    """This implements the common experience collection logic.

    Args:
        worker (RolloutWorker): Reference to the current rollout worker.
        base_env (BaseEnv): Env implementing BaseEnv.
        extra_batch_callback (fn): function to send extra batch data to.
        policies (Dict[PolicyID, Policy]): Map of policy ids to Policy
            instances.
        policy_mapping_fn (func): Function that maps agent ids to policy ids.
            This is called when an agent first enters the environment. The
            agent is then "bound" to the returned policy for the episode.
        rollout_fragment_length (int): Number of episode steps before
            `SampleBatch` is yielded. Set to infinity to yield complete
            episodes.
        horizon (int): Horizon of the episode.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        clip_rewards (bool): Whether to clip rewards before postprocessing.
        pack_multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
        clip_actions (bool): Whether to clip actions to the space range.
        callbacks (DefaultCallbacks): User callbacks to run on episode events.
        tf_sess (Session|None): Optional tensorflow session to use for batching
            TF policy evaluations.
        perf_stats (_PerfStats): Record perf stats into this object.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the episode
            and instead record done=False.
        observation_fn (ObservationFunction): Optional multi-agent
            observation func to use for preprocessing observations.
        _use_trajectory_view_api (bool): Whether to use the (experimental)
            `_use_trajectory_view_api` to make generic trajectory views
            available to Models. Default: False.

    Yields:
        rollout (SampleBatch): Object containing state, action, reward,
            terminal condition, and other fields as dictated by `policy`.
    """

    # Try to get Env's `max_episode_steps` prop. If it doesn't exist, ignore
    # error and continue with max_episode_steps=None.
    max_episode_steps = None
    try:
        max_episode_steps = base_env.get_unwrapped()[0].spec.max_episode_steps
    except Exception:
        pass

    # Trainer has a given `horizon` setting.
    if horizon:
        # `horizon` is larger than env's limit -> Error and explain how
        # to increase Env's own episode limit.
        if max_episode_steps and horizon > max_episode_steps:
            raise ValueError(
                "Your `horizon` setting ({}) is larger than the Env's own "
                "timestep limit ({})! Try to increase the Env's limit via "
                "setting its `spec.max_episode_steps` property.".format(
                    horizon, max_episode_steps))
    # Otherwise, set Trainer's horizon to env's max-steps.
    elif max_episode_steps:
        horizon = max_episode_steps
        logger.debug(
            "No episode horizon specified, setting it to Env's limit ({}).".
            format(max_episode_steps))
    else:
        horizon = float("inf")
        logger.debug("No episode horizon specified, assuming inf.")

    # Pool of batch builders, which can be shared across episodes to pack
    # trajectory data.
    batch_builder_pool: List[MultiAgentSampleBatchBuilder] = []

    def get_batch_builder():
        if batch_builder_pool:
            return batch_builder_pool.pop()
        else:
            return MultiAgentSampleBatchBuilder(policies, clip_rewards,
                                                callbacks)

    def new_episode():
        episode = MultiAgentEpisode(policies, policy_mapping_fn,
                                    get_batch_builder, extra_batch_callback)
        # Call each policy's Exploration.on_episode_start method.
        # type: Policy
        for p in policies.values():
            if getattr(p, "exploration", None) is not None:
                p.exploration.on_episode_start(
                    policy=p,
                    environment=base_env,
                    episode=episode,
                    tf_sess=getattr(p, "_sess", None))
        callbacks.on_episode_start(
            worker=worker,
            base_env=base_env,
            policies=policies,
            episode=episode)
        return episode

    active_episodes: Dict[str, MultiAgentEpisode] = defaultdict(new_episode)

    while True:
        perf_stats.iters += 1
        t0 = time.time()
        # Get observations from all ready agents.
        # type: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
        unfiltered_obs, rewards, dones, infos, off_policy_actions = \
            base_env.poll()
        perf_stats.env_wait_time += time.time() - t0

        if log_once("env_returns"):
            logger.info("Raw obs from env: {}".format(
                summarize(unfiltered_obs)))
            logger.info("Info return from env: {}".format(summarize(infos)))

        # Process observations and prepare for policy evaluation.
        t1 = time.time()
        # type: Set[EnvID], Dict[PolicyID, List[PolicyEvalData]],
        #       List[Union[RolloutMetrics, SampleBatchType]]
        active_envs, to_eval, outputs = _process_observations(
            worker=worker,
            base_env=base_env,
            policies=policies,
            batch_builder_pool=batch_builder_pool,
            active_episodes=active_episodes,
            unfiltered_obs=unfiltered_obs,
            rewards=rewards,
            dones=dones,
            infos=infos,
            horizon=horizon,
            preprocessors=preprocessors,
            obs_filters=obs_filters,
            rollout_fragment_length=rollout_fragment_length,
            pack_multiple_episodes_in_batch=pack_multiple_episodes_in_batch,
            callbacks=callbacks,
            soft_horizon=soft_horizon,
            no_done_at_end=no_done_at_end,
            observation_fn=observation_fn,
            _use_trajectory_view_api=_use_trajectory_view_api)
        perf_stats.processing_time += time.time() - t1
        for o in outputs:
            yield o

        # Do batched policy eval (across vectorized envs).
        t2 = time.time()
        # type: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
        eval_results = _do_policy_eval(
            to_eval=to_eval,
            policies=policies,
            active_episodes=active_episodes,
            tf_sess=tf_sess,
            _use_trajectory_view_api=_use_trajectory_view_api)
        perf_stats.inference_time += time.time() - t2

        # Process results and update episode state.
        t3 = time.time()
        actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = \
            _process_policy_eval_results(
                to_eval=to_eval,
                eval_results=eval_results,
                active_episodes=active_episodes,
                active_envs=active_envs,
                off_policy_actions=off_policy_actions,
                policies=policies,
                clip_actions=clip_actions,
                _use_trajectory_view_api=_use_trajectory_view_api)
        perf_stats.processing_time += time.time() - t3

        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        t4 = time.time()
        base_env.send_actions(actions_to_send)
        perf_stats.env_wait_time += time.time() - t4
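
The horizon handling at the top of `_env_runner` resolves three cases: an explicit `horizon` larger than the Env's own limit is an error, a missing `horizon` falls back to the Env's `max_episode_steps`, and neither means an infinite horizon. A small standalone sketch of that decision logic (hypothetical helper, not part of RLlib):

def resolve_horizon(horizon, max_episode_steps):
    # An explicit horizon must not exceed the env's own episode limit.
    if horizon:
        if max_episode_steps and horizon > max_episode_steps:
            raise ValueError(
                "horizon ({}) is larger than the Env's timestep limit "
                "({})".format(horizon, max_episode_steps))
        return horizon
    # Fall back to the env's limit, or run without any horizon at all.
    return max_episode_steps or float("inf")

assert resolve_horizon(None, 200) == 200
assert resolve_horizon(100, 200) == 100
assert resolve_horizon(None, None) == float("inf")
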
Example #22
        def _compute_gradients_helper(self, samples):
            """Computes and returns grads as eager tensors."""

            # Increase the tracing counter to make sure we don't re-trace too
            # often. If eager_tracing=True, this counter should only get
            # incremented during the @tf.function trace operations, never when
            # calling the already traced function after that.
            self._re_trace_counter += 1

            # Gather all variables for which to calculate losses.
            if isinstance(self.model, tf.keras.Model):
                variables = self.model.trainable_variables
            else:
                variables = self.model.trainable_variables()

            # Calculate the loss(es) inside a tf GradientTape.
            with tf.GradientTape(
                    persistent=compute_gradients_fn is not None) as tape:
                losses = self._loss(self, self.model, self.dist_class, samples)
            losses = force_list(losses)

            # User provided a compute_gradients_fn.
            if compute_gradients_fn:
                # Wrap our tape inside a wrapper, such that the resulting
                # object looks like a "classic" tf.optimizer. This way, custom
                # compute_gradients_fn will work on both tf static graph
                # and tf-eager.
                optimizer = OptimizerWrapper(tape)
                # More than one loss terms/optimizers.
                if self.config["_tf_policy_handles_more_than_one_loss"]:
                    grads_and_vars = compute_gradients_fn(
                        self, [optimizer] * len(losses), losses)
                # Only one loss and one optimizer.
                else:
                    grads_and_vars = [
                        compute_gradients_fn(self, optimizer, losses[0])
                    ]
            # Default: Compute gradients using the above tape.
            else:
                grads_and_vars = [
                    list(zip(tape.gradient(loss, variables), variables))
                    for loss in losses
                ]

            if log_once("grad_vars"):
                for g_and_v in grads_and_vars:
                    for g, v in g_and_v:
                        if g is not None:
                            logger.info(f"Optimizing variable {v.name}")

            # `grads_and_vars` is returned as a list (one entry per
            # optimizer/loss) of lists of (grad, var) tuples.
            if self.config["_tf_policy_handles_more_than_one_loss"]:
                grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars]
            # `grads_and_vars` is returned as a list of (grad, var) tuples.
            else:
                grads_and_vars = grads_and_vars[0]
                grads = [g for g, _ in grads_and_vars]

            stats = self._stats(self, samples, grads)
            return grads_and_vars, grads, stats
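
The default branch above records all losses under a single persistent `GradientTape` and then differentiates each loss separately against the model's trainable variables. A minimal eager-mode sketch of that pattern with a toy Keras model (not RLlib's policy classes), assuming TF 2.x:

import tensorflow as tf

# Toy model and data; these stand in for the policy's model and sample batch.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
x = tf.random.normal((8, 4))
y = tf.random.normal((8, 1))

with tf.GradientTape(persistent=True) as tape:
    pred = model(x)
    losses = [tf.reduce_mean((pred - y) ** 2), tf.reduce_mean(tf.abs(pred - y))]

# One (grad, var) list per loss, mirroring the default branch above.
grads_and_vars = [
    list(zip(tape.gradient(loss, model.trainable_variables),
             model.trainable_variables))
    for loss in losses
]
del tape  # release the persistent tape
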
Example #23
def _do_policy_eval(
        *,
        to_eval: Dict[PolicyID, List[PolicyEvalData]],
        policies: Dict[PolicyID, Policy],
        active_episodes: Dict[str, MultiAgentEpisode],
        tf_sess=None,
        _use_trajectory_view_api=False
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
    """Call compute_actions on collected episode/model data to get next action.

    Args:
        to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy
            IDs to lists of PolicyEvalData objects (items in these lists will
            be the batch's items for the model forward pass).
        policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
            obj.
        active_episodes (defaultdict[str,MultiAgentEpisode]): Mapping from
            episode ID to currently ongoing MultiAgentEpisode object.
        tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
            batching TF policy evaluations.
        _use_trajectory_view_api (bool): Whether to use the (experimental)
            `_use_trajectory_view_api` procedure to collect samples.
            Default: False.

    Returns:
        eval_results: Dict mapping policy IDs to compute_actions() outputs.
    """

    eval_results: Dict[PolicyID, TensorStructType] = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches: Dict[PolicyID, Any] = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    # type: PolicyID, PolicyEvalData
    for policy_id, eval_data in to_eval.items():
        rnn_in: List[List[Any]] = [t.rnn_state for t in eval_data]
        policy: Policy = _get_or_raise(policies, policy_id)
        # If tf (non eager) AND TFPolicy's compute_actions method has not been
        # overridden -> Use `policy._build_compute_actions()`.
        if builder and (policy.compute_actions.__code__ is
                        TFPolicy.compute_actions.__code__):

            obs_batch: List[EnvObsType] = [t.obs for t in eval_data]
            state_batches: StateBatch = _to_column_format(rnn_in)
            # TODO(ekl): how can we make info batch available to TF code?
            prev_action_batch = [t.prev_action for t in eval_data]
            prev_reward_batch = [t.prev_reward for t in eval_data]

            pending_fetches[policy_id] = policy._build_compute_actions(
                builder,
                obs_batch=obs_batch,
                state_batches=state_batches,
                prev_action_batch=prev_action_batch,
                prev_reward_batch=prev_reward_batch,
                timestep=policy.global_timestep)
        else:
            rnn_in_cols: StateBatch = [
                np.stack([row[i] for row in rnn_in])
                for i in range(len(rnn_in[0]))
            ]
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                state_batches=rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data],
                timestep=policy.global_timestep)
    if builder:
        # type: PolicyID, Tuple[TensorStructType, StateBatch, dict]
        for pid, v in pending_fetches.items():
            eval_results[pid] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
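
In the non-graph branch above, per-item RNN states arrive row-major (one `[state_0, state_1, ...]` list per eval item) and are re-packed into column format, one stacked batch per state tensor, before `compute_actions` is called. The same transform in isolation, on toy data:

import numpy as np

# Row-major RNN states: one [h, c] pair per eval item.
rnn_in = [
    [np.zeros(2), np.ones(2)],
    [np.zeros(2), np.ones(2)],
    [np.zeros(2), np.ones(2)],
]
# Column format: one stacked batch per state tensor.
rnn_in_cols = [
    np.stack([row[i] for row in rnn_in]) for i in range(len(rnn_in[0]))
]
print([c.shape for c in rnn_in_cols])  # [(3, 2), (3, 2)]
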
Example #24
def _bootstrap_config(config: Dict[str, Any],
                      no_config_cache: bool = False) -> Dict[str, Any]:
    config = prepare_config(config)

    hasher = hashlib.sha1()
    hasher.update(json.dumps([config], sort_keys=True).encode("utf-8"))
    cache_key = os.path.join(tempfile.gettempdir(),
                             "ray-config-{}".format(hasher.hexdigest()))

    if os.path.exists(cache_key) and not no_config_cache:
        cli_logger.old_info(logger, "Using cached config at {}", cache_key)

        config_cache = json.loads(open(cache_key).read())
        if config_cache.get("_version", -1) == CONFIG_CACHE_VERSION:
            # TODO: Is it fine to re-resolve? As far as we know, it should be.
            #  Otherwise we would need config migrations, which seems
            #  overcomplicated given that resolving is relatively cheap.
            try_reload_log_state(config_cache["config"]["provider"],
                                 config_cache.get("provider_log_info"))

            if log_once("_printed_cached_config_warning"):
                cli_logger.verbose_warning(
                    "Loaded cached provider configuration "
                    "from " + cf.bold("{}"), cache_key)
                if cli_logger.verbosity == 0:
                    cli_logger.warning("Loaded cached provider configuration")
                cli_logger.warning(
                    "If you experience issues with "
                    "the cloud provider, try re-running "
                    "the command with {}.", cf.bold("--no-config-cache"))

            return config_cache["config"]
        else:
            cli_logger.warning(
                "Found cached cluster config "
                "but the version " + cf.bold("{}") + " "
                "(expected " + cf.bold("{}") + ") does not match.\n"
                "This is normal if cluster launcher was updated.\n"
                "Config will be re-resolved.",
                config_cache.get("_version", "none"), CONFIG_CACHE_VERSION)
    validate_config(config)

    importer = _NODE_PROVIDERS.get(config["provider"]["type"])
    if not importer:
        raise NotImplementedError("Unsupported provider {}".format(
            config["provider"]))

    provider_cls = importer(config["provider"])

    cli_logger.print("Checking {} environment settings",
                     _PROVIDER_PRETTY_NAMES.get(config["provider"]["type"]))
    resolved_config = provider_cls.bootstrap_config(config)

    if not no_config_cache:
        with open(cache_key, "w") as f:
            config_cache = {
                "_version": CONFIG_CACHE_VERSION,
                "provider_log_info": try_get_log_state(config["provider"]),
                "config": resolved_config
            }
            f.write(json.dumps(config_cache))
    return resolved_config
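
The cache key above is derived purely from the config contents: the config is serialized to canonical JSON (`sort_keys=True`) and hashed, so reordering keys does not invalidate the cache while any value change does. A self-contained sketch of that derivation (the helper name is illustrative):

import hashlib
import json
import os
import tempfile

def config_cache_key(config: dict) -> str:
    # A deterministic key: hash the canonical JSON form of the config.
    hasher = hashlib.sha1()
    hasher.update(json.dumps([config], sort_keys=True).encode("utf-8"))
    return os.path.join(tempfile.gettempdir(),
                        "ray-config-{}".format(hasher.hexdigest()))

# Identical configs map to the same cache file; any value change produces a new one.
assert config_cache_key({"a": 1, "b": 2}) == config_cache_key({"b": 2, "a": 1})
assert config_cache_key({"a": 1}) != config_cache_key({"a": 2})
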
Example #25
    def sample(self) -> SampleBatchType:
        """Returns a batch of experience sampled from this worker.

        This method must be implemented by subclasses.

        Returns:
            SampleBatchType: A columnar batch of experiences (e.g., tensors).

        Examples:
            >>> print(worker.sample())
            SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0], ...})
        """

        if self.fake_sampler and self.last_batch is not None:
            return self.last_batch

        if log_once("sample_start"):
            logger.info("Generating sample batch of size {}".format(
                self.rollout_fragment_length))

        batches = [self.input_reader.next()]
        steps_so_far = batches[0].count

        # In truncate_episodes mode, never pull more than 1 batch per env.
        # This avoids over-running the target batch size.
        if self.batch_mode == "truncate_episodes":
            max_batches = self.num_envs
        else:
            max_batches = float("inf")

        while (steps_so_far < self.rollout_fragment_length
               and len(batches) < max_batches):
            batch = self.input_reader.next()
            steps_so_far += batch.count
            batches.append(batch)
        batch = batches[0].concat_samples(batches)

        self.callbacks.on_sample_end(worker=self, samples=batch)

        # Always do writes prior to compression for consistency and to allow
        # for better compression inside the writer.
        self.output_writer.write(batch)

        # Do off-policy estimation if needed
        if self.reward_estimators:
            for sub_batch in batch.split_by_episode():
                for estimator in self.reward_estimators:
                    estimator.process(sub_batch)

        if log_once("sample_end"):
            logger.info("Completed sample batch:\n\n{}\n".format(
                summarize(batch)))

        if self.compress_observations == "bulk":
            batch.compress(bulk=True)
        elif self.compress_observations:
            batch.compress()

        if self.fake_sampler:
            self.last_batch = batch
        return batch
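
The while-loop in `sample()` keeps pulling fragments until `rollout_fragment_length` steps have been collected, capped by `max_batches` in `truncate_episodes` mode. Stripped of RLlib types, the accumulation looks roughly like this (toy `Batch` stand-in with only a `count` field):

import collections

Batch = collections.namedtuple("Batch", ["count"])

def collect(next_batch, rollout_fragment_length, max_batches):
    batches = [next_batch()]
    steps_so_far = batches[0].count
    # Keep pulling until enough steps were collected or the batch cap is hit.
    while (steps_so_far < rollout_fragment_length
           and len(batches) < max_batches):
        batch = next_batch()
        steps_so_far += batch.count
        batches.append(batch)
    return batches

reader = iter([Batch(40), Batch(40), Batch(40)])
out = collect(lambda: next(reader), rollout_fragment_length=100,
              max_batches=float("inf"))
print(len(out), sum(b.count for b in out))  # 3 120
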
Example #26
File: sampler.py Project: w0617/ray
def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
                rollout_fragment_length, horizon, preprocessors, obs_filters,
                clip_rewards, clip_actions, pack, callbacks, tf_sess,
                perf_stats, soft_horizon, no_done_at_end):
    """This implements the common experience collection logic.

    Args:
        base_env (BaseEnv): env implementing BaseEnv.
        extra_batch_callback (fn): function to send extra batch data to.
        policies (dict): Map of policy ids to Policy instances.
        policy_mapping_fn (func): Function that maps agent ids to policy ids.
            This is called when an agent first enters the environment. The
            agent is then "bound" to the returned policy for the episode.
        rollout_fragment_length (int): Number of episode steps before
            `SampleBatch` is yielded. Set to infinity to yield complete
            episodes.
        horizon (int): Horizon of the episode.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        clip_rewards (bool): Whether to clip rewards before postprocessing.
        pack (bool): Whether to pack multiple episodes into each batch. This
            guarantees batches will be exactly `rollout_fragment_length` in
            size.
        clip_actions (bool): Whether to clip actions to the space range.
        callbacks (dict): User callbacks to run on episode events.
        tf_sess (Session|None): Optional tensorflow session to use for batching
            TF policy evaluations.
        perf_stats (PerfStats): Record perf stats into this object.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the episode
            and instead record done=False.

    Yields:
        rollout (SampleBatch): Object containing state, action, reward,
            terminal condition, and other fields as dictated by `policy`.
    """

    # Try to get Env's max_episode_steps prop. If it doesn't exist, catch
    # error and continue.
    max_episode_steps = None
    try:
        max_episode_steps = base_env.get_unwrapped()[0].spec.max_episode_steps
    except Exception:
        pass

    # Trainer has a given `horizon` setting.
    if horizon:
        # `horizon` is larger than env's limit -> Error and explain how
        # to increase Env's own episode limit.
        if max_episode_steps and horizon > max_episode_steps:
            raise ValueError(
                "Your `horizon` setting ({}) is larger than the Env's own "
                "timestep limit ({})! Try to increase the Env's limit via "
                "setting its `spec.max_episode_steps` property.".format(
                    horizon, max_episode_steps))
    # Otherwise, set Trainer's horizon to env's max-steps.
    elif max_episode_steps:
        horizon = max_episode_steps
        logger.debug(
            "No episode horizon specified, setting it to Env's limit ({}).".
            format(max_episode_steps))
    else:
        horizon = float("inf")
        logger.debug("No episode horizon specified, assuming inf.")

    # Pool of batch builders, which can be shared across episodes to pack
    # trajectory data.
    batch_builder_pool = []

    def get_batch_builder():
        if batch_builder_pool:
            return batch_builder_pool.pop()
        else:
            return MultiAgentSampleBatchBuilder(
                policies, clip_rewards, callbacks.get("on_postprocess_traj"))

    def new_episode():
        episode = MultiAgentEpisode(policies, policy_mapping_fn,
                                    get_batch_builder, extra_batch_callback)
        if callbacks.get("on_episode_start"):
            callbacks["on_episode_start"]({
                "env": base_env,
                "policy": policies,
                "episode": episode,
            })
        return episode

    active_episodes = defaultdict(new_episode)

    while True:
        perf_stats.iters += 1
        t0 = time.time()
        # Get observations from all ready agents
        unfiltered_obs, rewards, dones, infos, off_policy_actions = \
            base_env.poll()
        perf_stats.env_wait_time += time.time() - t0

        if log_once("env_returns"):
            logger.info("Raw obs from env: {}".format(
                summarize(unfiltered_obs)))
            logger.info("Info return from env: {}".format(summarize(infos)))

        # Process observations and prepare for policy evaluation
        t1 = time.time()
        active_envs, to_eval, outputs = _process_observations(
            base_env, policies, batch_builder_pool, active_episodes,
            unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
            preprocessors, obs_filters, rollout_fragment_length, pack,
            callbacks, soft_horizon, no_done_at_end)
        perf_stats.processing_time += time.time() - t1
        for o in outputs:
            yield o

        # Do batched policy eval
        t2 = time.time()
        eval_results = _do_policy_eval(tf_sess, to_eval, policies,
                                       active_episodes)
        perf_stats.inference_time += time.time() - t2

        # Process results and update episode state
        t3 = time.time()
        actions_to_send = _process_policy_eval_results(to_eval, eval_results,
                                                       active_episodes,
                                                       active_envs,
                                                       off_policy_actions,
                                                       policies, clip_actions)
        perf_stats.processing_time += time.time() - t3

        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        t4 = time.time()
        base_env.send_actions(actions_to_send)
        perf_stats.env_wait_time += time.time() - t4
Example #27
def wrap_function(train_func, warn=True):
    if hasattr(train_func, "__mixins__"):
        inherit_from = train_func.__mixins__ + (FunctionRunner, )
    else:
        inherit_from = (FunctionRunner, )

    func_args = inspect.getfullargspec(train_func).args
    use_checkpoint = detect_checkpoint_function(train_func)
    use_config_single = detect_config_single(train_func)
    use_reporter = detect_reporter(train_func)

    if not any([use_checkpoint, use_config_single, use_reporter]):
        # use_reporter is hidden
        raise ValueError(
            "Unknown argument found in the Trainable function. "
            "The function args must include a 'config' positional "
            "parameter. Any other args must be 'checkpoint_dir'. "
            "Found: {}".format(func_args))

    if use_config_single and not use_checkpoint:
        if log_once("tune_function_checkpoint") and warn:
            logger.warning(
                "Function checkpointing is disabled. This may result in "
                "unexpected behavior when using checkpointing features or "
                "certain schedulers. To enable, set the train function "
                "arguments to be `func(config, checkpoint_dir=None)`.")

    class ImplicitFunc(*inherit_from):
        _name = train_func.__name__ if hasattr(train_func, "__name__") \
            else "func"

        def _trainable_func(self, config, reporter, checkpoint_dir):
            if not use_checkpoint and not use_reporter:
                fn = partial(train_func, config)
            elif use_checkpoint:
                fn = partial(train_func, config, checkpoint_dir=checkpoint_dir)
            else:
                fn = partial(train_func, config, reporter)

            def handle_output(output):
                if not output:
                    return
                elif isinstance(output, dict):
                    reporter(**output)
                elif isinstance(output, Number):
                    reporter(_metric=output)
                else:
                    raise ValueError(
                        "Invalid return or yield value. Either return/yield "
                        "a single number or a dictionary object in your "
                        "trainable function.")

            output = None
            if inspect.isgeneratorfunction(train_func):
                for output in fn():
                    handle_output(output)
            else:
                output = fn()
                handle_output(output)

            # If train_func returns, we need to notify the main event loop
            # of the last result while avoiding double logging. This is done
            # with the keyword RESULT_DUPLICATE -- see tune/trial_runner.py.
            reporter(**{RESULT_DUPLICATE: True})
            return output

    return ImplicitFunc
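
`ImplicitFunc` accepts both plain-return and generator-style trainables and routes whatever they produce through `handle_output`: dicts are reported as metrics, bare numbers are wrapped under a default key, anything else is rejected. A standalone sketch of those two contracts, with a plain callable standing in for Tune's reporter:

from numbers import Number

def handle_output(output, report):
    # Mirrors the wrapper's rule: dicts are reported as-is, bare numbers are
    # wrapped under a default metric key, anything else is an error.
    if not output:
        return
    if isinstance(output, dict):
        report(**output)
    elif isinstance(output, Number):
        report(_metric=output)
    else:
        raise ValueError("Return/yield a single number or a dict of metrics.")

def train_func(config):
    # A generator-style trainable: yield one metrics dict per iteration.
    for step in range(3):
        yield {"step": step, "loss": 1.0 / (step + 1)}

for metrics in train_func({"lr": 0.01}):
    handle_output(metrics, report=lambda **kw: print(kw))
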
Example #28
    def postprocess_episode(
            self,
            episode: MultiAgentEpisode,
            is_done: bool = False,
            check_dones: bool = False,
            build: bool = False) -> Union[None, SampleBatch, MultiAgentBatch]:
        episode_id = episode.episode_id
        policy_collector_group = episode.batch_builder

        # TODO: (sven) Once we implement multi-agent communication channels,
        #  we have to resolve the restriction of only sending other agent
        #  batches from the same policy to the postprocess methods.
        # Build SampleBatches for the given episode.
        pre_batches = {}
        for (eps_id, agent_id), collector in self.agent_collectors.items():
            # Build only if there is data and agent is part of given episode.
            if collector.agent_steps == 0 or eps_id != episode_id:
                continue
            pid = self.agent_key_to_policy_id[(eps_id, agent_id)]
            policy = self.policy_map[pid]
            pre_batch = collector.build(policy.view_requirements)
            pre_batches[agent_id] = (policy, pre_batch)

        # Apply reward clipping before calling postprocessing functions.
        if self.clip_rewards is True:
            for _, (_, pre_batch) in pre_batches.items():
                pre_batch["rewards"] = np.sign(pre_batch["rewards"])
        elif self.clip_rewards:
            for _, (_, pre_batch) in pre_batches.items():
                pre_batch["rewards"] = np.clip(pre_batch["rewards"],
                                               a_min=-self.clip_rewards,
                                               a_max=self.clip_rewards)

        post_batches = {}
        for agent_id, (_, pre_batch) in pre_batches.items():
            # Entire episode is said to be done.
            # Error if no DONE at end of this agent's trajectory.
            if is_done and check_dones and \
                    not pre_batch[SampleBatch.DONES][-1]:
                raise ValueError(
                    "Episode {} terminated for all agents, but we still"
                    "don't have a last observation for agent {} (policy "
                    "{}). ".format(
                        episode_id, agent_id, self.agent_key_to_policy_id[
                            (episode_id, agent_id)]) +
                    "Please ensure that you include the last observations "
                    "of all live agents when setting done[__all__] to "
                    "True. Alternatively, set no_done_at_end=True to "
                    "allow this.")

            other_batches = pre_batches.copy()
            del other_batches[agent_id]
            pid = self.agent_key_to_policy_id[(episode_id, agent_id)]
            policy = self.policy_map[pid]
            if any(pre_batch[SampleBatch.DONES][:-1]) or len(
                    set(pre_batch[SampleBatch.EPS_ID])) > 1:
                raise ValueError(
                    "Batches sent to postprocessing must only contain steps "
                    "from a single trajectory.", pre_batch)
            # Call the Policy's Exploration's postprocess method.
            post_batches[agent_id] = pre_batch
            if getattr(policy, "exploration", None) is not None:
                policy.exploration.postprocess_trajectory(
                    policy, post_batches[agent_id],
                    getattr(policy, "_sess", None))
            post_batches[agent_id] = policy.postprocess_trajectory(
                post_batches[agent_id], other_batches, episode)

        if log_once("after_post"):
            logger.info(
                "Trajectory fragment after postprocess_trajectory():\n\n{}\n".
                format(summarize(post_batches)))

        # Append into policy batches and reset.
        from ray.rllib.evaluation.rollout_worker import get_global_worker
        for agent_id, post_batch in sorted(post_batches.items()):
            agent_key = (episode_id, agent_id)
            pid = self.agent_key_to_policy_id[agent_key]
            policy = self.policy_map[pid]
            self.callbacks.on_postprocess_trajectory(
                worker=get_global_worker(),
                episode=episode,
                agent_id=agent_id,
                policy_id=pid,
                policies=self.policy_map,
                postprocessed_batch=post_batch,
                original_batches=pre_batches)
            # Add the postprocessed SampleBatch to the policy collectors for
            # training.
            policy_collector_group.policy_collectors[
                pid].add_postprocessed_batch_for_training(
                    post_batch, policy.view_requirements)

            if is_done:
                del self.agent_key_to_policy_id[agent_key]
                del self.agent_collectors[agent_key]

        if policy_collector_group:
            env_steps = self.episode_steps[episode_id]
            policy_collector_group.env_steps += env_steps
            agent_steps = self.agent_steps[episode_id]
            policy_collector_group.agent_steps += agent_steps

        if is_done:
            del self.episode_steps[episode_id]
            del self.agent_steps[episode_id]
            del self.episodes[episode_id]
            # Make PolicyCollectorGroup available for more agent batches in
            # other episodes. Do not reset count to 0.
            if policy_collector_group:
                self.policy_collector_groups.append(policy_collector_group)
        else:
            self.episode_steps[episode_id] = self.agent_steps[episode_id] = 0

        # Build a MultiAgentBatch from the episode and return.
        if build:
            return self._build_multi_agent_batch(episode)
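
The reward-clipping block near the top of `postprocess_episode` distinguishes a boolean `True` (sign clipping to -1/0/+1) from a numeric clip value (symmetric range clipping), and leaves rewards untouched otherwise. The same rule as a tiny standalone function (the function name is illustrative):

import numpy as np

def clip_rewards(rewards, clip):
    # clip is True       -> sign clipping to {-1, 0, +1}
    # clip is a float    -> clip to [-clip, +clip]
    # clip is False/None -> leave rewards untouched
    rewards = np.asarray(rewards, dtype=np.float32)
    if clip is True:
        return np.sign(rewards)
    if clip:
        return np.clip(rewards, a_min=-clip, a_max=clip)
    return rewards

r = [-3.2, -0.4, 0.0, 0.7, 5.0]
print(clip_rewards(r, True))   # [-1. -1.  0.  1.  1.]
print(clip_rewards(r, 2.0))    # [-2.  -0.4  0.   0.7  2. ]
print(clip_rewards(r, False))  # values unchanged
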
Example #29
    def _debug_vars(self):
        if log_once("grad_vars"):
            for _, v in self._grads_and_vars:
                logger.info("Optimizing variable {}".format(v))
Example #30
    def _initialize_loss_from_dummy_batch(
            self, auto_remove_unneeded_view_reqs: bool = True,
            stats_fn=None) -> None:

        # Create the optimizer/exploration optimizer here. Some initialization
        # steps (e.g. exploration postprocessing) may need this.
        self._optimizer = self.optimizer()

        # Test calls depend on variable init, so initialize model first.
        self._sess.run(tf1.global_variables_initializer())

        if self.config["_use_trajectory_view_api"]:
            logger.info("Testing `compute_actions` w/ dummy batch.")
            actions, state_outs, extra_fetches = \
                self.compute_actions_from_input_dict(
                    self._dummy_batch, explore=False, timestep=0)
            for key, value in extra_fetches.items():
                self._dummy_batch[key] = np.zeros_like(value)
                self._input_dict[key] = get_placeholder(value=value, name=key)
                if key not in self.view_requirements:
                    logger.info("Adding extra-action-fetch `{}` to "
                                "view-reqs.".format(key))
                    self.view_requirements[key] = \
                        ViewRequirement(space=gym.spaces.Box(
                            -1.0, 1.0, shape=value.shape[1:],
                            dtype=value.dtype))
            dummy_batch = self._dummy_batch
        else:

            def fake_array(tensor):
                shape = tensor.shape.as_list()
                shape = [s if s is not None else 1 for s in shape]
                return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

            dummy_batch = {
                SampleBatch.CUR_OBS: fake_array(self._obs_input),
                SampleBatch.NEXT_OBS: fake_array(self._obs_input),
                SampleBatch.DONES: np.array([False], dtype=np.bool),
                SampleBatch.ACTIONS: fake_array(
                    ModelCatalog.get_action_placeholder(self.action_space)),
                SampleBatch.REWARDS: np.array([0], dtype=np.float32),
            }
            if self._obs_include_prev_action_reward:
                dummy_batch.update({
                    SampleBatch.PREV_ACTIONS: fake_array(
                        self._prev_action_input),
                    SampleBatch.PREV_REWARDS: fake_array(
                        self._prev_reward_input),
                })
            state_init = self.get_initial_state()
            state_batches = []
            for i, h in enumerate(state_init):
                dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
                dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
                state_batches.append(np.expand_dims(h, 0))
            if state_init:
                dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
            for k, v in self.extra_compute_action_fetches().items():
                dummy_batch[k] = fake_array(v)
            dummy_batch = SampleBatch(dummy_batch)

        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        self.exploration.postprocess_trajectory(self, dummy_batch, self._sess)
        postprocessed_batch = self.postprocess_trajectory(dummy_batch)
        # Add new columns automatically to (loss) input_dict.
        if self.config["_use_trajectory_view_api"]:
            for key in dummy_batch.added_keys:
                if key not in self._input_dict:
                    self._input_dict[key] = get_placeholder(
                        value=dummy_batch[key], name=key)
                if key not in self.view_requirements:
                    self.view_requirements[key] = \
                        ViewRequirement(space=gym.spaces.Box(
                            -1.0, 1.0, shape=dummy_batch[key].shape[1:],
                            dtype=dummy_batch[key].dtype))

        if not self.config["_use_trajectory_view_api"]:
            train_batch = SampleBatch(
                dict({
                    SampleBatch.CUR_OBS: self._obs_input,
                }, **self._loss_input_dict))
            if self._obs_include_prev_action_reward:
                train_batch.update({
                    SampleBatch.PREV_ACTIONS: self._prev_action_input,
                    SampleBatch.PREV_REWARDS: self._prev_reward_input,
                    SampleBatch.CUR_OBS: self._obs_input,
                })

            for k, v in postprocessed_batch.items():
                if k in train_batch:
                    continue
                elif v.dtype == np.object:
                    continue  # can't handle arbitrary objects in TF
                elif k == "seq_lens" or k.startswith("state_in_"):
                    continue
                shape = (None, ) + v.shape[1:]
                dtype = np.float32 if v.dtype == np.float64 else v.dtype
                placeholder = tf1.placeholder(dtype, shape=shape, name=k)
                train_batch[k] = placeholder

            for i, si in enumerate(self._state_inputs):
                train_batch["state_in_{}".format(i)] = si
        else:
            train_batch = SampleBatch(
                dict(self._input_dict, **self._loss_input_dict))

        if self._state_inputs:
            train_batch["seq_lens"] = self._seq_lens

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})
        loss = self._do_loss_init(train_batch)

        all_accessed_keys = \
            train_batch.accessed_keys | dummy_batch.accessed_keys | \
            dummy_batch.added_keys | set(
                self.model.view_requirements.keys())

        TFPolicy._initialize_loss(self, loss, [(k, v)
                                               for k, v in train_batch.items()
                                               if k in all_accessed_keys])

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads))

        # Add new columns automatically to view-reqs.
        if self.config["_use_trajectory_view_api"] and \
                auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | \
                                dummy_batch.accessed_keys
            # Tag those only needed for post-processing (with some exceptions).
            for key in dummy_batch.accessed_keys:
                if key not in train_batch.accessed_keys and \
                        key not in self.model.view_requirements and \
                        key not in [
                            SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                            SampleBatch.UNROLL_ID, SampleBatch.DONES,
                            SampleBatch.REWARDS, SampleBatch.INFOS]:
                    if key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave DONES, REWARDS, and INFOS, no matter what.
            for key in list(self.view_requirements.keys()):
                if key not in all_accessed_keys and key not in [
                    SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                    SampleBatch.UNROLL_ID, SampleBatch.DONES,
                    SampleBatch.REWARDS, SampleBatch.INFOS] and \
                        key not in self.model.view_requirements:
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in dummy_batch.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key))
                    else:
                        del self.view_requirements[key]
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Add those data_cols (again) that are missing and have
            # dependencies by view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (vr.data_col is not None
                        and vr.data_col not in self.view_requirements):
                    used_for_training = \
                        vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training)

        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }

        # Initialize again after loss init.
        self._sess.run(tf1.global_variables_initializer())
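
Both loss-initialization examples (Examples #20 and #30) rely on tracking which batch keys the loss function actually reads, and then keeping only those columns as loss inputs or view requirements. A minimal sketch of that tracking idea, not RLlib's actual `UsageTrackingDict`/`SampleBatch` implementation:

class TrackingDict(dict):
    """A minimal stand-in for a usage-tracking batch: remembers which keys
    were read, so unused placeholders can be pruned afterwards."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.accessed_keys = set()

    def __getitem__(self, key):
        self.accessed_keys.add(key)
        return super().__getitem__(key)

batch = TrackingDict(obs=[1, 2], actions=[0, 1], rewards=[0.0, 1.0], infos=[{}, {}])
_ = batch["obs"], batch["rewards"]          # pretend the loss only read these
loss_inputs = [(k, batch[k]) for k in sorted(batch.accessed_keys)]
print([k for k, _ in loss_inputs])          # ['obs', 'rewards']
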