def __init__(
    self,
    num_shards: int = 1,
    learning_starts: int = 1000,
    capacity: int = 10000,
    replay_batch_size: int = 1,
    prioritized_replay_alpha: float = 0.6,
    prioritized_replay_beta: float = 0.4,
    prioritized_replay_eps: float = 1e-6,
    replay_mode: str = "independent",
    replay_sequence_length: int = 1,
    replay_burn_in: int = 0,
    replay_zero_init_states: bool = True,
    buffer_size=DEPRECATED_VALUE,
):
    """Initializes a MultiAgentReplayBuffer instance.

    Args:
        num_shards: The number of buffer shards that exist in total
            (including this one).
        learning_starts: Number of timesteps after which a call to
            `replay()` will yield samples (before that, `replay()` will
            return None).
        capacity: The capacity of the buffer. Note that when
            `replay_sequence_length` > 1, this is the number of sequences
            (not single timesteps) stored.
        replay_batch_size: The batch size to be sampled (in timesteps).
            Note that if `replay_sequence_length` > 1,
            `self.replay_batch_size` will be set to the number of
            sequences sampled (B).
        prioritized_replay_alpha: Alpha parameter for a prioritized
            replay buffer. Use 0.0 for no prioritization.
        prioritized_replay_beta: Beta parameter for a prioritized replay
            buffer.
        prioritized_replay_eps: Epsilon parameter for a prioritized
            replay buffer.
        replay_mode: One of "independent" or "lockstep". Determines
            whether, in the multi-agent case, sampling is done
            independently per policy or in lockstep across all
            agents/policies.
        replay_sequence_length: The sequence length (T) of a single
            sample. If > 1, we will sample B x T from this buffer.
        replay_burn_in: The burn-in length in case
            `replay_sequence_length` > 1. This is the number of timesteps
            each sequence overlaps with the previous one to generate a
            better internal state (=state after the burn-in), instead of
            starting from an all-zeros state at each RNN rollout.
        replay_zero_init_states: Whether the initial states in the buffer
            (if `replay_sequence_length` > 1) should always be all-zeros
            or should be updated with the previous train_batch's state
            outputs.
    """
    # Deprecated args.
    if buffer_size != DEPRECATED_VALUE:
        deprecation_warning(
            "ReplayBuffer(size)", "ReplayBuffer(capacity)", error=False
        )
        capacity = buffer_size

    self.replay_starts = learning_starts // num_shards
    self.capacity = capacity // num_shards
    self.replay_batch_size = replay_batch_size
    self.prioritized_replay_beta = prioritized_replay_beta
    self.prioritized_replay_eps = prioritized_replay_eps
    self.replay_mode = replay_mode
    self.replay_sequence_length = replay_sequence_length
    self.replay_burn_in = replay_burn_in
    self.replay_zero_init_states = replay_zero_init_states

    if replay_sequence_length > 1:
        # When sampling sequences, the batch size is counted in sequences
        # (B), not single timesteps.
        self.replay_batch_size = int(
            max(1, replay_batch_size // replay_sequence_length)
        )
        logger.info(
            "Since replay_sequence_length={} and replay_batch_size={}, "
            "we will replay {} sequences at a time.".format(
                replay_sequence_length, replay_batch_size, self.replay_batch_size
            )
        )

    if replay_mode not in ["lockstep", "independent"]:
        raise ValueError("Unsupported replay mode: {}".format(replay_mode))

    def gen_replay():
        while True:
            yield self.replay()

    ParallelIteratorWorker.__init__(self, gen_replay, False)

    def new_buffer():
        # alpha=0.0 means no prioritization -> use a plain ReplayBuffer.
        if prioritized_replay_alpha == 0.0:
            return ReplayBuffer(self.capacity)
        else:
            return PrioritizedReplayBuffer(
                self.capacity, alpha=prioritized_replay_alpha
            )

    self.replay_buffers = collections.defaultdict(new_buffer)

    # Metrics.
    self.add_batch_timer = TimerStat()
    self.replay_timer = TimerStat()
    self.update_priorities_timer = TimerStat()
    self.num_added = 0

    # Make externally accessible for testing.
    global _local_replay_buffer
    _local_replay_buffer = self

    # If set, return this instead of the usual data for testing.
    self._fake_batch = None
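
# Usage sketch (illustrative only, not part of the class). Constructing a
# sharded buffer, adding experiences, and replaying them. `sample_batch`
# is a hypothetical, pre-built batch of experiences; all concrete values
# are assumptions chosen for the example:
#
#   buffer = MultiAgentReplayBuffer(
#       num_shards=2,                # each shard holds capacity // 2 items
#       learning_starts=1000,        # per shard: learning_starts // 2
#       capacity=50000,
#       replay_sequence_length=20,   # sample B x T; batch size becomes B
#   )
#   buffer.add_batch(sample_batch)   # hypothetical batch of experiences
#   train_batch = buffer.replay()    # None until enough timesteps added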
def validate_buffer_config(config: dict) -> None:
    """Checks and fixes values in the replay buffer config.

    Checks the replay buffer config for common misconfigurations and warns
    or raises an error in case validation fails. The "type" key is changed
    into the inferred replay buffer class.

    Args:
        config: The replay buffer config to be validated.

    Raises:
        ValueError: When detecting severe misconfiguration.
    """
    if config.get("replay_buffer_config", None) is None:
        config["replay_buffer_config"] = {}

    if config.get("worker_side_prioritization", DEPRECATED_VALUE) != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['worker_side_prioritization']",
            new="config['replay_buffer_config']['worker_side_prioritization']",
            error=True,
        )

    prioritized_replay = config.get("prioritized_replay", DEPRECATED_VALUE)
    if prioritized_replay != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['prioritized_replay'] or config['replay_buffer_config']["
            "'prioritized_replay']",
            help="Replay prioritization specified by config key. RLlib's new "
            "replay buffer API requires setting "
            "`config['replay_buffer_config']['type']`, e.g. "
            "`config['replay_buffer_config']['type'] = "
            "'MultiAgentPrioritizedReplayBuffer'` to change the default "
            "behaviour.",
            error=True,
        )

    capacity = config.get("buffer_size", DEPRECATED_VALUE)
    if capacity == DEPRECATED_VALUE:
        capacity = config["replay_buffer_config"].get("buffer_size", DEPRECATED_VALUE)
    if capacity != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['buffer_size'] or config['replay_buffer_config']["
            "'buffer_size']",
            new="config['replay_buffer_config']['capacity']",
            error=True,
        )

    replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
    if replay_burn_in != DEPRECATED_VALUE:
        config["replay_buffer_config"]["replay_burn_in"] = replay_burn_in
        deprecation_warning(
            old="config['burn_in']",
            help="config['replay_buffer_config']['replay_burn_in']",
        )

    replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE)
    if replay_batch_size == DEPRECATED_VALUE:
        replay_batch_size = config["replay_buffer_config"].get(
            "replay_batch_size", DEPRECATED_VALUE
        )
    if replay_batch_size != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['replay_batch_size'] or config['replay_buffer_config']["
            "'replay_batch_size']",
            help="Specifying replay_batch_size is not supported anymore; it is "
            "derived from `train_batch_size` instead. Specify the number of "
            "items you want to replay upon calling the sample() method of "
            "replay buffers if this does not work for you.",
            error=True,
        )

    # Deprecation of old-style replay buffer args.
    # Warn before checking if we need a local buffer, so that algorithms
    # without a local buffer also get warned.
    keys_with_deprecated_positions = [
        "prioritized_replay_alpha",
        "prioritized_replay_beta",
        "prioritized_replay_eps",
        "no_local_replay_buffer",
        "replay_zero_init_states",
        "learning_starts",
        "replay_buffer_shards_colocated_with_driver",
    ]
    for k in keys_with_deprecated_positions:
        if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
            deprecation_warning(
                old="config['{}']".format(k),
                help="config['replay_buffer_config']['{}']".format(k),
                error=False,
            )
            # Copy values over to the new location in the config to support
            # both the new and the old configuration style.
            if config.get("replay_buffer_config") is not None:
                config["replay_buffer_config"][k] = config[k]

    replay_mode = config.get("multiagent", {}).get("replay_mode", DEPRECATED_VALUE)
    if replay_mode != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['multiagent']['replay_mode']",
            help="config['replay_buffer_config']['replay_mode']",
            error=False,
        )
        config["replay_buffer_config"]["replay_mode"] = replay_mode

    # Can't use DEPRECATED_VALUE here because `None` is also a deliberate
    # value set for some algorithms.
    # TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation.
    replay_sequence_length = config.get("replay_sequence_length", None)
    if replay_sequence_length is not None:
        config["replay_buffer_config"][
            "replay_sequence_length"
        ] = replay_sequence_length
        deprecation_warning(
            old="config['replay_sequence_length']",
            help="Replay sequence length specified at new "
            "location config['replay_buffer_config']["
            "'replay_sequence_length'] will be overwritten.",
            error=False,
        )

    replay_buffer_config = config["replay_buffer_config"]
    assert (
        "type" in replay_buffer_config
    ), "Can not instantiate ReplayBuffer from config without 'type' key."

    # Check if an old replay buffer should be instantiated.
    buffer_type = config["replay_buffer_config"]["type"]
    if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
        # Create a valid full [module].[class] string for from_config.
        config["replay_buffer_config"]["type"] = (
            "ray.rllib.utils.replay_buffers." + buffer_type
        )

    # Instantiate a dummy buffer to fail early on misconfiguration and to
    # find out about the inferred buffer class.
    dummy_buffer = from_config(buffer_type, config["replay_buffer_config"])

    config["replay_buffer_config"]["type"] = type(dummy_buffer)

    if hasattr(dummy_buffer, "update_priorities"):
        if config["multiagent"]["replay_mode"] == "lockstep":
            raise ValueError(
                "Prioritized replay is not supported when replay_mode=lockstep."
            )
        elif config["replay_buffer_config"].get("replay_sequence_length", 0) > 1:
            raise ValueError(
                "Prioritized replay is not supported when "
                "replay_sequence_length > 1."
            )
    else:
        if config["replay_buffer_config"].get("worker_side_prioritization"):
            raise ValueError(
                "Worker side prioritization is not supported when "
                "prioritized_replay=False."
            )
# Old-style (pre-ReplayBuffer-API) variant of the config validator, kept
# for algorithms that still use the legacy execution buffers.
def validate_buffer_config(config: dict):
    if config.get("replay_buffer_config", None) is None:
        config["replay_buffer_config"] = {}

    prioritized_replay = config.get("prioritized_replay", DEPRECATED_VALUE)
    if prioritized_replay != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['prioritized_replay']",
            help="Replay prioritization specified at new location config["
            "'replay_buffer_config']["
            "'prioritized_replay'] will be overwritten.",
            error=False,
        )
        config["replay_buffer_config"]["prioritized_replay"] = prioritized_replay

    capacity = config.get("buffer_size", DEPRECATED_VALUE)
    if capacity != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['buffer_size']",
            help="Buffer size specified at new location config["
            "'replay_buffer_config']["
            "'capacity'] will be overwritten.",
            error=False,
        )
        config["replay_buffer_config"]["capacity"] = capacity

    # Deprecation of old-style replay buffer args.
    # Warn before checking if we need a local buffer, so that algorithms
    # without a local buffer also get warned.
    deprecated_replay_buffer_keys = [
        "prioritized_replay_alpha",
        "prioritized_replay_beta",
        "prioritized_replay_eps",
        "learning_starts",
    ]
    for k in deprecated_replay_buffer_keys:
        if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
            deprecation_warning(
                old="config[{}]".format(k),
                help="config['replay_buffer_config'][{}] should be used "
                "for Q-Learning algorithms. Ignore this warning if "
                "you are not using a Q-Learning algorithm and still "
                "provide {}.".format(k, k),
                error=False,
            )
            # Copy values over to the new location in the config to support
            # both the new and the old configuration style.
            if config.get("replay_buffer_config") is not None:
                config["replay_buffer_config"][k] = config[k]

    # Old Ape-X configs may contain no_local_replay_buffer.
    no_local_replay_buffer = config.get("no_local_replay_buffer", False)
    if no_local_replay_buffer:
        deprecation_warning(
            old="config['no_local_replay_buffer']",
            help="no_local_replay_buffer specified at new location config["
            "'replay_buffer_config']["
            "'no_local_replay_buffer'] will be overwritten.",
            error=False,
        )
        config["replay_buffer_config"][
            "no_local_replay_buffer"
        ] = no_local_replay_buffer

    # TODO (Artur):
    if config["replay_buffer_config"].get("no_local_replay_buffer", False):
        return

    replay_buffer_config = config["replay_buffer_config"]
    assert (
        "type" in replay_buffer_config
    ), "Can not instantiate ReplayBuffer from config without 'type' key."

    # Check if an old replay buffer should be instantiated.
    buffer_type = config["replay_buffer_config"]["type"]
    if not config["replay_buffer_config"].get("_enable_replay_buffer_api", False):
        if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
            # Prepend the old-style buffers' path.
            assert buffer_type == "MultiAgentReplayBuffer", (
                "Without ReplayBuffer API, only MultiAgentReplayBuffer is "
                "supported!"
            )
            # Create a valid full [module].[class] string for from_config.
            buffer_type = "ray.rllib.execution.MultiAgentReplayBuffer"
        else:
            assert buffer_type in [
                "ray.rllib.execution.MultiAgentReplayBuffer",
                Legacy_MultiAgentReplayBuffer,
            ], (
                "Without ReplayBuffer API, only MultiAgentReplayBuffer is "
                "supported!"
            )

        config["replay_buffer_config"]["type"] = buffer_type

        # Remove from config, so it's not passed into the buffer c'tor.
        config["replay_buffer_config"].pop("_enable_replay_buffer_api", None)

        # We need to deprecate the old-style location of the following
        # buffer arguments and make users put them into the
        # "replay_buffer_config" field of their config.
        replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE)
        if replay_batch_size != DEPRECATED_VALUE:
            config["replay_buffer_config"]["replay_batch_size"] = replay_batch_size
            deprecation_warning(
                old="config['replay_batch_size']",
                help="Replay batch size specified at new "
                "location config['replay_buffer_config']["
                "'replay_batch_size'] will be overwritten.",
                error=False,
            )

        replay_mode = config.get("replay_mode", DEPRECATED_VALUE)
        if replay_mode != DEPRECATED_VALUE:
            config["replay_buffer_config"]["replay_mode"] = replay_mode
            deprecation_warning(
                old="config['multiagent']['replay_mode']",
                help="Replay mode specified at new "
                "location config['replay_buffer_config']["
                "'replay_mode'] will be overwritten.",
                error=False,
            )

        # Can't use DEPRECATED_VALUE here because `None` is also a deliberate
        # value set for some algorithms.
        # TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation.
        replay_sequence_length = config.get("replay_sequence_length", None)
        if replay_sequence_length is not None:
            config["replay_buffer_config"][
                "replay_sequence_length"
            ] = replay_sequence_length
            deprecation_warning(
                old="config['replay_sequence_length']",
                help="Replay sequence length specified at new "
                "location config['replay_buffer_config']["
                "'replay_sequence_length'] will be overwritten.",
                error=False,
            )

        replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
        if replay_burn_in != DEPRECATED_VALUE:
            config["replay_buffer_config"]["replay_burn_in"] = replay_burn_in
            deprecation_warning(
                old="config['burn_in']",
                help="Burn-in specified at new location config["
                "'replay_buffer_config']["
                "'replay_burn_in'] will be overwritten.",
            )

        replay_zero_init_states = config.get(
            "replay_zero_init_states", DEPRECATED_VALUE
        )
        if replay_zero_init_states != DEPRECATED_VALUE:
            config["replay_buffer_config"][
                "replay_zero_init_states"
            ] = replay_zero_init_states
            deprecation_warning(
                old="config['replay_zero_init_states']",
                help="Replay zero init states specified at new location "
                "config["
                "'replay_buffer_config']["
                "'replay_zero_init_states'] will be overwritten.",
                error=False,
            )

        # TODO (Artur): Move this logic into config objects.
        if config["replay_buffer_config"].get("prioritized_replay", False):
            is_prioritized_buffer = True
        else:
            is_prioritized_buffer = False
            # This triggers non-prioritization in the old-style replay buffer.
            config["replay_buffer_config"]["prioritized_replay_alpha"] = 0.0
    else:
        if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
            # Create a valid full [module].[class] string for from_config.
            config["replay_buffer_config"]["type"] = (
                "ray.rllib.utils.replay_buffers." + buffer_type
            )
        test_buffer = from_config(buffer_type, config["replay_buffer_config"])
        if hasattr(test_buffer, "update_priorities"):
            is_prioritized_buffer = True
        else:
            is_prioritized_buffer = False

    if is_prioritized_buffer:
        if config["multiagent"]["replay_mode"] == "lockstep":
            raise ValueError(
                "Prioritized replay is not supported when replay_mode=lockstep."
            )
        elif config["replay_buffer_config"].get("replay_sequence_length", 0) > 1:
            raise ValueError(
                "Prioritized replay is not supported when "
                "replay_sequence_length > 1."
            )
    else:
        if config.get("worker_side_prioritization"):
            raise ValueError(
                "Worker side prioritization is not supported when "
                "prioritized_replay=False."
            )

    if config["replay_buffer_config"].get("replay_batch_size", None) is None:
        # Fall back to train batch size if no replay batch size was provided.
        config["replay_buffer_config"]["replay_batch_size"] = config[
            "train_batch_size"
        ]

    # Pop prioritized_replay because it's not a valid parameter for older
    # replay buffers.
    config["replay_buffer_config"].pop("prioritized_replay", None)