示例#1
0
    def add_batch(self, batch: SampleBatchType) -> None:
        """Stores a batch in the replay buffer(s) of the matching policies.

        The incoming batch is copied and converted via `as_multi_agent()`
        before anything is stored. In lockstep mode the whole multi-agent
        batch is stored under the `_ALL_POLICIES` key; otherwise every
        per-policy batch is stored under its own policy ID. Each stored item
        is also appended to `self.last_added_batches` under the same key.

        Args:
            batch: The batch to be added.
        """
        # Copy first so that the replay buffer doesn't pin plasma memory.
        ma_batch = batch.copy().as_multi_agent()

        with self.add_batch_timer:
            if self.replay_mode != ReplayMode.LOCKSTEP:
                # Independent mode: one SampleBatch per policy buffer.
                for pid, policy_batch in ma_batch.policy_batches.items():
                    self.replay_buffers[pid].add_batch(policy_batch)
                    self.last_added_batches[pid].append(policy_batch)
            else:
                # Lockstep mode: all policies are always sampled together,
                # so the complete MultiAgentBatch lives under one shared key.
                self.replay_buffers[_ALL_POLICIES].add_batch(ma_batch)
                self.last_added_batches[_ALL_POLICIES].append(ma_batch)

        self.num_added += ma_batch.count
示例#2
0
    def add(self, batch: SampleBatchType) -> None:
        """Stores a batch in the replay buffer(s) of the matching policies.

        The incoming batch is copied and converted via `as_multi_agent()`.
        In "lockstep" mode it is cut into time slices of
        `self.replay_sequence_length` that are stored under the
        `_ALL_POLICIES` key; otherwise each per-policy batch is handed to
        `_add_to_policy_buffer()`. Every stored item is also appended to
        `self.last_added_batches` under the same key.

        Args:
            batch: The batch to be added.
        """
        # Copy first so that the replay buffer doesn't pin plasma memory,
        # then treat everything as multi-agent.
        ma_batch = batch.copy().as_multi_agent()

        with self.add_batch_timer:
            if self.replay_mode != "lockstep":
                # Independent mode: one buffer per policy.
                for pid, policy_batch in ma_batch.policy_batches.items():
                    self._add_to_policy_buffer(pid, policy_batch)
                    self.last_added_batches[pid].append(policy_batch)
            else:
                # Lockstep mode: all policies are always sampled together,
                # hence a single shared key. Prioritization is not
                # supported in this mode, so no weight is passed.
                seq_len = self.replay_sequence_length
                for time_slice in ma_batch.timeslices(seq_len):
                    self.replay_buffers[_ALL_POLICIES].add(
                        time_slice, weight=None)
                    self.last_added_batches[_ALL_POLICIES].append(time_slice)
        self._num_added += ma_batch.count
示例#3
0
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Stores a batch in the replay buffer(s) of the matching policies.

        A `None` batch is ignored (an info message is logged once). Any other
        batch is copied and converted via `as_multi_agent()`. In lockstep
        mode the whole multi-agent batch is stored under the `_ALL_POLICIES`
        key; otherwise each per-policy batch is stored under its policy ID.

        Args:
            batch: The batch to be added.
            **kwargs: Forward compatibility kwargs; passed on unchanged to
                `_add_to_underlying_buffer()`.
        """
        if batch is None:
            if log_once("empty_batch_added_to_buffer"):
                buffer_name = type(self).__name__
                logger.info(
                    "A batch that is `None` was added to {}. This can be "
                    "normal at the beginning of execution but might "
                    "indicate an issue.".format(buffer_name))
            return

        # Copy first so that the replay buffer doesn't pin plasma memory,
        # then treat everything as multi-agent.
        ma_batch = batch.copy().as_multi_agent()

        with self.add_batch_timer:
            if self.replay_mode == ReplayMode.LOCKSTEP:
                # All policies are always sampled together, so the complete
                # MultiAgentBatch is stored under one shared key.
                to_store = [(_ALL_POLICIES, ma_batch)]
            else:
                # Independent SampleBatches, one per policy ID.
                to_store = ma_batch.policy_batches.items()
            for key, item in to_store:
                self._add_to_underlying_buffer(key, item, **kwargs)
        self._num_added += ma_batch.count
示例#4
0
 def add_batch(self, batch: SampleBatchType) -> None:
     """Stores a batch in the replay buffer(s) of the matching policies.

     The batch is copied; a plain SampleBatch is wrapped into a
     MultiAgentBatch under DEFAULT_POLICY_ID. In "lockstep" mode the
     multi-agent batch is cut into time slices of
     `self.replay_sequence_length` and stored under `_ALL_POLICIES`.
     Otherwise each per-policy batch is sliced and every slice is added to
     that policy's buffer, weighted by the mean of its "weights" column if
     the slice has one.

     Args:
         batch: The batch to be added.
     """
     # Copy first so that the replay buffer doesn't pin plasma memory.
     copied = batch.copy()
     # From here on, everything is handled as multi-agent.
     if isinstance(copied, SampleBatch):
         copied = MultiAgentBatch({DEFAULT_POLICY_ID: copied}, copied.count)
     with self.add_batch_timer:
         if self.replay_mode != "lockstep":
             for pid, policy_batch in copied.policy_batches.items():
                 if self.replay_sequence_length != 1:
                     slices = timeslice_along_seq_lens_with_overlap(
                         sample_batch=policy_batch,
                         zero_pad_max_seq_len=self.replay_sequence_length,
                         pre_overlap=self.replay_burn_in,
                         zero_init_states=self.replay_zero_init_states,
                     )
                 else:
                     slices = policy_batch.timeslices(1)
                 for piece in slices:
                     # Prio-replay weights, if present, are averaged into
                     # one weight for the entire sequence.
                     seq_weight = (np.mean(piece["weights"])
                                   if "weights" in piece else None)
                     self.replay_buffers[pid].add(piece, weight=seq_weight)
         else:
             # Lockstep mode: all policies are always sampled together,
             # hence a single shared key. Prioritization is not supported
             # in this mode, so no weight is passed.
             for piece in copied.timeslices(self.replay_sequence_length):
                 self.replay_buffers[_ALL_POLICIES].add(piece, weight=None)
     self.num_added += copied.count
示例#5
0
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Stores a batch in the replay buffer(s) of the matching policies.

        A `None` batch is ignored (an info message is logged once). Any other
        batch is copied, converted via `as_multi_agent()` and passed to
        `_maybe_split_into_policy_batches()`; every (policy ID, batch) pair
        of the resulting mapping is then stored through
        `_add_to_underlying_buffer()`.

        Args:
            batch: The batch to be added.
            ``**kwargs``: Forward compatibility kwargs; passed on unchanged
                to `_add_to_underlying_buffer()`.
        """
        if batch is None:
            if log_once("empty_batch_added_to_buffer"):
                buffer_name = type(self).__name__
                logger.info(
                    "A batch that is `None` was added to {}. This can be "
                    "normal at the beginning of execution but might "
                    "indicate an issue.".format(buffer_name))
            return

        # Copy first so that the replay buffer doesn't pin plasma memory,
        # then treat everything as multi-agent.
        ma_batch = batch.copy().as_multi_agent()

        with self.add_batch_timer:
            per_policy = self._maybe_split_into_policy_batches(ma_batch)
            for pid, policy_batch in per_policy.items():
                self._add_to_underlying_buffer(pid, policy_batch, **kwargs)

        self._num_added += ma_batch.count
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Stores a batch in the replay buffer(s) of the matching policies.

        The incoming batch is copied and converted via `as_multi_agent()`.
        In lockstep mode the whole multi-agent batch is stored under the
        `_ALL_POLICIES` key; otherwise each per-policy batch is stored under
        its policy ID.

        Args:
            batch: The batch to be added.
            **kwargs: Forward compatibility kwargs; passed on unchanged to
                `_add_to_underlying_buffer()`.
        """
        # Copy first so that the replay buffer doesn't pin plasma memory,
        # then treat everything as multi-agent.
        ma_batch = batch.copy().as_multi_agent()

        with self.add_batch_timer:
            if self.replay_mode == ReplayMode.LOCKSTEP:
                # All policies are always sampled together, so the complete
                # MultiAgentBatch is stored under one shared key.
                to_store = [(_ALL_POLICIES, ma_batch)]
            else:
                # Independent SampleBatches, one per policy ID.
                to_store = ma_batch.policy_batches.items()
            for key, item in to_store:
                self._add_to_underlying_buffer(key, item, **kwargs)
        self._num_added += ma_batch.count
示例#7
0
    def add_batch(self, batch: SampleBatchType) -> None:
        """Stores a batch in the replay buffers of the matching policies.

        The incoming batch is copied and converted via `as_multi_agent()`.
        Each per-policy batch is then added to the buffer registered under
        its policy ID and appended to `self.last_added_batches` under the
        same ID.

        Args:
            batch: The batch to be added.
        """
        # Copy first so that the replay buffer doesn't pin plasma memory.
        ma_batch = batch.copy().as_multi_agent()

        with self.add_batch_timer:
            for pid, policy_batch in ma_batch.policy_batches.items():
                policy_buffer = self.replay_buffers[pid]
                policy_buffer.add_batch(policy_batch)
                self.last_added_batches[pid].append(policy_batch)
        self.num_added += ma_batch.count
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently, splits each policy batch
        according to `self._storage_unit` and adds the resulting items to
        that policy's buffer:

        - TIMESTEPS: single-step slices if `self.replay_sequence_length` is
          1, otherwise slices produced by
          `timeslice_along_seq_lens_with_overlap()`.
        - SEQUENCES: consecutive slices whose lengths are taken from the
          batch's `SampleBatch.SEQ_LENS` column.
        - EPISODES: the pieces returned by `split_by_episode()`; only
          pieces that start at `T == 0` and end with `DONES == True` are
          stored, all others are dropped (an info message is logged once).

        Every stored item is also appended to `self.last_added_batches`
        under the same policy ID.

        Args:
            batch: The batch to be added.
            **kwargs: Forward compatibility kwargs. They are merged with
                `self.underlying_buffer_call_args` via
                `merge_dicts_with_warning()` and passed to the underlying
                buffers' `add()`.
        """
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        # Handle everything as if multi-agent.
        batch = batch.as_multi_agent()

        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args,
                                          kwargs)

        # We need to split batches into timesteps, sequences or episodes
        # here already to properly keep track of self.last_added_batches
        # underlying buffers should not split up the batch any further
        with self.add_batch_timer:
            if self._storage_unit == StorageUnit.TIMESTEPS:
                for policy_id, sample_batch in batch.policy_batches.items():
                    if self.replay_sequence_length == 1:
                        timeslices = sample_batch.timeslices(1)
                    else:
                        timeslices = timeslice_along_seq_lens_with_overlap(
                            sample_batch=sample_batch,
                            zero_pad_max_seq_len=self.replay_sequence_length,
                            pre_overlap=self.replay_burn_in,
                            zero_init_states=self.replay_zero_init_states,
                        )
                    for time_slice in timeslices:
                        self.replay_buffers[policy_id].add(
                            time_slice, **kwargs)
                        self.last_added_batches[policy_id].append(time_slice)
            elif self._storage_unit == StorageUnit.SEQUENCES:
                for policy_id, sample_batch in batch.policy_batches.items():
                    # Offsets index into this policy's own batch, so they
                    # have to restart at 0 for every policy. Carrying the
                    # offset over from the previous policy would slice the
                    # wrong region of this batch.
                    timestep_count = 0
                    for seq_len in sample_batch.get(SampleBatch.SEQ_LENS):
                        end_seq = timestep_count + seq_len
                        # Slice once so the buffer and the bookkeeping list
                        # refer to the same item, as in the other branches.
                        sequence = sample_batch[timestep_count:end_seq]
                        self.replay_buffers[policy_id].add(
                            sequence, **kwargs)
                        self.last_added_batches[policy_id].append(sequence)
                        timestep_count = end_seq
            elif self._storage_unit == StorageUnit.EPISODES:
                for policy_id, sample_batch in batch.policy_batches.items():
                    for eps in sample_batch.split_by_episode():
                        # Only add full episodes to the buffer
                        if (eps.get(SampleBatch.T)[0] == 0 and eps.get(
                                SampleBatch.DONES)[-1] == True  # noqa E712
                            ):
                            self.replay_buffers[policy_id].add(eps, **kwargs)
                            self.last_added_batches[policy_id].append(eps)
                        else:
                            if log_once("only_full_episodes"):
                                logger.info(
                                    "This buffer uses episodes as a storage "
                                    "unit and thus allows only full episodes "
                                    "to be added to it. Some samples may be "
                                    "dropped.")

        self._num_added += batch.count