コード例 #1
0
ファイル: pool.py プロジェクト: stjordanis/ray
    def __init__(self,
                 processes=None,
                 initializer=None,
                 initargs=None,
                 maxtasksperchild=None,
                 context=None,
                 ray_address=None):
        """Initialize pool state, connect to ray and start the actor pool.

        Args:
            processes: Requested number of processes. Passed to
                ``_init_ray``; the value that call returns is what the
                actor pool is started with.
            initializer: Kept on the instance as ``_initializer``.
            initargs: Kept on the instance as ``_initargs``.
            maxtasksperchild: Kept as ``_maxtasksperchild``. Any falsy
                value (e.g. ``None`` or ``0``) is stored as ``-1``.
            context: Unsupported; a truthy value only produces a one-time
                warning.
            ray_address: Passed through to ``_init_ray``.
        """
        # Bookkeeping containers.
        self._registry: List[Tuple[Any, ray.ObjectRef]] = []
        self._registry_hashable: Dict[Hashable, ray.ObjectRef] = {}
        self._actor_deletion_ids = []

        # Caller-supplied settings.
        self._initializer = initializer
        self._initargs = initargs
        self._maxtasksperchild = maxtasksperchild if maxtasksperchild else -1
        self._closed = False

        # `context` is accepted for signature compatibility only.
        if context:
            if log_once("context_argument_warning"):
                logger.warning(
                    "The 'context' argument is not supported using ray. "
                    "Please refer to the documentation for how to control "
                    "ray initialization.")

        actual_processes = self._init_ray(processes, ray_address)
        self._start_actor_pool(actual_processes)
コード例 #2
0
    def action_space_contains(self, x: MultiAgentDict) -> bool:
        """Report whether the given actions lie inside the action space.

        The `_spaces_in_preferred_format` flag is computed on first use
        (when the attribute is missing or still None) and then reused.

        Args:
            x: Actions to check.

        Returns:
            The result of `self.action_space.contains(x)` when the spaces
            are in the preferred format. Otherwise True, after logging a
            one-time warning that no check is implemented.
        """
        # Lazily determine the space format.
        if getattr(self, "_spaces_in_preferred_format", None) is None:
            self._spaces_in_preferred_format = (
                self._check_if_space_maps_agent_id_to_sub_space()
            )

        if not self._spaces_in_preferred_format:
            # No generic check available: warn once and accept the input.
            if log_once("action_space_contains"):
                logger.warning(
                    "action_space_contains() has not been implemented")
            return True

        return self.action_space.contains(x)
コード例 #3
0
ファイル: episode.py プロジェクト: rlan/ray
    def policy_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> PolicyID:
        """Returns and stores the policy ID for the specified agent.

        If the agent is new, the policy mapping fn will be called to bind the
        agent to a policy for the duration of the entire episode (even if the
        policy_mapping_fn is changed in the meantime).

        Args:
            agent_id (AgentID): The agent ID to lookup the policy ID for.

        Returns:
            PolicyID: The policy ID for the specified agent.

        Raises:
            KeyError: If the resolved policy ID is not in `self.policy_map`.
            TypeError: Re-raised from `policy_mapping_fn` when the error is
                not a call-signature mismatch.
        """

        if agent_id not in self._agent_to_policy:
            # Try new API: pass in agent_id and episode as named args.
            # New signature should be: (agent_id, episode, worker, **kwargs)
            try:
                policy_id = self._agent_to_policy[agent_id] = \
                    self.policy_mapping_fn(agent_id, self, worker=self.worker)
            except TypeError as e:
                # Use str(e) rather than e.args[0]: the latter fails with
                # IndexError/TypeError when the exception carries no args
                # or a non-string first arg, hiding the original error.
                err_msg = str(e)
                is_signature_mismatch = (
                    "positional argument" in err_msg
                    or "unexpected keyword argument" in err_msg)
                if not is_signature_mismatch:
                    # Unrelated TypeError: propagate it unchanged.
                    raise
                if log_once("policy_mapping_new_signature"):
                    deprecation_warning(
                        old="policy_mapping_fn(agent_id)",
                        new="policy_mapping_fn(agent_id, episode, "
                        "worker, **kwargs)")
                # Fall back to the old single-argument signature.
                policy_id = self._agent_to_policy[agent_id] = \
                    self.policy_mapping_fn(agent_id)
        else:
            policy_id = self._agent_to_policy[agent_id]
        if policy_id not in self.policy_map:
            raise KeyError("policy_mapping_fn returned invalid policy id "
                           f"'{policy_id}'!")
        return policy_id
コード例 #4
0
    def to_air_checkpoint(self) -> Optional[Checkpoint]:
        """Convert the stored checkpoint payload into a ``Checkpoint``.

        The payload is read from ``self.dir_or_data``. An object reference
        is resolved via ``ray.get`` first; the resulting value is then
        dispatched on its type (str, bytes or dict).

        Returns:
            A ``Checkpoint``, or None when there is no payload or when a
            string payload cannot be resolved to a checkpoint directory
            (an error is logged once in that case).

        Raises:
            RuntimeError: If the payload is of an unsupported type.
        """
        from ray.tune.trainable.util import TrainableUtil

        payload = self.dir_or_data
        if not payload:
            return None

        if isinstance(payload, ray.ObjectRef):
            payload = ray.get(payload)

        if isinstance(payload, str):
            try:
                resolved_dir = TrainableUtil.find_checkpoint_dir(payload)
            except FileNotFoundError:
                if log_once("checkpoint_not_available"):
                    logger.error(
                        f"The requested checkpoint is not available on this node, "
                        f"most likely because you are using Ray client or disabled "
                        f"checkpoint synchronization. To avoid this, enable checkpoint "
                        f"synchronization to cloud storage by specifying a "
                        f"`SyncConfig`. The checkpoint may be available on a different "
                        f"node - please check this location on worker nodes: "
                        f"{payload}")
                return None
            return Checkpoint.from_directory(resolved_dir)

        if isinstance(payload, bytes):
            return Checkpoint.from_bytes(payload)

        if isinstance(payload, dict):
            return Checkpoint.from_dict(payload)

        raise RuntimeError(
            f"Unknown checkpoint data type: {type(payload)}")
コード例 #5
0
def warn_structure_refactor(old_module: str, new_module: str, direct: bool = True):
    """Emit a one-time ``DeprecationWarning`` that a module has moved.

    Args:
        old_module: Old module name. A trailing ``.py`` (if any) is removed
            before the name is used.
        new_module: New location to point users to.
        direct: If True, suggest a plain search-and-replace of the module
            name. If False, warn that the module may have been split or
            refactored.
    """
    # Strip only a trailing ".py". A blanket `replace(".py", "")` would also
    # mangle names such as `pkg.pytorch_utils` (-> `pkgtorch_utils`).
    if old_module.endswith(".py"):
        old_module = old_module[: -len(".py")]

    if not log_once(f"tune:structure:refactor:{old_module}"):
        return

    warning = (
        f"The module `{old_module}` has been moved to `{new_module}` and the old "
        f"location will be deprecated soon. Please adjust your imports to point "
        f"to the new location."
    )

    # The leading space keeps the appended sentence from being glued to the
    # previous one ("...new location.Example: ...").
    if direct:
        warning += (
            f" Example: Do a global search and "
            f"replace `{old_module}` with `{new_module}`."
        )
    else:
        warning += (
            f" ATTENTION: This module may have been split or refactored. Please "
            f"check the contents of `{new_module}` before making changes."
        )

    # Temporarily force the warning through regardless of active filters.
    with warnings.catch_warnings():
        warnings.simplefilter("always")
        warnings.warn(warning, DeprecationWarning)
コード例 #6
0
ファイル: rnn_sequencing.py プロジェクト: zhangbushi10/ray
def pad_batch_to_sequences_of_same_size(
    batch: SampleBatch,
    max_seq_len: int,
    shuffle: bool = False,
    batch_divisibility_req: int = 1,
    feature_keys: Optional[List[str]] = None,
    view_requirements: Optional[ViewRequirementsDict] = None,
):
    """Applies padding to `batch` so it's choppable into same-size sequences.

    Shuffles `batch` (if desired), makes sure divisibility requirement is met,
    then pads the batch ([B, ...]) into same-size chunks ([B, ...]) w/o
    adding a time dimension (yet).
    Padding depends on episodes found in batch and `max_seq_len`.

    Args:
        batch (SampleBatch): The SampleBatch object. All values in here have
            the shape [B, ...].
        max_seq_len (int): The max. sequence length to use for chopping.
        shuffle (bool): Whether to shuffle batch sequences. Shuffle may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.
        batch_divisibility_req (int): The int by which the batch dimension
            must be dividable.
        feature_keys (Optional[List[str]]): An optional list of keys to apply
            sequence-chopping to. If None, use all keys in batch that are not
            "state_in/out_"-type keys.
        view_requirements (Optional[ViewRequirementsDict]): An optional
            Policy ViewRequirements dict to be able to infer whether
            e.g. dynamic max'ing should be applied over the seq_lens.
            NOTE(review): although optional in the signature, it is
            subscripted with "state_in_0" whenever the batch contains
            "state_in_0" or "state_out_0", so None fails in that case.

    Returns:
        None. `batch` is modified in place: the chopped feature columns,
        the state columns and "seq_lens" are written back into it. In the
        simple (no RNN, no padding) case only the optional shuffle happens.
    """
    # With a requirement > 1, the batch only "meets" it if its size divides
    # evenly AND the largest agent index is 0.
    if batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
            # not multiagent
            and max(batch[SampleBatch.AGENT_INDEX]) == 0)
    else:
        meets_divisibility_reqs = True

    states_already_reduced_to_init = False

    # RNN/attention net case. Figure out whether we should apply dynamic
    # max'ing over the list of sequence lengths.
    if "state_in_0" in batch or "state_out_0" in batch:
        # Check, whether the state inputs have already been reduced to their
        # init values at the beginning of each max_seq_len chunk.
        # NOTE(review): this reads batch["state_in_0"] even if only
        # "state_out_0" matched the condition above - confirm both are
        # always present together.
        if batch.seq_lens is not None and \
                len(batch["state_in_0"]) == len(batch.seq_lens):
            states_already_reduced_to_init = True

        # RNN (or single timestep state-in): Set the max dynamically.
        if view_requirements["state_in_0"].shift_from is None:
            dynamic_max = True
        # Attention Nets (state inputs are over some range): No dynamic maxing
        # possible.
        else:
            dynamic_max = False
    # Multi-agent case.
    elif not meets_divisibility_reqs:
        # The divisibility requirement itself becomes the chunk length.
        max_seq_len = batch_divisibility_req
        dynamic_max = False
    # Simple case: No RNN/attention net, nor do we need to pad.
    else:
        if shuffle:
            batch.shuffle()
        return

    # RNN, attention net, or multi-agent case.
    state_keys = []
    feature_keys_ = feature_keys or []
    # Collect state-in keys; if the caller gave no feature keys, also collect
    # every ndarray column that is not a state-out, "infos" or "seq_lens".
    for k, v in batch.items():
        if k.startswith("state_in_"):
            state_keys.append(k)
        elif not feature_keys and not k.startswith("state_out_") and \
                k not in ["infos", "seq_lens"] and isinstance(v, np.ndarray):
            feature_keys_.append(k)

    feature_sequences, initial_states, seq_lens = \
        chop_into_sequences(
            feature_columns=[batch[k] for k in feature_keys_],
            state_columns=[batch[k] for k in state_keys],
            episode_ids=batch.get(SampleBatch.EPS_ID),
            unroll_ids=batch.get(SampleBatch.UNROLL_ID),
            agent_indices=batch.get(SampleBatch.AGENT_INDEX),
            seq_lens=getattr(batch, "seq_lens", batch.get("seq_lens")),
            max_seq_len=max_seq_len,
            dynamic_max=dynamic_max,
            states_already_reduced_to_init=states_already_reduced_to_init,
            shuffle=shuffle)

    # Write the chopped columns back into the batch (in-place update).
    for i, k in enumerate(feature_keys_):
        batch[k] = feature_sequences[i]
    for i, k in enumerate(state_keys):
        batch[k] = initial_states[i]
    batch["seq_lens"] = np.array(seq_lens)

    # Log a summary of the padded data once per process.
    if log_once("rnn_ma_feed_dict"):
        logger.info("Padded input for RNN/Attn.Nets/MA:\n\n{}\n".format(
            summarize({
                "features": feature_sequences,
                "initial_states": initial_states,
                "seq_lens": seq_lens,
                "max_seq_len": max_seq_len,
            })))
コード例 #7
0
ファイル: sample_batch.py プロジェクト: javigm98/ray
 def data(self):
     """Deprecated accessor that returns this object itself.

     A deprecation warning (non-raising) is issued the first time this
     is used.
     """
     first_use = log_once("SampleBatch.data")
     if first_use:
         deprecation_warning(
             old="SampleBatch.data[..]",
             new="SampleBatch[..]",
             error=False,
         )
     return self
コード例 #8
0
    def get_best_checkpoint(
        self,
        trial: Trial,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        return_path: bool = False,
    ) -> Optional[Union[Checkpoint, str]]:
        """Gets best persistent checkpoint path of provided trial.

        Any checkpoints with an associated metric value of ``nan`` will be filtered out.

        Args:
            trial: The log directory of a trial, or a trial instance.
            metric: key of trial info to return, e.g. "mean_accuracy".
                "training_iteration" is used by default if no value was
                passed to ``self.default_metric``.
            mode: One of [min, max]. Defaults to ``self.default_mode``.
            return_path: If True, only returns the path (and not the
                ``Checkpoint`` object). If using Ray client, it is not
                guaranteed that this path is available on the local
                (client) node. Can also contain a cloud URI.

        Returns:
            :class:`Checkpoint <ray.air.Checkpoint>` object or string
            if ``return_path=True``. ``None`` if no (non-nan) checkpoint
            exists, or if the best checkpoint is not available on this node
            and ``return_path=False``.
        """
        metric = metric or self.default_metric or TRAINING_ITERATION
        mode = self._validate_mode(mode)

        checkpoint_paths = self.get_trial_checkpoints_paths(trial, metric)

        # Filter out nan. Comparing nan values leads to undefined ordering.
        candidates = [
            (path, value) for path, value in checkpoint_paths
            if not is_nan(value)
        ]

        if not candidates:
            logger.error(f"No checkpoints have been found for trial {trial}.")
            return None

        # Flip the sign for "max" so that the best entry is always the
        # minimum. `min` returns the first minimal item, which is the same
        # entry a stable sort would place first, without sorting everything.
        sign = -1 if mode == "max" else 1
        best_path, _ = min(candidates, key=lambda item: sign * item[1])
        cloud_path = self._parse_cloud_path(best_path)

        if cloud_path:
            # Prefer cloud path over local path for downstream processing
            if return_path:
                return cloud_path
            return Checkpoint.from_uri(cloud_path)
        elif os.path.exists(best_path):
            if return_path:
                return best_path
            return Checkpoint.from_directory(best_path)
        else:
            if log_once("checkpoint_not_available"):
                logger.error(
                    f"The requested checkpoint for trial {trial} is not available on "
                    f"this node, most likely because you are using Ray client or "
                    f"disabled checkpoint synchronization. To avoid this, enable "
                    f"checkpoint synchronization to cloud storage by specifying a "
                    f"`SyncConfig`. The checkpoint may be available on a different "
                    f"node - please check this location on worker nodes: {best_path}"
                )
            # The path may still be meaningful on another node, so it is
            # returned as-is when only the path was requested.
            if return_path:
                return best_path
            return None
コード例 #9
0
ファイル: rnn_sequencing.py プロジェクト: zhongchun/ray
def pad_batch_to_sequences_of_same_size(batch,
                                        max_seq_len,
                                        shuffle=False,
                                        batch_divisibility_req=1,
                                        feature_keys=None):
    """Pads `batch` in place so it can be chopped into equal-size sequences.

    Optionally shuffles `batch`, checks the divisibility requirement, and
    then pads the batch ([B, ...]) into same-size chunks ([B, ...]) without
    adding a time dimension (yet). How much padding is applied depends on
    the episodes found in the batch and on `max_seq_len`.

    Args:
        batch (SampleBatch): The SampleBatch object. All values in here have
            the shape [B, ...].
        max_seq_len (int): The max. sequence length to use for chopping.
        shuffle (bool): Whether to shuffle batch sequences. Shuffle may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.
        batch_divisibility_req (int): The int by which the batch dimension
            must be dividable.
        feature_keys (Optional[List[str]]): An optional list of keys to apply
            sequence-chopping to. If None, use all keys in batch that are not
            "state_in/out_"-type keys.

    Returns:
        None. The chopped columns and "seq_lens" are written into `batch`.
    """
    # A requirement of 1 (or less) is trivially satisfied.
    meets_divisibility_reqs = True
    if batch_divisibility_req > 1:
        batch_size = len(batch[SampleBatch.CUR_OBS])
        meets_divisibility_reqs = (
            batch_size % batch_divisibility_req == 0
            # not multiagent
            and max(batch[SampleBatch.AGENT_INDEX]) == 0)

    is_rnn = "state_in_0" in batch

    # Simple case: neither RNN nor any need to pad.
    if not is_rnn and meets_divisibility_reqs:
        if shuffle:
            batch.shuffle()
        return

    if is_rnn:
        # RNN case.
        dynamic_max = True
    else:
        # Multi-agent case: chunk by the divisibility requirement.
        max_seq_len = batch_divisibility_req
        dynamic_max = False

    # RNN or multi-agent case: sort the batch keys into state inputs and
    # (unless explicitly given by the caller) feature columns.
    state_keys = []
    feature_keys_ = feature_keys or []
    infer_feature_keys = not feature_keys
    for key in batch.keys():
        if "state_in_" in key:
            state_keys.append(key)
            continue
        if infer_feature_keys and "state_out_" not in key \
                and key != "infos":
            feature_keys_.append(key)

    feature_columns = [batch[key] for key in feature_keys_]
    state_columns = [batch[key] for key in state_keys]
    feature_sequences, initial_states, seq_lens = chop_into_sequences(
        batch[SampleBatch.EPS_ID],
        batch[SampleBatch.UNROLL_ID],
        batch[SampleBatch.AGENT_INDEX],
        feature_columns,
        state_columns,
        max_seq_len,
        dynamic_max=dynamic_max,
        shuffle=shuffle)

    # Write the results back into the batch.
    for key, chopped in zip(feature_keys_, feature_sequences):
        batch[key] = chopped
    for key, init_state in zip(state_keys, initial_states):
        batch[key] = init_state
    batch["seq_lens"] = seq_lens

    if log_once("rnn_ma_feed_dict"):
        summary = summarize({
            "features": feature_sequences,
            "initial_states": initial_states,
            "seq_lens": seq_lens,
            "max_seq_len": max_seq_len,
        })
        logger.info("Padded input for RNN:\n\n{}\n".format(summary))
コード例 #10
0
ファイル: tf_modelv2.py プロジェクト: tuyulers5/jav44
 def register_variables(self, variables: List[TensorType]) -> None:
     """Append the given variables to this model's `var_list`.

     Deprecated: a (non-raising) deprecation warning is issued on first use.

     Args:
         variables: The variables to add.
     """
     first_use = log_once("deprecated_tfmodelv2_register_variables")
     if first_use:
         deprecation_warning(
             old="TFModelV2.register_variables",
             error=False,
         )
     self.var_list.extend(variables)
コード例 #11
0
ファイル: rnn_sequencing.py プロジェクト: vishalbelsare/ray
def timeslice_along_seq_lens_with_overlap(
    sample_batch: SampleBatchType,
    seq_lens: Optional[List[int]] = None,
    zero_pad_max_seq_len: int = 0,
    pre_overlap: int = 0,
    zero_init_states: bool = True,
) -> List["SampleBatch"]:
    """Slices batch along `seq_lens` (each seq-len item produces one batch).

    Args:
        sample_batch: The SampleBatch to timeslice.
        seq_lens (Optional[List[int]]): An optional list of seq_lens to slice
            at. If None, use `sample_batch[SampleBatch.SEQ_LENS]`. If that is
            missing as well, the batch is cut into slices of size
            `zero_pad_max_seq_len - pre_overlap` (plus one shorter trailing
            slice for any remainder).
        zero_pad_max_seq_len: If >0, already zero-pad the resulting
            slices up to this length. NOTE: This max-len will include the
            additional timesteps gained via setting pre_overlap (see Example).
        pre_overlap: If >0, will overlap each two consecutive slices by
            this many timesteps (toward the left side). This will cause
            zero-padding at the very beginning of the batch.
        zero_init_states: Whether initial states should always be
            zero'd. If False, will use the state_outs of the batch to
            populate state_in values.

    Returns:
        List[SampleBatch]: The list of (new) SampleBatches.

    Raises:
        AssertionError: If no sequence lengths are available or they are
            empty.

    Examples:
        assert seq_lens == [5, 5, 2]
        assert sample_batch.count == 12
        # self = 0 1 2 3 4 | 5 6 7 8 9 | 10 11 <- timesteps
        slices = timeslice_along_seq_lens_with_overlap(
            sample_batch=sample_batch,
            zero_pad_max_seq_len=10,
            pre_overlap=3)
        # Z = zero padding (at beginning or end).
        #             |pre (3)|     seq     | max-seq-len (up to 10)
        # slices[0] = | Z Z Z |  0  1 2 3 4 | Z Z
        # slices[1] = | 2 3 4 |  5  6 7 8 9 | Z Z
        # slices[2] = | 7 8 9 | 10 11 Z Z Z | Z Z
        # Note that `zero_pad_max_seq_len=10` includes the 3 pre-overlaps
        #  count (makes sure each slice has exactly length 10).
    """
    if seq_lens is None:
        seq_lens = sample_batch.get(SampleBatch.SEQ_LENS)
    else:
        if sample_batch.get(SampleBatch.SEQ_LENS) is not None and log_once(
            "overriding_sequencing_information"
        ):
            logger.warning(
                "Found sequencing information in a batch that will be "
                "ignored when slicing. Ignore this warning if you know "
                "what you are doing."
            )

    if seq_lens is None:
        # NOTE(review): with the default arguments this is 0 and the
        # `divmod` below raises ZeroDivisionError - confirm callers always
        # pass zero_pad_max_seq_len > pre_overlap on this path.
        max_seq_len = zero_pad_max_seq_len - pre_overlap
        if log_once("no_sequence_lengths_available_for_time_slicing"):
            logger.warning(
                "Trying to slice a batch along sequences without "
                "sequence lengths being provided in the batch. Batch will "
                "be sliced into slices of size "
                "{} = {} - {} = zero_pad_max_seq_len - pre_overlap.".format(
                    max_seq_len, zero_pad_max_seq_len, pre_overlap
                )
            )
        num_seq_lens, last_seq_len = divmod(len(sample_batch), max_seq_len)
        # Each full slice covers `max_seq_len` timesteps (the overlap is
        # added on top later). Using `zero_pad_max_seq_len` here would make
        # the lengths sum to more than the batch whenever pre_overlap > 0.
        seq_lens = [max_seq_len] * num_seq_lens + (
            [last_seq_len] if last_seq_len else []
        )

    assert (
        seq_lens is not None and len(seq_lens) > 0
    ), "Cannot timeslice along `seq_lens` when `seq_lens` is empty or None!"
    # Generate n slices based on seq_lens.
    # Each entry is (overlap start, actual slice start, slice end).
    start = 0
    slices = []
    for seq_len in seq_lens:
        pre_begin = start - pre_overlap
        slice_begin = start
        end = start + seq_len
        slices.append((pre_begin, slice_begin, end))
        start += seq_len

    timeslices = []
    for begin, slice_begin, end in slices:
        zero_length = None
        data_begin = 0
        zero_init_states_ = zero_init_states
        if begin < 0:
            # The overlap reaches before the start of the batch: left-pad
            # with zeros and force zero initial states.
            zero_length = pre_overlap
            data_begin = slice_begin
            zero_init_states_ = True
        else:
            # `begin` is known to be >= 0 in this branch.
            eps_ids = sample_batch[SampleBatch.EPS_ID][begin:end]
            is_last_episode_ids = eps_ids == eps_ids[-1]
            if not is_last_episode_ids[0]:
                # The window starts in a different episode than it ends in:
                # replace the timesteps not belonging to the last episode
                # with zeros.
                zero_length = int(sum(1.0 - is_last_episode_ids))
                data_begin = begin + zero_length
                zero_init_states_ = True

        if zero_length is not None:
            data = {
                k: np.concatenate(
                    [
                        np.zeros(shape=(zero_length,) + v.shape[1:], dtype=v.dtype),
                        v[data_begin:end],
                    ]
                )
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }
        else:
            data = {
                k: v[begin:end]
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }

        if zero_init_states_:
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                data[key] = np.zeros_like(sample_batch[key][0:1])
                # Del state_out_n from data if exists.
                data.pop("state_out_{}".format(i), None)
                i += 1
                key = "state_in_{}".format(i)
        # TODO: This will not work with attention nets as their state_outs are
        #  not compatible with state_ins.
        else:
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                # Use the state output of the timestep right before `begin`
                # as this slice's initial state.
                data[key] = sample_batch["state_out_{}".format(i)][begin - 1 : begin]
                del data["state_out_{}".format(i)]
                i += 1
                key = "state_in_{}".format(i)

        timeslices.append(SampleBatch(data, seq_lens=[end - begin]))

    # Zero-pad each slice if necessary.
    if zero_pad_max_seq_len > 0:
        for ts in timeslices:
            ts.right_zero_pad(max_seq_len=zero_pad_max_seq_len, exclude_states=True)

    return timeslices
コード例 #12
0
ファイル: env.py プロジェクト: bladesaber/ray
def check_multiagent_environments(env: "MultiAgentEnv") -> None:
    """Checking for common errors in RLlib MultiAgentEnvs.

    Resets the env, samples observations and actions, performs one step and
    verifies that every returned element is a multi-agent dict and is
    contained in the corresponding space.

    Args:
        env: The env to be checked.

    Raises:
        ValueError: If `env` is not a MultiAgentEnv, if one of its
            `*_space_contains` methods raises, or if a reset/sampled/stepped
            observation or a sampled action is not contained in its space.
    """
    from ray.rllib.env import MultiAgentEnv

    if not isinstance(env, MultiAgentEnv):
        raise ValueError("The passed env is not a MultiAgentEnv.")
    elif not (hasattr(env, "observation_space")
              and hasattr(env, "action_space") and hasattr(env, "_agent_ids")
              and hasattr(env, "_spaces_in_preferred_format")):
        if log_once("ma_env_super_ctor_called"):
            logger.warning(
                f"Your MultiAgentEnv {env} does not have some or all of the needed "
                "base-class attributes! Make sure you call `super().__init__` from "
                "within your MutiAgentEnv's constructor. "
                "This will raise an error in the future.")
        # Install placeholder attributes so the checks below can run.
        env.observation_space = (
            env.action_space) = env._spaces_in_preferred_format = None
        env._agent_ids = set()

    reset_obs = env.reset()
    sampled_obs = env.observation_space_sample()
    _check_if_element_multi_agent_dict(env, reset_obs, "reset()")
    _check_if_element_multi_agent_dict(env, sampled_obs,
                                       "env.observation_space_sample()")

    # First make sure the user's `contains` implementation runs at all, so
    # that a crash in it is reported as such.
    try:
        env.observation_space_contains(reset_obs)
    except Exception as e:
        raise ValueError(
            "Your observation_space_contains function has some error ") from e

    if not env.observation_space_contains(reset_obs):
        error = (
            _not_contained_error("env.reset", "observation") +
            f"\n\n reset_obs: {reset_obs}\n\n env.observation_space_sample():"
            f" {sampled_obs}\n\n ")
        raise ValueError(error)

    if not env.observation_space_contains(sampled_obs):
        error = (
            _not_contained_error("observation_space_sample", "observation") +
            f"\n\n env.observation_space_sample():"
            f" {sampled_obs}\n\n ")
        raise ValueError(error)

    sampled_action = env.action_space_sample()
    _check_if_element_multi_agent_dict(env, sampled_action,
                                       "action_space_sample")
    try:
        env.action_space_contains(sampled_action)
    except Exception as e:
        raise ValueError(
            "Your action_space_contains function has some error ") from e

    if not env.action_space_contains(sampled_action):
        # f-string so the offending action actually shows up in the message.
        error = (_not_contained_error("action_space_sample", "action") +
                 f"\n\n sampled_action {sampled_action}\n\n")
        raise ValueError(error)

    next_obs, reward, done, info = env.step(sampled_action)
    _check_if_element_multi_agent_dict(env, next_obs, "step, next_obs")
    _check_if_element_multi_agent_dict(env, reward, "step, reward")
    _check_if_element_multi_agent_dict(env, done, "step, done")
    _check_if_element_multi_agent_dict(env, info, "step, info")
    # The helpers are called in their base-env mode, so wrap each
    # per-agent dict under a dummy env id.
    _check_reward({"dummy_env_id": reward},
                  base_env=True,
                  agent_ids=env.get_agent_ids())
    _check_done({"dummy_env_id": done},
                base_env=True,
                agent_ids=env.get_agent_ids())
    _check_info({"dummy_env_id": info},
                base_env=True,
                agent_ids=env.get_agent_ids())
    if not env.observation_space_contains(next_obs):
        # f-string so the observations are interpolated into the message.
        error = (
            _not_contained_error("env.step(sampled_action)", "observation") +
            f":\n\n next_obs: {next_obs} \n\n sampled_obs: {sampled_obs}")
        raise ValueError(error)
コード例 #13
0
ファイル: torch_policy_template.py プロジェクト: yncxcw/ray
def build_torch_policy(
    name: str,
    *,
    loss_fn: Optional[Callable[
        [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch],
        Union[TensorType, List[TensorType]]]],
    get_default_config: Optional[Callable[[], TrainerConfigDict]] = None,
    stats_fn: Optional[Callable[[Policy, SampleBatch],
                                Dict[str, TensorType]]] = None,
    postprocess_fn=None,
    extra_action_out_fn: Optional[Callable[[
        Policy, Dict[
            str,
            TensorType], List[TensorType], ModelV2, TorchDistributionWrapper
    ], Dict[str, TensorType]]] = None,
    extra_grad_process_fn: Optional[
        Callable[[Policy, "torch.optim.Optimizer", TensorType],
                 Dict[str, TensorType]]] = None,
    extra_learn_fetches_fn: Optional[Callable[[Policy],
                                              Dict[str, TensorType]]] = None,
    optimizer_fn: Optional[Callable[[Policy, TrainerConfigDict],
                                    "torch.optim.Optimizer"]] = None,
    validate_spaces: Optional[Callable[
        [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
    before_init: Optional[Callable[
        [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
    before_loss_init: Optional[Callable[
        [Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
        None]] = None,
    after_init: Optional[Callable[
        [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
    _after_loss_init: Optional[Callable[
        [Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
        None]] = None,
    action_sampler_fn: Optional[Callable[[TensorType, List[TensorType]],
                                         Tuple[TensorType,
                                               TensorType]]] = None,
    action_distribution_fn: Optional[
        Callable[[Policy, ModelV2, TensorType, TensorType, TensorType],
                 Tuple[TensorType, type, List[TensorType]]]] = None,
    make_model: Optional[Callable[
        [Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
        ModelV2]] = None,
    make_model_and_action_dist: Optional[Callable[
        [Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
        Tuple[ModelV2, Type[TorchDistributionWrapper]]]] = None,
    compute_gradients_fn: Optional[Callable[[Policy, SampleBatch],
                                            Tuple[ModelGradients,
                                                  dict]]] = None,
    apply_gradients_fn: Optional[Callable[[Policy, "torch.optim.Optimizer"],
                                          None]] = None,
    mixins: Optional[List[type]] = None,
    get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
) -> Type[TorchPolicy]:
    """Deprecated wrapper around `build_policy_class(framework='torch')`.

    Issues a one-time (non-raising) deprecation warning, then forwards every
    argument by keyword to `build_policy_class`, adding `framework="torch"`.

    Args:
        name: Forwarded unchanged as `name`.
        loss_fn: Required keyword argument; forwarded unchanged.

        All remaining keyword arguments are optional and forwarded unchanged
        under their own names; refer to `build_policy_class` for what each
        of them means.

    Returns:
        Whatever `build_policy_class` returns for these arguments.
    """

    if log_once("deprecation_warning_build_torch_policy"):
        deprecation_warning(old="build_torch_policy",
                            new="build_policy_class(framework='torch')",
                            error=False)
    # Snapshot the local namespace. No other local has been bound at this
    # point, so this contains exactly the function's parameters. Do not
    # introduce local variables above this line.
    kwargs = locals().copy()
    # Set to torch and call new function.
    kwargs["framework"] = "torch"
    return build_policy_class(**kwargs)
コード例 #14
0
def check_gym_environments(env: gym.Env) -> None:
    """Run basic sanity checks on a single-agent gym environment.

    The spaces are validated first. The env is then reset once and stepped
    once with an action sampled from its action space. An observation that is
    not contained in the observation space is passed, together with an
    observation sampled from that space, through
    `convert_element_to_space_type()`; an error is raised only if the
    converted observation is still not contained. Finally the done flag,
    reward and info returned by `env.step()` are handed to `_check_done()`,
    `_check_reward()` and `_check_info()`.

    Exceptions raised by `env.reset()`, `env.step()` or the `_check_*`
    helpers are not caught here.

    Args:
        env: Environment to be checked.

    Warning:
        Logged once if env has no `spec` attribute, or if `env.spec` has no
        `max_episode_steps` attribute.

    Raises:
        AttributeError: If env has no observation space.
        AttributeError: If env has no action space.
        ValueError: If the observation space is not a gym.spaces.Space.
        ValueError: If the action space is not a gym.spaces.Space.
        ValueError: If the observation returned by env.reset() is not
            contained in the observation space, even after conversion.
        ValueError: If the observation returned by env.step() is not
            contained in the observation space, even after conversion.
    """

    # Both spaces have to exist before anything else can be inspected.
    if not hasattr(env, "observation_space"):
        raise AttributeError("Env must have observation_space.")
    if not hasattr(env, "action_space"):
        raise AttributeError("Env must have action_space.")

    # ... and both have to be actual gym spaces.
    if not isinstance(env.observation_space, gym.spaces.Space):
        raise ValueError("Observation space must be a gym.space")
    if not isinstance(env.action_space, gym.spaces.Space):
        raise ValueError("Action space must be a gym.space")

    # Warn (once) when the env does not advertise an episode step limit.
    has_step_limit = hasattr(env, "spec") and hasattr(
        env.spec, "max_episode_steps"
    )
    if not has_step_limit and log_once("max_episode_steps"):
        logger.warning(
            "Your env doesn't have a .spec.max_episode_steps "
            "attribute. This is fine if you have set 'horizon' "
            "in your config dictionary, or `soft_horizon`. "
            "However, if you haven't, 'horizon' will default "
            "to infinity, and your environment will not be "
            "reset."
        )

    def _dtype_of(value):
        # Values exposing `.dtype` report it; anything else falls back to its
        # Python type. Only used for the error messages below.
        if hasattr(value, "dtype"):
            return value.dtype
        return type(value)

    def _raise_unless_convertible(observation, template, message):
        # Second chance for an observation the space rejected: retry the
        # containment test on the converted observation.
        # NOTE(review): presumably the conversion casts `observation` to the
        # dtype(s) of `template` - confirm against the helper.
        converted = convert_element_to_space_type(observation, template)
        if not env.observation_space.contains(converted):
            raise ValueError(message)

    # Draw one sample from each space: the action drives env.step(), the
    # observation serves as the conversion template.
    action_sample = env.action_space.sample()
    observation_sample = env.observation_space.sample()

    # The observation returned by env.reset() must lie in the observation
    # space.
    initial_obs = env.reset()
    if not env.observation_space.contains(initial_obs):
        initial_obs_dtype = _dtype_of(initial_obs)
        obs_space_dtype = env.observation_space.dtype
        reset_message = (
            f"The observation collected from env.reset() was not  "
            f"contained within your env's observation space. Its possible "
            f"that There was a type mismatch, or that one of the "
            f"sub-observations  was out of bounds: \n\n reset_obs: "
            f"{initial_obs}\n\n env.observation_space: "
            f"{env.observation_space}\n\n reset_obs's dtype: "
            f"{initial_obs_dtype}\n\n env.observation_space's dtype: "
            f"{obs_space_dtype}"
        )
        _raise_unless_convertible(initial_obs, observation_sample, reset_message)

    # Step once with the sampled action and apply the same containment test
    # to the resulting observation.
    stepped_obs, step_reward, step_done, step_info = env.step(action_sample)
    if not env.observation_space.contains(stepped_obs):
        stepped_obs_dtype = _dtype_of(stepped_obs)
        obs_space_dtype = env.observation_space.dtype
        step_message = (
            f"The observation collected from env.step(sampled_action) was "
            f"not contained within your env's observation space. Its "
            f"possible that There was a type mismatch, or that one of the "
            f"sub-observations was out of bounds:\n\n next_obs: {stepped_obs}"
            f"\n\n env.observation_space: {env.observation_space}"
            f"\n\n next_obs's dtype: {stepped_obs_dtype}"
            f"\n\n env.observation_space's dtype: {obs_space_dtype}"
        )
        _raise_unless_convertible(stepped_obs, observation_sample, step_message)

    # The remaining step outputs are validated by dedicated helpers.
    _check_done(step_done)
    _check_reward(step_reward)
    _check_info(step_info)
コード例 #15
0
def check_env(env: EnvType) -> None:
    """Run pre-checks on env that uncover common errors in environments.

    Checking is skipped entirely (with a one-time warning) if the env has a
    truthy `_skip_env_checking` attribute. Otherwise the env is dispatched to
    the checker matching its type: `check_multiagent_environments()` for
    MultiAgentEnv, `check_gym_environments()` for gym.Env and
    `check_base_env()` for BaseEnv. For the remaining supported types only a
    warning is logged.

    Args:
        env: Environment to be checked.

    Raises:
        ValueError: If env is not an instance of one of the supported
            environment types, or if any type-specific check raises. The
            message embeds the formatted traceback of the underlying error,
            which is also attached as `__cause__`. See the
            `check_gym_environments` docstring for the gym-specific checks.
    """
    # Imported at call time rather than module level.
    # NOTE(review): presumably to avoid a circular import - confirm.
    from ray.rllib.env import (
        BaseEnv,
        MultiAgentEnv,
        RemoteBaseEnv,
        VectorEnv,
        ExternalMultiAgentEnv,
        ExternalEnv,
    )

    if getattr(env, "_skip_env_checking", False):
        # This is a work around for some environments that we already have in
        # RLlib that we want to skip checking for now until we have the time
        # to fix them.
        if log_once("skip_env_checking"):
            logger.warning("Skipping env checking for this experiment")
        return

    supported_types = (
        BaseEnv,
        gym.Env,
        MultiAgentEnv,
        RemoteBaseEnv,
        VectorEnv,
        ExternalMultiAgentEnv,
        ExternalEnv,
        ActorHandle,
    )

    try:
        if not isinstance(env, supported_types):
            raise ValueError(
                "Env must be one of the supported types: BaseEnv, gym.Env, "
                "MultiAgentEnv, VectorEnv, RemoteBaseEnv, ExternalMultiAgentEnv, "
                f"ExternalEnv, but instead was a {type(env)}"
            )
        # Order matters: MultiAgentEnv is tested before gym.Env and BaseEnv,
        # so an env matching several of these types gets the multi-agent
        # checks.
        if isinstance(env, MultiAgentEnv):
            check_multiagent_environments(env)
        elif isinstance(env, gym.Env):
            check_gym_environments(env)
        elif isinstance(env, BaseEnv):
            check_base_env(env)
        else:
            logger.warning(
                "Env checking isn't implemented for VectorEnvs, RemoteBaseEnvs, "
                "ExternalMultiAgentEnv, or ExternalEnvs or Environments that are "
                "Ray actors"
            )
    except Exception as err:
        # Every failure (including the unsupported-type error above) is
        # re-raised as a ValueError that explains how to disable the checks.
        actual_error = traceback.format_exc()
        raise ValueError(
            f"{actual_error}\n"
            "The above error has been found in your environment! "
            "We've added a module for checking your custom environments. It "
            "may cause your experiment to fail if your environment is not set up "
            "correctly. You can disable this behavior by setting "
            "`disable_env_checking=True` in your config "
            "dictionary. You can run the environment checking module "
            "standalone by calling ray.rllib.utils.check_env([env])."
        ) from err
コード例 #16
0
ファイル: torch_trainer.py プロジェクト: yynst2/ray
    def __init__(
        self,
        *,
        training_operator_cls,
        initialization_hook=None,
        config=None,
        num_workers=1,
        num_cpus_per_worker=1,
        use_gpu="auto",
        backend="auto",
        wrap_ddp=True,
        timeout_s=NCCL_TIMEOUT_S,
        use_fp16=False,
        use_tqdm=False,
        apex_args=None,
        add_dist_sampler=True,
        scheduler_step_freq=None,
        use_local=False,
        # Deprecated Args.
        num_replicas=None,
        batch_size=None,
        model_creator=None,
        data_creator=None,
        optimizer_creator=None,
        scheduler_creator=None,
        loss_creator=None,
        serialize_data_creation=None,
        data_loader_args=None,
    ):
        """Validate the arguments, store the settings and start the workers.

        Args:
            training_operator_cls: Stored unchanged on the trainer.
            initialization_hook: Stored unchanged on the trainer.
            config: Stored on the trainer; an empty dict is used if None.
            num_workers: Stored as `max_replicas` and passed to
                `_start_workers()`. Values > 1 require
                `dist.is_available()`.
            num_cpus_per_worker: Stored unchanged on the trainer.
            use_gpu: "auto" resolves to `torch.cuda.is_available()`.
            backend: "auto" resolves to "nccl" if `use_gpu` is truthy,
                otherwise to "gloo".
            wrap_ddp: Stored unchanged on the trainer.
            timeout_s: Stored unchanged on the trainer.
            use_fp16: Stored unchanged on the trainer.
            use_tqdm: Stored unchanged on the trainer.
            apex_args: Must be a dict if truthy; stored unchanged.
            add_dist_sampler: Stored unchanged on the trainer.
            scheduler_step_freq: If truthy, checked with
                `_validate_scheduler_step_freq()`; stored unchanged.
            use_local: Stored unchanged; a one-time warning is logged if
                truthy.
            num_replicas: Deprecated; raises if not None.
            batch_size: Deprecated; raises if not None.
            model_creator: Deprecated; raises if truthy.
            data_creator: Deprecated; raises if truthy.
            optimizer_creator: Deprecated; raises if truthy.
            scheduler_creator: Deprecated; raises if truthy.
            loss_creator: Deprecated; raises if truthy.
            serialize_data_creation: Deprecated; a one-time warning is
                logged if True. The value is still stored on the trainer.
            data_loader_args: Deprecated; raises if truthy.

        Raises:
            DeprecationWarning: If any creator function, `num_replicas`,
                `batch_size` or `data_loader_args` is supplied.
            ValueError: If `num_workers > 1` while `dist.is_available()` is
                False, or if `apex_args` is truthy but not a dict.
        """
        if (model_creator or data_creator or optimizer_creator
                or scheduler_creator or loss_creator):
            raise DeprecationWarning(
                "Creator functions are deprecated. You should create a "
                "custom TrainingOperator, override setup, and register all "
                "training state there. See TrainingOperator for more info. "
                "If you would still like to use creator functions, you can "
                "do CustomOperator = TrainingOperator.from_creators("
                "model_creator, ...) and pass in CustomOperator into "
                "TorchTrainer.")

        if use_local and log_once("use_local"):
            logger.warning("use_local is set to True. This could lead to "
                           "issues with Cuda devices. If you are seeing this "
                           "issue, try setting use_local to False. For more "
                           "information, see "
                           "https://github.com/ray-project/ray/issues/9202.")

        if num_workers > 1 and not dist.is_available():
            raise ValueError(
                ("Distributed PyTorch is not supported on macOS. "
                 "To run without distributed PyTorch, set 'num_workers=1'. "
                 "For more information, see "
                 "https://github.com/pytorch/examples/issues/467."))

        if num_replicas is not None:
            raise DeprecationWarning(
                "num_replicas is deprecated. Use num_workers instead.")

        if batch_size is not None:
            raise DeprecationWarning(
                "batch_size is deprecated. Use config={'batch_size': N} "
                "specify a batch size for each worker or "
                "config={ray.util.sgd.utils.BATCH_SIZE: N} to specify a "
                "batch size to be used across all workers.")

        # Emit through the module-level `logger` like every other message in
        # this method, so the warning follows this module's logging
        # configuration instead of going to the root logger.
        if serialize_data_creation is True and log_once(
                "serialize_data_creation"):
            logger.warning(
                "serialize_data_creation is deprecated and will be "
                "ignored. If you require serialized data loading you "
                "should implement this in TrainingOperator.setup. "
                "You may find FileLock useful here.")

        if data_loader_args:
            raise DeprecationWarning(
                "data_loader_args is deprecated. You can return a "
                "torch.utils.data.DataLoader in data_creator. Ray will "
                "automatically set a DistributedSampler if a DataLoader is "
                "returned and num_workers > 1.")

        self.training_operator_cls = training_operator_cls

        self.initialization_hook = initialization_hook
        self.config = {} if config is None else config

        # Resolve the "auto" placeholders to concrete values.
        if use_gpu == "auto":
            use_gpu = torch.cuda.is_available()

        _remind_gpu_usage(use_gpu)

        if backend == "auto":
            backend = "nccl" if use_gpu else "gloo"

        logger.debug(f"Using {backend} as backend.")
        self.backend = backend
        self.num_cpus_per_worker = num_cpus_per_worker
        self.use_gpu = use_gpu
        self.max_replicas = num_workers

        self.serialize_data_creation = serialize_data_creation
        self.wrap_ddp = wrap_ddp
        self.timeout_s = timeout_s
        self.use_fp16 = use_fp16
        self.use_tqdm = use_tqdm
        self.add_dist_sampler = add_dist_sampler
        self.use_local = use_local

        if apex_args and not isinstance(apex_args, dict):
            raise ValueError("apex_args needs to be a dict object.")

        self.apex_args = apex_args
        self.temp_dir = tempfile.mkdtemp(prefix="raysgd")
        self._num_failures = 0
        self._last_resize = float("-inf")

        if scheduler_step_freq:
            _validate_scheduler_step_freq(scheduler_step_freq)

        self.scheduler_step_freq = scheduler_step_freq

        # Ray is only auto-initialized when more than one worker is requested
        # and the caller has not initialized it already.
        if not ray.is_initialized() and self.max_replicas > 1:
            logger.info("Automatically initializing single-node Ray. To use "
                        "multi-node training, be sure to run `ray.init("
                        "address='auto')` before instantiating the Trainer.")
            ray.init()
        self._start_workers(self.max_replicas)
コード例 #17
0
def update_priorities_in_replay_buffer(
    replay_buffer: ReplayBuffer,
    config: TrainerConfigDict,
    train_batch: SampleBatchType,
    train_results: ResultDict,
) -> None:
    """Updates the priorities in a prioritized replay buffer, given training results.

    For every policy in `train_results`, the "td_error" entry (looked up at
    the top level of the policy's results first, then under
    `LEARNER_STATS_KEY`) is paired with the "batch_indexes" of that policy's
    train batch and handed to `replay_buffer.update_priorities()`.

    Nothing is done if the given buffer is not a
    `MultiAgentPrioritizedReplayBuffer`. Policies for which either the
    td-error or the batch indices are missing are skipped with a one-time
    warning.

    Args:
        replay_buffer: The replay buffer, whose priority values to update. This
            may also be a buffer that does not support priorities.
        config: The Trainer's config dict. Not read by this function.
        train_batch: The batch used for the training update. Must expose
            `policy_batches`, keyed by policy id.
        train_results: A train results dict, generated by e.g. the
            `train_one_step()` utility.

    Raises:
        AssertionError: If the number of batch indices differs from the
            number of td-errors and cannot be reduced to it by keeping the
            first index of every `replay_sequence_length`-sized chunk.
    """
    # Only update priorities if buffer supports them.
    if not isinstance(replay_buffer, MultiAgentPrioritizedReplayBuffer):
        return

    # Go through training results for the different policies (maybe multi-agent).
    prio_dict = {}
    for policy_id, info in train_results.items():
        # TODO(sven): This is currently structured differently for
        #  torch/tf. Clean up these results/info dicts across
        #  policies (note: fixing this in torch_policy.py will
        #  break e.g. DDPPO!).
        # Resolve the fallback lazily: looking up `info[LEARNER_STATS_KEY]`
        # unconditionally would raise a KeyError for results that carry
        # "td_error" at the top level but have no learner-stats entry.
        if "td_error" in info:
            td_error = info["td_error"]
        else:
            td_error = info.get(LEARNER_STATS_KEY, {}).get("td_error")

        policy_batch = train_batch.policy_batches[policy_id]
        # Set the get_interceptor to None in order to be able to access the numpy
        # arrays directly (instead of e.g. a torch array).
        policy_batch.set_get_interceptor(None)
        # Get the replay buffer row indices that make up the `train_batch`.
        batch_indices = policy_batch.get("batch_indexes")

        if td_error is None:
            if log_once(f"no_td_error_in_train_results_from_policy_{policy_id}"):
                logger.warning(
                    "Trying to update priorities for policy with id `{}` in "
                    "prioritized replay buffer without providing td_errors in "
                    "train_results. Priority update for this policy is being "
                    "skipped.".format(policy_id))
            continue

        if batch_indices is None:
            if log_once(f"no_batch_indices_in_train_result_for_policy_{policy_id}"):
                logger.warning(
                    "Trying to update priorities for policy with id `{}` in "
                    "prioritized replay buffer without providing batch_indices in "
                    "train_batch. Priority update for this policy is being "
                    "skipped.".format(policy_id))
            continue

        # Try to transform batch_indices to td_error dimensions: keep only
        # the first index of every chunk of T consecutive indices.
        if len(batch_indices) != len(td_error):
            T = replay_buffer.replay_sequence_length
            assert (len(batch_indices) > len(td_error)
                    and len(batch_indices) % T == 0), (
                f"Cannot align {len(batch_indices)} batch indices with "
                f"{len(td_error)} td-errors using sequence length {T}.")
            batch_indices = batch_indices.reshape([-1, T])[:, 0]
            assert len(batch_indices) == len(td_error), (
                f"Got {len(batch_indices)} batch indices after reshaping, "
                f"but {len(td_error)} td-errors.")
        prio_dict[policy_id] = (batch_indices, td_error)

    # Make the actual buffer API call to update the priority weights on all
    # policies.
    replay_buffer.update_priorities(prio_dict)
コード例 #18
0
    def _initialize_loss(self):
        """Build the loss by tracing it once with a dummy batch.

        Steps, in order:
          1. Assemble a dummy batch of `B * T` rows, where `T` is
             `config["model"]["max_seq_len"]` and
             `B = config["train_batch_size"] // T`.
          2. Run the variable initializer and pass the dummy batch through
             `postprocess_trajectory()`.
          3. Call the model on the real input tensors, then create one
             placeholder per postprocessed column that has no input tensor
             yet.
          4. Call `_do_loss_init()` on the resulting train batch and register
             only the columns it accessed as loss inputs with
             `TFPolicy._initialize_loss()`.
          5. Add gradient stats (if a `_grad_stats_fn` is set) and run the
             variable initializer again.
        """

        def fake_array(tensor, none_shape):
            # Zero array with `none_shape` prepended to the known (non-None)
            # dimensions of `tensor`, using the tensor's numpy dtype.
            shape = tensor.shape.as_list()
            non_none_shape = [s for s in shape if s is not None]
            none_shape = none_shape if isinstance(none_shape, list) else [none_shape]
            shape = none_shape + non_none_shape
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        T = self.config["model"]["max_seq_len"]
        B = self.config["train_batch_size"] // T
        dummy_batch = {
            SampleBatch.CUR_OBS: fake_array(self._obs_input, B * T),
            SampleBatch.NEXT_OBS: fake_array(self._obs_input, B * T),
            # Builtin `bool` rather than the `np.bool` alias, which newer
            # NumPy releases removed.
            SampleBatch.DONES: np.array([False] * B * T, dtype=bool),
            SampleBatch.ACTIONS: fake_array(
                ModelCatalog.get_action_placeholder(self.action_space), B * T
            ),
            SampleBatch.REWARDS: np.array([0] * B * T, dtype=np.float32),
            SampleBatch.INFOS: np.array([self.sample_info] * B * T),
        }
        if self._obs_include_prev_action_reward:
            dummy_batch.update(
                {
                    SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input, B * T),
                    SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input, B * T),
                }
            )

        # Repeat every initial state B * T times for both the state-in and
        # state-out columns.
        state_init = self.get_initial_state()
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.repeat(
                np.expand_dims(h, 0), B * T, 0
            )
            dummy_batch["state_out_{}".format(i)] = np.repeat(
                np.expand_dims(h, 0), B * T, 0
            )
        if state_init:
            dummy_batch["seq_lens"] = np.array([T] * B * T, dtype=np.int32)
        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v, B * T)

        # postprocessing might depend on variable init, so run it first here
        self._sess.run(tf.global_variables_initializer())

        postprocessed_batch = self.postprocess_trajectory(SampleBatch(dummy_batch))

        # model forward pass for the loss (needed after postprocess to
        # overwrite any tensor state from that call)
        self.model(self._input_dict, self._state_in, self._seq_lens)

        # Start the train batch from the input tensors that already exist.
        if self._obs_include_prev_action_reward:
            train_batch = UsageTrackingDict(
                {
                    SampleBatch.PREV_ACTIONS: self._prev_action_input,
                    SampleBatch.PREV_REWARDS: self._prev_reward_input,
                    SampleBatch.CUR_OBS: self._obs_input,
                }
            )
            loss_inputs = [
                (SampleBatch.PREV_ACTIONS, self._prev_action_input),
                (SampleBatch.PREV_REWARDS, self._prev_reward_input),
                (SampleBatch.CUR_OBS, self._obs_input),
            ]
        else:
            train_batch = UsageTrackingDict({SampleBatch.CUR_OBS: self._obs_input})
            loss_inputs = [
                (SampleBatch.CUR_OBS, self._obs_input),
            ]

        # Create a placeholder for every remaining postprocessed column.
        for k, v in postprocessed_batch.items():
            if k in train_batch:
                continue
            # Builtin `object` rather than the removed `np.object` alias.
            elif v.dtype == object:
                continue  # can't handle arbitrary objects in TF
            elif k == "seq_lens" or k.startswith("state_in_"):
                continue
            # Leave the leading dimension open; float64 columns get float32
            # placeholders.
            shape = (None,) + v.shape[1:]
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            placeholder = tf.placeholder(dtype, shape=shape, name=k)
            train_batch[k] = placeholder

        # State inputs and sequence lengths always come from the existing
        # tensors, never from new placeholders.
        for i, si in enumerate(self._state_in):
            train_batch["state_in_{}".format(i)] = si
        train_batch["seq_lens"] = self._seq_lens

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)
                )
            )

        self._loss_input_dict = train_batch
        loss = self._do_loss_init(train_batch)
        # Only columns accessed while building the loss become loss inputs;
        # "seq_lens" and "state_in_*" are excluded here.
        for k in sorted(train_batch.accessed_keys):
            if k != "seq_lens" and not k.startswith("state_in_"):
                loss_inputs.append((k, train_batch[k]))

        TFPolicy._initialize_loss(self, loss, loss_inputs)
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads)
            )
        # NOTE(review): presumably re-run so that variables created while
        # building the loss are initialized as well - confirm.
        self._sess.run(tf.global_variables_initializer())