コード例 #1
0
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch of experiences using reservoir sampling.

        While fewer than ``self.capacity`` timesteps have been added, the
        batch is stored via ``ReplayBuffer.add``. Afterwards, a slot index is
        drawn uniformly from ``[0, self._num_add_calls - 1]``. The new batch
        replaces the stored batch at that slot only if the index falls inside
        the current storage; otherwise it is dropped.

        Args:
            batch: SampleBatch to add to this buffer's storage.
            **kwargs: Forward compatibility kwargs (unused here).
        """
        # Update add counts.
        self._num_add_calls += 1
        # Update our timesteps counts.
        # NOTE(review): if ReplayBuffer.add also increments these counters,
        # timesteps are counted twice on the fill path below — verify.
        self._num_timesteps_added += batch.count
        self._num_timesteps_added_wrap += batch.count

        if self._num_timesteps_added < self.capacity:
            ReplayBuffer.add(self, batch)
        else:
            # Eviction of older samples has already started (buffer is "full")
            self._eviction_started = True
            idx = random.randint(0, self._num_add_calls - 1)
            # `self.capacity` is measured in timesteps, but `self._storage`
            # holds whole batches and can therefore have fewer than
            # `capacity` entries. Bound by the real storage length so the
            # assignment below cannot raise IndexError.
            if idx < len(self._storage):
                self._num_evicted += 1
                self._evicted_hit_stats.push(self._hit_count[idx])
                self._hit_count[idx] = 0
                self._storage[idx] = batch

                assert batch.count > 0, batch
                warn_replay_capacity(item=batch,
                                     num_items=self.capacity / batch.count)
コード例 #2
0
ファイル: replay_buffer.py プロジェクト: wuisawesome/ray
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Write ``batch`` into the circular storage at ``self._next_idx``.

        Storage grows by appending until the write position wraps. Once the
        timesteps added since the last wrap reach ``self.capacity``, the
        position resets to 0 and eviction is marked as started. From then on,
        the hit count of the slot that will be overwritten next is recorded
        and reset.

        Args:
            batch: SampleBatch to add to this buffer's storage.
            **kwargs: Forward compatibility kwargs.
        """
        assert batch.count > 0, batch
        warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

        n_steps = batch.count
        self._num_timesteps_added += n_steps
        self._num_timesteps_added_wrap += n_steps

        slot = self._next_idx
        if slot < len(self._storage):
            # Overwrite an existing slot (size estimate is left unchanged).
            self._storage[slot] = batch
        else:
            self._storage.append(batch)
            self._est_size_bytes += batch.size_bytes()

        if self._num_timesteps_added_wrap < self.capacity:
            self._next_idx = slot + 1
        else:
            # Capacity reached: restart writing from the front.
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0

        # Once the buffer is "full", record the hit count of the slot that
        # the next call will overwrite, then clear it.
        if self._eviction_started:
            upcoming = self._next_idx
            self._evicted_hit_stats.push(self._hit_count[upcoming])
            self._hit_count[upcoming] = 0
コード例 #3
0
ファイル: replay_ops.py プロジェクト: miqdigital/ray
 def add_batch(self, sample_batch: SampleBatchType) -> None:
     """Store ``sample_batch`` in a fixed set of round-robin slots.

     Batches are appended until ``self.num_slots`` are held. After that, the
     slot at ``self.replay_index`` is overwritten and the index advances
     cyclically. Nothing is stored when ``self.num_slots`` is not positive.
     """
     warn_replay_capacity(item=sample_batch, num_items=self.num_slots)
     if self.num_slots <= 0:
         return
     if len(self.replay_batches) >= self.num_slots:
         self.replay_batches[self.replay_index] = sample_batch
         self.replay_index = (self.replay_index + 1) % self.num_slots
     else:
         self.replay_batches.append(sample_batch)
コード例 #4
0
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch of experiences to this buffer.

        Splits ``batch`` into chunks of timesteps, sequences or episodes,
        depending on ``self._storage_unit``, and stores each chunk via
        ``self._add_single_batch``. In episode mode, only complete episodes
        (first ``T`` equal to 0 and last ``DONES`` true) are stored; other
        episode fragments are dropped with a one-time info log.

        Args:
            batch: Batch to add to this buffer's storage.
            **kwargs: Forward compatibility kwargs, passed through to
                ``self._add_single_batch``.

        Raises:
            ValueError: If ``batch`` is a MultiAgentBatch and the storage unit
                is not timesteps, or if the storage unit is sequences and
                ``batch`` has no sequence lengths.
        """
        assert batch.count > 0, batch
        warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

        if (isinstance(batch, MultiAgentBatch)
                and self._storage_unit != StorageUnit.TIMESTEPS):
            raise ValueError("Can not add MultiAgentBatch to ReplayBuffer "
                             "with storage_unit {}"
                             "".format(str(self._storage_unit)))

        if self._storage_unit == StorageUnit.TIMESTEPS:
            self._add_single_batch(batch, **kwargs)

        elif self._storage_unit == StorageUnit.SEQUENCES:
            seq_lens = batch.get(SampleBatch.SEQ_LENS)
            # Without sequence lengths there is no way to split the batch;
            # fail with a clear message instead of a TypeError on iteration.
            if seq_lens is None:
                raise ValueError("Can not add a batch without {} to "
                                 "ReplayBuffer with storage_unit {}"
                                 "".format(SampleBatch.SEQ_LENS,
                                           str(self._storage_unit)))
            timestep_count = 0
            for seq_len in seq_lens:
                start_seq = timestep_count
                end_seq = timestep_count + seq_len
                self._add_single_batch(batch[start_seq:end_seq], **kwargs)
                timestep_count = end_seq

        elif self._storage_unit == StorageUnit.EPISODES:
            for eps in batch.split_by_episode():
                if (eps.get(SampleBatch.T)[0] == 0
                        and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
                    ):
                    # Only add full episodes to the buffer
                    self._add_single_batch(eps, **kwargs)
                else:
                    if log_once("only_full_episodes"):
                        logger.info("This buffer uses episodes as a storage "
                                    "unit and thus allows only full episodes "
                                    "to be added to it. Some samples may be "
                                    "dropped.")
コード例 #5
0
    def add(self, batch: SampleBatchType, weight: float) -> None:
        """Add a batch of experiences together with its sampling weight.

        The batch is written into the circular storage at ``self._next_idx``.
        Its priority, ``weight ** self._alpha``, is recorded in both
        ``self._it_sum`` and ``self._it_min`` at that same slot.

        Args:
            batch: SampleBatch to add to this buffer's storage.
            weight: The weight of the added sample used in subsequent sampling
                steps. ``None`` is also accepted and is replaced by
                ``self._max_priority``.
        """
        # The slot this batch lands in; self._next_idx is advanced below.
        slot = self._next_idx

        assert batch.count > 0, batch
        warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

        n_steps = batch.count
        self._num_timesteps_added += n_steps
        self._num_timesteps_added_wrap += n_steps

        if slot < len(self._storage):
            self._storage[slot] = batch
        else:
            self._storage.append(batch)
            self._est_size_bytes += batch.size_bytes()

        if self._num_timesteps_added_wrap < self.capacity:
            self._next_idx = slot + 1
        else:
            # Capacity reached: restart writing from the front.
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0

        # Once the buffer is "full", record and clear the hit count of the
        # slot the next call will overwrite.
        if self._eviction_started:
            upcoming = self._next_idx
            self._evicted_hit_stats.push(self._hit_count[upcoming])
            self._hit_count[upcoming] = 0

        base = self._max_priority if weight is None else weight
        priority = base**self._alpha
        self._it_sum[slot] = priority
        self._it_min[slot] = priority
コード例 #6
0
ファイル: reservoir_buffer.py プロジェクト: patrickstuedi/ray
    def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
        """Add a SampleBatch of experiences to self._storage.

        An item consists of either one or more timesteps, a sequence or an
        episode. Differs from add() in that it does not consider the storage
        unit or type of batch and simply stores it.

        While fewer than ``self.capacity`` timesteps have been added, the item
        is appended. Afterwards, reservoir sampling is used. A slot index is
        drawn uniformly from ``[0, self._num_add_calls - 1]``, and the item
        replaces the stored batch at that slot only if the index is inside the
        current storage; otherwise it is dropped.

        Args:
            item: The batch to be added.
            **kwargs: Forward compatibility kwargs.
        """
        # Update our timesteps counts.
        self._num_timesteps_added += item.count
        self._num_timesteps_added_wrap += item.count

        # Update add counts.
        self._num_add_calls += 1

        if self._num_timesteps_added < self.capacity:
            self._storage.append(item)
            self._est_size_bytes += item.size_bytes()
        else:
            # Eviction of older samples has already started (buffer is "full")
            self._eviction_started = True
            idx = random.randint(0, self._num_add_calls - 1)
            if idx < len(self._storage):
                self._num_evicted += 1
                # Record the evicted batch's hit count exactly once, then
                # reset the slot's counter for the incoming item. (Pushing a
                # second time would add a spurious 0 to the statistics.)
                self._evicted_hit_stats.push(self._hit_count[idx])
                self._hit_count[idx] = 0
                # This is a bit of a hack: ReplayBuffer always inserts at
                # self._next_idx
                # NOTE(review): presumably kept so code reading _next_idx
                # sees the slot just written — confirm with callers.
                self._next_idx = idx
                self._storage[idx] = item

                assert item.count > 0, item
                warn_replay_capacity(item=item,
                                     num_items=self.capacity / item.count)