Example #1
    def add(self, batch: SampleBatchType) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently adds the individual policy
        batches to the storage.

        Args:
            batch: The batch to be added.
        """
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        # Handle everything as if multi-agent.
        batch = batch.as_multi_agent()

        with self.add_batch_timer:
            # Lockstep mode: Store under _ALL_POLICIES key (we will always
            # only sample from all policies at the same time).
            if self.replay_mode == "lockstep":
                # Note that prioritization is not supported in this mode.
                for s in batch.timeslices(self.replay_sequence_length):
                    self.replay_buffers[_ALL_POLICIES].add(s, weight=None)
                    self.last_added_batches[_ALL_POLICIES].append(s)
            else:
                for policy_id, sample_batch in batch.policy_batches.items():
                    self._add_to_policy_buffer(policy_id, sample_batch)
                    self.last_added_batches[policy_id].append(sample_batch)
        self._num_added += batch.count
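
A minimal sketch of the single-agent to multi-agent wrapping performed above (toy dict stand-ins for RLlib's SampleBatch/MultiAgentBatch, not the real classes):

# Toy stand-in showing what `as_multi_agent()` does: wrap a plain batch
# under DEFAULT_POLICY_ID so single- and multi-agent input look the same.
DEFAULT_POLICY_ID = "default_policy"

def as_multi_agent(batch):
    if "policy_batches" in batch:
        return batch  # already multi-agent
    return {"policy_batches": {DEFAULT_POLICY_ID: batch},
            "count": batch["count"]}

single = {"obs": [0, 1, 2], "count": 3}
multi = as_multi_agent(single)
assert list(multi["policy_batches"]) == [DEFAULT_POLICY_ID]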
Example #2
 def add_batch(self, batch: SampleBatchType) -> None:
     # Make a copy so the replay buffer doesn't pin plasma memory.
     batch = batch.copy()
     # Handle everything as if multi-agent.
     if isinstance(batch, SampleBatch):
         batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
     with self.add_batch_timer:
         # Lockstep mode: Store under _ALL_POLICIES key (we will always
         # only sample from all policies at the same time).
         if self.replay_mode == "lockstep":
             # Note that prioritization is not supported in this mode.
             for s in batch.timeslices(self.replay_sequence_length):
                 self.replay_buffers[_ALL_POLICIES].add(s, weight=None)
         else:
             for policy_id, sample_batch in batch.policy_batches.items():
                 if self.replay_sequence_length == 1:
                     timeslices = sample_batch.timeslices(1)
                 else:
                     timeslices = timeslice_along_seq_lens_with_overlap(
                         sample_batch=sample_batch,
                         zero_pad_max_seq_len=self.replay_sequence_length,
                         pre_overlap=self.replay_burn_in,
                         zero_init_states=self.replay_zero_init_states,
                     )
                 for time_slice in timeslices:
                     # If SampleBatch has prio-replay weights, average
                     # over these to use as a weight for the entire
                     # sequence.
                     if "weights" in time_slice:
                         weight = np.mean(time_slice["weights"])
                     else:
                         weight = None
                     self.replay_buffers[policy_id].add(time_slice,
                                                        weight=weight)
     self.num_added += batch.count
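
A short sketch of the weight handling above: if a time slice carries a per-step `weights` column, its mean is used as the priority weight for the whole slice (illustrative values, not from the source):

import numpy as np

# The mean of the per-step prio-replay weights becomes the slice's weight.
time_slice = {"weights": np.array([0.5, 1.0, 1.5]), "obs": np.zeros(3)}
weight = np.mean(time_slice["weights"]) if "weights" in time_slice else None
assert weight == 1.0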
Example #3
    def _add_to_underlying_buffer(self, policy_id: PolicyID,
                                  batch: SampleBatchType, **kwargs) -> None:
        """Add a batch of experiences to the underlying buffer of a policy.

        If the storage unit is `timesteps`, cut the batch into timeslices
        before adding them to the appropriate buffer. Otherwise, let the
        underlying buffer decide how to slice batches.

        Args:
            policy_id: ID of the policy that corresponds to the underlying
                buffer.
            batch: SampleBatch to add to the underlying buffer.
            **kwargs: Forward compatibility kwargs.
        """
        # Merge kwargs, overwriting standard call arguments
        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args,
                                          kwargs)

        # For the storage unit `timesteps`, the underlying buffer will
        # simply store the samples how they arrive. For sequences and
        # episodes, the underlying buffer may split them itself.
        if self._storage_unit is StorageUnit.TIMESTEPS:
            if self.replay_sequence_length == 1:
                timeslices = batch.timeslices(1)
            else:
                timeslices = timeslice_along_seq_lens_with_overlap(
                    sample_batch=batch,
                    zero_pad_max_seq_len=self.replay_sequence_length,
                    pre_overlap=self.replay_burn_in,
                    zero_init_states=self.replay_zero_init_states,
                )
            for time_slice in timeslices:
                self.replay_buffers[policy_id].add(time_slice, **kwargs)
        else:
            self.replay_buffers[policy_id].add(batch, **kwargs)
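
A toy illustration of `batch.timeslices(1)` on a columnar batch (a dict-based stand-in for SampleBatch, not RLlib's implementation):

# Cut a columnar batch into per-timestep slices, as timeslices(1) does.
def timeslices(batch, size=1):
    n = len(next(iter(batch.values())))
    return [{k: v[i:i + size] for k, v in batch.items()}
            for i in range(0, n, size)]

batch = {"obs": [1, 2, 3], "actions": [0, 1, 0]}
assert len(timeslices(batch)) == 3
assert timeslices(batch)[0] == {"obs": [1], "actions": [0]}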
Example #4
    def _add_to_underlying_buffer(self, policy_id: PolicyID,
                                  batch: SampleBatchType, **kwargs) -> None:
        """Add a batch of experiences to the underlying buffer of a policy.

        If the storage unit is `timesteps`, cut the batch into timeslices
        before adding them to the appropriate buffer. Otherwise, let the
        underlying buffer decide how to slice batches.

        Args:
            policy_id: ID of the policy that corresponds to the underlying
                buffer.
            batch: SampleBatch to add to the underlying buffer.
            **kwargs: Forward compatibility kwargs.
        """
        # For the storage unit `timesteps`, the underlying buffer will
        # simply store the samples how they arrive. For sequences and
        # episodes, the underlying buffer may split them itself.
        if self._storage_unit is StorageUnit.TIMESTEPS:
            if self.replay_sequence_length == 1:
                timeslices = batch.timeslices(1)
            else:
                timeslices = timeslice_along_seq_lens_with_overlap(
                    sample_batch=batch,
                    zero_pad_max_seq_len=self.replay_sequence_length,
                    pre_overlap=self.replay_burn_in,
                    zero_init_states=self.replay_zero_init_states,
                )
            for time_slice in timeslices:
                # If SampleBatch has prio-replay weights, average
                # over these to use as a weight for the entire
                # sequence.
                if self.replay_mode is ReplayMode.INDEPENDENT:
                    if "weights" in time_slice and len(time_slice["weights"]):
                        weight = np.mean(time_slice["weights"])
                    else:
                        weight = None

                    if "weight" in kwargs and weight is not None:
                        if log_once("overwrite_weight"):
                            logger.warning(
                                "Adding batches with column "
                                "`weights` to this buffer while "
                                "providing weights as a call argument "
                                "to the add method results in the "
                                "column being overwritten.")

                    kwargs = {"weight": weight, **kwargs}
                else:
                    if "weight" in kwargs:
                        if log_once("lockstep_no_weight_allowed"):
                            logger.warning("Settings weights for batches in "
                                           "lockstep mode is not allowed."
                                           "Weights are being ignored.")

                    kwargs = {**kwargs, "weight": None}
                self.replay_buffers[policy_id].add(time_slice, **kwargs)
        else:
            self.replay_buffers[policy_id].add(batch, **kwargs)
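
A sketch of the kwargs precedence used above (illustrative values): `{"weight": weight, **kwargs}` lets an explicit `weight=` call argument win over the averaged column, while `{**kwargs, "weight": None}` forces the weight off in lockstep mode:

computed = {"weight": 0.7}             # mean of the `weights` column
caller = {"weight": 0.2}               # explicit call argument
independent = {**computed, **caller}   # caller's value wins -> 0.2
lockstep = {**caller, "weight": None}  # weight forced to None
assert independent["weight"] == 0.2 and lockstep["weight"] is None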
Example #5
    def _add_to_underlying_buffer(
        self, policy_id: PolicyID, batch: SampleBatchType, **kwargs
    ) -> None:
        """Add a batch of experiences to the underlying buffer of a policy.

        If the storage unit is `timesteps`, cut the batch into timeslices
        before adding them to the appropriate buffer. Otherwise, let the
        underlying buffer decide how to slice batches.

        Args:
            policy_id: ID of the policy that corresponds to the underlying
                buffer.
            batch: SampleBatch to add to the underlying buffer.
            ``**kwargs``: Forward compatibility kwargs.
        """
        # Merge kwargs, overwriting standard call arguments
        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)

        # For the storage unit `timesteps`, the underlying buffer will
        # simply store the samples how they arrive. For sequences and
        # episodes, the underlying buffer may split them itself.
        if self.storage_unit is StorageUnit.TIMESTEPS:
            timeslices = batch.timeslices(1)
        elif self.storage_unit is StorageUnit.SEQUENCES:
            timeslices = timeslice_along_seq_lens_with_overlap(
                sample_batch=batch,
                seq_lens=batch.get(SampleBatch.SEQ_LENS)
                if self.replay_sequence_override
                else None,
                zero_pad_max_seq_len=self.replay_sequence_length,
                pre_overlap=self.replay_burn_in,
                zero_init_states=self.replay_zero_init_states,
            )
        elif self.storage_unit == StorageUnit.EPISODES:
            timeslices = []
            for eps in batch.split_by_episode():
                if (
                    eps.get(SampleBatch.T)[0] == 0
                    and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
                ):
                    # Only add full episodes to the buffer
                    timeslices.append(eps)
                else:
                    if log_once("only_full_episodes"):
                        logger.info(
                            "This buffer uses episodes as a storage "
                            "unit and thus allows only full episodes "
                            "to be added to it. Some samples may be "
                            "dropped."
                        )
        elif self.storage_unit == StorageUnit.FRAGMENTS:
            timeslices = [batch]
        else:
            raise ValueError("Unknown `storage_unit={}`".format(self.storage_unit))

        for slice in timeslices:
            self.replay_buffers[policy_id].add(slice, **kwargs)
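
A toy version of the full-episode check in the EPISODES branch: a slice qualifies only if it starts at t == 0 and its last `dones` flag is True (dict stand-in, not a SampleBatch):

def is_full_episode(eps):
    # Episode must start at timestep 0 and end with a terminal flag.
    return eps["t"][0] == 0 and bool(eps["dones"][-1])

assert is_full_episode({"t": [0, 1, 2], "dones": [False, False, True]})
assert not is_full_episode({"t": [5, 6], "dones": [False, True]})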
Example #6
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch of experiences to this buffer.

        Splits batch into chunks of timesteps, sequences or episodes, depending on
        `self._storage_unit`. Calls `self._add_single_batch` to add resulting slices
        to the buffer storage.

        Args:
            batch: Batch to add.
            ``**kwargs``: Forward compatibility kwargs.
        """
        if not batch.count > 0:
            return

        warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

        if self.storage_unit == StorageUnit.TIMESTEPS:
            timeslices = batch.timeslices(1)
            for t in timeslices:
                self._add_single_batch(t, **kwargs)

        elif self.storage_unit == StorageUnit.SEQUENCES:
            timestep_count = 0
            for seq_len in batch.get(SampleBatch.SEQ_LENS):
                start_seq = timestep_count
                end_seq = timestep_count + seq_len
                self._add_single_batch(batch[start_seq:end_seq], **kwargs)
                timestep_count = end_seq

        elif self.storage_unit == StorageUnit.EPISODES:
            for eps in batch.split_by_episode():
                if (eps.get(SampleBatch.T)[0] == 0
                        and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
                    ):
                    # Only add full episodes to the buffer
                    self._add_single_batch(eps, **kwargs)
                else:
                    if log_once("only_full_episodes"):
                        logger.info("This buffer uses episodes as a storage "
                                    "unit and thus allows only full episodes "
                                    "to be added to it. Some samples may be "
                                    "dropped.")

        elif self.storage_unit == StorageUnit.FRAGMENTS:
            self._add_single_batch(batch, **kwargs)
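
A sketch of the SEQUENCES branch above: walking `seq_lens` to cut the batch into contiguous [start, end) slices (plain lists as illustrative data):

seq_lens = [2, 3]
obs = ["a", "b", "c", "d", "e"]
slices, start = [], 0
for seq_len in seq_lens:
    # Each sequence occupies the next `seq_len` timesteps of the batch.
    slices.append(obs[start:start + seq_len])
    start += seq_len
assert slices == [["a", "b"], ["c", "d", "e"]]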
Example #7
 def _add_to_policy_buffer(self, policy_id: PolicyID,
                           batch: SampleBatchType) -> None:
     if self.replay_sequence_length == 1:
         timeslices = batch.timeslices(1)
     else:
         timeslices = timeslice_along_seq_lens_with_overlap(
             sample_batch=batch,
             zero_pad_max_seq_len=self.replay_sequence_length,
             pre_overlap=self.replay_burn_in,
             zero_init_states=self.replay_zero_init_states,
         )
     for time_slice in timeslices:
         # If SampleBatch has prio-replay weights, average
         # over these to use as a weight for the entire
         # sequence.
         if "weights" in time_slice and len(time_slice["weights"]):
             weight = np.mean(time_slice["weights"])
         else:
             weight = None
         self.replay_buffers[policy_id].add(time_slice, weight=weight)
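
A sketch of the empty-column guard above: `np.mean` of an empty array is NaN (with a runtime warning), so an empty `weights` column falls back to `weight=None` (illustrative data):

import numpy as np

for col in (np.array([]), np.array([2.0, 4.0])):
    time_slice = {"weights": col}
    has_weights = "weights" in time_slice and len(time_slice["weights"])
    weight = np.mean(time_slice["weights"]) if has_weights else None
    print(weight)  # None, then 3.0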
Example #8
    def _add_to_underlying_buffer(self, policy_id: PolicyID,
                                  batch: SampleBatchType, **kwargs) -> None:
        """Add a batch of experiences to the underlying buffer of a policy.

        If the storage unit is `timesteps`, cut the batch into timeslices
        before adding them to the appropriate buffer. Otherwise, let the
        underlying buffer decide how to slice batches.

        Args:
            policy_id: ID of the policy that corresponds to the underlying
                buffer.
            batch: SampleBatch to add to the underlying buffer.
            ``**kwargs``: Forward compatibility kwargs.
        """
        # Merge kwargs, overwriting standard call arguments
        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args,
                                          kwargs)

        # For the storage unit `timesteps`, the underlying buffer will
        # simply store the samples how they arrive. For sequences and
        # episodes, the underlying buffer may split them itself.
        if self.storage_unit is StorageUnit.TIMESTEPS:
            timeslices = batch.timeslices(1)
        elif self.storage_unit is StorageUnit.SEQUENCES:
            timeslices = timeslice_along_seq_lens_with_overlap(
                sample_batch=batch,
                seq_lens=batch.get(SampleBatch.SEQ_LENS)
                if self.replay_sequence_override else None,
                zero_pad_max_seq_len=self.replay_sequence_length,
                pre_overlap=self.replay_burn_in,
                zero_init_states=self.replay_zero_init_states,
            )
        elif self.storage_unit == StorageUnit.EPISODES:
            timeslices = []
            for eps in batch.split_by_episode():
                if (eps.get(SampleBatch.T)[0] == 0
                        and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
                    ):
                    # Only add full episodes to the buffer
                    timeslices.append(eps)
                else:
                    if log_once("only_full_episodes"):
                        logger.info("This buffer uses episodes as a storage "
                                    "unit and thus allows only full episodes "
                                    "to be added to it. Some samples may be "
                                    "dropped.")
        elif self.storage_unit == StorageUnit.FRAGMENTS:
            timeslices = [batch]
        else:
            raise ValueError("Unknown `storage_unit={}`".format(
                self.storage_unit))

        for slice in timeslices:
            # If SampleBatch has prio-replay weights, average
            # over these to use as a weight for the entire
            # sequence.
            if self.replay_mode is ReplayMode.INDEPENDENT:
                if "weights" in slice and len(slice["weights"]):
                    weight = np.mean(slice["weights"])
                else:
                    weight = None

                if "weight" in kwargs and weight is not None:
                    if log_once("overwrite_weight"):
                        logger.warning("Adding batches with column "
                                       "`weights` to this buffer while "
                                       "providing weights as a call argument "
                                       "to the add method results in the "
                                       "column being overwritten.")

                kwargs = {"weight": weight, **kwargs}
            else:
                if "weight" in kwargs:
                    if log_once("lockstep_no_weight_allowed"):
                        logger.warning("Settings weights for batches in "
                                       "lockstep mode is not allowed."
                                       "Weights are being ignored.")

                kwargs = {**kwargs, "weight": None}
            self.replay_buffers[policy_id].add(slice, **kwargs)
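
A minimal stand-in for the `log_once` helper used in these warnings (assumed behavior: it returns True only the first time a key is seen, so each warning fires once per process):

_seen = set()

def log_once(key):
    """Return True the first time `key` is seen, False afterwards."""
    if key in _seen:
        return False
    _seen.add(key)
    return True

assert log_once("overwrite_weight") is True
assert log_once("overwrite_weight") is False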