Exemplo n.º 1
0
    def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
        """Store one SampleBatch in ``self._storage`` at ``self._next_idx``.

        The batch may hold one or more timesteps, a sequence or an episode.
        Unlike add(), this method ignores the storage unit and the type of
        the batch; it simply writes the batch into the next slot.

        The slot is appended while ``self._next_idx`` is past the end of the
        storage list and overwritten otherwise. ``self._est_size_bytes`` is
        kept in sync in both cases.

        Args:
            item: The batch to be added.
            **kwargs: Forward compatibility kwargs.
        """
        added_steps = item.count
        self._num_timesteps_added += added_steps
        self._num_timesteps_added_wrap += added_steps

        slot = self._next_idx
        if slot < len(self._storage):
            # Overwriting: drop the footprint of the batch being replaced
            # before the new batch's footprint is added below.
            self._est_size_bytes -= self._storage[slot].size_bytes()
            self._storage[slot] = item
        else:
            self._storage.append(item)
        self._est_size_bytes += item.size_bytes()

        # Once eviction has started (buffer is "full"), record how often the
        # slot was hit and reset its counter.
        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[slot])
            self._hit_count[slot] = 0

        # Not at capacity yet: just advance to the next slot.
        if self._num_timesteps_added_wrap < self.capacity:
            self._next_idx += 1
            return

        # Capacity reached: wrap around and treat storage as a circular
        # buffer from now on.
        self._eviction_started = True
        self._num_timesteps_added_wrap = 0
        self._next_idx = 0
Exemplo n.º 2
0
def warn_replay_buffer_capacity(*, item: SampleBatchType,
                                capacity: int) -> None:
    """Report the estimated memory footprint of a replay buffer.

    Does nothing unless ``log_once("warn_replay_buffer_capacity")`` returns a
    truthy value. Otherwise the estimate ``capacity * item.size_bytes()`` is
    compared against the total system memory reported by psutil: the message
    is logged at INFO level, at WARNING level when the estimate exceeds 20%
    of total memory, and raised as an error when it exceeds total memory.

    Args:
        item: A (example) item that's supposed to be added to the buffer.
            Its size in bytes is used to estimate the buffer's overall
            memory footprint.
        capacity: The capacity value of the buffer, interpreted as the
            number of items (such as the given `item`) that will eventually
            be stored in the buffer.

    Raises:
        ValueError: If the estimated memory footprint of the buffer exceeds
            the machine's total RAM.
    """
    if not log_once("warn_replay_buffer_capacity"):
        return

    bytes_per_item = item.size_bytes()
    system_gb = psutil.virtual_memory().total / 1e9
    estimated_gb = capacity * bytes_per_item / 1e9

    template = ("Estimated max memory usage for replay buffer is {} GB "
                "({} batches of size {}, {} bytes each), "
                "available system memory is {} GB")
    msg = template.format(estimated_gb, capacity, item.count,
                          bytes_per_item, system_gb)

    if estimated_gb > system_gb:
        raise ValueError(msg)
    if estimated_gb > 0.2 * system_gb:
        logger.warning(msg)
    else:
        logger.info(msg)
Exemplo n.º 3
0
    def add(self, item: SampleBatchType):
        """Add a single-timestep batch to the buffer.

        While ``self._next_idx`` is not None the batch is written to that
        slot (appended if the index is past the end of ``self._storage``,
        otherwise overwriting the batch stored there). Once
        ``self._num_timesteps_added`` reaches ``self._maxsize``, the next
        slot is chosen by ``self._sample_reservior_buffer_next_index()``
        instead of by incrementing the index.

        When ``self._next_idx`` is None the batch is not stored; only a new
        index is sampled for the following call.

        Args:
            item: The batch to add. Must contain exactly one timestep.
        """
        # Validate first: the capacity warning below divides by `item.count`.
        assert item.count == 1, item
        warn_replay_buffer_size(item=item,
                                num_items=self._maxsize / item.count)
        if self._next_idx is not None:

            self._num_timesteps_added += item.count

            if self._next_idx >= len(self._storage):
                self._storage.append(item)
            else:
                # Overwriting: remove the replaced batch's footprint so the
                # size estimate does not drift.
                replaced = self._storage[self._next_idx]
                self._est_size_bytes -= replaced.size_bytes()
                self._storage[self._next_idx] = item
            self._est_size_bytes += item.size_bytes()

            if self._num_timesteps_added >= self._maxsize:
                self._eviction_started = True
                self._sample_reservior_buffer_next_index()
            else:
                self._next_idx += 1

            # NOTE(review): the helper above presumably may leave
            # `_next_idx` as None (this method checks for that on entry);
            # in that case no slot is scheduled for eviction, so there is
            # nothing to record and indexing with None must be avoided.
            if self._eviction_started and self._next_idx is not None:
                self._evicted_hit_stats.push(self._hit_count[self._next_idx])
                self._hit_count[self._next_idx] = 0
        else:
            assert self._add_calls >= self._maxsize
            self._sample_reservior_buffer_next_index()
        self._add_calls += 1
Exemplo n.º 4
0
    def add(self, item: SampleBatchType, weight: float) -> None:
        """Add a batch of experiences to the circular storage.

        The batch is written at ``self._next_idx``: appended while that index
        is past the end of ``self._storage``, otherwise it overwrites the
        batch stored there. Once the timesteps added since the last wrap
        reach ``self._maxsize``, the write index returns to 0 and eviction
        is flagged as started.

        Args:
            item: The batch to add. Must contain at least one timestep.
            weight: Not read by this method.
        """
        # Validate first: the capacity warning below divides by `item.count`.
        assert item.count > 0, item
        warn_replay_buffer_size(item=item,
                                num_items=self._maxsize / item.count)

        self._num_timesteps_added += item.count
        self._num_timesteps_added_wrap += item.count

        if self._next_idx >= len(self._storage):
            self._storage.append(item)
        else:
            # Overwriting: remove the replaced batch's footprint so the size
            # estimate stays accurate when batches differ in size.
            replaced = self._storage[self._next_idx]
            self._est_size_bytes -= replaced.size_bytes()
            self._storage[self._next_idx] = item
        self._est_size_bytes += item.size_bytes()

        # Wrap around storage as a circular buffer once we hit maxsize.
        if self._num_timesteps_added_wrap >= self._maxsize:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

        # Eviction has started: record and reset the hit counter of the slot
        # the write index now points at.
        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0
Exemplo n.º 5
0
    def add(self, item: SampleBatchType, weight: float) -> None:
        """Add a batch of experiences.

        The batch is written at ``self._next_idx``: appended while that index
        is past the end of ``self._storage``, otherwise it overwrites the
        batch stored there. Once the timesteps added since the last wrap
        reach ``self.capacity``, the write index returns to 0 and eviction
        is flagged as started.

        Args:
            item: SampleBatch to add to this buffer's storage. Must contain
                at least one timestep.
            weight: The weight of the added sample used in subsequent
                sampling steps. Only relevant if this ReplayBuffer is
                a PrioritizedReplayBuffer. Not read by this method.
        """
        assert item.count > 0, item
        warn_replay_capacity(item=item, num_items=self.capacity / item.count)

        # Update our timesteps counts.
        self._num_timesteps_added += item.count
        self._num_timesteps_added_wrap += item.count

        if self._next_idx >= len(self._storage):
            self._storage.append(item)
        else:
            # Overwriting: remove the replaced batch's footprint so the size
            # estimate stays accurate when batches differ in size.
            replaced = self._storage[self._next_idx]
            self._est_size_bytes -= replaced.size_bytes()
            self._storage[self._next_idx] = item
        self._est_size_bytes += item.size_bytes()

        # Wrap around storage as a circular buffer once we hit capacity.
        if self._num_timesteps_added_wrap >= self.capacity:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

        # Eviction of older samples has already started (buffer is "full"):
        # record and reset the hit counter of the slot the write index now
        # points at.
        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0
Exemplo n.º 6
0
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch of experiences.

        The batch is written at ``self._next_idx``: appended while that index
        is past the end of ``self._storage``, otherwise it overwrites the
        batch stored there. Once the timesteps added since the last wrap
        reach ``self.capacity``, the write index returns to 0 and eviction
        is flagged as started.

        Args:
            batch: SampleBatch to add to this buffer's storage. Must contain
                at least one timestep.
            **kwargs: Forward compatibility kwargs.
        """
        assert batch.count > 0, batch
        warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

        # Update our timesteps counts.
        self._num_timesteps_added += batch.count
        self._num_timesteps_added_wrap += batch.count

        if self._next_idx >= len(self._storage):
            self._storage.append(batch)
        else:
            # Overwriting: remove the replaced batch's footprint so the size
            # estimate stays accurate when batches differ in size.
            replaced = self._storage[self._next_idx]
            self._est_size_bytes -= replaced.size_bytes()
            self._storage[self._next_idx] = batch
        self._est_size_bytes += batch.size_bytes()

        # Wrap around storage as a circular buffer once we hit capacity.
        if self._num_timesteps_added_wrap >= self.capacity:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

        # Eviction of older samples has already started (buffer is "full"):
        # record and reset the hit counter of the slot the write index now
        # points at.
        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0
Exemplo n.º 7
0
    def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
        """Add a SampleBatch of experiences to self._storage.

        An item consists of either one or more timesteps, a sequence or an
        episode. Differs from add() in that it does not consider the storage
        unit or type of batch and simply stores it.

        While fewer than ``self.capacity`` timesteps have been added, the
        item is appended. Afterwards a random index in
        ``[0, self._num_add_calls - 1]`` is drawn; the item replaces the
        stored batch at that index if the index falls inside the storage,
        and is dropped otherwise.

        Args:
            item: The batch to be added.
            **kwargs: Forward compatibility kwargs.
        """
        # Update our timesteps counts.
        self._num_timesteps_added += item.count
        self._num_timesteps_added_wrap += item.count

        # Update add counts.
        self._num_add_calls += 1

        if self._num_timesteps_added < self.capacity:
            self._storage.append(item)
            self._est_size_bytes += item.size_bytes()
        else:
            # Eviction of older samples has already started (buffer is "full")
            self._eviction_started = True
            idx = random.randint(0, self._num_add_calls - 1)
            if idx < len(self._storage):
                self._num_evicted += 1
                # Record the evicted slot's hit count exactly once, then
                # reset it for the incoming item.
                self._evicted_hit_stats.push(self._hit_count[idx])
                self._hit_count[idx] = 0
                # This is a bit of a hack: ReplayBuffer always inserts at
                # self._next_idx
                self._next_idx = idx

                # Swap the old batch's footprint for the new one's.
                item_to_be_removed = self._storage[idx]
                self._est_size_bytes -= item_to_be_removed.size_bytes()
                self._storage[idx] = item
                self._est_size_bytes += item.size_bytes()

                assert item.count > 0, item
                warn_replay_capacity(item=item,
                                     num_items=self.capacity / item.count)
Exemplo n.º 8
0
    def add(self, item: SampleBatchType, weight: float):
        """Add a batch to the circular storage of ``self._maxsize`` slots.

        The batch is written at ``self._next_idx``: appended while that index
        is past the end of ``self._storage``, otherwise it overwrites the
        batch stored there. The write index then advances modulo
        ``self._maxsize``; eviction is flagged as started once the last slot
        has been written.

        Args:
            item: The batch to add. Must contain at least one timestep.
            weight: Not read by this method.
        """
        assert item.count > 0, item
        self._num_added += 1

        if self._next_idx >= len(self._storage):
            self._storage.append(item)
        else:
            # Overwriting: remove the replaced batch's footprint so the size
            # estimate stays accurate when batches differ in size.
            replaced = self._storage[self._next_idx]
            self._est_size_bytes -= replaced.size_bytes()
            self._storage[self._next_idx] = item
        self._est_size_bytes += item.size_bytes()

        if self._next_idx + 1 >= self._maxsize:
            self._eviction_started = True
        self._next_idx = (self._next_idx + 1) % self._maxsize

        # Eviction has started: record and reset the hit counter of the slot
        # the write index now points at.
        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0
Exemplo n.º 9
0
def warn_replay_capacity(*, item: SampleBatchType, num_items: int) -> None:
    """Report whether the configured replay buffer capacity fits in memory.

    Does nothing unless ``log_once("replay_capacity")`` returns a truthy
    value. Otherwise the estimate ``num_items * item.size_bytes()`` is
    compared against the total system memory reported by psutil: the message
    is logged at INFO level, at WARNING level when the estimate exceeds 20%
    of total memory, and raised as an error when it exceeds total memory.

    Args:
        item: An example item whose size in bytes is used for the estimate.
        num_items: The number of such items the buffer is expected to hold.

    Raises:
        ValueError: If the estimated footprint exceeds total system memory.
    """
    if not log_once("replay_capacity"):
        return

    per_item_bytes = item.size_bytes()
    ram_gb = psutil.virtual_memory().total / 1e9
    needed_gb = num_items * per_item_bytes / 1e9

    template = ("Estimated max memory usage for replay buffer is {} GB "
                "({} batches of size {}, {} bytes each), "
                "available system memory is {} GB")
    msg = template.format(needed_gb, num_items, item.count, per_item_bytes,
                          ram_gb)

    if needed_gb > ram_gb:
        raise ValueError(msg)
    if needed_gb > 0.2 * ram_gb:
        logger.warning(msg)
    else:
        logger.info(msg)
Exemplo n.º 10
0
    def add(self, batch: SampleBatchType, weight: float) -> None:
        """Add a batch of experiences and register its priority.

        The batch is written at ``self._next_idx``: appended while that index
        is past the end of ``self._storage``, otherwise it overwrites the
        batch stored there. Once the timesteps added since the last wrap
        reach ``self.capacity``, the write index returns to 0 and eviction
        is flagged as started. Finally ``weight ** self._alpha`` is stored
        in ``self._it_sum`` and ``self._it_min`` at the slot that was
        written.

        Args:
            batch: SampleBatch to add to this buffer's storage. Must contain
                at least one timestep.
            weight: The weight of the added sample used in subsequent sampling
                steps. If None, ``self._max_priority`` is used instead.
        """
        # Remember the slot being written; `_next_idx` is advanced below.
        idx = self._next_idx

        assert batch.count > 0, batch
        warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

        # Update our timesteps counts.
        self._num_timesteps_added += batch.count
        self._num_timesteps_added_wrap += batch.count

        if self._next_idx >= len(self._storage):
            self._storage.append(batch)
        else:
            # Overwriting: remove the replaced batch's footprint so the size
            # estimate stays accurate when batches differ in size.
            replaced = self._storage[self._next_idx]
            self._est_size_bytes -= replaced.size_bytes()
            self._storage[self._next_idx] = batch
        self._est_size_bytes += batch.size_bytes()

        # Wrap around storage as a circular buffer once we hit capacity.
        if self._num_timesteps_added_wrap >= self.capacity:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

        # Eviction of older samples has already started (buffer is "full"):
        # record and reset the hit counter of the slot the write index now
        # points at.
        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0

        if weight is None:
            weight = self._max_priority
        self._it_sum[idx] = weight**self._alpha
        self._it_min[idx] = weight**self._alpha