예제 #1
0
    def concat_samples(samples: List[Dict[str, TensorType]]) -> \
            Union["SampleBatch", "MultiAgentBatch"]:
        """Concatenates n SampleBatches or MultiAgentBatches.

        If the first element of `samples` is a MultiAgentBatch, the whole
        list is delegated to `MultiAgentBatch.concat_samples`. Otherwise,
        batches with `count == 0` are skipped, and every column of the first
        non-empty batch is concatenated across all non-empty batches via
        `concat_aligned`. Their `seq_lens` (where present) are collected in
        order.

        Args:
            samples (List[SampleBatch]): List of SampleBatches (or
                MultiAgentBatches) to concatenate.

        Returns:
            Union[SampleBatch, MultiAgentBatch]: A new (compressed)
                SampleBatch or MultiAgentBatch. If `samples` is empty or holds
                only empty batches, `SampleBatch()` is returned instead of
                failing with an IndexError (NOTE(review): assumes this
                SampleBatch version accepts construction with no columns --
                confirm).
        """
        # Guard: `samples[0]` below would raise IndexError on an empty list.
        if not samples:
            return SampleBatch()
        if isinstance(samples[0], MultiAgentBatch):
            return MultiAgentBatch.concat_samples(samples)
        seq_lens = []
        concat_samples = []
        for s in samples:
            # Empty batches contribute no rows; skip them.
            if s.count > 0:
                concat_samples.append(s)
                if s.seq_lens is not None:
                    seq_lens.extend(s.seq_lens)

        # Only empty batches were given: nothing to concatenate. Avoid the
        # IndexError on `concat_samples[0]` below.
        if not concat_samples:
            return SampleBatch()

        # The first non-empty batch defines the column set and time layout.
        time_major = concat_samples[0].time_major
        out = {}
        for k in concat_samples[0].keys():
            out[k] = concat_aligned([s[k] for s in concat_samples],
                                    time_major=time_major)
        return SampleBatch(out,
                           _seq_lens=seq_lens,
                           _time_major=time_major)
예제 #2
0
    def concat(self, other: "SampleBatch") -> "SampleBatch":
        """Builds a new SampleBatch by appending `other`'s data to this one's.

        Args:
            other (SampleBatch): The batch whose data follows this batch's
                data in every column of the result. It must have exactly the
                same columns as this batch.

        Returns:
            SampleBatch: A new batch in which each column is the
                `concat_aligned` result of `[self[col], other[col]]`.

        Raises:
            ValueError: If the two batches' column sets differ.

        Examples:
            >>> b1 = SampleBatch({"a": [1, 2]})
            >>> b2 = SampleBatch({"a": [3, 4, 5]})
            >>> print(b1.concat(b2))
            {"a": [1, 2, 3, 4, 5]}
        """
        own_columns = self.keys()
        other_columns = other.keys()
        if own_columns != other_columns:
            raise ValueError(
                "SampleBatches to concat must have same columns! {} vs {}".
                format(list(own_columns), list(other_columns)))
        merged = {
            column: concat_aligned([self[column], other[column]])
            for column in own_columns
        }
        return SampleBatch(merged)
예제 #3
0
 def concat_samples(samples):
     """Concatenates n SampleBatches or MultiAgentBatches.

     Args:
         samples (List[SampleBatch]): Batches to concatenate. If the first
             element is a MultiAgentBatch, the whole list is delegated to
             `MultiAgentBatch.concat_samples`.

     Returns:
         A new SampleBatch holding, for every column of the first non-empty
         batch, the `concat_aligned` result of that column across all
         non-empty batches, or the MultiAgentBatch result. If `samples` is
         empty or holds only empty batches, `SampleBatch()` is returned
         instead of failing with an IndexError (NOTE(review): assumes this
         SampleBatch version accepts construction with no columns --
         confirm).
     """
     # Guard: `samples[0]` would raise an opaque IndexError on [].
     if not samples:
         return SampleBatch()
     if isinstance(samples[0], MultiAgentBatch):
         return MultiAgentBatch.concat_samples(samples)
     # Empty batches contribute no rows; drop them.
     non_empty = [s for s in samples if s.count > 0]
     if not non_empty:
         return SampleBatch()
     out = {}
     for k in non_empty[0].keys():
         out[k] = concat_aligned([s[k] for s in non_empty])
     return SampleBatch(out)
예제 #4
0
    def concat(self, other):
        """Returns a new SampleBatch with each data column concatenated.

        Args:
            other (SampleBatch): The batch whose data follows this batch's
                data in every column. It must have exactly the same columns
                as this batch.

        Returns:
            SampleBatch: A new batch in which each column is the
                `concat_aligned` result of `[self[col], other[col]]`.

        Raises:
            ValueError: If the two batches' column sets differ.

        Examples:
            >>> b1 = SampleBatch({"a": [1, 2]})
            >>> b2 = SampleBatch({"a": [3, 4, 5]})
            >>> print(b1.concat(b2))
            {"a": [1, 2, 3, 4, 5]}
        """

        # Validate with an explicit exception: an `assert` is stripped when
        # Python runs with -O, which would silently skip this check.
        if self.keys() != other.keys():
            raise ValueError(
                "SampleBatches to concat must have same columns! {} vs {}".
                format(list(self.keys()), list(other.keys())))
        out = {}
        for k in self.keys():
            out[k] = concat_aligned([self[k], other[k]])
        return SampleBatch(out)
예제 #5
0
    def concat_samples(samples: List["SampleBatch"]) -> \
            Union["SampleBatch", "MultiAgentBatch"]:
        """Concatenates n SampleBatches or MultiAgentBatches.

        If the first element is a MultiAgentBatch, the whole list is handed
        to `MultiAgentBatch.concat_samples`. Otherwise, batches with
        `count == 0` are dropped, and the columns of the remaining ones are
        concatenated with `concat_aligned`. The `zero_padded`, `max_seq_len`
        and `time_major` settings of the result are taken from the first
        non-empty batch.

        Args:
            samples (List[SampleBatch]): List of SampleBatches (or
                MultiAgentBatches) to concatenate.

        Returns:
            Union[SampleBatch, MultiAgentBatch]: A new (concatenated)
                SampleBatch or MultiAgentBatch. An empty SampleBatch is
                returned if `samples` is empty or holds only empty batches.

        Raises:
            AssertionError: (Only when assertions are enabled.) If the
                non-empty batches disagree on `zero_padded` or, when
                zero-padded, on `max_seq_len`.
        """
        # Guard: `samples[0]` below would raise IndexError on an empty list.
        if not samples:
            return SampleBatch()
        if isinstance(samples[0], MultiAgentBatch):
            return MultiAgentBatch.concat_samples(samples)

        # Empty batches contribute no rows and are skipped entirely.
        concat_samples = [s for s in samples if s.count > 0]

        # If we don't have any samples (only empty SampleBatches),
        # return an empty SampleBatch here.
        if not concat_samples:
            return SampleBatch()

        # Take the reference settings from the first *non-empty* batch. An
        # empty batch carries no data, so its settings must neither be
        # enforced on the others nor copied into the result.
        zero_padded = concat_samples[0].zero_padded
        max_seq_len = concat_samples[0].max_seq_len
        seq_lens = []
        for s in concat_samples:
            assert s.zero_padded == zero_padded, \
                "Cannot concat zero-padded and non-zero-padded batches!"
            if zero_padded:
                assert s.max_seq_len == max_seq_len, \
                    "Zero-padded batches must share the same `max_seq_len`!"
            if s.get("seq_lens") is not None:
                seq_lens.extend(s["seq_lens"])

        # Collect the concat'd data.
        time_major = concat_samples[0].time_major
        concatd_data = {}
        for k in concat_samples[0].keys():
            concatd_data[k] = concat_aligned(
                [s[k] for s in concat_samples], time_major=time_major)

        # Return a new (concat'd) SampleBatch.
        return SampleBatch(
            concatd_data,
            seq_lens=seq_lens,
            _time_major=time_major,
            _zero_padded=zero_padded,
            _max_seq_len=max_seq_len,
        )
예제 #6
0
    def concat_samples(samples):
        """Concatenates n SampleBatches or MultiAgentBatches.

        Args:
            samples (List[SampleBatch]): Batches to concatenate. If the first
                element is a MultiAgentBatch, the whole list is delegated to
                `MultiAgentBatch.concat_samples`.

        Returns:
            Union[SampleBatch, MultiAgentBatch]: A new (compressed)
                SampleBatch/MultiAgentBatch. If `samples` is empty or holds
                only empty batches, `SampleBatch()` is returned instead of
                failing with an IndexError (NOTE(review): assumes this
                SampleBatch version accepts construction with no columns --
                confirm).
        """
        # Guard: `samples[0]` would raise an opaque IndexError on [].
        if not samples:
            return SampleBatch()
        if isinstance(samples[0], MultiAgentBatch):
            return MultiAgentBatch.concat_samples(samples)
        # Empty batches contribute no rows; drop them.
        non_empty = [s for s in samples if s.count > 0]
        if not non_empty:
            return SampleBatch()
        out = {}
        for k in non_empty[0].keys():
            out[k] = concat_aligned([s[k] for s in non_empty])
        return SampleBatch(out)
예제 #7
0
    def concat(self, other):
        """Concatenates `other` onto this batch, column by column.

        Both batches must have identical columns. The result is a new
        SampleBatch in which each column is the `concat_aligned` result of
        this batch's column followed by `other`'s.

        Raises:
            ValueError: If the two batches' column sets differ.

        Examples:
            >>> b1 = SampleBatch({"a": [1, 2]})
            >>> b2 = SampleBatch({"a": [3, 4, 5]})
            >>> print(b1.concat(b2))
            {"a": [1, 2, 3, 4, 5]}
        """
        left_keys = self.keys()
        right_keys = other.keys()
        if left_keys == right_keys:
            return SampleBatch({
                key: concat_aligned([self[key], other[key]])
                for key in left_keys
            })
        raise ValueError(
            "SampleBatches to concat must have same columns! {} vs {}".format(
                list(left_keys), list(right_keys)))
예제 #8
0
    def concat_samples(samples: List["SampleBatch"]) -> \
            Union["SampleBatch", "MultiAgentBatch"]:
        """Concatenates n SampleBatches or MultiAgentBatches.

        If the first element is a MultiAgentBatch, the whole list is handed
        to `MultiAgentBatch.concat_samples`. Otherwise, batches with
        `count == 0` are dropped. The columns of the remaining ones are
        concatenated with `concat_aligned`, and their `seq_lens` are gathered
        into an int32 numpy array. The `zero_padded`, `max_seq_len` and
        `time_major` settings of the result come from the first non-empty
        batch.

        Args:
            samples (List[SampleBatch]): List of SampleBatches (or
                MultiAgentBatches) to concatenate.

        Returns:
            Union[SampleBatch, MultiAgentBatch]: A new (compressed)
                SampleBatch or MultiAgentBatch. If `samples` is empty or holds
                only empty batches, `SampleBatch()` is returned instead of
                failing with an IndexError (NOTE(review): assumes this
                SampleBatch version accepts construction with no columns --
                confirm).

        Raises:
            AssertionError: (Only when assertions are enabled.) If the
                non-empty batches disagree on `zero_padded` or, when
                zero-padded, on `max_seq_len`.
        """
        # Guard: `samples[0]` below would raise IndexError on an empty list.
        if not samples:
            return SampleBatch()
        if isinstance(samples[0], MultiAgentBatch):
            return MultiAgentBatch.concat_samples(samples)

        # Empty batches contribute no rows and are skipped entirely.
        concat_samples = [s for s in samples if s.count > 0]
        if not concat_samples:
            return SampleBatch()

        # Take the reference settings from the first *non-empty* batch. An
        # empty batch carries no data, so its settings must neither be
        # enforced on the others nor copied into the result.
        zero_padded = concat_samples[0].zero_padded
        max_seq_len = concat_samples[0].max_seq_len
        seq_lens = []
        for s in concat_samples:
            assert s.zero_padded == zero_padded, \
                "Cannot concat zero-padded and non-zero-padded batches!"
            if zero_padded:
                assert s.max_seq_len == max_seq_len, \
                    "Zero-padded batches must share the same `max_seq_len`!"
            if s.seq_lens is not None:
                seq_lens.extend(s.seq_lens)

        time_major = concat_samples[0].time_major
        out = {}
        for k in concat_samples[0].keys():
            out[k] = concat_aligned(
                [s[k] for s in concat_samples], time_major=time_major)
        return SampleBatch(
            out,
            _seq_lens=np.array(seq_lens, dtype=np.int32),
            _time_major=time_major,
            _dont_check_lens=True,
            _zero_padded=zero_padded,
            _max_seq_len=max_seq_len,
        )
예제 #9
0
 def concat_key(*column_parts):
     """Concatenates the given positional parts via `concat_aligned`.

     The parts are passed on as a tuple, together with `time_major`. Both
     `time_major` and `concat_aligned` are resolved from the enclosing scope.
     """
     merged = concat_aligned(column_parts, time_major)
     return merged
예제 #10
0
    def concat_samples(
        samples: Union[List["SampleBatch"], List["MultiAgentBatch"]],
    ) -> Union["SampleBatch", "MultiAgentBatch"]:
        """Concatenates n SampleBatches or MultiAgentBatches.

        If the first element is a MultiAgentBatch, the whole list is handed
        to `MultiAgentBatch.concat_samples`. Otherwise, batches with
        `count == 0` are skipped and each column of the remaining batches is
        concatenated:
        - nested column structures are concatenated leaf by leaf via
          `tree.map_structure`;
        - the "infos" column is concatenated as a whole.
        The `zero_padded`, `max_seq_len` and `time_major` settings of the
        result are taken from the first non-empty batch.

        Args:
            samples (Union[List[SampleBatch], List[MultiAgentBatch]]): List of
                SampleBatches or MultiAgentBatches to be concatenated.

        Returns:
            Union[SampleBatch, MultiAgentBatch]: A new (concatenated)
                SampleBatch or MultiAgentBatch. An empty SampleBatch is
                returned if `samples` is empty or holds only empty batches.

        Raises:
            ValueError: If the data under some key cannot be concatenated
                (e.g. its sub-structures differ between batches). The
                underlying error is chained as the cause.
            AssertionError: (Only when assertions are enabled.) If the
                non-empty batches disagree on `zero_padded`, `time_major`,
                or, when zero-padded, `max_seq_len`.

        Examples:
            >>> b1 = SampleBatch({"a": np.array([1, 2]),
            ...                   "b": np.array([10, 11])})
            >>> b2 = SampleBatch({"a": np.array([3]),
            ...                   "b": np.array([12])})
            >>> print(SampleBatch.concat_samples([b1, b2]))
            {"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])}
        """
        # Guard: `samples[0]` below would raise IndexError on an empty list.
        if not samples:
            return SampleBatch()
        if isinstance(samples[0], MultiAgentBatch):
            return MultiAgentBatch.concat_samples(samples)

        # Empty batches contribute no rows and are skipped entirely.
        concat_samples = [s for s in samples if s.count > 0]

        # If we don't have any samples (only empty SampleBatches),
        # return an empty SampleBatch here.
        if not concat_samples:
            return SampleBatch()

        # Take the reference settings from the first *non-empty* batch. An
        # empty batch carries no data, so its settings must neither be
        # enforced on the others nor copied into the result.
        zero_padded = concat_samples[0].zero_padded
        max_seq_len = concat_samples[0].max_seq_len
        time_major = concat_samples[0].time_major

        concatd_seq_lens = []
        for s in concat_samples:
            assert s.zero_padded == zero_padded, \
                "Cannot concat zero-padded and non-zero-padded batches!"
            assert s.time_major == time_major, \
                "Cannot concat time-major and batch-major batches!"
            if zero_padded:
                assert s.max_seq_len == max_seq_len, \
                    "Zero-padded batches must share the same `max_seq_len`!"
            if s.get(SampleBatch.SEQ_LENS) is not None:
                concatd_seq_lens.extend(s[SampleBatch.SEQ_LENS])

        # Collect the concat'd data.
        concatd_data = {}

        def concat_key(*values):
            return concat_aligned(values, time_major)

        try:
            for k in concat_samples[0].keys():
                if k == "infos":
                    # "infos" is concatenated as a whole rather than leaf by
                    # leaf; presumably because info dicts need not share a
                    # common nested structure -- confirm.
                    concatd_data[k] = concat_aligned(
                        [s[k] for s in concat_samples], time_major=time_major)
                else:
                    concatd_data[k] = tree.map_structure(
                        concat_key, *[c[k] for c in concat_samples])
        except Exception as e:
            # Chain the original error so its root cause stays visible.
            raise ValueError(f"Cannot concat data under key '{k}', b/c "
                             "sub-structures under that key don't match. "
                             f"`samples`={samples}") from e

        # Return a new (concat'd) SampleBatch.
        return SampleBatch(
            concatd_data,
            seq_lens=concatd_seq_lens,
            _time_major=time_major,
            _zero_padded=zero_padded,
            _max_seq_len=max_seq_len,
        )