def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
    """Returns log likelihood for actions in given batch for policy.

    Computes likelihoods by passing the observations through the current
    policy's `compute_log_likelihoods()` method.

    Args:
        batch: The SampleBatch or MultiAgentBatch to calculate action
            log likelihoods from. This batch/batches must contain OBS
            and ACTIONS keys.

    Returns:
        The log likelihoods of the actions in the batch, given the
        observations and the policy.
    """
    num_state_inputs = 0
    for k in batch.keys():
        if k.startswith("state_in_"):
            num_state_inputs += 1
    state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
    log_likelihoods: TensorType = self.policy.compute_log_likelihoods(
        actions=batch[SampleBatch.ACTIONS],
        obs_batch=batch[SampleBatch.OBS],
        state_batches=[batch[k] for k in state_keys],
        prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
        prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
        actions_normalized=True,
    )
    log_likelihoods = convert_to_numpy(log_likelihoods)
    return log_likelihoods
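
# Minimal, self-contained sketch (an assumption, not part of the original
# source: a plain dict stands in for a SampleBatch) of how the method above
# discovers recurrent-state columns: it counts the consecutive "state_in_{i}"
# keys and forwards them to `compute_log_likelihoods` in index order.
batch = {
    "obs": [[0.1, 0.2]],
    "actions": [1],
    "state_in_0": [[0.0, 0.0]],
    "state_in_1": [[0.0, 0.0]],
}
num_state_inputs = sum(1 for k in batch.keys() if k.startswith("state_in_"))
state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
assert state_keys == ["state_in_0", "state_in_1"]
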
def _add_to_underlying_buffer(
    self, policy_id: PolicyID, batch: SampleBatchType, **kwargs
) -> None:
    """Add a batch of experiences to the underlying buffer of a policy.

    If the storage unit is `timesteps`, cut the batch into timeslices
    before adding them to the appropriate buffer. Otherwise, let the
    underlying buffer decide how to slice batches.

    Args:
        policy_id: ID of the policy that corresponds to the underlying
            buffer.
        batch: SampleBatch to add to the underlying buffer.
        ``**kwargs``: Forward compatibility kwargs.
    """
    # Merge kwargs, overwriting standard call arguments.
    kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)

    # For the storage unit `timesteps`, the underlying buffer will
    # simply store the samples how they arrive. For sequences and
    # episodes, the underlying buffer may split them itself.
    if self.storage_unit is StorageUnit.TIMESTEPS:
        timeslices = batch.timeslices(1)
    elif self.storage_unit is StorageUnit.SEQUENCES:
        timeslices = timeslice_along_seq_lens_with_overlap(
            sample_batch=batch,
            seq_lens=batch.get(SampleBatch.SEQ_LENS)
            if self.replay_sequence_override
            else None,
            zero_pad_max_seq_len=self.replay_sequence_length,
            pre_overlap=self.replay_burn_in,
            zero_init_states=self.replay_zero_init_states,
        )
    elif self.storage_unit == StorageUnit.EPISODES:
        timeslices = []
        for eps in batch.split_by_episode():
            if (
                eps.get(SampleBatch.T)[0] == 0
                and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
            ):
                # Only add full episodes to the buffer.
                timeslices.append(eps)
            else:
                if log_once("only_full_episodes"):
                    logger.info(
                        "This buffer uses episodes as a storage "
                        "unit and thus allows only full episodes "
                        "to be added to it. Some samples may be "
                        "dropped."
                    )
    elif self.storage_unit == StorageUnit.FRAGMENTS:
        timeslices = [batch]
    else:
        raise ValueError("Unknown `storage_unit={}`".format(self.storage_unit))

    for slice in timeslices:
        self.replay_buffers[policy_id].add(slice, **kwargs)
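
# Self-contained sketch (plain dicts instead of SampleBatches; the field
# names "t"/"dones" mirror SampleBatch.T / SampleBatch.DONES) of the
# full-episode check in the EPISODES branch above: a slice is only kept if
# it starts at t=0 and its last step is terminal; everything else is dropped.
episodes = [
    {"t": [0, 1, 2], "dones": [False, False, True]},  # complete -> kept
    {"t": [3, 4], "dones": [False, False]},            # truncated -> dropped
]
kept = [
    eps for eps in episodes
    if eps["t"][0] == 0 and eps["dones"][-1] == True  # noqa E712
]
assert len(kept) == 1
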
def action_prob(self, batch: SampleBatchType) -> np.ndarray:
    """Returns the probabilities of the batch's actions under the current policy."""
    num_state_inputs = 0
    for k in batch.keys():
        if k.startswith("state_in_"):
            num_state_inputs += 1
    state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
    log_likelihoods: TensorType = self.policy.compute_log_likelihoods(
        actions=batch[SampleBatch.ACTIONS],
        obs_batch=batch[SampleBatch.CUR_OBS],
        state_batches=[batch[k] for k in state_keys],
        prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
        prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
    )
    log_likelihoods = convert_to_numpy(log_likelihoods)
    return np.exp(log_likelihoods)
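
# Self-contained numerical sketch of the final step above: per-action
# log-likelihoods are mapped back to probabilities with an element-wise exp.
import numpy as np

log_likelihoods = np.log(np.array([0.7, 0.2, 0.1]))
probs = np.exp(log_likelihoods)
assert np.allclose(probs, [0.7, 0.2, 0.1])
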
def add(self, batch: SampleBatchType, **kwargs) -> None:
    """Adds a batch of experiences to this buffer.

    Also splits experiences into chunks of timesteps, sequences or
    episodes, depending on `self._storage_unit`. Calls
    `self._add_single_batch`.

    Args:
        batch: Batch to add to this buffer's storage.
        **kwargs: Forward compatibility kwargs.
    """
    assert batch.count > 0, batch
    warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

    if (
        type(batch) == MultiAgentBatch
        and self._storage_unit != StorageUnit.TIMESTEPS
    ):
        raise ValueError(
            "Can not add MultiAgentBatch to ReplayBuffer "
            "with storage_unit {}".format(str(self._storage_unit))
        )

    if self._storage_unit == StorageUnit.TIMESTEPS:
        self._add_single_batch(batch, **kwargs)

    elif self._storage_unit == StorageUnit.SEQUENCES:
        timestep_count = 0
        for seq_len in batch.get(SampleBatch.SEQ_LENS):
            start_seq = timestep_count
            end_seq = timestep_count + seq_len
            self._add_single_batch(batch[start_seq:end_seq], **kwargs)
            timestep_count = end_seq

    elif self._storage_unit == StorageUnit.EPISODES:
        for eps in batch.split_by_episode():
            if (
                eps.get(SampleBatch.T)[0] == 0
                and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
            ):
                # Only add full episodes to the buffer.
                self._add_single_batch(eps, **kwargs)
            else:
                if log_once("only_full_episodes"):
                    logger.info(
                        "This buffer uses episodes as a storage "
                        "unit and thus allows only full episodes "
                        "to be added to it. Some samples may be "
                        "dropped."
                    )

    elif self._storage_unit == StorageUnit.FRAGMENTS:
        self._add_single_batch(batch, **kwargs)

def add(self, batch: SampleBatchType, **kwargs) -> None:
    """Adds a batch of experiences to this buffer.

    Splits batch into chunks of timesteps, sequences or episodes,
    depending on `self._storage_unit`. Calls `self._add_single_batch` to
    add resulting slices to the buffer storage.

    Args:
        batch: Batch to add.
        ``**kwargs``: Forward compatibility kwargs.
    """
    if not batch.count > 0:
        return

    warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

    if self.storage_unit == StorageUnit.TIMESTEPS:
        timeslices = batch.timeslices(1)
        for t in timeslices:
            self._add_single_batch(t, **kwargs)

    elif self.storage_unit == StorageUnit.SEQUENCES:
        timestep_count = 0
        for seq_len in batch.get(SampleBatch.SEQ_LENS):
            start_seq = timestep_count
            end_seq = timestep_count + seq_len
            self._add_single_batch(batch[start_seq:end_seq], **kwargs)
            timestep_count = end_seq

    elif self.storage_unit == StorageUnit.EPISODES:
        for eps in batch.split_by_episode():
            if (
                eps.get(SampleBatch.T)[0] == 0
                and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
            ):
                # Only add full episodes to the buffer.
                self._add_single_batch(eps, **kwargs)
            else:
                if log_once("only_full_episodes"):
                    logger.info(
                        "This buffer uses episodes as a storage "
                        "unit and thus allows only full episodes "
                        "to be added to it. Some samples may be "
                        "dropped."
                    )

    elif self.storage_unit == StorageUnit.FRAGMENTS:
        self._add_single_batch(batch, **kwargs)
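
# Self-contained sketch (plain lists instead of a SampleBatch) of the
# SEQUENCES slicing arithmetic in `add` above: SEQ_LENS partitions a flat
# run of timesteps into contiguous per-sequence chunks that are then added
# one by one.
seq_lens = [3, 2, 4]
flat_timesteps = list(range(sum(seq_lens)))  # 9 timesteps total

chunks = []
timestep_count = 0
for seq_len in seq_lens:
    start_seq, end_seq = timestep_count, timestep_count + seq_len
    chunks.append(flat_timesteps[start_seq:end_seq])
    timestep_count = end_seq

assert chunks == [[0, 1, 2], [3, 4], [5, 6, 7, 8]]
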
def timeslice_along_seq_lens_with_overlap(
    sample_batch: SampleBatchType,
    seq_lens: Optional[List[int]] = None,
    zero_pad_max_seq_len: int = 0,
    pre_overlap: int = 0,
    zero_init_states: bool = True,
) -> List["SampleBatch"]:
    """Slices batch along `seq_lens` (each seq-len item produces one batch).

    Args:
        sample_batch: The SampleBatch to timeslice.
        seq_lens (Optional[List[int]]): An optional list of seq_lens to
            slice at. If None, use `sample_batch[SampleBatch.SEQ_LENS]`.
        zero_pad_max_seq_len: If >0, already zero-pad the resulting slices
            up to this length. NOTE: This max-len will include the
            additional timesteps gained via setting pre_overlap (see
            Example).
        pre_overlap: If >0, will overlap each two consecutive slices by
            this many timesteps (toward the left side). This will cause
            zero-padding at the very beginning of the batch.
        zero_init_states: Whether initial states should always be zero'd.
            If False, will use the state_outs of the batch to populate
            state_in values.

    Returns:
        List[SampleBatch]: The list of (new) SampleBatches.

    Examples:
        assert seq_lens == [5, 5, 2]
        assert sample_batch.count == 12
        # self = 0 1 2 3 4 | 5 6 7 8 9 | 10 11 <- timesteps
        slices = timeslice_along_seq_lens_with_overlap(
            sample_batch=sample_batch,
            zero_pad_max_seq_len=10,
            pre_overlap=3)
        # Z = zero padding (at beginning or end).
        #             |pre (3)|     seq      | max-seq-len (up to 10)
        # slices[0] = | Z Z Z | 0  1  2 3 4  | Z Z
        # slices[1] = | 2 3 4 | 5  6  7 8 9  | Z Z
        # slices[2] = | 7 8 9 | 10 11 Z Z Z  | Z Z
        # Note that `zero_pad_max_seq_len=10` includes the 3 pre-overlap
        # timesteps (makes sure each slice has exactly length 10).
    """
    if seq_lens is None:
        seq_lens = sample_batch.get(SampleBatch.SEQ_LENS)
    else:
        if sample_batch.get(SampleBatch.SEQ_LENS) is not None and log_once(
            "overriding_sequencing_information"
        ):
            logger.warning(
                "Found sequencing information in a batch that will be "
                "ignored when slicing. Ignore this warning if you know "
                "what you are doing."
            )

    if seq_lens is None:
        max_seq_len = zero_pad_max_seq_len - pre_overlap
        if log_once("no_sequence_lengths_available_for_time_slicing"):
            logger.warning(
                "Trying to slice a batch along sequences without "
                "sequence lengths being provided in the batch. Batch will "
                "be sliced into slices of size "
                "{} = {} - {} = zero_pad_max_seq_len - pre_overlap.".format(
                    max_seq_len, zero_pad_max_seq_len, pre_overlap
                )
            )
        num_seq_lens, last_seq_len = divmod(len(sample_batch), max_seq_len)
        seq_lens = [zero_pad_max_seq_len] * num_seq_lens + (
            [last_seq_len] if last_seq_len else []
        )

    assert (
        seq_lens is not None and len(seq_lens) > 0
    ), "Cannot timeslice along `seq_lens` when `seq_lens` is empty or None!"

    # Generate n slices based on seq_lens.
    start = 0
    slices = []
    for seq_len in seq_lens:
        pre_begin = start - pre_overlap
        slice_begin = start
        end = start + seq_len
        slices.append((pre_begin, slice_begin, end))
        start += seq_len

    timeslices = []
    for begin, slice_begin, end in slices:
        zero_length = None
        data_begin = 0
        zero_init_states_ = zero_init_states
        if begin < 0:
            zero_length = pre_overlap
            data_begin = slice_begin
            zero_init_states_ = True
        else:
            eps_ids = sample_batch[SampleBatch.EPS_ID][begin if begin >= 0 else 0 : end]
            is_last_episode_ids = eps_ids == eps_ids[-1]
            if not is_last_episode_ids[0]:
                zero_length = int(sum(1.0 - is_last_episode_ids))
                data_begin = begin + zero_length
                zero_init_states_ = True

        if zero_length is not None:
            data = {
                k: np.concatenate(
                    [
                        np.zeros(shape=(zero_length,) + v.shape[1:], dtype=v.dtype),
                        v[data_begin:end],
                    ]
                )
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }
        else:
            data = {
                k: v[begin:end]
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }

        if zero_init_states_:
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                data[key] = np.zeros_like(sample_batch[key][0:1])
                # Del state_out_n from data if exists.
                data.pop("state_out_{}".format(i), None)
                i += 1
                key = "state_in_{}".format(i)
        # TODO: This will not work with attention nets as their state_outs are
        #  not compatible with state_ins.
        else:
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                data[key] = sample_batch["state_out_{}".format(i)][begin - 1 : begin]
                del data["state_out_{}".format(i)]
                i += 1
                key = "state_in_{}".format(i)

        timeslices.append(SampleBatch(data, seq_lens=[end - begin]))

    # Zero-pad each slice if necessary.
    if zero_pad_max_seq_len > 0:
        for ts in timeslices:
            ts.right_zero_pad(max_seq_len=zero_pad_max_seq_len, exclude_states=True)

    return timeslices
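
# Self-contained sketch of the slice-boundary bookkeeping above, replayed on
# the docstring example (seq_lens=[5, 5, 2], pre_overlap=3): each slice keeps
# `pre_overlap` extra timesteps to the left of its sequence start; a negative
# pre_begin means the overlap reaches before the batch and gets zero-padded.
seq_lens = [5, 5, 2]
pre_overlap = 3

start = 0
slices = []
for seq_len in seq_lens:
    # (pre_begin, slice_begin, end)
    slices.append((start - pre_overlap, start, start + seq_len))
    start += seq_len

assert slices == [(-3, 0, 5), (2, 5, 10), (7, 10, 12)]
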
def _add_to_underlying_buffer(
    self, policy_id: PolicyID, batch: SampleBatchType, **kwargs
) -> None:
    """Add a batch of experiences to the underlying buffer of a policy.

    If the storage unit is `timesteps`, cut the batch into timeslices
    before adding them to the appropriate buffer. Otherwise, let the
    underlying buffer decide how to slice batches.

    Args:
        policy_id: ID of the policy that corresponds to the underlying
            buffer.
        batch: SampleBatch to add to the underlying buffer.
        ``**kwargs``: Forward compatibility kwargs.
    """
    # Merge kwargs, overwriting standard call arguments.
    kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)

    # For the storage unit `timesteps`, the underlying buffer will
    # simply store the samples how they arrive. For sequences and
    # episodes, the underlying buffer may split them itself.
    if self.storage_unit is StorageUnit.TIMESTEPS:
        timeslices = batch.timeslices(1)
    elif self.storage_unit is StorageUnit.SEQUENCES:
        timeslices = timeslice_along_seq_lens_with_overlap(
            sample_batch=batch,
            seq_lens=batch.get(SampleBatch.SEQ_LENS)
            if self.replay_sequence_override
            else None,
            zero_pad_max_seq_len=self.replay_sequence_length,
            pre_overlap=self.replay_burn_in,
            zero_init_states=self.replay_zero_init_states,
        )
    elif self.storage_unit == StorageUnit.EPISODES:
        timeslices = []
        for eps in batch.split_by_episode():
            if (
                eps.get(SampleBatch.T)[0] == 0
                and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
            ):
                # Only add full episodes to the buffer.
                timeslices.append(eps)
            else:
                if log_once("only_full_episodes"):
                    logger.info(
                        "This buffer uses episodes as a storage "
                        "unit and thus allows only full episodes "
                        "to be added to it. Some samples may be "
                        "dropped."
                    )
    elif self.storage_unit == StorageUnit.FRAGMENTS:
        timeslices = [batch]
    else:
        raise ValueError("Unknown `storage_unit={}`".format(self.storage_unit))

    for slice in timeslices:
        # If SampleBatch has prio-replay weights, average
        # over these to use as a weight for the entire
        # sequence.
        if self.replay_mode is ReplayMode.INDEPENDENT:
            if "weights" in slice and len(slice["weights"]):
                weight = np.mean(slice["weights"])
            else:
                weight = None

            if "weight" in kwargs and weight is not None:
                if log_once("overwrite_weight"):
                    logger.warning(
                        "Adding batches with column "
                        "`weights` to this buffer while "
                        "providing weights as a call argument "
                        "to the add method results in the "
                        "column being overwritten."
                    )

            kwargs = {"weight": weight, **kwargs}
        else:
            if "weight" in kwargs:
                if log_once("lockstep_no_weight_allowed"):
                    logger.warning(
                        "Setting weights for batches in "
                        "lockstep mode is not allowed. "
                        "Weights are being ignored."
                    )

            kwargs = {**kwargs, "weight": None}

        self.replay_buffers[policy_id].add(slice, **kwargs)
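
# Self-contained sketch of the weight handling above in INDEPENDENT mode:
# per-timestep prio-replay "weights" of a slice are averaged into a single
# `weight`, and a weight the caller already passed via kwargs takes
# precedence over it (the case the "overwrite_weight" warning flags).
import numpy as np

call_kwargs = {"weight": 2.0}              # weight passed to the add() call
slice_weights = np.array([0.5, 1.0, 1.5])  # per-timestep prio weights
weight = float(np.mean(slice_weights))     # averaged per-slice weight -> 1.0
merged = {"weight": weight, **call_kwargs}
assert merged["weight"] == 2.0             # caller-provided value wins
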