Exemplo n.º 1
0
 def concat_key(*values):
     """Concatenate the positional `values` using `concat_aligned`.

     `time_major` is not a parameter. It is read from the enclosing scope
     (presumably the surrounding `concat_samples` call -- confirm).
     """
     return concat_aligned(values, time_major=time_major)
Exemplo n.º 2
0
    def concat_samples(
        samples: Union[List["SampleBatch"], List["MultiAgentBatch"]],
    ) -> Union["SampleBatch", "MultiAgentBatch"]:
        """Concatenates n SampleBatches or MultiAgentBatches.

        If any element of `samples` is a MultiAgentBatch, the call is
        delegated to `MultiAgentBatch.concat_samples`. Otherwise, empty
        SampleBatches (count == 0) are ignored entirely. The remaining
        batches must agree on `zero_padded`, on `time_major`, and on whether
        `max_seq_len` is set. The result's `max_seq_len` is the maximum over
        the non-empty inputs.

        Args:
            samples: List of SampleBatches or MultiAgentBatches to be
                concatenated.

        Returns:
            A new (concatenated) SampleBatch or MultiAgentBatch. An empty
            SampleBatch is returned if `samples` is empty or contains only
            empty SampleBatches.

        Raises:
            ValueError: If only some of the non-empty batches provide a
                `max_seq_len`, or if the data under some key cannot be
                concatenated (the original error is chained as the cause).
            AssertionError: If the non-empty batches disagree on
                `zero_padded` or `time_major`, or (when zero-padded) on
                `max_seq_len`.

        Examples:
            >>> import numpy as np
            >>> from ray.rllib.policy.sample_batch import SampleBatch
            >>> b1 = SampleBatch({"a": np.array([1, 2]), # doctest: +SKIP
            ...                   "b": np.array([10, 11])})
            >>> b2 = SampleBatch({"a": np.array([3]), # doctest: +SKIP
            ...                   "b": np.array([12])})
            >>> print(SampleBatch.concat_samples([b1, b2])) # doctest: +SKIP
            {"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])}
        """
        if any(isinstance(s, MultiAgentBatch) for s in samples):
            return MultiAgentBatch.concat_samples(samples)

        # Only non-empty batches contribute data, so only they take part in
        # the consistency checks below.
        non_empty_samples = [s for s in samples if s.count > 0]

        # If we don't have any samples (0 or only empty SampleBatches),
        # return an empty SampleBatch here. This also covers `samples == []`,
        # which previously failed with an IndexError on `samples[0]`.
        if not non_empty_samples:
            return SampleBatch()

        # Take the reference settings from the first *non-empty* batch, so a
        # leading empty batch cannot impose settings on the result.
        first = non_empty_samples[0]
        zero_padded = first.zero_padded
        max_seq_len = first.max_seq_len
        time_major = first.time_major

        concatd_seq_lens = []
        for s in non_empty_samples:
            assert s.zero_padded == zero_padded, (
                "All non-empty samples must share the same `zero_padded`.")
            assert s.time_major == time_major, (
                "All non-empty samples must share the same `time_major`.")
            if (s.max_seq_len is None or
                    max_seq_len is None) and s.max_seq_len != max_seq_len:
                raise ValueError(
                    "Samples must consistently provide or omit max_seq_len"
                )
            if zero_padded:
                # Zero-padded data can only be concatenated if all batches
                # were padded to the same length.
                assert s.max_seq_len == max_seq_len, (
                    "Zero-padded samples must share the same `max_seq_len`.")
            if max_seq_len is not None:
                max_seq_len = max(max_seq_len, s.max_seq_len)
            if s.get(SampleBatch.SEQ_LENS) is not None:
                concatd_seq_lens.extend(s[SampleBatch.SEQ_LENS])

        # Collect the concat'd data.
        concatd_data = {}

        def concat_key(*values):
            return concat_aligned(values, time_major)

        for k in non_empty_samples[0].keys():
            try:
                if k == "infos":
                    # `infos` is concatenated directly, not via
                    # `tree.map_structure`, so its elements are not recursed
                    # into.
                    concatd_data[k] = concat_aligned(
                        [s[k] for s in non_empty_samples],
                        time_major=time_major)
                else:
                    concatd_data[k] = tree.map_structure(
                        concat_key, *[s[k] for s in non_empty_samples])
            except Exception as e:
                raise ValueError(f"Cannot concat data under key '{k}', b/c "
                                 "sub-structures under that key don't match. "
                                 f"`samples`={samples}") from e

        # Return a new (concat'd) SampleBatch.
        return SampleBatch(
            concatd_data,
            seq_lens=concatd_seq_lens,
            _time_major=time_major,
            _zero_padded=zero_padded,
            _max_seq_len=max_seq_len,
        )