class LearnerThread(threading.Thread):
    """Background thread that updates the local model from sample trajectories.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving heavyweight gradient ops session runs off the main thread
    improves overall throughput.
    """

    def __init__(
        self,
        local_worker: RolloutWorker,
        minibatch_buffer_size: int,
        num_sgd_iter: int,
        learner_queue_size: int,
        learner_queue_timeout: int,
    ):
        """Initialize the learner thread.

        Args:
            local_worker: process local rollout worker holding
                policies this thread will call learn_on_batch() on
            minibatch_buffer_size: max number of train batches to store
                in the minibatching buffer
            num_sgd_iter: number of passes to learn on per train batch
            learner_queue_size: max size of queue of inbound
                train batches to this thread
            learner_queue_timeout: raise an exception if the queue has
                been empty for this long in seconds
        """
        threading.Thread.__init__(self)
        # Rolling window (last 50 pushes) of the inbound queue's size.
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        self.inqueue = queue.Queue(maxsize=learner_queue_size)
        self.outqueue = queue.Queue()
        # Re-serves each inbound train batch `num_sgd_iter` times.
        self.minibatch_buffer = MinibatchBuffer(
            inqueue=self.inqueue,
            size=minibatch_buffer_size,
            timeout=learner_queue_timeout,
            num_passes=num_sgd_iter,
            init_num_passes=num_sgd_iter,
        )
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        # NOTE(review): load timers are only reported here, never pushed in
        # this class; presumably subclasses (e.g. multi-GPU) use them.
        self.load_timer = TimerStat()
        self.load_wait_timer = TimerStat()
        self.daemon = True
        self.weights_updated = False
        self.learner_info = {}
        self.stopped = False
        self.num_steps = 0

    def run(self) -> None:
        """Thread main loop: call `step()` until `self.stopped` is set."""
        # Switch on eager mode if configured.
        if self.local_worker.policy_config.get("framework") in ["tf2", "tfe"]:
            tf1.enable_eager_execution()
        while not self.stopped:
            self.step()

    def step(self) -> Optional[_NextValueNotReady]:
        """Run one learning update on the next batch from the minibatch buffer.

        Returns:
            A `_NextValueNotReady` instance if the minibatch buffer raised
            `queue.Empty`, otherwise None.
        """
        with self.queue_timer:
            try:
                batch, _ = self.minibatch_buffer.get()
            except queue.Empty:
                return _NextValueNotReady()
        with self.grad_timer:
            # Use LearnerInfoBuilder as a unified way to build the final
            # results dict from `learn_on_loaded_batch` call(s).
            # This makes sure results dicts always have the same structure
            # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
            # tf vs torch).
            learner_info_builder = LearnerInfoBuilder(num_devices=1)
            multi_agent_results = self.local_worker.learn_on_batch(batch)
            for pid, results in multi_agent_results.items():
                learner_info_builder.add_learn_on_batch_results(results, pid)
            self.learner_info = learner_info_builder.finalize()

        self.weights_updated = True
        self.num_steps += 1
        # Put tuple: env-steps, agent-steps, and learner info into the queue.
        self.outqueue.put(
            (batch.count, batch.agent_steps(), self.learner_info))
        self.learner_queue_size.push(self.inqueue.qsize())

    def add_learner_metrics(self, result: Dict,
                            overwrite_learner_info=True) -> Dict:
        """Add internal metrics to a result dict.

        Args:
            result: Results dict whose "info" sub-dict is updated in place.
            overwrite_learner_info: If True, also store a deep copy of the
                most recent learner info under the LEARNER_INFO key.

        Returns:
            The same `result` dict that was passed in.
        """

        def timer_to_ms(timer):
            return round(1000 * timer.mean, 3)

        # Build the update once; keys are inserted in the same order as
        # before ("learner_queue", [LEARNER_INFO], "timing_breakdown").
        info = {"learner_queue": self.learner_queue_size.stats()}
        if overwrite_learner_info:
            info[LEARNER_INFO] = copy.deepcopy(self.learner_info)
        info["timing_breakdown"] = {
            "learner_grad_time_ms": timer_to_ms(self.grad_timer),
            "learner_load_time_ms": timer_to_ms(self.load_timer),
            "learner_load_wait_time_ms": timer_to_ms(self.load_wait_timer),
            "learner_dequeue_time_ms": timer_to_ms(self.queue_timer),
        }
        result["info"].update(info)
        return result
class ReplayBuffer:
    """FIFO ring buffer of SampleBatches with uniform random sampling."""

    def __init__(self,
                 capacity: int = 10000,
                 storage_unit: str = "timesteps",
                 **kwargs):
        """Initializes a (FIFO) ReplayBuffer instance.

        Args:
            capacity: Max number of timesteps to store in this FIFO
                buffer. After reaching this number, older samples will be
                dropped to make space for new ones. Must be greater than 0.
            storage_unit: Either 'timesteps', 'sequences' or 'episodes'.
                Specifies how experiences are stored.
            **kwargs: Forward compatibility kwargs.

        Raises:
            ValueError: If `storage_unit` is unknown or `capacity` <= 0.
        """
        if storage_unit in ["timesteps", StorageUnit.TIMESTEPS]:
            self._storage_unit = StorageUnit.TIMESTEPS
        elif storage_unit in ["sequences", StorageUnit.SEQUENCES]:
            self._storage_unit = StorageUnit.SEQUENCES
        elif storage_unit in ["episodes", StorageUnit.EPISODES]:
            self._storage_unit = StorageUnit.EPISODES
        else:
            raise ValueError(
                "storage_unit must be either 'timesteps', `sequences` or `episodes`."
            )

        # A non-positive capacity leaves `self._hit_count` empty, which would
        # fail with an IndexError as soon as eviction starts.
        if capacity <= 0:
            raise ValueError(
                "Capacity of replay buffer has to be greater than zero "
                "but was set to {}.".format(capacity))

        # The actual storage (list of SampleBatches or MultiAgentBatches).
        self._storage = []

        # Caps the number of timesteps stored in this buffer
        self.capacity = capacity
        # The next index to override in the buffer.
        self._next_idx = 0
        # How often each storage slot has been sampled. Sized `capacity`,
        # presumably sufficient since each stored item normally holds at
        # least one timestep.
        self._hit_count = np.zeros(self.capacity)

        # Whether we have already hit our capacity (and have therefore
        # started to evict older samples).
        self._eviction_started = False

        # Number of (single) timesteps that have been added to the buffer
        # over its lifetime. Note that each added item (batch) may contain
        # more than one timestep.
        self._num_timesteps_added = 0
        self._num_timesteps_added_wrap = 0

        # Number of (single) timesteps that have been sampled from the buffer
        # over its lifetime.
        self._num_timesteps_sampled = 0

        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
        self._est_size_bytes = 0

        self.batch_size = None

    def __len__(self) -> int:
        """Returns the number of items currently stored in this buffer."""
        return len(self._storage)

    @ExperimentalAPI
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch of experiences to this buffer.

        Also splits experiences into chunks of timesteps, sequences
        or episodes, depending on self._storage_unit. Calls
        self._add_single_batch.

        Args:
            batch: Batch to add to this buffer's storage.
            **kwargs: Forward compatibility kwargs.

        Raises:
            ValueError: If a MultiAgentBatch is added while the storage
                unit is not 'timesteps'.
        """
        assert batch.count > 0, batch
        warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

        if (isinstance(batch, MultiAgentBatch)
                and self._storage_unit != StorageUnit.TIMESTEPS):
            raise ValueError("Can not add MultiAgentBatch to ReplayBuffer "
                             "with storage_unit {}"
                             "".format(str(self._storage_unit)))

        if self._storage_unit == StorageUnit.TIMESTEPS:
            # The whole batch is stored as a single item.
            self._add_single_batch(batch, **kwargs)
        elif self._storage_unit == StorageUnit.SEQUENCES:
            timestep_count = 0
            for seq_len in batch.get(SampleBatch.SEQ_LENS):
                start_seq = timestep_count
                end_seq = timestep_count + seq_len
                self._add_single_batch(batch[start_seq:end_seq], **kwargs)
                timestep_count = end_seq
        elif self._storage_unit == StorageUnit.EPISODES:
            for eps in batch.split_by_episode():
                if (eps.get(SampleBatch.T)[0] == 0
                        and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
                    ):
                    # Only add full episodes to the buffer
                    self._add_single_batch(eps, **kwargs)
                else:
                    if log_once("only_full_episodes"):
                        logger.info("This buffer uses episodes as a storage "
                                    "unit and thus allows only full episodes "
                                    "to be added to it. Some samples may be "
                                    "dropped.")

    @ExperimentalAPI
    def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
        """Add a SampleBatch of experiences to self._storage.

        An item consists of either one or more timesteps, a sequence or an
        episode. Differs from add() in that it does not consider the storage
        unit or type of batch and simply stores it.

        Args:
            item: The batch to be added.
            **kwargs: Forward compatibility kwargs.
        """
        self._num_timesteps_added += item.count
        self._num_timesteps_added_wrap += item.count

        if self._next_idx >= len(self._storage):
            self._storage.append(item)
            self._est_size_bytes += item.size_bytes()
        else:
            # Keep the size estimate accurate when overwriting a slot.
            self._est_size_bytes -= self._storage[self._next_idx].size_bytes()
            self._storage[self._next_idx] = item
            self._est_size_bytes += item.size_bytes()

        # Eviction of older samples has already started (buffer is "full").
        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0

        # Wrap around storage as a circular buffer once we hit capacity.
        if self._num_timesteps_added_wrap >= self.capacity:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

    @ExperimentalAPI
    def sample(self, num_items: int, **kwargs) -> Optional[SampleBatchType]:
        """Samples `num_items` items from this buffer.

        Samples in the results may be repeated.

        Examples for storage of SamplesBatches:
        - If storage unit `timesteps` has been chosen and batches of
        size 5 have been added (each batch is stored as one item),
        sample(5) will yield a concatenated batch of 25 timesteps.
        - If storage unit 'sequences' has been chosen and sequences of
        different lengths have been added, sample(5) will yield a concatenated
        batch with a number of timesteps equal to the sum of timesteps in
        the 5 sampled sequences.
        - If storage unit 'episodes' has been chosen and episodes of
        different lengths have been added, sample(5) will yield a concatenated
        batch with a number of timesteps equal to the sum of timesteps in
        the 5 sampled episodes.

        Args:
            num_items: Number of items to sample from this buffer.
            **kwargs: Forward compatibility kwargs.

        Returns:
            Concatenated batch of items.

        Raises:
            ValueError: If the buffer is empty.
        """
        if len(self) == 0:
            raise ValueError("Trying to sample from an empty buffer.")
        idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
        sample = self._encode_sample(idxes)
        self._num_timesteps_sampled += sample.count
        return sample

    @ExperimentalAPI
    def stats(self, debug: bool = False) -> dict:
        """Returns the stats of this buffer.

        Args:
            debug: If True, adds sample eviction statistics to the returned
                stats dict.

        Returns:
            A dictionary of stats about this buffer.
        """
        data = {
            "added_count": self._num_timesteps_added,
            "added_count_wrapped": self._num_timesteps_added_wrap,
            "eviction_started": self._eviction_started,
            "sampled_count": self._num_timesteps_sampled,
            "est_size_bytes": self._est_size_bytes,
            "num_entries": len(self._storage),
        }
        if debug:
            data.update(self._evicted_hit_stats.stats())
        return data

    @ExperimentalAPI
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Returns:
            The serializable local state.
        """
        state = {"_storage": self._storage, "_next_idx": self._next_idx}
        state.update(self.stats(debug=False))
        return state

    @ExperimentalAPI
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state: The new state to set this buffer. Can be
                obtained by calling `self.get_state()`.
        """
        # The actual storage.
        self._storage = state["_storage"]
        self._next_idx = state["_next_idx"]
        # Stats and counts.
        self._num_timesteps_added = state["added_count"]
        self._num_timesteps_added_wrap = state["added_count_wrapped"]
        self._eviction_started = state["eviction_started"]
        self._num_timesteps_sampled = state["sampled_count"]
        self._est_size_bytes = state["est_size_bytes"]

    def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
        """Fetches concatenated samples at given indices from the storage.

        Also increments the per-slot hit counter used for eviction stats.
        """
        samples = []
        for i in idxes:
            self._hit_count[i] += 1
            samples.append(self._storage[i])

        if samples:
            # Assume all samples are of same type
            sample_type = type(samples[0])
            out = sample_type.concat_samples(samples)
        else:
            out = SampleBatch()
        out.decompress_if_needed()
        return out

    def get_host(self) -> str:
        """Returns the computer's network name.

        Returns:
            The computer's networks name or an empty string, if the network
            name could not be determined.
        """
        return platform.node()
class ReplayBuffer(ParallelIteratorWorker):
    """The lowest-level replay buffer interface used by RLlib.

    This class implements a basic ring-type of buffer with random sampling.
    ReplayBuffer is the base class for advanced types that add functionality while
    retaining compatibility through inheritance.

    The following examples show how buffers behave with different storage_units
    and capacities. This behaviour is generally similar for other buffers, although
    they might not implement all storage_units.

    Examples:
        >>> from ray.rllib.utils.replay_buffers import (  # doctest: +SKIP
        ...     ReplayBuffer, StorageUnit)
        >>> from ray.rllib.policy.sample_batch import SampleBatch # doctest: +SKIP
        >>> # Store any batch as a whole
        >>> buffer = ReplayBuffer(capacity=10,
        ...                         storage_unit=StorageUnit.FRAGMENTS) # doctest: +SKIP
        >>> buffer.add(SampleBatch({"a": [1], "b": [2]})) # doctest: +SKIP
        >>> print(buffer.sample(1)) # doctest: +SKIP
        >>> # SampleBatch(1: ['a', 'b'])
        >>> # Store only complete episodes
        >>> buffer = ReplayBuffer(capacity=10,
        ...                         storage_unit=StorageUnit.EPISODES) # doctest: +SKIP
        >>> buffer.add(SampleBatch({"c": [1, 2, 3, 4], # doctest: +SKIP
        ...                        SampleBatch.T: [0, 1, 0, 1],
        ...                        SampleBatch.DONES: [False, True, False, True],
        ...                        SampleBatch.EPS_ID: [0, 0, 1, 1]})) # doctest: +SKIP
        >>> eps_n = buffer.sample(1) # doctest: +SKIP
        >>> print(eps_n[SampleBatch.EPS_ID]) # doctest: +SKIP
        >>> # [1 1]
        >>> # Store single timesteps
        >>> buffer = ReplayBuffer(capacity=2,  # doctest: +SKIP
        ...                         storage_unit=StorageUnit.TIMESTEPS) # doctest: +SKIP
        >>> buffer.add(SampleBatch({"a": [1, 2],
        ...                         SampleBatch.T: [0, 1]})) # doctest: +SKIP
        >>> t_n = buffer.sample(1) # doctest: +SKIP
        >>> print(t_n["a"]) # doctest: +SKIP
        >>> # [2]
        >>> buffer.add(SampleBatch({"a": [3], SampleBatch.T: [2]})) # doctest: +SKIP
        >>> print(buffer._eviction_started) # doctest: +SKIP
        >>> # True
        >>> t_n = buffer.sample(1) # doctest: +SKIP
        >>> print(t_n["a"]) # doctest: +SKIP
        >>> # [3] # doctest: +SKIP
        >>> buffer = ReplayBuffer(capacity=10, # doctest: +SKIP
        ...                         storage_unit=StorageUnit.SEQUENCES) # doctest: +SKIP
        >>> buffer.add(SampleBatch({"c": [1, 2, 3], # doctest: +SKIP
        ...                        SampleBatch.SEQ_LENS: [1, 2]})) # doctest: +SKIP
        >>> seq_n = buffer.sample(1) # doctest: +SKIP
        >>> print(seq_n["c"]) # doctest: +SKIP
        >>> # [1]
    """

    def __init__(
        self,
        capacity: int = 10000,
        storage_unit: Union[str, StorageUnit] = "timesteps",
        **kwargs,
    ):
        """Initializes a (FIFO) ReplayBuffer instance.

        Args:
            capacity: Max number of timesteps to store in this FIFO
                buffer. After reaching this number, older samples will be
                dropped to make space for new ones. Must be greater than 0.
            storage_unit: If not a StorageUnit, either 'timesteps',
                'sequences', 'episodes' or 'fragments'. Specifies how
                experiences are stored.
            ``**kwargs``: Forward compatibility kwargs.

        Raises:
            ValueError: If `storage_unit` is unknown or `capacity` <= 0.
        """

        if storage_unit in ["timesteps", StorageUnit.TIMESTEPS]:
            self.storage_unit = StorageUnit.TIMESTEPS
        elif storage_unit in ["sequences", StorageUnit.SEQUENCES]:
            self.storage_unit = StorageUnit.SEQUENCES
        elif storage_unit in ["episodes", StorageUnit.EPISODES]:
            self.storage_unit = StorageUnit.EPISODES
        elif storage_unit in ["fragments", StorageUnit.FRAGMENTS]:
            self.storage_unit = StorageUnit.FRAGMENTS
        else:
            raise ValueError(
                "storage_unit must be either 'timesteps', 'sequences' or 'episodes' "
                "or 'fragments', but is {}".format(storage_unit))

        # The actual storage (list of SampleBatches or MultiAgentBatches).
        self._storage = []

        # Caps the number of timesteps stored in this buffer
        if capacity <= 0:
            raise ValueError(
                "Capacity of replay buffer has to be greater than zero "
                "but was set to {}.".format(capacity))
        self.capacity = capacity
        # The next index to override in the buffer.
        self._next_idx = 0
        # How often each storage slot has been sampled. Sized `capacity`,
        # presumably sufficient since each stored item normally holds at
        # least one timestep.
        self._hit_count = np.zeros(self.capacity)

        # Whether we have already hit our capacity (and have therefore
        # started to evict older samples).
        self._eviction_started = False

        # Number of (single) timesteps that have been added to the buffer
        # over its lifetime. Note that each added item (batch) may contain
        # more than one timestep.
        self._num_timesteps_added = 0
        self._num_timesteps_added_wrap = 0

        # Number of (single) timesteps that have been sampled from the buffer
        # over its lifetime.
        self._num_timesteps_sampled = 0

        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
        self._est_size_bytes = 0

        self.batch_size = None

    def __len__(self) -> int:
        """Returns the number of items currently stored in this buffer."""
        return len(self._storage)

    @DeveloperAPI
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch of experiences to this buffer.

        Splits batch into chunks of timesteps, sequences or episodes, depending on
        `self.storage_unit` ('fragments' stores the batch as a whole). Calls
        `self._add_single_batch` to add resulting slices to the buffer storage.
        Empty batches are silently ignored.

        Args:
            batch: Batch to add.
            ``**kwargs``: Forward compatibility kwargs.
        """
        if not batch.count > 0:
            return

        warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

        if self.storage_unit == StorageUnit.TIMESTEPS:
            timeslices = batch.timeslices(1)
            for t in timeslices:
                self._add_single_batch(t, **kwargs)
        elif self.storage_unit == StorageUnit.SEQUENCES:
            timestep_count = 0
            for seq_len in batch.get(SampleBatch.SEQ_LENS):
                start_seq = timestep_count
                end_seq = timestep_count + seq_len
                self._add_single_batch(batch[start_seq:end_seq], **kwargs)
                timestep_count = end_seq
        elif self.storage_unit == StorageUnit.EPISODES:
            for eps in batch.split_by_episode():
                if (
                    eps.get(SampleBatch.T)[0] == 0
                    and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
                ):
                    # Only add full episodes to the buffer
                    self._add_single_batch(eps, **kwargs)
                else:
                    if log_once("only_full_episodes"):
                        logger.info("This buffer uses episodes as a storage "
                                    "unit and thus allows only full episodes "
                                    "to be added to it. Some samples may be "
                                    "dropped.")
        elif self.storage_unit == StorageUnit.FRAGMENTS:
            self._add_single_batch(batch, **kwargs)

    @DeveloperAPI
    def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
        """Add a SampleBatch of experiences to self._storage.

        An item consists of either one or more timesteps, a sequence or an
        episode. Differs from add() in that it does not consider the storage
        unit or type of batch and simply stores it.

        Args:
            item: The batch to be added.
            ``**kwargs``: Forward compatibility kwargs.
        """
        self._num_timesteps_added += item.count
        self._num_timesteps_added_wrap += item.count

        if self._next_idx >= len(self._storage):
            self._storage.append(item)
            self._est_size_bytes += item.size_bytes()
        else:
            # Replace the slot and keep the size estimate in sync.
            item_to_be_removed = self._storage[self._next_idx]
            self._est_size_bytes -= item_to_be_removed.size_bytes()
            self._storage[self._next_idx] = item
            self._est_size_bytes += item.size_bytes()

        # Eviction of older samples has already started (buffer is "full").
        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0

        # Wrap around storage as a circular buffer once we hit capacity.
        if self._num_timesteps_added_wrap >= self.capacity:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

    @DeveloperAPI
    def sample(self, num_items: int, **kwargs) -> Optional[SampleBatchType]:
        """Samples `num_items` items from this buffer.

        The items depend on the buffer's storage_unit.
        Samples in the results may be repeated.

        Examples for sampling results:

        1) If storage unit 'timesteps' has been chosen and batches of
        size 5 have been added, sample(5) will yield a concatenated batch of
        5 timesteps (each stored item is a single timestep).

        2) If storage unit 'sequences' has been chosen and sequences of
        different lengths have been added, sample(5) will yield a concatenated
        batch with a number of timesteps equal to the sum of timesteps in
        the 5 sampled sequences.

        3) If storage unit 'episodes' has been chosen and episodes of
        different lengths have been added, sample(5) will yield a concatenated
        batch with a number of timesteps equal to the sum of timesteps in
        the 5 sampled episodes.

        Args:
            num_items: Number of items to sample from this buffer.
            ``**kwargs``: Forward compatibility kwargs.

        Returns:
            Concatenated batch of items.

        Raises:
            ValueError: If the buffer is empty.
        """
        if len(self) == 0:
            raise ValueError("Trying to sample from an empty buffer.")
        idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
        sample = self._encode_sample(idxes)
        self._num_timesteps_sampled += sample.count
        return sample

    @DeveloperAPI
    def stats(self, debug: bool = False) -> dict:
        """Returns the stats of this buffer.

        Args:
            debug: If True, adds sample eviction statistics to the returned
                stats dict.

        Returns:
            A dictionary of stats about this buffer.
        """
        data = {
            "added_count": self._num_timesteps_added,
            "added_count_wrapped": self._num_timesteps_added_wrap,
            "eviction_started": self._eviction_started,
            "sampled_count": self._num_timesteps_sampled,
            "est_size_bytes": self._est_size_bytes,
            "num_entries": len(self._storage),
        }
        if debug:
            data.update(self._evicted_hit_stats.stats())
        return data

    @DeveloperAPI
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Returns:
            The serializable local state.
        """
        state = {"_storage": self._storage, "_next_idx": self._next_idx}
        state.update(self.stats(debug=False))
        return state

    @DeveloperAPI
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state: The new state to set this buffer. Can be
                obtained by calling `self.get_state()`.
        """
        # The actual storage.
        self._storage = state["_storage"]
        self._next_idx = state["_next_idx"]
        # Stats and counts.
        self._num_timesteps_added = state["added_count"]
        self._num_timesteps_added_wrap = state["added_count_wrapped"]
        self._eviction_started = state["eviction_started"]
        self._num_timesteps_sampled = state["sampled_count"]
        self._est_size_bytes = state["est_size_bytes"]

    @DeveloperAPI
    def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
        """Fetches concatenated samples at given indices from the storage.

        Also increments the per-slot hit counter used for eviction stats.
        """
        samples = []
        for i in idxes:
            self._hit_count[i] += 1
            samples.append(self._storage[i])

        if samples:
            # We assume all samples are of same type
            sample_type = type(samples[0])
            out = sample_type.concat_samples(samples)
        else:
            out = SampleBatch()
        out.decompress_if_needed()
        return out

    @DeveloperAPI
    def get_host(self) -> str:
        """Returns the computer's network name.

        Returns:
            The computer's networks name or an empty string, if the network
            name could not be determined.
        """
        return platform.node()

    @DeveloperAPI
    def apply(
        self,
        func: Callable[["ReplayBuffer", Optional[Any], Optional[Any]], T],
        *args,
        **kwargs,
    ) -> T:
        """Calls the given function with this ReplayBuffer instance.

        This is useful if we want to apply a function to a set of remote actors.

        Args:
            func: A callable that accepts the replay buffer itself, args and kwargs
            ``*args``: Any args to pass to func
            ``**kwargs``: Any kwargs to pass to func

        Returns:
            Return value of the induced function call
        """
        return func(self, *args, **kwargs)

    @Deprecated(old="ReplayBuffer.add_batch()", new="ReplayBuffer.add()", error=False)
    def add_batch(self, *args, **kwargs):
        return self.add(*args, **kwargs)

    @Deprecated(
        old="ReplayBuffer.replay(num_items)",
        new="ReplayBuffer.sample(num_items)",
        error=False,
    )
    def replay(self, num_items):
        return self.sample(num_items)

    @Deprecated(
        help="ReplayBuffers could be iterated over by default before. "
        "Making a buffer an iterator will soon "
        "be deprecated altogether. Consider switching to the training "
        "iteration API to resolve this.",
        error=False,
    )
    def make_iterator(self, num_items_to_replay: int):
        """Make this buffer a ParallelIteratorWorker to retain compatibility.

        Execution plans have made heavy use of buffers as ParallelIteratorWorkers.
        This method provides an easy way to support this for now.
        """

        def gen_replay():
            while True:
                yield self.sample(num_items_to_replay)

        ParallelIteratorWorker.__init__(self, gen_replay, False)
class LearnerThread(threading.Thread):
    """Background thread that updates the local model from replay data.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving heavyweight gradient ops session runs off the main thread
    improves overall throughput.
    """

    def __init__(self, local_worker):
        """Initialize the learner thread.

        Args:
            local_worker: Rollout worker whose `learn_on_batch()` is called
                on each replayed batch.
        """
        threading.Thread.__init__(self)
        # Rolling window (last 50 pushes) of the inbound queue's size.
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.overall_timer = TimerStat()
        self.daemon = True
        self.weights_updated = False
        self.stopped = False
        self.learner_info = {}

    def run(self):
        """Thread main loop: call `step()` until `self.stopped` is set."""
        # Switch on eager mode if configured.
        if self.local_worker.policy_config.get("framework") in ["tf2", "tfe"]:
            tf1.enable_eager_execution()
        while not self.stopped:
            self.step()

    def step(self):
        """Consume one `(replay_actor, batch)` item from `self.inqueue`.

        If the batch is not None, learns on it and puts
        `(replay_actor, prio_dict, batch.count)` into `self.outqueue`, where
        `prio_dict` maps policy IDs to `(batch_indexes, td_error)` tuples.
        """
        with self.overall_timer:
            with self.queue_timer:
                ra, replay = self.inqueue.get()
            if replay is not None:
                prio_dict = {}
                with self.grad_timer:
                    # Use LearnerInfoBuilder as a unified way to build the
                    # final results dict from `learn_on_loaded_batch` call(s).
                    # This makes sure results dicts always have the same
                    # structure no matter the setup (multi-GPU, multi-agent,
                    # minibatch SGD, tf vs torch).
                    learner_info_builder = LearnerInfoBuilder(num_devices=1)
                    multi_agent_results = self.local_worker.learn_on_batch(
                        replay)
                    for pid, results in multi_agent_results.items():
                        learner_info_builder.add_learn_on_batch_results(
                            results, pid)
                        td_error = results["td_error"]
                        # Switch off auto-conversion from numpy to torch/tf
                        # tensors for the indices. This may lead to errors
                        # when sent to the buffer for processing
                        # (may get manipulated if they are part of a tensor).
                        replay.policy_batches[pid].set_get_interceptor(None)
                        prio_dict[pid] = (
                            replay.policy_batches[pid].get("batch_indexes"),
                            td_error,
                        )
                    self.learner_info = learner_info_builder.finalize()
                    self.grad_timer.push_units_processed(replay.count)
                self.outqueue.put((ra, prio_dict, replay.count))
                # Only flag new weights when a learning update actually ran.
                self.weights_updated = True
            self.learner_queue_size.push(self.inqueue.qsize())
            self.overall_timer.push_units_processed(
                replay.count if replay is not None else 0)
class PrioritizedReplayBuffer(ReplayBuffer):
    """This buffer implements Prioritized Experience Replay

    The algorithm has been described by Tom Schaul et. al. in "Prioritized
    Experience Replay". See https://arxiv.org/pdf/1511.05952.pdf for
    the full paper.
    """

    @ExperimentalAPI
    def __init__(self,
                 capacity: int = 10000,
                 storage_unit: str = "timesteps",
                 alpha: float = 1.0,
                 **kwargs):
        """Initializes a PrioritizedReplayBuffer instance.

        Args:
            capacity: Max number of timesteps to store in the FIFO
                buffer. After reaching this number, older samples will be
                dropped to make space for new ones.
            storage_unit: Either 'timesteps', 'sequences' or
                'episodes'. Specifies how experiences are stored.
            alpha: How much prioritization is used. Must be > 0; values
                approaching 0.0 approach uniform sampling, 1.0 is full
                prioritization.
            **kwargs: Forward compatibility kwargs.
        """
        ReplayBuffer.__init__(self, capacity, storage_unit, **kwargs)

        assert alpha > 0
        self._alpha = alpha

        # Segment tree must have capacity that is a power of 2
        it_capacity = 1
        while it_capacity < self.capacity:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0
        self._prio_change_stats = WindowStat("reprio", 1000)

    @ExperimentalAPI
    @override(ReplayBuffer)
    def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
        """Add a batch of experiences to self._storage with weight.

        An item consists of either one or more timesteps, a sequence or an
        episode. Differs from add() in that it does not consider the storage
        unit or type of batch and simply stores it.

        Args:
            item: The item to be added.
            **kwargs: Forward compatibility kwargs. An optional "weight"
                entry sets the item's priority; defaults to the current
                max priority.
        """
        weight = kwargs.get("weight", None)

        if weight is None:
            weight = self._max_priority

        # Priorities are written at the slot the parent will fill next.
        self._it_sum[self._next_idx] = weight**self._alpha
        self._it_min[self._next_idx] = weight**self._alpha

        ReplayBuffer._add_single_batch(self, item, **kwargs)

    def _sample_proportional(self, num_items: int) -> List[int]:
        """Draw `num_items` storage indices with probability ~ priority."""
        res = []
        for _ in range(num_items):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage))
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    @ExperimentalAPI
    @override(ReplayBuffer)
    def sample(self, num_items: int, beta: float,
               **kwargs) -> Optional[SampleBatchType]:
        """Sample `num_items` items from this buffer, including prio. weights.

        Samples in the results may be repeated.

        Examples for storage of SamplesBatches:
        - If storage unit `timesteps` has been chosen and batches of
        size 5 have been added (each batch is stored as one item),
        sample(5) will yield a concatenated batch of 25 timesteps.
        - If storage unit 'sequences' has been chosen and sequences of
        different lengths have been added, sample(5) will yield a concatenated
        batch with a number of timesteps equal to the sum of timesteps in
        the 5 sampled sequences.
        - If storage unit 'episodes' has been chosen and episodes of
        different lengths have been added, sample(5) will yield a concatenated
        batch with a number of timesteps equal to the sum of timesteps in
        the 5 sampled episodes.

        Args:
            num_items: Number of items to sample from this buffer.
            beta: To what degree to use importance weights
                (0 - no corrections, 1 - full correction).
            **kwargs: Forward compatibility kwargs.

        Returns:
            Concatenated SampleBatch of items including "weights" and
            "batch_indexes" fields denoting IS of each sampled
            transition and original idxes in buffer of sampled experiences.

        Raises:
            ValueError: If the buffer is empty (the priority sum would be 0).
        """
        assert beta >= 0.0

        if len(self) == 0:
            raise ValueError("Trying to sample from an empty buffer.")

        idxes = self._sample_proportional(num_items)

        weights = []
        batch_indexes = []
        # Total priority mass is loop-invariant; compute it once.
        total_priority = self._it_sum.sum()
        p_min = self._it_min.min() / total_priority
        max_weight = (p_min * len(self))**(-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / total_priority
            weight = (p_sample * len(self))**(-beta)
            count = self._storage[idx].count
            # If zero-padded, count will not be the actual batch size of the
            # data.
            if (isinstance(self._storage[idx], SampleBatch)
                    and self._storage[idx].zero_padded):
                actual_size = self._storage[idx].max_seq_len
            else:
                actual_size = count
            weights.extend([weight / max_weight] * actual_size)
            batch_indexes.extend([idx] * actual_size)
            self._num_timesteps_sampled += count
        batch = self._encode_sample(idxes)

        # Note: prioritization is not supported in multi agent lockstep
        if isinstance(batch, SampleBatch):
            batch["weights"] = np.array(weights)
            batch["batch_indexes"] = np.array(batch_indexes)

        return batch

    @ExperimentalAPI
    def update_priorities(self, idxes: List[int],
                          priorities: List[float]) -> None:
        """Update priorities of items at given indices.

        Sets priority of item at index idxes[i] in buffer
        to priorities[i].

        Args:
            idxes: List of indices of items
            priorities: List of updated priorities corresponding to
                items at the idxes denoted by variable `idxes`.
        """
        # Making sure we don't pass in e.g. a torch tensor.
        assert isinstance(
            idxes, (list, np.ndarray)
        ), "ERROR: `idxes` is not a list or np.ndarray, but " "{}!".format(
            type(idxes).__name__)
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            delta = priority**self._alpha - self._it_sum[idx]
            self._prio_change_stats.push(delta)
            self._it_sum[idx] = priority**self._alpha
            self._it_min[idx] = priority**self._alpha

            self._max_priority = max(self._max_priority, priority)

    @ExperimentalAPI
    @override(ReplayBuffer)
    def stats(self, debug: bool = False) -> Dict:
        """Returns the stats of this buffer.

        Args:
            debug: If true, adds sample eviction statistics to the
                returned stats dict.

        Returns:
            A dictionary of stats about this buffer.
        """
        parent = ReplayBuffer.stats(self, debug)
        if debug:
            parent.update(self._prio_change_stats.stats())
        return parent

    @ExperimentalAPI
    @override(ReplayBuffer)
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Returns:
            The serializable local state.
        """
        # Get parent state.
        state = super().get_state()
        # Add prio weights.
        state.update({
            "sum_segment_tree": self._it_sum.get_state(),
            "min_segment_tree": self._it_min.get_state(),
            "max_priority": self._max_priority,
        })
        return state

    @ExperimentalAPI
    @override(ReplayBuffer)
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state: The new state to set this buffer. Can be obtained by
                calling `self.get_state()`.
        """
        super().set_state(state)
        self._it_sum.set_state(state["sum_segment_tree"])
        self._it_min.set_state(state["min_segment_tree"])
        self._max_priority = state["max_priority"]
class PrioritizedReplayBuffer(ReplayBuffer):
    """Replay buffer that samples stored items according to their priority.

    Each storage slot has a priority (raised to the power `alpha`) that is
    kept in two segment trees: a sum tree, searched by prefix sum to draw
    indices, and a min tree, used to normalize importance-sampling weights.
    """

    @ExperimentalAPI
    def __init__(
        self,
        capacity: int = 10000,
        storage_unit: str = "timesteps",
        alpha: float = 1.0,
    ):
        """Initializes a PrioritizedReplayBuffer instance.

        Args:
            capacity: Max number of timesteps to store in the FIFO buffer.
                After reaching this number, older samples will be
                dropped to make space for new ones.
            storage_unit: Either 'sequences' or 'timesteps'. Specifies how
                experiences are stored.
            alpha: How much prioritization is used (0.0=no prioritization,
                1.0=full prioritization). Must be strictly positive.
        """
        ReplayBuffer.__init__(self, capacity, storage_unit)

        assert alpha > 0
        self._alpha = alpha

        # Size the segment trees to the smallest power of two that is
        # >= capacity.
        it_capacity = 1
        while it_capacity < self.capacity:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        # Default priority for newly added items (see `add`); raised to the
        # largest priority ever passed to `update_priorities`.
        self._max_priority = 1.0
        # Rolling window of (alpha-scaled) priority changes, reported by
        # `stats(debug=True)`.
        self._prio_change_stats = WindowStat("reprio", 1000)

    @ExperimentalAPI
    @override(ReplayBuffer)
    def add(
        self,
        batch: SampleBatchType,
        weight: Optional[float] = None,
        **kwargs,
    ) -> None:
        """Adds a batch of experiences together with its initial priority.

        Args:
            batch: SampleBatch to add to this buffer's storage.
            weight: The priority of the added batch, used in subsequent
                sampling steps. If None (the default), the largest
                priority seen so far (`self._max_priority`) is used.
            **kwargs: Forward compatibility kwargs (ignored).
        """
        # Remember the slot this batch goes into; `_next_idx` moves on below.
        idx = self._next_idx

        assert batch.count > 0, batch
        warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

        # Update our timesteps counts.
        self._num_timesteps_added += batch.count
        self._num_timesteps_added_wrap += batch.count

        if self._next_idx >= len(self._storage):
            self._storage.append(batch)
            self._est_size_bytes += batch.size_bytes()
        else:
            self._storage[self._next_idx] = batch

        # Wrap around storage as a circular buffer once we hit capacity.
        if self._num_timesteps_added_wrap >= self.capacity:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

        # Eviction of older samples has already started (buffer is "full").
        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0

        if weight is None:
            weight = self._max_priority
        priority = weight**self._alpha
        self._it_sum[idx] = priority
        self._it_min[idx] = priority

    def _sample_proportional(self, num_items: int) -> List[int]:
        """Draws `num_items` storage indices via the sum tree.

        A uniformly random mass in [0, total priority of stored items) is
        mapped to a slot by prefix-sum search, once per requested item.
        Indices are drawn with replacement and may repeat.
        """
        res = []
        for _ in range(num_items):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage))
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    @ExperimentalAPI
    @override(ReplayBuffer)
    def sample(self, num_items: int, beta: float) -> Optional[SampleBatchType]:
        """Sample `num_items` items from this buffer, including prio. weights.

        If less than `num_items` records are in this buffer, some samples in
        the results may be repeated to fulfil the batch size (`num_items`)
        request.

        Args:
            num_items: Number of items to sample from this buffer.
            beta: To what degree to use importance weights
                (0 - no corrections, 1 - full correction).

        Returns:
            Concatenated batch of items including "weights" and
            "batch_indexes" fields denoting IS of each sampled
            transition and original idxes in buffer of sampled experiences.
            None if the buffer is empty.
        """
        # If we don't have any samples yet in this buffer, return None.
        if len(self) == 0:
            return None

        assert beta >= 0.0

        idxes = self._sample_proportional(num_items)

        weights = []
        batch_indexes = []
        # (p * N) ** -beta shrinks as p grows, so the lowest-priority item
        # has the largest IS weight; every weight is normalized by it.
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self))**(-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self))**(-beta)
            count = self._storage[idx].count
            # If zero-padded, count will not be the actual batch size of the
            # data.
            if (isinstance(self._storage[idx], SampleBatch)
                    and self._storage[idx].zero_padded):
                actual_size = self._storage[idx].max_seq_len
            else:
                actual_size = count
            # One weight / index entry per row of the stored item.
            weights.extend([weight / max_weight] * actual_size)
            batch_indexes.extend([idx] * actual_size)
            self._num_timesteps_sampled += count
        batch = self._encode_sample(idxes)

        # Note: prioritization is not supported in lockstep replay mode.
        if isinstance(batch, SampleBatch):
            batch["weights"] = np.array(weights)
            batch["batch_indexes"] = np.array(batch_indexes)

        return batch

    @ExperimentalAPI
    def update_priorities(self, idxes: List[int],
                          priorities: List[float]) -> None:
        """Update priorities of sampled transitions.

        Sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Args:
            idxes: List of indices of sampled transitions
            priorities: List of updated priorities corresponding to
                transitions at the sampled idxes denoted by
                variable `idxes`. Each must be > 0.
        """
        # Making sure we don't pass in e.g. a torch tensor.
        assert isinstance(
            idxes, (list, np.ndarray)
        ), "ERROR: `idxes` is not a list or np.ndarray, but " "{}!".format(
            type(idxes).__name__)
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            delta = priority**self._alpha - self._it_sum[idx]
            self._prio_change_stats.push(delta)
            self._it_sum[idx] = priority**self._alpha
            self._it_min[idx] = priority**self._alpha

            self._max_priority = max(self._max_priority, priority)

    @ExperimentalAPI
    @override(ReplayBuffer)
    def stats(self, debug: bool = False) -> Dict:
        """Returns the stats of this buffer.

        Args:
            debug: If true, adds sample eviction statistics and priority
                change statistics to the returned stats dict.

        Returns:
            A dictionary of stats about this buffer.
        """
        parent = ReplayBuffer.stats(self, debug)
        if debug:
            parent.update(self._prio_change_stats.stats())
        return parent

    @ExperimentalAPI
    @override(ReplayBuffer)
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Returns:
            The serializable local state: the parent's state plus both
            segment trees and the current max priority.
        """
        # Get parent state.
        state = super().get_state()
        # Add prio weights.
        state.update({
            "sum_segment_tree": self._it_sum.get_state(),
            "min_segment_tree": self._it_min.get_state(),
            "max_priority": self._max_priority,
        })
        return state

    @ExperimentalAPI
    @override(ReplayBuffer)
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state: The new state to set this buffer. Can be obtained by
                calling `self.get_state()`.
        """
        super().set_state(state)
        self._it_sum.set_state(state["sum_segment_tree"])
        self._it_min.set_state(state["min_segment_tree"])
        self._max_priority = state["max_priority"]
class ReplayBuffer:
    """Circular FIFO buffer of sample batches with uniform random sampling.

    Capacity is measured in timesteps. Items (batches) are appended until
    the timesteps added since the last wrap reach `capacity`. After that,
    new items overwrite old ones starting again at slot 0.
    """

    @DeveloperAPI
    def __init__(self,
                 capacity: int = 10000,
                 size: Optional[int] = DEPRECATED_VALUE):
        """Initializes a ReplayBuffer instance.

        Args:
            capacity: Max number of timesteps to store in the FIFO buffer.
                After reaching this number, older samples will be
                dropped to make space for new ones.
            size: Deprecated alias for `capacity`. If given, it replaces
                `capacity` and `deprecation_warning` is called with
                `error=False`.
        """
        # Deprecated args.
        if size != DEPRECATED_VALUE:
            deprecation_warning(
                "ReplayBuffer(size)", "ReplayBuffer(capacity)", error=False)
            capacity = size

        # The actual storage (list of SampleBatches).
        self._storage = []

        self.capacity = capacity
        # The next index to override in the buffer.
        self._next_idx = 0
        # Per-slot count of how often the item in that slot was sampled
        # (incremented in `_encode_sample`, reset on eviction in `add`).
        self._hit_count = np.zeros(self.capacity)

        # Whether we have already hit our capacity (and have therefore
        # started to evict older samples).
        self._eviction_started = False

        # Number of (single) timesteps that have been added to the buffer
        # over its lifetime. Note that each added item (batch) may contain
        # more than one timestep.
        self._num_timesteps_added = 0
        self._num_timesteps_added_wrap = 0

        # Number of (single) timesteps that have been sampled from the buffer
        # over its lifetime.
        self._num_timesteps_sampled = 0

        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
        self._est_size_bytes = 0

    def __len__(self) -> int:
        """Returns the number of items currently stored in this buffer."""
        return len(self._storage)

    @DeveloperAPI
    def add(self, item: SampleBatchType, weight: float) -> None:
        """Add a batch of experiences.

        Args:
            item: SampleBatch to add to this buffer's storage.
            weight: The weight of the added sample used in subsequent
                sampling steps. Only relevant if this ReplayBuffer is
                a PrioritizedReplayBuffer.
        """
        assert item.count > 0, item
        warn_replay_capacity(item=item, num_items=self.capacity / item.count)

        # Update our timesteps counts.
        self._num_timesteps_added += item.count
        self._num_timesteps_added_wrap += item.count

        if self._next_idx >= len(self._storage):
            self._storage.append(item)
            self._est_size_bytes += item.size_bytes()
        else:
            self._storage[self._next_idx] = item

        # Wrap around storage as a circular buffer once we hit capacity.
        if self._num_timesteps_added_wrap >= self.capacity:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

        # Eviction of older samples has already started (buffer is "full").
        # Record how often the slot about to be overwritten was sampled.
        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0

    @DeveloperAPI
    def sample(self, num_items: int, beta: float = 0.0) -> SampleBatchType:
        """Sample a batch of size `num_items` from this buffer.

        If less than `num_items` records are in this buffer, some samples in
        the results may be repeated to fulfil the batch size (`num_items`)
        request.

        Args:
            num_items: Number of items to sample from this buffer.
            beta: The prioritized replay beta value. Only relevant if this
                ReplayBuffer is a PrioritizedReplayBuffer.

        Returns:
            Concatenated batch of items, or None if the buffer is empty.
        """
        # If we don't have any samples yet in this buffer, return None.
        if len(self) == 0:
            return None

        idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
        sample = self._encode_sample(idxes)
        # Update our timesteps counters.
        self._num_timesteps_sampled += len(sample)
        return sample

    @DeveloperAPI
    def stats(self, debug: bool = False) -> dict:
        """Returns the stats of this buffer.

        Args:
            debug: If True, adds sample eviction statistics to the returned
                stats dict.

        Returns:
            A dictionary of stats about this buffer.
        """
        data = {
            "added_count": self._num_timesteps_added,
            "added_count_wrapped": self._num_timesteps_added_wrap,
            "eviction_started": self._eviction_started,
            "sampled_count": self._num_timesteps_sampled,
            "est_size_bytes": self._est_size_bytes,
            "num_entries": len(self._storage),
        }
        if debug:
            data.update(self._evicted_hit_stats.stats())
        return data

    @DeveloperAPI
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Returns:
            The serializable local state: storage, next index and the
            (non-debug) stats counters.
        """
        state = {"_storage": self._storage, "_next_idx": self._next_idx}
        state.update(self.stats(debug=False))
        return state

    @DeveloperAPI
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state: The new state to set this buffer. Can be
                obtained by calling `self.get_state()`.
        """
        # The actual storage.
        self._storage = state["_storage"]
        self._next_idx = state["_next_idx"]
        # Stats and counts.
        self._num_timesteps_added = state["added_count"]
        self._num_timesteps_added_wrap = state["added_count_wrapped"]
        self._eviction_started = state["eviction_started"]
        self._num_timesteps_sampled = state["sampled_count"]
        self._est_size_bytes = state["est_size_bytes"]

    def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
        """Concatenates the stored items at `idxes` into one batch.

        Each index also counts as one hit. These counts feed the
        "evicted_hit" stats when a slot is overwritten in `add`.
        """
        for i in idxes:
            self._hit_count[i] += 1
        out = SampleBatch.concat_samples([self._storage[i] for i in idxes])
        out.decompress_if_needed()
        return out
class ReplayBuffer:
    """Circular FIFO buffer of sample batches with uniform random sampling.

    Capacity is measured in timesteps. Items (batches) are appended until
    the timesteps added since the last wrap reach `capacity`. After that,
    new items overwrite old ones starting again at slot 0.
    """

    @DeveloperAPI
    def __init__(self,
                 capacity: int = 10000,
                 size: Optional[int] = DEPRECATED_VALUE):
        """Initializes a ReplayBuffer instance.

        Args:
            capacity (int): Max number of timesteps to store in the FIFO
                buffer. After reaching this number, older samples will be
                dropped to make space for new ones.
            size (Optional[int]): Deprecated alias for `capacity`. If
                given, it replaces `capacity`.
        """
        # Deprecated args.
        if size != DEPRECATED_VALUE:
            deprecation_warning(
                "ReplayBuffer(size)", "ReplayBuffer(capacity)", error=False)
            capacity = size

        # The actual storage (list of SampleBatches).
        self._storage = []

        self.capacity = capacity
        # The next index to override in the buffer.
        self._next_idx = 0
        # Per-slot count of how often the item in that slot was sampled
        # (incremented in `_encode_sample`, reset on eviction in `add`).
        self._hit_count = np.zeros(self.capacity)

        # Whether we have already hit our capacity (and have therefore
        # started to evict older samples).
        self._eviction_started = False

        # Lifetime timestep counters (each added item may hold several
        # timesteps); `_wrap` is reset to 0 whenever storage wraps around.
        self._num_timesteps_added = 0
        self._num_timesteps_added_wrap = 0
        self._num_timesteps_sampled = 0
        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
        self._est_size_bytes = 0

    def __len__(self) -> int:
        """Returns the number of items currently stored in this buffer."""
        return len(self._storage)

    @DeveloperAPI
    def add(self, item: SampleBatchType, weight: float) -> None:
        """Add a batch of experiences.

        Args:
            item (SampleBatchType): Batch to add to this buffer's storage.
            weight (float): Unused here; only relevant for prioritized
                subclasses.
        """
        assert item.count > 0, item
        warn_replay_capacity(item=item, num_items=self.capacity / item.count)

        self._num_timesteps_added += item.count
        self._num_timesteps_added_wrap += item.count

        if self._next_idx >= len(self._storage):
            self._storage.append(item)
            self._est_size_bytes += item.size_bytes()
        else:
            self._storage[self._next_idx] = item

        # Wrap around storage as a circular buffer once we hit capacity.
        if self._num_timesteps_added_wrap >= self.capacity:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

        # Once eviction has started, record how often the slot about to be
        # overwritten was sampled, then reset its counter.
        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0

    def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
        """Concatenates the stored items at `idxes` into one batch.

        Each index also counts as one hit. These counts feed the
        "evicted_hit" stats when a slot is overwritten in `add`.
        """
        for i in idxes:
            self._hit_count[i] += 1
        out = SampleBatch.concat_samples([self._storage[i] for i in idxes])
        out.decompress_if_needed()
        return out

    @DeveloperAPI
    def sample(self, num_items: int) -> Optional[SampleBatchType]:
        """Sample a batch of experiences.

        Items are drawn uniformly at random with replacement, so some may
        repeat.

        Args:
            num_items (int): Number of items to sample from this buffer.

        Returns:
            Optional[SampleBatchType]: concatenated batch of items, or None
                if the buffer is empty.
        """
        # Nothing to sample from yet (randint(0, -1) would raise).
        if len(self) == 0:
            return None

        idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
        sample = self._encode_sample(idxes)
        # Count sampled timesteps; reported as "sampled_count" in stats().
        self._num_timesteps_sampled += len(sample)
        return sample

    @DeveloperAPI
    def stats(self, debug=False) -> dict:
        """Returns a dict of counters describing this buffer.

        Args:
            debug (bool): If True, also include the evicted-hit statistics.

        Returns:
            dict: Stats about this buffer.
        """
        data = {
            "added_count": self._num_timesteps_added,
            "added_count_wrapped": self._num_timesteps_added_wrap,
            "eviction_started": self._eviction_started,
            "sampled_count": self._num_timesteps_sampled,
            "est_size_bytes": self._est_size_bytes,
            "num_entries": len(self._storage),
        }
        if debug:
            data.update(self._evicted_hit_stats.stats())
        return data

    @DeveloperAPI
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Returns:
            Dict[str, Any]: The serializable local state.
        """
        state = {"_storage": self._storage, "_next_idx": self._next_idx}
        state.update(self.stats(debug=False))
        return state

    @DeveloperAPI
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state (Dict[str, Any]): The new state to set this buffer. Can be
                obtained by calling `self.get_state()`.
        """
        # The actual storage.
        self._storage = state["_storage"]
        self._next_idx = state["_next_idx"]
        # Stats and counts.
        self._num_timesteps_added = state["added_count"]
        self._num_timesteps_added_wrap = state["added_count_wrapped"]
        self._eviction_started = state["eviction_started"]
        self._num_timesteps_sampled = state["sampled_count"]
        self._est_size_bytes = state["est_size_bytes"]
class PrioritizedReplayBuffer(ReplayBuffer):
    """Replay buffer that samples stored items according to their priority.

    Each storage slot has a priority (raised to the power `alpha`) that is
    kept in two segment trees: a sum tree, searched by prefix sum to draw
    indices, and a min tree, used to normalize importance-sampling weights.
    """

    @DeveloperAPI
    def __init__(self,
                 capacity: int = 10000,
                 alpha: float = 1.0,
                 size: Optional[int] = DEPRECATED_VALUE):
        """Initializes a PrioritizedReplayBuffer instance.

        Args:
            capacity (int): Max number of timesteps to store in the FIFO
                buffer. After reaching this number, older samples will be
                dropped to make space for new ones.
            alpha (float): How much prioritization is used
                (0.0=no prioritization, 1.0=full prioritization). Must be
                strictly positive.
            size (Optional[int]): Deprecated alias for `capacity`, forwarded
                to the base class.
        """
        super(PrioritizedReplayBuffer, self).__init__(capacity, size)
        assert alpha > 0
        self._alpha = alpha

        # Size the segment trees to the smallest power of two that is
        # >= capacity.
        it_capacity = 1
        while it_capacity < self.capacity:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        # Default priority for newly added items; raised to the largest
        # priority ever passed to `update_priorities`.
        self._max_priority = 1.0
        # Rolling window of (alpha-scaled) priority changes, reported by
        # `stats(debug=True)`.
        self._prio_change_stats = WindowStat("reprio", 1000)

    @DeveloperAPI
    @override(ReplayBuffer)
    def add(self, item: SampleBatchType,
            weight: Optional[float] = None) -> None:
        """Adds `item` and assigns its slot an initial priority.

        Args:
            item (SampleBatchType): Batch to add to this buffer's storage.
            weight (Optional[float]): Initial priority of the item. If None
                (the default), the largest priority seen so far is used.
        """
        # Slot the base class is about to write `item` into.
        idx = self._next_idx
        super(PrioritizedReplayBuffer, self).add(item, weight)
        if weight is None:
            weight = self._max_priority
        priority = weight**self._alpha
        self._it_sum[idx] = priority
        self._it_min[idx] = priority

    def _sample_proportional(self, num_items: int) -> List[int]:
        """Draws `num_items` storage indices via the sum tree.

        A uniformly random mass in [0, total priority of stored items) is
        mapped to a slot by prefix-sum search, once per requested item.
        Indices are drawn with replacement and may repeat.
        """
        res = []
        for _ in range(num_items):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage))
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    @DeveloperAPI
    @override(ReplayBuffer)
    def sample(self, num_items: int,
               beta: float) -> Optional[SampleBatchType]:
        """Sample a batch of experiences and return priority weights, indices.

        Args:
            num_items (int): Number of items to sample from this buffer.
            beta (float): To what degree to use importance weights
                (0 - no corrections, 1 - full correction).

        Returns:
            Optional[SampleBatchType]: Concatenated batch of items including
                "weights" and "batch_indexes" fields denoting IS of each
                sampled transition and original idxes in buffer of sampled
                experiences. None if the buffer is empty.
        """
        # Nothing stored yet: the `self._storage[idx]` lookups below
        # would fail, so return None like the base-class sampler.
        if len(self._storage) == 0:
            return None

        assert beta >= 0.0

        idxes = self._sample_proportional(num_items)

        weights = []
        batch_indexes = []
        # (p * N) ** -beta shrinks as p grows, so the lowest-priority item
        # has the largest IS weight; every weight is normalized by it.
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage))**(-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage))**(-beta)
            count = self._storage[idx].count
            # If zero-padded, count will not be the actual batch size of the
            # data.
            if (isinstance(self._storage[idx], SampleBatch)
                    and self._storage[idx].zero_padded):
                actual_size = self._storage[idx].max_seq_len
            else:
                actual_size = count
            # One weight / index entry per row of the stored item.
            weights.extend([weight / max_weight] * actual_size)
            batch_indexes.extend([idx] * actual_size)
            self._num_timesteps_sampled += count
        batch = self._encode_sample(idxes)

        # Note: prioritization is not supported in lockstep replay mode.
        if isinstance(batch, SampleBatch):
            batch["weights"] = np.array(weights)
            batch["batch_indexes"] = np.array(batch_indexes)

        return batch

    @DeveloperAPI
    def update_priorities(self, idxes: List[int],
                          priorities: List[float]) -> None:
        """Update priorities of sampled transitions.

        Sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Args:
            idxes (List[int]): Indices of sampled transitions (a list or
                np.ndarray; e.g. torch tensors are rejected).
            priorities (List[float]): Updated priorities (each > 0)
                corresponding to the transitions at `idxes`.
        """
        # Making sure we don't pass in e.g. a torch tensor.
        assert isinstance(idxes, (list, np.ndarray)), \
            "ERROR: `idxes` is not a list or np.ndarray, but " \
            "{}!".format(type(idxes).__name__)
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            delta = priority**self._alpha - self._it_sum[idx]
            self._prio_change_stats.push(delta)
            self._it_sum[idx] = priority**self._alpha
            self._it_min[idx] = priority**self._alpha

            self._max_priority = max(self._max_priority, priority)

    @DeveloperAPI
    @override(ReplayBuffer)
    def stats(self, debug: bool = False) -> Dict:
        """Returns the base-class stats plus, in debug mode, priority stats.

        Args:
            debug (bool): If True, include eviction and priority-change
                statistics.

        Returns:
            Dict: Stats about this buffer.
        """
        parent = ReplayBuffer.stats(self, debug)
        if debug:
            parent.update(self._prio_change_stats.stats())
        return parent

    @DeveloperAPI
    @override(ReplayBuffer)
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Returns:
            Dict[str, Any]: The serializable local state.
        """
        # Get parent state.
        state = super().get_state()
        # Add prio weights.
        state.update({
            "sum_segment_tree": self._it_sum.get_state(),
            "min_segment_tree": self._it_min.get_state(),
            "max_priority": self._max_priority,
        })
        return state

    @DeveloperAPI
    @override(ReplayBuffer)
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state (Dict[str, Any]): The new state to set this buffer. Can be
                obtained by calling `self.get_state()`.
        """
        super().set_state(state)
        self._it_sum.set_state(state["sum_segment_tree"])
        self._it_min.set_state(state["min_segment_tree"])
        self._max_priority = state["max_priority"]
class ReplayBuffer:
    """Circular FIFO buffer of sample batches with uniform random sampling.

    Capacity is measured in timesteps. Items (batches) are appended until
    the timesteps added since the last wrap reach `capacity`. After that,
    new items overwrite old ones starting again at slot 0.
    """

    def __init__(
        self, capacity: int = 10000, storage_unit: str = "timesteps", **kwargs
    ):
        """Initializes a ReplayBuffer instance.

        Args:
            capacity: Max number of timesteps to store in the FIFO
                buffer. After reaching this number, older samples will be
                dropped to make space for new ones.
            storage_unit: Either 'sequences' or 'timesteps'. Specifies how
                experiences are stored.
            **kwargs: Forward compatibility kwargs.

        Raises:
            ValueError: If `storage_unit` is neither 'timesteps' nor
                'sequences'.
        """

        if storage_unit == "timesteps":
            self._store_as_sequences = False
        elif storage_unit == "sequences":
            self._store_as_sequences = True
        else:
            raise ValueError(
                "storage_unit must be either 'sequences' or 'timesteps', "
                "got {!r}".format(storage_unit))

        # The actual storage (list of SampleBatches).
        self._storage = []

        self.capacity = capacity
        # The next index to override in the buffer.
        self._next_idx = 0
        # Per-slot count of how often the item in that slot was sampled
        # (incremented in `_encode_sample`, reset on eviction in `add`).
        self._hit_count = np.zeros(self.capacity)

        # Whether we have already hit our capacity (and have therefore
        # started to evict older samples).
        self._eviction_started = False

        # Number of (single) timesteps that have been added to the buffer
        # over its lifetime. Note that each added item (batch) may contain
        # more than one timestep.
        self._num_timesteps_added = 0
        self._num_timesteps_added_wrap = 0

        # Number of (single) timesteps that have been sampled from the buffer
        # over its lifetime.
        self._num_timesteps_sampled = 0

        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
        self._est_size_bytes = 0

    def __len__(self) -> int:
        """Returns the number of items currently stored in this buffer."""
        return len(self._storage)

    @ExperimentalAPI
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch of experiences.

        Args:
            batch: SampleBatch to add to this buffer's storage.
            **kwargs: Forward compatibility kwargs.
        """
        assert batch.count > 0, batch
        warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

        # Update our timesteps counts.
        self._num_timesteps_added += batch.count
        self._num_timesteps_added_wrap += batch.count

        if self._next_idx >= len(self._storage):
            self._storage.append(batch)
            self._est_size_bytes += batch.size_bytes()
        else:
            self._storage[self._next_idx] = batch

        # Wrap around storage as a circular buffer once we hit capacity.
        if self._num_timesteps_added_wrap >= self.capacity:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

        # Eviction of older samples has already started (buffer is "full").
        # Record how often the slot about to be overwritten was sampled.
        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0

    @ExperimentalAPI
    def sample(self, num_items: int, **kwargs) -> Optional[SampleBatchType]:
        """Samples a batch of size `num_items` from this buffer.

        If less than `num_items` records are in this buffer, some samples in
        the results may be repeated to fulfil the batch size (`num_items`)
        request.

        Args:
            num_items: Number of items to sample from this buffer.
            **kwargs: Forward compatibility kwargs.

        Returns:
            Concatenated batch of items. None if buffer is empty.
        """
        # If we don't have any samples yet in this buffer, return None.
        if len(self) == 0:
            return None

        idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
        sample = self._encode_sample(idxes)
        # Update our timesteps counters.
        self._num_timesteps_sampled += len(sample)
        return sample

    @ExperimentalAPI
    def stats(self, debug: bool = False) -> dict:
        """Returns the stats of this buffer.

        Args:
            debug: If True, adds sample eviction statistics to the returned
                stats dict.

        Returns:
            A dictionary of stats about this buffer.
        """
        data = {
            "added_count": self._num_timesteps_added,
            "added_count_wrapped": self._num_timesteps_added_wrap,
            "eviction_started": self._eviction_started,
            "sampled_count": self._num_timesteps_sampled,
            "est_size_bytes": self._est_size_bytes,
            "num_entries": len(self._storage),
        }
        if debug:
            data.update(self._evicted_hit_stats.stats())
        return data

    @ExperimentalAPI
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Returns:
            The serializable local state: storage, next index and the
            (non-debug) stats counters.
        """
        state = {"_storage": self._storage, "_next_idx": self._next_idx}
        state.update(self.stats(debug=False))
        return state

    @ExperimentalAPI
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state: The new state to set this buffer. Can be
                obtained by calling `self.get_state()`.
        """
        # The actual storage.
        self._storage = state["_storage"]
        self._next_idx = state["_next_idx"]
        # Stats and counts.
        self._num_timesteps_added = state["added_count"]
        self._num_timesteps_added_wrap = state["added_count_wrapped"]
        self._eviction_started = state["eviction_started"]
        self._num_timesteps_sampled = state["sampled_count"]
        self._est_size_bytes = state["est_size_bytes"]

    def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
        """Concatenates the stored items at `idxes` into one batch.

        Each index also counts as one hit. These counts feed the
        "evicted_hit" stats when a slot is overwritten in `add`.
        """
        for i in idxes:
            self._hit_count[i] += 1
        out = SampleBatch.concat_samples([self._storage[i] for i in idxes])
        out.decompress_if_needed()
        return out

    def get_host(self) -> str:
        """Returns the computer's network name.

        Returns:
            The computer's networks name or an empty string, if the network
            name could not be determined.
        """
        return platform.node()