def concat_key(*values):
    """Hand the collected positional ``values`` to ``concat_aligned``.

    The arguments are forwarded as one tuple, together with ``time_major``.

    NOTE(review): ``time_major`` is not a parameter of this function; it has
    to be resolvable from the surrounding scope at call time -- confirm where
    it is bound.
    """
    merged = concat_aligned(values, time_major)
    return merged
def concat_samples(
    samples: Union[List["SampleBatch"], List["MultiAgentBatch"]],
) -> Union["SampleBatch", "MultiAgentBatch"]:
    """Concatenates n SampleBatches or MultiAgentBatches.

    If any element of `samples` is a MultiAgentBatch, the whole list is
    delegated to `MultiAgentBatch.concat_samples`. Otherwise empty batches
    (`count == 0`) are skipped, and the first non-empty batch provides the
    reference `zero_padded`, `time_major` and `max_seq_len` settings that
    all other non-empty batches are validated against.

    Args:
        samples: List of SampleBatches or MultiAgentBatches to be
            concatenated.

    Returns:
        A new (concatenated) SampleBatch or MultiAgentBatch. If `samples`
        is empty or contains only empty SampleBatches, an empty SampleBatch
        is returned.

    Raises:
        ValueError: If the non-empty batches disagree on `zero_padded` or
            `time_major`, if only some of them provide `max_seq_len`, if
            zero-padded batches have differing `max_seq_len` values, or if
            the data stored under some key cannot be concatenated.

    Examples:
        >>> import numpy as np
        >>> from ray.rllib.policy.sample_batch import SampleBatch
        >>> b1 = SampleBatch({"a": np.array([1, 2]), # doctest: +SKIP
        ...                   "b": np.array([10, 11])})
        >>> b2 = SampleBatch({"a": np.array([3]), # doctest: +SKIP
        ...                   "b": np.array([12])})
        >>> print(SampleBatch.concat_samples([b1, b2])) # doctest: +SKIP
        {"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])}
    """
    if any(isinstance(s, MultiAgentBatch) for s in samples):
        return MultiAgentBatch.concat_samples(samples)

    # Empty batches contribute no data, so they must not take part in the
    # consistency checks below either.
    non_empty = [s for s in samples if s.count > 0]

    # If we don't have any samples (0 or only empty SampleBatches),
    # return an empty SampleBatch here.
    if not non_empty:
        return SampleBatch()

    # The first batch that actually holds data defines the reference
    # settings. Using `samples[0]` would fail on an empty list and would
    # compare against the settings of a possibly empty leading batch.
    first = non_empty[0]
    zero_padded = first.zero_padded
    time_major = first.time_major
    max_seq_len = first.max_seq_len

    concatd_seq_lens = []
    for s in non_empty:
        # Explicit exceptions rather than `assert`, so the validation is
        # still performed when Python runs with `-O`.
        if s.zero_padded != zero_padded or s.time_major != time_major:
            raise ValueError(
                "All SampleBatches' `zero_padded` and `time_major` settings "
                "must be consistent"
            )
        # Exactly one of the two being None means the batches are mixed.
        if (s.max_seq_len is None) != (max_seq_len is None):
            raise ValueError(
                "Samples must consistently provide or omit max_seq_len"
            )
        if zero_padded and s.max_seq_len != max_seq_len:
            raise ValueError(
                "For `zero_padded` SampleBatches, the values of "
                "`max_seq_len` must be consistent"
            )
        if max_seq_len is not None:
            # Non-padded batches may differ; keep the largest value seen.
            max_seq_len = max(max_seq_len, s.max_seq_len)
        if s.get(SampleBatch.SEQ_LENS) is not None:
            concatd_seq_lens.extend(s[SampleBatch.SEQ_LENS])

    def concat_key(*values):
        # Receives one leaf per batch from `tree.map_structure`.
        return concat_aligned(values, time_major)

    # Collect the concat'd data.
    concatd_data = {}
    # Pre-bind so the error message below is well-defined even if the key
    # iteration itself fails before the first assignment.
    key = None
    try:
        for key in first.keys():
            if key == "infos":
                # "infos" is concatenated as a whole instead of being
                # mapped leaf-by-leaf (presumably because its entries are
                # arbitrary per-step objects -- confirm).
                concatd_data[key] = concat_aligned(
                    [s[key] for s in non_empty], time_major=time_major
                )
            else:
                concatd_data[key] = tree.map_structure(
                    concat_key, *[s[key] for s in non_empty]
                )
    except Exception as err:
        # Chain the original error so the root cause stays visible.
        raise ValueError(
            f"Cannot concat data under key '{key}', b/c "
            "sub-structures under that key don't match. "
            f"`samples`={samples}"
        ) from err

    # Return a new (concat'd) SampleBatch.
    return SampleBatch(
        concatd_data,
        seq_lens=concatd_seq_lens,
        _time_major=time_major,
        _zero_padded=zero_padded,
        _max_seq_len=max_seq_len,
    )