Example 1
    def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
        """Returns log likelihood for actions in given batch for policy.

        Computes likelihoods by passing the observations through the current
        policy's `compute_log_likelihoods()` method

        Args:
            batch: The SampleBatch or MultiAgentBatch to calculate action
                log likelihoods from. This batch/batches must contain OBS
                and ACTIONS keys.

        Returns:
            The log-likelihoods of the actions in the batch, given the
            observations and the policy.
        """
        num_state_inputs = 0
        for k in batch.keys():
            if k.startswith("state_in_"):
                num_state_inputs += 1
        state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
        log_likelihoods: TensorType = self.policy.compute_log_likelihoods(
            actions=batch[SampleBatch.ACTIONS],
            obs_batch=batch[SampleBatch.OBS],
            state_batches=[batch[k] for k in state_keys],
            prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
            actions_normalized=True,
        )
        log_likelihoods = convert_to_numpy(log_likelihoods)
        return log_likelihoods
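
The `state_in_` discovery loop above is a common RLlib pattern for recurrent policies, whose RNN inputs are stored under "state_in_0", "state_in_1", and so on. A minimal, self-contained sketch of that pattern (the key list here is made up for illustration):

# Illustration of the state-key discovery used above. The batch keys
# are hypothetical; a real SampleBatch provides them via batch.keys().
batch_keys = ["obs", "actions", "state_in_0", "state_in_1", "rewards"]
num_state_inputs = sum(k.startswith("state_in_") for k in batch_keys)
state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
assert state_keys == ["state_in_0", "state_in_1"]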
Example 2
    def _add_to_underlying_buffer(
        self, policy_id: PolicyID, batch: SampleBatchType, **kwargs
    ) -> None:
        """Add a batch of experiences to the underlying buffer of a policy.

        If the storage unit is `timesteps`, cut the batch into timeslices
        before adding them to the appropriate buffer. Otherwise, let the
        underlying buffer decide how to slice batches.

        Args:
            policy_id: ID of the policy that corresponds to the underlying
                buffer.
            batch: SampleBatch to add to the underlying buffer.
            ``**kwargs``: Forward compatibility kwargs.
        """
        # Merge kwargs, overwriting standard call arguments
        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)

        # For the storage unit `timesteps`, the underlying buffer will
        # simply store the samples how they arrive. For sequences and
        # episodes, the underlying buffer may split them itself.
        if self.storage_unit is StorageUnit.TIMESTEPS:
            timeslices = batch.timeslices(1)
        elif self.storage_unit is StorageUnit.SEQUENCES:
            timeslices = timeslice_along_seq_lens_with_overlap(
                sample_batch=batch,
                seq_lens=batch.get(SampleBatch.SEQ_LENS)
                if self.replay_sequence_override
                else None,
                zero_pad_max_seq_len=self.replay_sequence_length,
                pre_overlap=self.replay_burn_in,
                zero_init_states=self.replay_zero_init_states,
            )
        elif self.storage_unit == StorageUnit.EPISODES:
            timeslices = []
            for eps in batch.split_by_episode():
                if (
                    eps.get(SampleBatch.T)[0] == 0
                    and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
                ):
                    # Only add full episodes to the buffer
                    timeslices.append(eps)
                else:
                    if log_once("only_full_episodes"):
                        logger.info(
                            "This buffer uses episodes as a storage "
                            "unit and thus allows only full episodes "
                            "to be added to it. Some samples may be "
                            "dropped."
                        )
        elif self.storage_unit == StorageUnit.FRAGMENTS:
            timeslices = [batch]
        else:
            raise ValueError("Unknown `storage_unit={}`".format(self.storage_unit))

        for slice in timeslices:
            self.replay_buffers[policy_id].add(slice, **kwargs)
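
A minimal sketch of the `timesteps` path above, assuming RLlib is installed; the batch columns here are made up:

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

# A 3-timestep batch with made-up columns.
batch = SampleBatch({
    "obs": np.arange(6, dtype=np.float32).reshape(3, 2),
    "actions": np.array([0, 1, 0]),
})

# With storage unit `timesteps`, the method above cuts the batch into
# one-step slices before handing them to the underlying buffer.
timeslices = batch.timeslices(1)
assert len(timeslices) == 3
assert all(t.count == 1 for t in timeslices)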
Example 3
    def action_prob(self, batch: SampleBatchType) -> np.ndarray:
        """Returns the probs for the batch actions for the current policy."""

        num_state_inputs = 0
        for k in batch.keys():
            if k.startswith("state_in_"):
                num_state_inputs += 1
        state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
        log_likelihoods: TensorType = self.policy.compute_log_likelihoods(
            actions=batch[SampleBatch.ACTIONS],
            obs_batch=batch[SampleBatch.CUR_OBS],
            state_batches=[batch[k] for k in state_keys],
            prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS))
        log_likelihoods = convert_to_numpy(log_likelihoods)
        return np.exp(log_likelihoods)
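
Since `compute_log_likelihoods()` returns log-probabilities, exponentiating recovers the probabilities, as in the return statement above. A self-contained numeric illustration (the values are made up):

import numpy as np

# Made-up log-likelihoods for a categorical policy over 3 actions.
log_likelihoods = np.log(np.array([0.7, 0.2, 0.1]))

# Exponentiating the log-likelihoods recovers the probabilities.
probs = np.exp(log_likelihoods)
assert np.allclose(probs, [0.7, 0.2, 0.1])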
Example 4
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch of experiences to this buffer.

        Also splits experiences into chunks of timesteps, sequences
        or episodes, depending on self._storage_unit. Calls
        self._add_single_batch.

        Args:
            batch: Batch to add to this buffer's storage.
            **kwargs: Forward compatibility kwargs.
        """
        assert batch.count > 0, batch
        warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

        if (type(batch) == MultiAgentBatch
                and self._storage_unit != StorageUnit.TIMESTEPS):
            raise ValueError("Can not add MultiAgentBatch to ReplayBuffer "
                             "with storage_unit {}"
                             "".format(str(self._storage_unit)))

        if self._storage_unit == StorageUnit.TIMESTEPS:
            self._add_single_batch(batch, **kwargs)

        elif self._storage_unit == StorageUnit.SEQUENCES:
            timestep_count = 0
            for seq_len in batch.get(SampleBatch.SEQ_LENS):
                start_seq = timestep_count
                end_seq = timestep_count + seq_len
                self._add_single_batch(batch[start_seq:end_seq], **kwargs)
                timestep_count = end_seq

        elif self._storage_unit == StorageUnit.EPISODES:
            for eps in batch.split_by_episode():
                if (eps.get(SampleBatch.T)[0] == 0
                        and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
                    ):
                    # Only add full episodes to the buffer
                    self._add_single_batch(eps, **kwargs)
                else:
                    if log_once("only_full_episodes"):
                        logger.info("This buffer uses episodes as a storage "
                                    "unit and thus allows only full episodes "
                                    "to be added to it. Some samples may be "
                                    "dropped.")
        elif self._storage_unit == StorageUnit.FRAGMENTS:
            self._add_single_batch(batch, **kwargs)
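
A self-contained illustration of the SEQUENCES branch above: the batch is cut at the boundaries given by SEQ_LENS (the lengths and data here are made up):

import numpy as np

# Made-up per-sequence lengths and a 12-step data column.
seq_lens = [5, 5, 2]
data = np.arange(12)

# Walk the batch, emitting one slice per sequence length, exactly as
# the loop in the SEQUENCES branch does.
slices, timestep_count = [], 0
for seq_len in seq_lens:
    start_seq, end_seq = timestep_count, timestep_count + seq_len
    slices.append(data[start_seq:end_seq])
    timestep_count = end_seq

assert [len(s) for s in slices] == [5, 5, 2]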
Example 5
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch of experiences to this buffer.

        Splits batch into chunks of timesteps, sequences or episodes, depending on
        `self._storage_unit`. Calls `self._add_single_batch` to add resulting slices
        to the buffer storage.

        Args:
            batch: Batch to add.
            ``**kwargs``: Forward compatibility kwargs.
        """
        if not batch.count > 0:
            return

        warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

        if self.storage_unit == StorageUnit.TIMESTEPS:
            timeslices = batch.timeslices(1)
            for t in timeslices:
                self._add_single_batch(t, **kwargs)

        elif self.storage_unit == StorageUnit.SEQUENCES:
            timestep_count = 0
            for seq_len in batch.get(SampleBatch.SEQ_LENS):
                start_seq = timestep_count
                end_seq = timestep_count + seq_len
                self._add_single_batch(batch[start_seq:end_seq], **kwargs)
                timestep_count = end_seq

        elif self.storage_unit == StorageUnit.EPISODES:
            for eps in batch.split_by_episode():
                if (eps.get(SampleBatch.T)[0] == 0
                        and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
                    ):
                    # Only add full episodes to the buffer
                    self._add_single_batch(eps, **kwargs)
                else:
                    if log_once("only_full_episodes"):
                        logger.info("This buffer uses episodes as a storage "
                                    "unit and thus allows only full episodes "
                                    "to be added to it. Some samples may be "
                                    "dropped.")

        elif self.storage_unit == StorageUnit.FRAGMENTS:
            self._add_single_batch(batch, **kwargs)
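
A self-contained illustration of the EPISODES check above: a slice counts as a full episode only if it starts at t=0 and its last done flag is True (the arrays here are made up):

import numpy as np

def is_full_episode(t, dones):
    # Mirrors the condition in the EPISODES branch above.
    return t[0] == 0 and bool(dones[-1])

# A complete episode: starts at t=0 and ends with done=True.
assert is_full_episode(np.array([0, 1, 2]), np.array([False, False, True]))
# A fragment starting mid-episode would be dropped by the buffer.
assert not is_full_episode(np.array([3, 4, 5]), np.array([False, False, True]))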
Example 6
def timeslice_along_seq_lens_with_overlap(
    sample_batch: SampleBatchType,
    seq_lens: Optional[List[int]] = None,
    zero_pad_max_seq_len: int = 0,
    pre_overlap: int = 0,
    zero_init_states: bool = True,
) -> List["SampleBatch"]:
    """Slices batch along `seq_lens` (each seq-len item produces one batch).

    Args:
        sample_batch: The SampleBatch to timeslice.
        seq_lens: An optional list of seq_lens to slice
            at. If None, use `sample_batch[SampleBatch.SEQ_LENS]`.
        zero_pad_max_seq_len: If >0, already zero-pad the resulting
            slices up to this length. NOTE: This max-len will include the
            additional timesteps gained via setting pre_overlap (see Example).
        pre_overlap: If >0, will overlap each two consecutive slices by
            this many timesteps (toward the left side). This will cause
            zero-padding at the very beginning of the batch.
        zero_init_states: Whether initial states should always be
            zero'd. If False, will use the state_outs of the batch to
            populate state_in values.

    Returns:
        The list of (new) SampleBatches.

    Examples:
        assert seq_lens == [5, 5, 2]
        assert sample_batch.count == 12
        # self = 0 1 2 3 4 | 5 6 7 8 9 | 10 11 <- timesteps
        slices = timeslice_along_seq_lens_with_overlap(
            sample_batch=sample_batch,
            zero_pad_max_seq_len=10,
            pre_overlap=3)
        # Z = zero padding (at beginning or end).
        #             |pre (3)|     seq     | max-seq-len (up to 10)
        # slices[0] = | Z Z Z |  0  1 2 3 4 | Z Z
        # slices[1] = | 2 3 4 |  5  6 7 8 9 | Z Z
        # slices[2] = | 7 8 9 | 10 11 Z Z Z | Z Z
        # Note that `zero_pad_max_seq_len=10` includes the 3 pre-overlaps
        #  count (makes sure each slice has exactly length 10).
    """
    if seq_lens is None:
        seq_lens = sample_batch.get(SampleBatch.SEQ_LENS)
    else:
        if sample_batch.get(SampleBatch.SEQ_LENS) is not None and log_once(
            "overriding_sequencing_information"
        ):
            logger.warning(
                "Found sequencing information in a batch that will be "
                "ignored when slicing. Ignore this warning if you know "
                "what you are doing."
            )

    if seq_lens is None:
        max_seq_len = zero_pad_max_seq_len - pre_overlap
        if log_once("no_sequence_lengths_available_for_time_slicing"):
            logger.warning(
                "Trying to slice a batch along sequences without "
                "sequence lengths being provided in the batch. Batch will "
                "be sliced into slices of size "
                "{} = {} - {} = zero_pad_max_seq_len - pre_overlap.".format(
                    max_seq_len, zero_pad_max_seq_len, pre_overlap
                )
            )
        num_seq_lens, last_seq_len = divmod(len(sample_batch), max_seq_len)
        seq_lens = [zero_pad_max_seq_len] * num_seq_lens + (
            [last_seq_len] if last_seq_len else []
        )

    assert (
        seq_lens is not None and len(seq_lens) > 0
    ), "Cannot timeslice along `seq_lens` when `seq_lens` is empty or None!"
    # Generate n slices based on seq_lens.
    start = 0
    slices = []
    for seq_len in seq_lens:
        pre_begin = start - pre_overlap
        slice_begin = start
        end = start + seq_len
        slices.append((pre_begin, slice_begin, end))
        start += seq_len

    timeslices = []
    for begin, slice_begin, end in slices:
        zero_length = None
        data_begin = 0
        zero_init_states_ = zero_init_states
        if begin < 0:
            zero_length = pre_overlap
            data_begin = slice_begin
            zero_init_states_ = True
        else:
            eps_ids = sample_batch[SampleBatch.EPS_ID][begin if begin >= 0 else 0 : end]
            is_last_episode_ids = eps_ids == eps_ids[-1]
            if not is_last_episode_ids[0]:
                zero_length = int(sum(1.0 - is_last_episode_ids))
                data_begin = begin + zero_length
                zero_init_states_ = True

        if zero_length is not None:
            data = {
                k: np.concatenate(
                    [
                        np.zeros(shape=(zero_length,) + v.shape[1:], dtype=v.dtype),
                        v[data_begin:end],
                    ]
                )
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }
        else:
            data = {
                k: v[begin:end]
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }

        if zero_init_states_:
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                data[key] = np.zeros_like(sample_batch[key][0:1])
                # Del state_out_n from data if exists.
                data.pop("state_out_{}".format(i), None)
                i += 1
                key = "state_in_{}".format(i)
        # TODO: This will not work with attention nets as their state_outs are
        #  not compatible with state_ins.
        else:
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                data[key] = sample_batch["state_out_{}".format(i)][begin - 1 : begin]
                del data["state_out_{}".format(i)]
                i += 1
                key = "state_in_{}".format(i)

        timeslices.append(SampleBatch(data, seq_lens=[end - begin]))

    # Zero-pad each slice if necessary.
    if zero_pad_max_seq_len > 0:
        for ts in timeslices:
            ts.right_zero_pad(max_seq_len=zero_pad_max_seq_len, exclude_states=True)

    return timeslices
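
A minimal sketch reproducing the docstring example, assuming RLlib is installed and that this function is importable from `ray.rllib.policy.rnn_sequencing` (the column values are made up; `EPS_ID` is included because the slicer consults it at overlap boundaries):

import numpy as np
from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
from ray.rllib.policy.sample_batch import SampleBatch

# 12 timesteps, sliced at [5, 5, 2] with 3 steps of pre-overlap and
# zero-padding up to length 10, as in the docstring example.
batch = SampleBatch({
    "obs": np.arange(12, dtype=np.float32),
    SampleBatch.EPS_ID: np.zeros(12, dtype=np.int64),
})
slices = timeslice_along_seq_lens_with_overlap(
    sample_batch=batch,
    seq_lens=[5, 5, 2],
    zero_pad_max_seq_len=10,
    pre_overlap=3,
)
assert len(slices) == 3
# Every slice is padded to exactly zero_pad_max_seq_len timesteps.
assert all(len(s["obs"]) == 10 for s in slices)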
Example 7
    def _add_to_underlying_buffer(self, policy_id: PolicyID,
                                  batch: SampleBatchType, **kwargs) -> None:
        """Add a batch of experiences to the underlying buffer of a policy.

        If the storage unit is `timesteps`, cut the batch into timeslices
        before adding them to the appropriate buffer. Otherwise, let the
        underlying buffer decide how to slice batches.

        Args:
            policy_id: ID of the policy that corresponds to the underlying
                buffer.
            batch: SampleBatch to add to the underlying buffer.
            ``**kwargs``: Forward compatibility kwargs.
        """
        # Merge kwargs, overwriting standard call arguments
        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args,
                                          kwargs)

        # For the storage unit `timesteps`, the underlying buffer will
        # simply store the samples how they arrive. For sequences and
        # episodes, the underlying buffer may split them itself.
        if self.storage_unit is StorageUnit.TIMESTEPS:
            timeslices = batch.timeslices(1)
        elif self.storage_unit is StorageUnit.SEQUENCES:
            timeslices = timeslice_along_seq_lens_with_overlap(
                sample_batch=batch,
                seq_lens=batch.get(SampleBatch.SEQ_LENS)
                if self.replay_sequence_override else None,
                zero_pad_max_seq_len=self.replay_sequence_length,
                pre_overlap=self.replay_burn_in,
                zero_init_states=self.replay_zero_init_states,
            )
        elif self.storage_unit == StorageUnit.EPISODES:
            timeslices = []
            for eps in batch.split_by_episode():
                if (eps.get(SampleBatch.T)[0] == 0
                        and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
                    ):
                    # Only add full episodes to the buffer
                    timeslices.append(eps)
                else:
                    if log_once("only_full_episodes"):
                        logger.info("This buffer uses episodes as a storage "
                                    "unit and thus allows only full episodes "
                                    "to be added to it. Some samples may be "
                                    "dropped.")
        elif self.storage_unit == StorageUnit.FRAGMENTS:
            timeslices = [batch]
        else:
            raise ValueError("Unknown `storage_unit={}`".format(
                self.storage_unit))

        for slice in timeslices:
            # If SampleBatch has prio-replay weights, average
            # over these to use as a weight for the entire
            # sequence.
            if self.replay_mode is ReplayMode.INDEPENDENT:
                if "weights" in slice and len(slice["weights"]):
                    weight = np.mean(slice["weights"])
                else:
                    weight = None

                if "weight" in kwargs and weight is not None:
                    if log_once("overwrite_weight"):
                        logger.warning("Adding batches with column "
                                       "`weights` to this buffer while "
                                       "providing weights as a call argument "
                                       "to the add method results in the "
                                       "column being overwritten.")

                kwargs = {"weight": weight, **kwargs}
            else:
                if "weight" in kwargs:
                    if log_once("lockstep_no_weight_allowed"):
                        logger.warning("Settings weights for batches in "
                                       "lockstep mode is not allowed."
                                       "Weights are being ignored.")

                kwargs = {**kwargs, "weight": None}
            self.replay_buffers[policy_id].add(slice, **kwargs)
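
A self-contained illustration of the weighting logic above: when a slice carries per-timestep prioritized-replay weights, their mean becomes the weight for the whole sequence (the values here are made up):

import numpy as np

# Made-up per-timestep prio-replay weights for one sequence slice.
slice_weights = np.array([0.5, 1.0, 1.5])

# Average over the column to get one weight for the entire sequence,
# or fall back to None when no weights are present.
weight = np.mean(slice_weights) if len(slice_weights) else None
assert weight == 1.0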