Example #1
    def __init__(
        self,
        num_shards: int = 1,
        learning_starts: int = 1000,
        capacity: int = 10000,
        replay_batch_size: int = 1,
        prioritized_replay_alpha: float = 0.6,
        prioritized_replay_beta: float = 0.4,
        prioritized_replay_eps: float = 1e-6,
        replay_mode: str = "independent",
        replay_sequence_length: int = 1,
        replay_burn_in: int = 0,
        replay_zero_init_states: bool = True,
        buffer_size=DEPRECATED_VALUE,
    ):
        """Initializes a MultiAgentReplayBuffer instance.

        Args:
            num_shards: The number of buffer shards that exist in total
                (including this one).
            learning_starts: Number of timesteps after which a call to
                `replay()` will yield samples (before that, `replay()` will
                return None).
            capacity: The capacity of the buffer. Note that when
                `replay_sequence_length` > 1, this is the number of sequences
                (not single timesteps) stored.
            replay_batch_size: The batch size to be sampled (in timesteps).
                Note that if `replay_sequence_length` > 1,
                `self.replay_batch_size` will be set to the number of
                sequences sampled (B).
            prioritized_replay_alpha: Alpha parameter for a prioritized
                replay buffer. Use 0.0 for no prioritization.
            prioritized_replay_beta: Beta parameter for a prioritized
                replay buffer.
            prioritized_replay_eps: Epsilon parameter for a prioritized
                replay buffer.
            replay_mode: One of "independent" or "lockstep". Determines
                whether, in the multi-agent case, sampling is done
                independently per policy or in lockstep across all
                agents/policies.
            replay_sequence_length: The sequence length (T) of a single
                sample. If > 1, we will sample B x T from this buffer.
            replay_burn_in: The burn-in length in case
                `replay_sequence_length` > 0. This is the number of timesteps
                each sequence overlaps with the previous one to generate a
                better internal state (=state after the burn-in), instead of
                starting from 0.0 for each RNN rollout.
            replay_zero_init_states: Whether the initial states in the
                buffer (if replay_sequence_length > 0) are always 0.0 or
                should be updated with the previous train_batch state outputs.
        """
        # Deprecated args.
        if buffer_size != DEPRECATED_VALUE:
            deprecation_warning("ReplayBuffer(size)",
                                "ReplayBuffer(capacity)",
                                error=False)
            capacity = buffer_size

        self.replay_starts = learning_starts // num_shards
        self.capacity = capacity // num_shards
        self.replay_batch_size = replay_batch_size
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.replay_mode = replay_mode
        self.replay_sequence_length = replay_sequence_length
        self.replay_burn_in = replay_burn_in
        self.replay_zero_init_states = replay_zero_init_states

        if replay_sequence_length > 1:
            self.replay_batch_size = int(
                max(1, replay_batch_size // replay_sequence_length))
            logger.info(
                "Since replay_sequence_length={} and replay_batch_size={}, "
                "we will replay {} sequences at a time.".format(
                    replay_sequence_length, replay_batch_size,
                    self.replay_batch_size))

        if replay_mode not in ["lockstep", "independent"]:
            raise ValueError("Unsupported replay mode: {}".format(replay_mode))

        def gen_replay():
            while True:
                yield self.replay()

        ParallelIteratorWorker.__init__(self, gen_replay, False)

        def new_buffer():
            if prioritized_replay_alpha == 0.0:
                return ReplayBuffer(self.capacity)
            else:
                return PrioritizedReplayBuffer(self.capacity,
                                               alpha=prioritized_replay_alpha)

        self.replay_buffers = collections.defaultdict(new_buffer)

        # Metrics.
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.update_priorities_timer = TimerStat()
        self.num_added = 0

        # Make externally accessible for testing.
        global _local_replay_buffer
        _local_replay_buffer = self
        # If set, return this instead of the usual data for testing.
        self._fake_batch = None
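
A minimal usage sketch for the constructor above (hypothetical: the import
path and all values are assumptions for illustration, not taken from this
snippet):

# Sketch: build a single-shard buffer and check that replay() yields
# nothing before `learning_starts` timesteps have been added.
from ray.rllib.execution.replay_buffer import MultiAgentReplayBuffer

buffer = MultiAgentReplayBuffer(
    num_shards=1,
    learning_starts=1000,
    capacity=50000,
    replay_batch_size=32,
    prioritized_replay_alpha=0.6,  # 0.0 would disable prioritization
)

# Per the docstring, replay() returns None until `learning_starts`
# timesteps have been added via add_batch().
assert buffer.replay() is None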
Example #2
def validate_buffer_config(config: dict) -> None:
    """Checks and fixes values in the replay buffer config.

    Checks the replay buffer config for common misconfigurations and warns or
    raises an error if validation fails. The "type" key is replaced by the
    inferred replay buffer class.

    Args:
        config: The replay buffer config to be validated.

    Raises:
        ValueError: When detecting severe misconfiguration.
    """
    if config.get("replay_buffer_config", None) is None:
        config["replay_buffer_config"] = {}

    if config.get("worker_side_prioritization",
                  DEPRECATED_VALUE) != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['worker_side_prioritization']",
            new="config['replay_buffer_config']['worker_side_prioritization']",
            error=True,
        )

    prioritized_replay = config.get("prioritized_replay", DEPRECATED_VALUE)
    if prioritized_replay != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['prioritized_replay'] or config['replay_buffer_config']["
            "'prioritized_replay']",
            help=
            "Replay prioritization specified by config key. RLlib's new replay "
            "buffer API requires setting `config["
            "'replay_buffer_config']['type']`, e.g. `config["
            "'replay_buffer_config']['type'] = "
            "'MultiAgentPrioritizedReplayBuffer'` to change the default "
            "behaviour.",
            error=True,
        )

    capacity = config.get("buffer_size", DEPRECATED_VALUE)
    if capacity == DEPRECATED_VALUE:
        capacity = config["replay_buffer_config"].get("buffer_size",
                                                      DEPRECATED_VALUE)
    if capacity != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['buffer_size'] or config['replay_buffer_config']["
            "'buffer_size']",
            new="config['replay_buffer_config']['capacity']",
            error=True,
        )

    replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
    if replay_burn_in != DEPRECATED_VALUE:
        config["replay_buffer_config"]["replay_burn_in"] = replay_burn_in
        deprecation_warning(
            old="config['burn_in']",
            help="config['replay_buffer_config']['replay_burn_in']",
        )

    replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE)
    if replay_batch_size == DEPRECATED_VALUE:
        replay_batch_size = config["replay_buffer_config"].get(
            "replay_batch_size", DEPRECATED_VALUE)
    if replay_batch_size != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['replay_batch_size'] or config['replay_buffer_config']["
            "'replay_batch_size']",
            help=
            "Specification of replay_batch_size is not supported anymore; it "
            "is derived from `train_batch_size` instead. Specify the number "
            "of items you want to replay upon calling the sample() method of "
            "replay buffers if this does not work for you.",
            error=True,
        )

    # Deprecation of old-style replay buffer args.
    # Warn before checking whether we need a local buffer so that algorithms
    # without a local buffer also get warned.
    keys_with_deprecated_positions = [
        "prioritized_replay_alpha",
        "prioritized_replay_beta",
        "prioritized_replay_eps",
        "no_local_replay_buffer",
        "replay_zero_init_states",
        "learning_starts",
        "replay_buffer_shards_colocated_with_driver",
    ]
    for k in keys_with_deprecated_positions:
        if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
            deprecation_warning(
                old="config['{}']".format(k),
                help="config['replay_buffer_config']['{}']"
                "".format(k),
                error=False,
            )
            # Copy values over to new location in config to support new
            # and old configuration style.
            if config.get("replay_buffer_config") is not None:
                config["replay_buffer_config"][k] = config[k]

    replay_mode = config.get("multiagent", {}).get("replay_mode",
                                                   DEPRECATED_VALUE)
    if replay_mode != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['multiagent']['replay_mode']",
            help="config['replay_buffer_config']['replay_mode']",
            error=False,
        )
        config["replay_buffer_config"]["replay_mode"] = replay_mode

    # Can't use DEPRECATED_VALUE here because this is also a deliberate
    # value set for some algorithms
    # TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation
    replay_sequence_length = config.get("replay_sequence_length", None)
    if replay_sequence_length is not None:
        config["replay_buffer_config"][
            "replay_sequence_length"] = replay_sequence_length
        deprecation_warning(
            old="config['replay_sequence_length']",
            help="Replay sequence length specified at new "
            "location config['replay_buffer_config']["
            "'replay_sequence_length'] will be overwritten.",
            error=False,
        )

    replay_buffer_config = config["replay_buffer_config"]
    assert (
        "type" in replay_buffer_config
    ), "Can not instantiate ReplayBuffer from config without 'type' key."

    # Check if old replay buffer should be instantiated
    buffer_type = config["replay_buffer_config"]["type"]

    if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
        # Create valid full [module].[class] string for from_config
        config["replay_buffer_config"]["type"] = (
            "ray.rllib.utils.replay_buffers." + buffer_type)

    # Instantiate a dummy buffer to fail early on misconfiguration and to
    # discover the inferred buffer class.
    dummy_buffer = from_config(buffer_type, config["replay_buffer_config"])

    config["replay_buffer_config"]["type"] = type(dummy_buffer)

    if hasattr(dummy_buffer, "update_priorities"):
        if config["multiagent"]["replay_mode"] == "lockstep":
            raise ValueError(
                "Prioritized replay is not supported when replay_mode=lockstep."
            )
        elif config["replay_buffer_config"].get("replay_sequence_length",
                                                0) > 1:
            raise ValueError("Prioritized replay is not supported when "
                             "replay_sequence_length > 1.")
    else:
        if config["replay_buffer_config"].get("worker_side_prioritization"):
            raise ValueError(
                "Worker side prioritization is not supported when "
                "prioritized_replay=False.")
Example #3
def validate_buffer_config(config: dict):
    if config.get("replay_buffer_config", None) is None:
        config["replay_buffer_config"] = {}

    prioritized_replay = config.get("prioritized_replay")
    if prioritized_replay != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['prioritized_replay']",
            help="Replay prioritization specified at new location config["
            "'replay_buffer_config']["
            "'prioritized_replay'] will be overwritten.",
            error=False,
        )
        config["replay_buffer_config"][
            "prioritized_replay"] = prioritized_replay

    capacity = config.get("buffer_size", DEPRECATED_VALUE)
    if capacity != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['buffer_size']",
            help="Buffer size specified at new location config["
            "'replay_buffer_config']["
            "'capacity'] will be overwritten.",
            error=False,
        )
        config["replay_buffer_config"]["capacity"] = capacity

    # Deprecation of old-style replay buffer args.
    # Warn before checking whether we need a local buffer so that algorithms
    # without a local buffer also get warned.
    deprecated_replay_buffer_keys = [
        "prioritized_replay_alpha",
        "prioritized_replay_beta",
        "prioritized_replay_eps",
        "learning_starts",
    ]
    for k in deprecated_replay_buffer_keys:
        if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
            deprecation_warning(
                old="config['{}']".format(k),
                help="config['replay_buffer_config']['{}'] should be used "
                "for Q-Learning algorithms. Ignore this warning if "
                "you are not using a Q-Learning algorithm and still "
                "provide {}."
                "".format(k, k),
                error=False,
            )
            # Copy values over to new location in config to support new
            # and old configuration style
            if config.get("replay_buffer_config") is not None:
                config["replay_buffer_config"][k] = config[k]

    # Old Ape-X configs may contain no_local_replay_buffer
    no_local_replay_buffer = config.get("no_local_replay_buffer", False)
    if no_local_replay_buffer:
        deprecation_warning(
            old="config['no_local_replay_buffer']",
            help="no_local_replay_buffer specified at new location config["
            "'replay_buffer_config']["
            "'capacity'] will be overwritten.",
            error=False,
        )
        config["replay_buffer_config"][
            "no_local_replay_buffer"] = no_local_replay_buffer

    # TODO (Artur):
    if config["replay_buffer_config"].get("no_local_replay_buffer", False):
        return

    replay_buffer_config = config["replay_buffer_config"]
    assert (
        "type" in replay_buffer_config
    ), "Can not instantiate ReplayBuffer from config without 'type' key."

    # Check if old replay buffer should be instantiated
    buffer_type = config["replay_buffer_config"]["type"]
    if not config["replay_buffer_config"].get("_enable_replay_buffer_api",
                                              False):
        if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
            # Prepend old-style buffers' path
            assert buffer_type == "MultiAgentReplayBuffer", (
                "Without "
                "ReplayBuffer "
                "API, only "
                "MultiAgentReplayBuffer "
                "is supported!")
            # Create valid full [module].[class] string for from_config
            buffer_type = "ray.rllib.execution.MultiAgentReplayBuffer"
        else:
            assert buffer_type in [
                "ray.rllib.execution.MultiAgentReplayBuffer",
                Legacy_MultiAgentReplayBuffer,
            ], ("Without ReplayBuffer API, only "
                "MultiAgentReplayBuffer is supported!")

        config["replay_buffer_config"]["type"] = buffer_type

        # Remove from config, so it's not passed into the buffer c'tor
        config["replay_buffer_config"].pop("_enable_replay_buffer_api", None)

        # We need to deprecate the old-style location of the following
        # buffer arguments and make users put them into the
        # "replay_buffer_config" field of their config.
        replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE)
        if replay_batch_size != DEPRECATED_VALUE:
            config["replay_buffer_config"][
                "replay_batch_size"] = replay_batch_size
            deprecation_warning(
                old="config['replay_batch_size']",
                help="Replay batch size specified at new "
                "location config['replay_buffer_config']["
                "'replay_batch_size'] will be overwritten.",
                error=False,
            )

        replay_mode = config.get("replay_mode", DEPRECATED_VALUE)
        if replay_mode != DEPRECATED_VALUE:
            config["replay_buffer_config"]["replay_mode"] = replay_mode
            deprecation_warning(
                old="config['multiagent']['replay_mode']",
                help="Replay sequence length specified at new "
                "location config['replay_buffer_config']["
                "'replay_mode'] will be overwritten.",
                error=False,
            )

        # Can't use DEPRECATED_VALUE here because this is also a deliberate
        # value set for some algorithms
        # TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation
        replay_sequence_length = config.get("replay_sequence_length", None)
        if replay_sequence_length is not None:
            config["replay_buffer_config"][
                "replay_sequence_length"] = replay_sequence_length
            deprecation_warning(
                old="config['replay_sequence_length']",
                help="Replay sequence length specified at new "
                "location config['replay_buffer_config']["
                "'replay_sequence_length'] will be overwritten.",
                error=False,
            )

        replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
        if replay_burn_in != DEPRECATED_VALUE:
            config["replay_buffer_config"]["replay_burn_in"] = replay_burn_in
            deprecation_warning(
                old="config['burn_in']",
                help="Burn in specified at new location config["
                "'replay_buffer_config']["
                "'replay_burn_in'] will be overwritten.",
            )

        replay_zero_init_states = config.get("replay_zero_init_states",
                                             DEPRECATED_VALUE)
        if replay_zero_init_states != DEPRECATED_VALUE:
            config["replay_buffer_config"][
                "replay_zero_init_states"] = replay_zero_init_states
            deprecation_warning(
                old="config['replay_zero_init_states']",
                help="Replay zero init states specified at new location "
                "config["
                "'replay_buffer_config']["
                "'replay_zero_init_states'] will be overwritten.",
                error=False,
            )

        # TODO (Artur): Move this logic into config objects
        if config["replay_buffer_config"].get("prioritized_replay", False):
            is_prioritized_buffer = True
        else:
            is_prioritized_buffer = False
            # This triggers non-prioritization in old-style replay buffer
            config["replay_buffer_config"]["prioritized_replay_alpha"] = 0.0
    else:
        if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
            # Create valid full [module].[class] string for from_config
            config["replay_buffer_config"]["type"] = (
                "ray.rllib.utils.replay_buffers." + buffer_type)
        test_buffer = from_config(buffer_type, config["replay_buffer_config"])
        if hasattr(test_buffer, "update_priorities"):
            is_prioritized_buffer = True
        else:
            is_prioritized_buffer = False

    if is_prioritized_buffer:
        if config["multiagent"]["replay_mode"] == "lockstep":
            raise ValueError(
                "Prioritized replay is not supported when replay_mode=lockstep."
            )
        elif config["replay_buffer_config"].get("replay_sequence_length",
                                                0) > 1:
            raise ValueError("Prioritized replay is not supported when "
                             "replay_sequence_length > 1.")
    else:
        if config.get("worker_side_prioritization"):
            raise ValueError(
                "Worker side prioritization is not supported when "
                "prioritized_replay=False.")

    if config["replay_buffer_config"].get("replay_batch_size", None) is None:
        # Fall back to train batch size if no replay batch size was provided
        config["replay_buffer_config"]["replay_batch_size"] = config[
            "train_batch_size"]

    # Pop prioritized replay because it's not a valid parameter for older
    # replay buffers
    config["replay_buffer_config"].pop("prioritized_replay", None)