    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently, adds the individual policy
        batches to the storage.

        Args:
            batch: The batch to be added.
            **kwargs: Forward compatibility kwargs.
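
        Examples:
            >>> # Illustrative only; assumes `buffer` is an instance of this
            >>> # class configured with storage_unit=StorageUnit.TIMESTEPS.
            >>> batch = SampleBatch({"obs": [1, 2], "t": [0, 1],
            ...                      "dones": [False, True]})
            >>> buffer.add(batch)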
        """
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        # Handle everything as if multi-agent.
        batch = batch.as_multi_agent()

        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args,
                                          kwargs)

        # We need to split batches into timesteps, sequences, or episodes
        # here already to properly keep track of self.last_added_batches.
        # The underlying buffers should not split up the batch any further.
        with self.add_batch_timer:
            if self._storage_unit == StorageUnit.TIMESTEPS:
                for policy_id, sample_batch in batch.policy_batches.items():
                    if self.replay_sequence_length == 1:
                        timeslices = sample_batch.timeslices(1)
                    else:
                        timeslices = timeslice_along_seq_lens_with_overlap(
                            sample_batch=sample_batch,
                            zero_pad_max_seq_len=self.replay_sequence_length,
                            pre_overlap=self.replay_burn_in,
                            zero_init_states=self.replay_zero_init_states,
                        )
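                        # Illustrative (assuming the helper's slicing
                        # semantics): with replay_sequence_length=3 and
                        # replay_burn_in=1, slices overlap by one timestep
                        # (the burn-in) and are zero-padded up to length 3.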
                    for time_slice in timeslices:
                        self.replay_buffers[policy_id].add(
                            time_slice, **kwargs)
                        self.last_added_batches[policy_id].append(time_slice)
            elif self._storage_unit == StorageUnit.SEQUENCES:
                for policy_id, sample_batch in batch.policy_batches.items():
                    # Slice offsets are relative to each policy's batch.
                    timestep_count = 0
                    for seq_len in sample_batch.get(SampleBatch.SEQ_LENS):
                        start_seq = timestep_count
                        end_seq = timestep_count + seq_len
                        time_slice = sample_batch[start_seq:end_seq]
                        self.replay_buffers[policy_id].add(
                            time_slice, **kwargs)
                        self.last_added_batches[policy_id].append(time_slice)
                        timestep_count = end_seq
            elif self._storage_unit == StorageUnit.EPISODES:
                for policy_id, sample_batch in batch.policy_batches.items():
                    for eps in sample_batch.split_by_episode():
                        # Only add full episodes to the buffer.
                        if (eps.get(SampleBatch.T)[0] == 0
                                and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
                            ):
                            self.replay_buffers[policy_id].add(eps, **kwargs)
                            self.last_added_batches[policy_id].append(eps)
                        else:
                            if log_once("only_full_episodes"):
                                logger.info(
                                    "This buffer uses episodes as a storage "
                                    "unit and thus allows only full episodes "
                                    "to be added to it. Some samples may be "
                                    "dropped.")

        self._num_added += batch.count

    def sample(self,
               num_items: int,
               policy_id: PolicyID = DEFAULT_POLICY_ID,
               **kwargs) -> Optional[SampleBatchType]:
        """Samples a batch of size `num_items` from a specified buffer.

        Concatenates old samples to new ones according to
        self.replay_ratio. If not enough new samples are available, mixes in
        fewer old samples to retain self.replay_ratio on average. Returns
        an empty batch if there are no items in the buffer.

        Args:
            num_items: Number of items to sample from this buffer.
            policy_id: ID of the policy that produced the experiences to be
                sampled.
            **kwargs: Forward compatibility kwargs.

        Returns:
            Concatenated MultiAgentBatch of items.
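
        Examples:
            >>> # Illustrative only; assumes `buffer` was created with
            >>> # replay_ratio=0.5 and already contains enough samples.
            >>> mixed = buffer.sample(32)  # ~16 new + ~16 replayed samples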
        """
        # Merge kwargs, overwriting standard call arguments
        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args,
                                          kwargs)

        def mix_batches(_policy_id):
            """Mixes old with new samples.

            Tries to mix according to self.replay_ratio on average.
            If not enough new samples are available, mixes in fewer old
            samples to retain self.replay_ratio on average.
            """
            def round_up_or_down(value, ratio):
                """Returns an integer averaging to value*ratio."""
                product = value * ratio
                ceil_prob = product % 1
                if random.uniform(0, 1) < ceil_prob:
                    return int(np.ceil(product))
                else:
                    return int(np.floor(product))
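            # E.g., round_up_or_down(10, 0.25) returns 3 with probability 0.5
            # and 2 otherwise, averaging to 2.5 over many calls.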

            max_num_new = round_up_or_down(num_items, 1 - self.replay_ratio)
            # If num_items * (1 - self.replay_ratio) is not a whole number,
            # we need one more new sample with a probability of
            # (num_items * (1 - self.replay_ratio)) % 1.

            _buffer = self.replay_buffers[_policy_id]
            output_batches = self.last_added_batches[_policy_id][:max_num_new]
            self.last_added_batches[_policy_id] = self.last_added_batches[
                _policy_id][max_num_new:]

            # No replay desired
            if self.replay_ratio == 0.0:
                return SampleBatch.concat_samples(output_batches)
            # Only replay desired
            elif self.replay_ratio == 1.0:
                return _buffer.sample(num_items, **kwargs)

            num_new = len(output_batches)

            if np.isclose(num_new, num_items * (1 - self.replay_ratio)):
                # The optimal case: we can mix in a round number of old
                # samples on average.
                num_old = num_items - max_num_new
            else:
                # We never want to return more elements than num_items
                num_old = min(
                    num_items - max_num_new,
                    round_up_or_down(
                        num_new, self.replay_ratio / (1 - self.replay_ratio)),
                )
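            # E.g., with num_items=10 and replay_ratio=0.5, max_num_new is 5.
            # If only 2 new samples are available, num_old becomes
            # min(5, round_up_or_down(2, 1.0)) = 2, so 2 new and 2 old samples
            # are returned, preserving the replay ratio on average.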

            output_batches.append(_buffer.sample(num_old, **kwargs))
            # Depending on the implementation of underlying buffers, samples
            # might be SampleBatches
            output_batches = [
                batch.as_multi_agent() for batch in output_batches
            ]
            return MultiAgentBatch.concat_samples(output_batches)

        def check_buffer_is_ready(_policy_id):
            """Returns whether the buffer can fulfill this sample() call."""
            if ((len(self.replay_buffers[_policy_id]) == 0)
                    and self.replay_ratio > 0.0) or (
                        len(self.last_added_batches[_policy_id]) == 0
                        and self.replay_ratio < 1.0):
                return False
            return True

        with self.replay_timer:
            samples = []

            if self.replay_mode == ReplayMode.LOCKSTEP:
                assert (
                    policy_id is None
                ), "`policy_id` specifier not allowed in `lockstep` mode!"
                if check_buffer_is_ready(_ALL_POLICIES):
                    samples.append(mix_batches(_ALL_POLICIES).as_multi_agent())
            elif policy_id is not None:
                if check_buffer_is_ready(policy_id):
                    samples.append(mix_batches(policy_id).as_multi_agent())
            else:
                for policy_id, replay_buffer in self.replay_buffers.items():
                    if check_buffer_is_ready(policy_id):
                        samples.append(mix_batches(policy_id).as_multi_agent())

            return MultiAgentBatch.concat_samples(samples)

    def _add_to_underlying_buffer(self, policy_id: PolicyID,
                                  batch: SampleBatchType, **kwargs) -> None:
        """Add a batch of experiences to the underlying buffer of a policy.

        If the storage unit is `timesteps`, cut the batch into timeslices
        before adding them to the appropriate buffer. Otherwise, let the
        underlying buffer decide how to slice batches.

        Args:
            policy_id: ID of the policy that corresponds to the underlying
                buffer.
            batch: SampleBatch to add to the underlying buffer.
            **kwargs: Forward compatibility kwargs.
        """
        # Merge kwargs, overwriting standard call arguments
        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args,
                                          kwargs)

        # For the storage unit `timesteps`, the underlying buffer will
        # simply store the samples as they arrive. For sequences and
        # episodes, the underlying buffer may split them itself.
        if self.storage_unit is StorageUnit.TIMESTEPS:
            timeslices = batch.timeslices(1)
        elif self.storage_unit is StorageUnit.SEQUENCES:
            timeslices = timeslice_along_seq_lens_with_overlap(
                sample_batch=batch,
                seq_lens=batch.get(SampleBatch.SEQ_LENS)
                if self.replay_sequence_override else None,
                zero_pad_max_seq_len=self.replay_sequence_length,
                pre_overlap=self.replay_burn_in,
                zero_init_states=self.replay_zero_init_states,
            )
        elif self.storage_unit == StorageUnit.EPISODES:
            timeslices = []
            for eps in batch.split_by_episode():
                if (eps.get(SampleBatch.T)[0] == 0
                        and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
                    ):
                    # Only add full episodes to the buffer
                    timeslices.append(eps)
                else:
                    if log_once("only_full_episodes"):
                        logger.info("This buffer uses episodes as a storage "
                                    "unit and thus allows only full episodes "
                                    "to be added to it. Some samples may be "
                                    "dropped.")
        elif self.storage_unit == StorageUnit.FRAGMENTS:
            timeslices = [batch]
        else:
            raise ValueError("Unknown `storage_unit={}`".format(
                self.storage_unit))

        for time_slice in timeslices:
            # If SampleBatch has prio-replay weights, average
            # over these to use as a weight for the entire
            # sequence.
            if self.replay_mode is ReplayMode.INDEPENDENT:
                if "weights" in slice and len(slice["weights"]):
                    weight = np.mean(slice["weights"])
                else:
                    weight = None
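                # E.g., a slice with weights [0.2, 0.4] is stored with a
                # single averaged weight of 0.3 for the whole sequence.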

                if "weight" in kwargs and weight is not None:
                    if log_once("overwrite_weight"):
                        logger.warning("Adding batches with column "
                                       "`weights` to this buffer while "
                                       "providing weights as a call argument "
                                       "to the add method results in the "
                                       "column being overwritten.")

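                # Build a per-slice kwargs dict so `kwargs` is not mutated
                # across loop iterations; note the merge order lets a
                # caller-provided `weight` kwarg take precedence here.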
                kwargs = {"weight": weight, **kwargs}
            else:
                if "weight" in kwargs:
                    if log_once("lockstep_no_weight_allowed"):
                        logger.warning("Settings weights for batches in "
                                       "lockstep mode is not allowed."
                                       "Weights are being ignored.")

                slice_kwargs = {**kwargs, "weight": None}
            self.replay_buffers[policy_id].add(time_slice, **slice_kwargs)