def add(self, batch: SampleBatchType, **kwargs) -> None:
    """Adds a batch to the appropriate policy's replay buffer.

    Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
    it is not a MultiAgentBatch. Subsequently, adds the individual policy
    batches to the storage.

    Args:
        batch: The batch to be added.
        **kwargs: Forward compatibility kwargs.
    """
    # Make a copy so the replay buffer doesn't pin plasma memory.
    batch = batch.copy()
    # Handle everything as if multi-agent.
    batch = batch.as_multi_agent()

    kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)

    # We need to split batches into timesteps, sequences, or episodes here
    # already to properly keep track of self.last_added_batches. The
    # underlying buffers should not split up the batch any further.
    with self.add_batch_timer:
        if self._storage_unit == StorageUnit.TIMESTEPS:
            for policy_id, sample_batch in batch.policy_batches.items():
                if self.replay_sequence_length == 1:
                    timeslices = sample_batch.timeslices(1)
                else:
                    timeslices = timeslice_along_seq_lens_with_overlap(
                        sample_batch=sample_batch,
                        zero_pad_max_seq_len=self.replay_sequence_length,
                        pre_overlap=self.replay_burn_in,
                        zero_init_states=self.replay_zero_init_states,
                    )
                for time_slice in timeslices:
                    self.replay_buffers[policy_id].add(time_slice, **kwargs)
                    self.last_added_batches[policy_id].append(time_slice)
        elif self._storage_unit == StorageUnit.SEQUENCES:
            for policy_id, sample_batch in batch.policy_batches.items():
                # Slice indices are relative to each policy's batch, so
                # reset the running timestep count per policy.
                timestep_count = 0
                for seq_len in sample_batch.get(SampleBatch.SEQ_LENS):
                    start_seq = timestep_count
                    end_seq = timestep_count + seq_len
                    self.replay_buffers[policy_id].add(
                        sample_batch[start_seq:end_seq], **kwargs
                    )
                    self.last_added_batches[policy_id].append(
                        sample_batch[start_seq:end_seq]
                    )
                    timestep_count = end_seq
        elif self._storage_unit == StorageUnit.EPISODES:
            for policy_id, sample_batch in batch.policy_batches.items():
                for eps in sample_batch.split_by_episode():
                    # Only add full episodes to the buffer.
                    if (
                        eps.get(SampleBatch.T)[0] == 0
                        and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
                    ):
                        self.replay_buffers[policy_id].add(eps, **kwargs)
                        self.last_added_batches[policy_id].append(eps)
                    else:
                        if log_once("only_full_episodes"):
                            logger.info(
                                "This buffer uses episodes as a storage "
                                "unit and thus allows only full episodes "
                                "to be added to it. Some samples may be "
                                "dropped."
                            )

    self._num_added += batch.count
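# The EPISODES branch above only accepts complete episodes. A minimal
# standalone sketch of that check (illustrative only; plain NumPy arrays
# stand in for the SampleBatch.T and SampleBatch.DONES columns):
#
#     import numpy as np
#
#     def is_full_episode(t, dones) -> bool:
#         # Mirrors `eps.get(SampleBatch.T)[0] == 0
#         # and eps.get(SampleBatch.DONES)[-1]`.
#         return t[0] == 0 and bool(dones[-1])
#
#     assert is_full_episode(np.array([0, 1, 2]), np.array([False, False, True]))
#     # Rejected: starts mid-episode.
#     assert not is_full_episode(np.array([3, 4]), np.array([False, True]))
#     # Rejected: truncated, never reaches a terminal step.
#     assert not is_full_episode(np.array([0, 1]), np.array([False, False]))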
def sample(
    self,
    num_items: int,
    policy_id: PolicyID = DEFAULT_POLICY_ID,
    **kwargs,
) -> Optional[SampleBatchType]:
    """Samples a batch of size `num_items` from a specified buffer.

    Concatenates old samples to new ones according to self.replay_ratio.
    If not enough new samples are available, mixes in fewer old samples
    to retain self.replay_ratio on average. Returns an empty batch if
    there are no items in the buffer.

    Args:
        num_items: Number of items to sample from this buffer.
        policy_id: ID of the policy that produced the experiences to be
            sampled.
        **kwargs: Forward compatibility kwargs.

    Returns:
        Concatenated MultiAgentBatch of items.
    """
    # Merge kwargs, overwriting standard call arguments.
    kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)

    def mix_batches(_policy_id):
        """Mixes old with new samples.

        Tries to mix according to self.replay_ratio on average. If not
        enough new samples are available, mixes in fewer old samples to
        retain self.replay_ratio on average.
        """

        def round_up_or_down(value, ratio):
            """Returns an integer averaging to value * ratio."""
            product = value * ratio
            ceil_prob = product % 1
            if random.uniform(0, 1) < ceil_prob:
                return int(np.ceil(product))
            else:
                return int(np.floor(product))

        # If num_items * (1 - self.replay_ratio) is not a whole number,
        # we need one more new sample with a probability of
        # (num_items * (1 - self.replay_ratio)) % 1.
        max_num_new = round_up_or_down(num_items, 1 - self.replay_ratio)
        _buffer = self.replay_buffers[_policy_id]
        output_batches = self.last_added_batches[_policy_id][:max_num_new]
        self.last_added_batches[_policy_id] = self.last_added_batches[
            _policy_id
        ][max_num_new:]

        # No replay desired.
        if self.replay_ratio == 0.0:
            return SampleBatch.concat_samples(output_batches)
        # Only replay desired.
        elif self.replay_ratio == 1.0:
            return _buffer.sample(num_items, **kwargs)

        num_new = len(output_batches)

        if np.isclose(num_new, num_items * (1 - self.replay_ratio)):
            # The optimal case: we can mix in a round number of old
            # samples on average.
            num_old = num_items - max_num_new
        else:
            # We never want to return more elements than num_items.
            num_old = min(
                num_items - max_num_new,
                round_up_or_down(
                    num_new, self.replay_ratio / (1 - self.replay_ratio)
                ),
            )

        output_batches.append(_buffer.sample(num_old, **kwargs))
        # Depending on the implementation of underlying buffers, samples
        # might be SampleBatches.
        output_batches = [batch.as_multi_agent() for batch in output_batches]
        return MultiAgentBatch.concat_samples(output_batches)

    def check_buffer_is_ready(_policy_id):
        if (
            len(self.replay_buffers[_policy_id]) == 0
            and self.replay_ratio > 0.0
        ) or (
            len(self.last_added_batches[_policy_id]) == 0
            and self.replay_ratio < 1.0
        ):
            return False
        return True

    with self.replay_timer:
        samples = []

        if self.replay_mode == ReplayMode.LOCKSTEP:
            assert (
                policy_id is None
            ), "`policy_id` specifier not allowed in `lockstep` mode!"
            if check_buffer_is_ready(_ALL_POLICIES):
                samples.append(mix_batches(_ALL_POLICIES).as_multi_agent())
        elif policy_id is not None:
            if check_buffer_is_ready(policy_id):
                samples.append(mix_batches(policy_id).as_multi_agent())
        else:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if check_buffer_is_ready(policy_id):
                    samples.append(mix_batches(policy_id).as_multi_agent())

        return MultiAgentBatch.concat_samples(samples)
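# The stochastic rounding in `round_up_or_down` is what keeps the realized
# new/old mix at `self.replay_ratio` on average, even when the exact split
# is fractional. A standalone sanity sketch (illustrative only, not part
# of this buffer's API):
#
#     import random
#     import numpy as np
#
#     def round_up_or_down(value, ratio):
#         product = value * ratio
#         ceil_prob = product % 1
#         if random.uniform(0, 1) < ceil_prob:
#             return int(np.ceil(product))
#         return int(np.floor(product))
#
#     # Each draw is 2 or 3; the mean converges to 10 * 0.25 = 2.5.
#     draws = [round_up_or_down(10, 0.25) for _ in range(100_000)]
#     print(np.mean(draws))  # ~2.5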
def _add_to_underlying_buffer(
    self, policy_id: PolicyID, batch: SampleBatchType, **kwargs
) -> None:
    """Add a batch of experiences to the underlying buffer of a policy.

    If the storage unit is `timesteps`, cut the batch into timeslices
    before adding them to the appropriate buffer. Otherwise, let the
    underlying buffer decide how to slice batches.

    Args:
        policy_id: ID of the policy that corresponds to the underlying
            buffer.
        batch: SampleBatch to add to the underlying buffer.
        **kwargs: Forward compatibility kwargs.
    """
    # Merge kwargs, overwriting standard call arguments.
    kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)

    # For the storage unit `timesteps`, the underlying buffer will
    # simply store the samples as they arrive. For sequences and
    # episodes, the underlying buffer may split them itself.
    if self.storage_unit is StorageUnit.TIMESTEPS:
        timeslices = batch.timeslices(1)
    elif self.storage_unit is StorageUnit.SEQUENCES:
        timeslices = timeslice_along_seq_lens_with_overlap(
            sample_batch=batch,
            seq_lens=batch.get(SampleBatch.SEQ_LENS)
            if self.replay_sequence_override
            else None,
            zero_pad_max_seq_len=self.replay_sequence_length,
            pre_overlap=self.replay_burn_in,
            zero_init_states=self.replay_zero_init_states,
        )
    elif self.storage_unit == StorageUnit.EPISODES:
        timeslices = []
        for eps in batch.split_by_episode():
            if (
                eps.get(SampleBatch.T)[0] == 0
                and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
            ):
                # Only add full episodes to the buffer.
                timeslices.append(eps)
            else:
                if log_once("only_full_episodes"):
                    logger.info(
                        "This buffer uses episodes as a storage "
                        "unit and thus allows only full episodes "
                        "to be added to it. Some samples may be "
                        "dropped."
                    )
    elif self.storage_unit == StorageUnit.FRAGMENTS:
        timeslices = [batch]
    else:
        raise ValueError(
            "Unknown `storage_unit={}`".format(self.storage_unit)
        )

    # Note: `time_slice` avoids shadowing the built-in `slice`.
    for time_slice in timeslices:
        # If the SampleBatch has prio-replay weights, average over these
        # to use as a weight for the entire sequence.
        if self.replay_mode is ReplayMode.INDEPENDENT:
            if "weights" in time_slice and len(time_slice["weights"]):
                weight = np.mean(time_slice["weights"])
            else:
                weight = None

            if "weight" in kwargs and weight is not None:
                if log_once("overwrite_weight"):
                    logger.warning(
                        "Adding batches with column `weights` to this "
                        "buffer while providing weights as a call "
                        "argument to the add method results in the "
                        "column being overwritten."
                    )

            kwargs = {"weight": weight, **kwargs}
        else:
            if "weight" in kwargs:
                if log_once("lockstep_no_weight_allowed"):
                    logger.warning(
                        "Setting weights for batches in lockstep mode "
                        "is not allowed. Weights are being ignored."
                    )

            kwargs = {**kwargs, "weight": None}

        self.replay_buffers[policy_id].add(time_slice, **kwargs)
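# Note on the kwargs precedence above (a sketch of plain dict-merge
# semantics, not buffer-specific behavior): because the caller's kwargs
# are unpacked *after* the computed per-slice weight, an explicit `weight`
# call argument wins over the averaged `weights` column in INDEPENDENT
# mode:
#
#     import numpy as np
#
#     weight = float(np.mean(np.array([0.2, 0.4, 0.6])))  # 0.4, from the column
#
#     merged = {"weight": weight, **{"weight": 1.0}}  # caller passed weight=1.0
#     assert merged["weight"] == 1.0  # call argument overwrites the average
#
#     merged = {"weight": weight, **{}}  # no call argument
#     assert merged["weight"] == 0.4  # falls back to the averaged column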