def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
    """Write one batch into the current slot of the circular storage.

    The batch is stored as-is, whatever it holds (one or more timesteps, a
    sequence or an episode). Unlike add(), no storage-unit or batch-type
    handling happens here.

    Args:
        item: The batch to be stored.
        **kwargs: Forward compatibility kwargs (unused).
    """
    steps = item.count
    self._num_timesteps_added += steps
    self._num_timesteps_added_wrap += steps

    slot = self._next_idx
    if slot < len(self._storage):
        # Slot is occupied: take the old batch's bytes out of the estimate
        # before replacing it.
        evicted = self._storage[slot]
        self._est_size_bytes -= evicted.size_bytes()
        self._storage[slot] = item
    else:
        # Storage is still growing.
        self._storage.append(item)
    self._est_size_bytes += item.size_bytes()

    # Once eviction is on, record how often the slot we just wrote to had
    # been hit, then reset its counter.
    if self._eviction_started:
        self._evicted_hit_stats.push(self._hit_count[slot])
        self._hit_count[slot] = 0

    # Circular buffer: restart at slot 0 once `capacity` timesteps have been
    # added since the last wrap.
    wrapped = self._num_timesteps_added_wrap >= self.capacity
    if wrapped:
        self._eviction_started = True
        self._num_timesteps_added_wrap = 0
    self._next_idx = 0 if wrapped else slot + 1
def warn_replay_buffer_capacity(*, item: SampleBatchType, capacity: int) -> None:
    """Report the estimated memory footprint of a replay buffer.

    The estimate is `capacity * item.size_bytes()`, compared against the
    machine's total memory as reported by `psutil.virtual_memory()`. The
    whole check only runs when `log_once("warn_replay_buffer_capacity")`
    returns a truthy value.

    Args:
        item: An example item of the kind that will be added to the buffer;
            its `size_bytes()` is used as the per-item footprint.
        capacity: Number of such items the buffer is expected to hold.

    Raises:
        ValueError: If the estimate exceeds the machine's total memory.
    """
    if not log_once("warn_replay_buffer_capacity"):
        return

    bytes_per_item = item.size_bytes()
    total_gb = psutil.virtual_memory().total / 1e9
    mem_size = capacity * bytes_per_item / 1e9
    msg = (
        "Estimated max memory usage for replay buffer is {} GB "
        "({} batches of size {}, {} bytes each), "
        "available system memory is {} GB".format(
            mem_size, capacity, item.count, bytes_per_item, total_gb
        )
    )

    # Hard failure above total memory, warning above 20% of it, info otherwise.
    if mem_size > total_gb:
        raise ValueError(msg)
    if mem_size > 0.2 * total_gb:
        logger.warning(msg)
    else:
        logger.info(msg)
def add(self, item: SampleBatchType):
    """Add a single-timestep batch to the buffer.

    While `self._next_idx` is set, the item is written to that slot. Once
    `self._num_timesteps_added` reaches `self._maxsize`, the next slot is
    chosen by `self._sample_reservior_buffer_next_index()` instead of being
    incremented. When `self._next_idx` is None the item is not stored and
    only a new slot is sampled.

    Args:
        item: The batch to add. Must contain exactly one timestep.

    Raises:
        AssertionError: If `item.count != 1`.
    """
    # Validate first: the capacity estimate below divides by item.count, so
    # an empty batch must fail here (showing the item) rather than with a
    # ZeroDivisionError.
    assert item.count == 1, item
    warn_replay_buffer_size(item=item, num_items=self._maxsize / item.count)

    if self._next_idx is not None:
        self._num_timesteps_added += item.count
        if self._next_idx >= len(self._storage):
            self._storage.append(item)
            self._est_size_bytes += item.size_bytes()
        else:
            # Overwrite: swap the old batch's size for the new one's so the
            # size estimate keeps matching what is actually stored.
            self._est_size_bytes -= self._storage[self._next_idx].size_bytes()
            self._storage[self._next_idx] = item
            self._est_size_bytes += item.size_bytes()

        if self._num_timesteps_added >= self._maxsize:
            self._eviction_started = True
            self._sample_reservior_buffer_next_index()
        else:
            self._next_idx += 1

        # NOTE(review): the sampling helper is not visible here, but the
        # `is not None` check above implies it can clear `_next_idx`. Never
        # index the hit counters with None.
        if self._eviction_started and self._next_idx is not None:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0
    else:
        # No slot was selected for this item: drop it and sample again.
        assert self._add_calls >= self._maxsize
        self._sample_reservior_buffer_next_index()
    self._add_calls += 1
def add(self, item: SampleBatchType, weight: float) -> None:
    """Add a batch of experiences to the circular storage.

    Args:
        item: The batch to add. Must contain at least one timestep.
        weight: Not used by this method.

    Raises:
        AssertionError: If `item.count` is not positive.
    """
    # Validate first: the capacity estimate below divides by item.count, so
    # an empty batch must fail here rather than with a ZeroDivisionError.
    assert item.count > 0, item
    warn_replay_buffer_size(item=item, num_items=self._maxsize / item.count)

    self._num_timesteps_added += item.count
    self._num_timesteps_added_wrap += item.count

    if self._next_idx >= len(self._storage):
        self._storage.append(item)
        self._est_size_bytes += item.size_bytes()
    else:
        # Overwrite: swap the old batch's size for the new one's so the
        # size estimate keeps matching what is actually stored.
        self._est_size_bytes -= self._storage[self._next_idx].size_bytes()
        self._storage[self._next_idx] = item
        self._est_size_bytes += item.size_bytes()

    # Wrap around storage as a circular buffer once we hit maxsize.
    if self._num_timesteps_added_wrap >= self._maxsize:
        self._eviction_started = True
        self._num_timesteps_added_wrap = 0
        self._next_idx = 0
    else:
        self._next_idx += 1

    # Once eviction is on, record and reset the hit count of the slot that
    # `_next_idx` now points to.
    if self._eviction_started:
        self._evicted_hit_stats.push(self._hit_count[self._next_idx])
        self._hit_count[self._next_idx] = 0
def add(self, item: SampleBatchType, weight: float) -> None:
    """Add a batch of experiences.

    Args:
        item: SampleBatch to add to this buffer's storage. Must contain at
            least one timestep.
        weight: The weight of the added sample used in subsequent sampling
            steps. Only relevant if this ReplayBuffer is a
            PrioritizedReplayBuffer; not used by this method.

    Raises:
        AssertionError: If `item.count` is not positive.
    """
    assert item.count > 0, item
    warn_replay_capacity(item=item, num_items=self.capacity / item.count)

    # Update our timesteps counts.
    self._num_timesteps_added += item.count
    self._num_timesteps_added_wrap += item.count

    if self._next_idx >= len(self._storage):
        self._storage.append(item)
        self._est_size_bytes += item.size_bytes()
    else:
        # Overwrite: swap the old batch's size for the new one's so the
        # size estimate keeps matching what is actually stored.
        self._est_size_bytes -= self._storage[self._next_idx].size_bytes()
        self._storage[self._next_idx] = item
        self._est_size_bytes += item.size_bytes()

    # Wrap around storage as a circular buffer once we hit capacity.
    if self._num_timesteps_added_wrap >= self.capacity:
        self._eviction_started = True
        self._num_timesteps_added_wrap = 0
        self._next_idx = 0
    else:
        self._next_idx += 1

    # Eviction of older samples has already started (buffer is "full"):
    # record and reset the hit count of the slot `_next_idx` now points to.
    if self._eviction_started:
        self._evicted_hit_stats.push(self._hit_count[self._next_idx])
        self._hit_count[self._next_idx] = 0
def add(self, batch: SampleBatchType, **kwargs) -> None:
    """Adds a batch of experiences.

    Args:
        batch: SampleBatch to add to this buffer's storage. Must contain at
            least one timestep.
        **kwargs: Forward compatibility kwargs (unused).

    Raises:
        AssertionError: If `batch.count` is not positive.
    """
    assert batch.count > 0, batch
    warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

    # Update our timesteps counts.
    self._num_timesteps_added += batch.count
    self._num_timesteps_added_wrap += batch.count

    if self._next_idx >= len(self._storage):
        self._storage.append(batch)
        self._est_size_bytes += batch.size_bytes()
    else:
        # Overwrite: swap the old batch's size for the new one's so the
        # size estimate keeps matching what is actually stored.
        self._est_size_bytes -= self._storage[self._next_idx].size_bytes()
        self._storage[self._next_idx] = batch
        self._est_size_bytes += batch.size_bytes()

    # Wrap around storage as a circular buffer once we hit capacity.
    if self._num_timesteps_added_wrap >= self.capacity:
        self._eviction_started = True
        self._num_timesteps_added_wrap = 0
        self._next_idx = 0
    else:
        self._next_idx += 1

    # Eviction of older samples has already started (buffer is "full"):
    # record and reset the hit count of the slot `_next_idx` now points to.
    if self._eviction_started:
        self._evicted_hit_stats.push(self._hit_count[self._next_idx])
        self._hit_count[self._next_idx] = 0
def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
    """Add a SampleBatch of experiences to self._storage.

    An item consists of either one or more timesteps, a sequence or an
    episode. Differs from add() in that it does not consider the storage
    unit or type of batch and simply stores it.

    While fewer than `self.capacity` timesteps have been added, the item is
    appended. After that, a slot index is drawn uniformly from
    `[0, self._num_add_calls - 1]`; the item replaces the batch in that slot
    if the index falls inside the storage, and is dropped otherwise.

    Args:
        item: The batch to be added.
        ``**kwargs``: Forward compatibility kwargs.
    """
    self._num_timesteps_added += item.count
    self._num_timesteps_added_wrap += item.count

    # Update add counts.
    self._num_add_calls += 1

    if self._num_timesteps_added < self.capacity:
        self._storage.append(item)
        self._est_size_bytes += item.size_bytes()
    else:
        # Eviction of older samples has already started (buffer is "full")
        self._eviction_started = True
        idx = random.randint(0, self._num_add_calls - 1)
        if idx < len(self._storage):
            self._num_evicted += 1
            # Record the evicted slot's hit count exactly once. Pushing it a
            # second time after the reset would add a spurious 0 to the
            # stats for every eviction.
            self._evicted_hit_stats.push(self._hit_count[idx])
            self._hit_count[idx] = 0
            # This is a bit of a hack: ReplayBuffer always inserts at
            # self._next_idx
            self._next_idx = idx

            item_to_be_removed = self._storage[idx]
            self._est_size_bytes -= item_to_be_removed.size_bytes()
            self._storage[idx] = item
            self._est_size_bytes += item.size_bytes()

            assert item.count > 0, item
            warn_replay_capacity(item=item, num_items=self.capacity / item.count)
def add(self, item: SampleBatchType, weight: float):
    """Add a batch to the circular storage, one slot per call.

    The slot index advances by one on every call and wraps modulo
    `self._maxsize`. Eviction is flagged as started once the last slot has
    been written.

    Args:
        item: The batch to add. Must contain at least one timestep.
        weight: Not used by this method.

    Raises:
        AssertionError: If `item.count` is not positive.
    """
    assert item.count > 0, item
    self._num_added += 1

    if self._next_idx >= len(self._storage):
        self._storage.append(item)
        self._est_size_bytes += item.size_bytes()
    else:
        # Overwrite: swap the old batch's size for the new one's so the
        # size estimate keeps matching what is actually stored.
        self._est_size_bytes -= self._storage[self._next_idx].size_bytes()
        self._storage[self._next_idx] = item
        self._est_size_bytes += item.size_bytes()

    # The slot just written was the last one: from now on, adds overwrite.
    if self._next_idx + 1 >= self._maxsize:
        self._eviction_started = True
    self._next_idx = (self._next_idx + 1) % self._maxsize

    # Record and reset the hit count of the slot `_next_idx` now points to.
    if self._eviction_started:
        self._evicted_hit_stats.push(self._hit_count[self._next_idx])
        self._hit_count[self._next_idx] = 0
def warn_replay_capacity(*, item: SampleBatchType, num_items: int) -> None:
    """Warn if the configured replay buffer capacity is too large.

    The footprint estimate is `num_items * item.size_bytes()`, compared
    against the machine's total memory as reported by
    `psutil.virtual_memory()`. Nothing happens unless
    `log_once("replay_capacity")` returns a truthy value.

    Args:
        item: An example item; its `size_bytes()` is the per-item footprint.
        num_items: Number of such items the buffer is expected to hold.

    Raises:
        ValueError: If the estimate exceeds the machine's total memory.
    """
    if not log_once("replay_capacity"):
        return

    item_bytes = item.size_bytes()
    system_gb = psutil.virtual_memory().total / 1e9
    needed_gb = num_items * item_bytes / 1e9
    msg = (
        "Estimated max memory usage for replay buffer is {} GB "
        "({} batches of size {}, {} bytes each), "
        "available system memory is {} GB".format(
            needed_gb, num_items, item.count, item_bytes, system_gb
        )
    )

    # Hard failure above total memory, warning above 20% of it, info otherwise.
    if needed_gb > system_gb:
        raise ValueError(msg)
    if needed_gb > 0.2 * system_gb:
        logger.warning(msg)
    else:
        logger.info(msg)
def add(self, batch: SampleBatchType, weight: float) -> None:
    """Add a batch of experiences and register its priority.

    The batch is written to the slot `self._next_idx` points to on entry.
    The same index is then used to store `weight ** self._alpha` in both
    `self._it_sum` and `self._it_min`.

    Args:
        batch: SampleBatch to add to this buffer's storage. Must contain at
            least one timestep.
        weight: The weight of the added sample used in subsequent sampling
            steps. If None, `self._max_priority` is used instead.

    Raises:
        AssertionError: If `batch.count` is not positive.
    """
    # Remember the slot being written; `_next_idx` is advanced below.
    idx = self._next_idx

    assert batch.count > 0, batch
    warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

    # Update our timesteps counts.
    self._num_timesteps_added += batch.count
    self._num_timesteps_added_wrap += batch.count

    if self._next_idx >= len(self._storage):
        self._storage.append(batch)
        self._est_size_bytes += batch.size_bytes()
    else:
        # Overwrite: swap the old batch's size for the new one's so the
        # size estimate keeps matching what is actually stored.
        self._est_size_bytes -= self._storage[self._next_idx].size_bytes()
        self._storage[self._next_idx] = batch
        self._est_size_bytes += batch.size_bytes()

    # Wrap around storage as a circular buffer once we hit capacity.
    if self._num_timesteps_added_wrap >= self.capacity:
        self._eviction_started = True
        self._num_timesteps_added_wrap = 0
        self._next_idx = 0
    else:
        self._next_idx += 1

    # Eviction of older samples has already started (buffer is "full"):
    # record and reset the hit count of the slot `_next_idx` now points to.
    if self._eviction_started:
        self._evicted_hit_stats.push(self._hit_count[self._next_idx])
        self._hit_count[self._next_idx] = 0

    if weight is None:
        weight = self._max_priority
    priority = weight**self._alpha
    self._it_sum[idx] = priority
    self._it_min[idx] = priority