Example #1
def _create_embedded_rollout_worker(kwargs, send_fn):
    """Create a local rollout worker and a thread that samples from it.

    Args:
        kwargs (dict): args for the RolloutWorker constructor.
        send_fn (fn): function to send a JSON request to the server.
    """

    # Since the server acts as an input datasource, we have to reset the
    # input config to the default, which runs env rollouts.
    kwargs = kwargs.copy()
    del kwargs["input_creator"]

    # Since the server also acts as an output writer, reset the output config
    # to the default ("output": None); otherwise the local rollout worker
    # might write to an unknown output directory.
    del kwargs["output_creator"]

    # If server has no env (which is the expected case):
    # Generate a dummy ExternalEnv here using RandomEnv and the
    # given observation/action spaces.
    if kwargs["policy_config"].get("env") is None:
        from ray.rllib.examples.env.random_env import RandomEnv, RandomMultiAgentEnv
        from ray.rllib.utils.pre_checks.multi_agent import check_multi_agent

        config = {
            "action_space": kwargs["policy_config"]["action_space"],
            "observation_space": kwargs["policy_config"]["observation_space"],
        }
        _, is_ma = check_multi_agent(kwargs["policy_config"])
        kwargs["env_creator"] = _auto_wrap_external(
            lambda _: (RandomMultiAgentEnv if is_ma else RandomEnv)(config)
        )
        kwargs["policy_config"]["env"] = True
    # Otherwise, use the env specified by the server args.
    else:
        real_env_creator = kwargs["env_creator"]
        kwargs["env_creator"] = _auto_wrap_external(real_env_creator)

    logger.info("Creating rollout worker with kwargs={}".format(kwargs))
    from ray.rllib.evaluation.rollout_worker import RolloutWorker

    rollout_worker = RolloutWorker(**kwargs)

    inference_thread = _LocalInferenceThread(rollout_worker, send_fn)
    inference_thread.start()

    return rollout_worker, inference_thread
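
For context, here is a minimal sketch (not part of the source) of the placeholder environment built by the dummy-env branch above: RandomEnv simply samples from the spaces it is given, so the local RolloutWorker has something to reset and step. The concrete spaces below are made up for illustration; _auto_wrap_external, _LocalInferenceThread and logger are module-level helpers from the same module as this function.

import gym

from ray.rllib.examples.env.random_env import RandomEnv

# Illustrative spaces; in _create_embedded_rollout_worker they come from
# kwargs["policy_config"].
dummy_env = RandomEnv({
    "action_space": gym.spaces.Discrete(2),
    "observation_space": gym.spaces.Box(-1.0, 1.0, shape=(4,)),
})
obs = dummy_env.reset()
obs, reward, done, info = dummy_env.step(dummy_env.action_space.sample())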
Example #2
def check_train_results(train_results):
    """Checks proper structure of a Trainer.train() returned dict.

    Args:
        train_results: The train results dict to check.

    Raises:
        AssertionError: If `train_results` doesn't have the proper structure or
            data in it.
    """
    # NumPy is needed for the scalar checks below.
    import numpy as np

    # Import these here to avoid circular dependencies.
    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
    from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
        LEARNER_STATS_KEY
    from ray.rllib.utils.pre_checks.multi_agent import check_multi_agent

    # Assert that some keys are where we would expect them.
    for key in [
            "agent_timesteps_total",
            "config",
            "custom_metrics",
            "episode_len_mean",
            "episode_reward_max",
            "episode_reward_mean",
            "episode_reward_min",
            "episodes_total",
            "hist_stats",
            "info",
            "iterations_since_restore",
            "num_healthy_workers",
            "perf",
            "policy_reward_max",
            "policy_reward_mean",
            "policy_reward_min",
            "sampler_perf",
            "time_since_restore",
            "time_this_iter_s",
            "timesteps_since_restore",
            "timesteps_total",
            "timers",
            "time_total_s",
            "training_iteration",
    ]:
        assert key in train_results, \
            f"'{key}' not found in `train_results` ({train_results})!"

    _, is_multi_agent = check_multi_agent(train_results["config"])

    # Check in particular the "info" dict.
    info = train_results["info"]
    assert LEARNER_INFO in info, \
        f"'{LEARNER_INFO}' not found in train_results['info'] ({info})!"
    assert "num_steps_trained" in info or "num_env_steps_trained" in info, \
        f"'num_(env_)?steps_trained' not in train_results['info'] ({info})!"

    learner_info = info[LEARNER_INFO]

    # Make sure we have a default_policy key if we are not in a
    # multi-agent setup.
    if not is_multi_agent:
        # APEX algos sometimes have an empty learner info dict (no metrics
        # collected yet).
        assert len(learner_info) == 0 or DEFAULT_POLICY_ID in learner_info, \
            f"'{DEFAULT_POLICY_ID}' not found in " \
            f"train_results['info']['learner'] ({learner_info})!"

    for pid, policy_stats in learner_info.items():
        if pid == "batch_count":
            continue
        # Expect td-errors to be per batch-item.
        if "td_error" in policy_stats:
            configured_b = train_results["config"]["train_batch_size"]
            actual_b = policy_stats["td_error"].shape[0]
            # R2D2 case.
            if (configured_b - actual_b) / actual_b > 0.1:
                assert configured_b / (
                    train_results["config"]["model"]["max_seq_len"] +
                    train_results["config"]["burn_in"]) == actual_b

        # Make sure each policy has the LEARNER_STATS_KEY under it.
        assert LEARNER_STATS_KEY in policy_stats
        learner_stats = policy_stats[LEARNER_STATS_KEY]
        for key, value in learner_stats.items():
            # Min- and max-stats should be single values.
            if key.startswith("min_") or key.startswith("max_"):
                assert np.isscalar(value), \
                    f"'{key}' value not a scalar ({value})!"

    return train_results
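
A short usage sketch (assuming the Trainer API of this RLlib generation; PGTrainer and the config values are chosen purely for illustration):

import ray
from ray.rllib.agents.pg import PGTrainer

ray.init()
trainer = PGTrainer(config={"env": "CartPole-v0", "num_workers": 0})
results = trainer.train()
# Raises an AssertionError if any expected key or structure is missing.
check_train_results(results)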