def _to_json(batch: SampleBatchType, compress_columns: List[str]) -> str:
    """Serialize a (multi-agent) sample batch into a JSON string.

    Every column value is passed through `_to_jsonable`; the `compress`
    flag is True exactly for columns whose key appears in
    `compress_columns`.

    Args:
        batch: The batch to serialize. A `MultiAgentBatch` is written with
            its `count` and one column dict per policy ID under
            "policy_batches"; anything else is treated as a flat mapping
            of columns and tagged as "SampleBatch".
        compress_columns: Keys of the columns to request compression for.

    Returns:
        The JSON-encoded string produced by `json.dumps`.
    """

    def _encode_columns(columns):
        # Convert one column mapping, compressing only the requested keys.
        return {
            col: _to_jsonable(val, compress=col in compress_columns)
            for col, val in columns.items()
        }

    if isinstance(batch, MultiAgentBatch):
        payload = {
            "type": "MultiAgentBatch",
            "count": batch.count,
            "policy_batches": {
                pid: _encode_columns(policy_batch)
                for pid, policy_batch in batch.policy_batches.items()
            },
        }
    else:
        # The "type" tag goes in first so it stays the leading key.
        payload = {"type": "SampleBatch"}
        payload.update(_encode_columns(batch))

    return json.dumps(payload)
def timeslice_along_seq_lens_with_overlap(
    sample_batch: SampleBatchType,
    seq_lens: Optional[List[int]] = None,
    zero_pad_max_seq_len: int = 0,
    pre_overlap: int = 0,
    zero_init_states: bool = True,
) -> List["SampleBatch"]:
    """Slices batch along `seq_lens` (each seq-len item produces one batch).

    Args:
        sample_batch: The SampleBatch to timeslice.
        seq_lens (Optional[List[int]]): An optional list of seq_lens to slice
            at. If None, use `sample_batch[SampleBatch.SEQ_LENS]`. If that is
            missing as well, the batch is cut into consecutive chunks of
            `zero_pad_max_seq_len - pre_overlap` timesteps (the last chunk
            may be shorter).
        zero_pad_max_seq_len: If >0, already zero-pad the resulting slices up
            to this length. NOTE: This max-len will include the additional
            timesteps gained via setting pre_overlap (see Example).
        pre_overlap: If >0, will overlap each two consecutive slices by this
            many timesteps (toward the left side). This will cause
            zero-padding at the very beginning of the batch.
        zero_init_states: Whether initial states should always be zero'd.
            If False, will use the state_outs of the batch to populate
            state_in values.

    Returns:
        List[SampleBatch]: The list of (new) SampleBatches.

    Raises:
        ValueError: If no sequence lengths are available (neither passed in
            nor stored in the batch) and
            `zero_pad_max_seq_len - pre_overlap` is not positive, so no
            chunk size can be derived.
        AssertionError: If the resulting `seq_lens` is empty or None.

    Examples:
        assert seq_lens == [5, 5, 2]
        assert sample_batch.count == 12
        # self = 0 1 2 3 4 | 5 6 7 8 9 | 10 11 <- timesteps
        slices = timeslice_along_seq_lens_with_overlap(
            sample_batch=sample_batch,
            zero_pad_max_seq_len=10,
            pre_overlap=3)
        # Z = zero padding (at beginning or end).
        #             |pre (3)| seq     | max-seq-len (up to 10)
        # slices[0] = | Z Z Z |  0  1 2 3 4 | Z Z
        # slices[1] = | 2 3 4 |  5  6 7 8 9 | Z Z
        # slices[2] = | 7 8 9 | 10 11 Z Z Z | Z Z
        # Note that `zero_pad_max_seq_len=10` includes the 3 pre-overlaps
        #  count (makes sure each slice has exactly length 10).
    """
    if seq_lens is None:
        seq_lens = sample_batch.get(SampleBatch.SEQ_LENS)
    elif sample_batch.get(SampleBatch.SEQ_LENS) is not None and log_once(
        "overriding_sequencing_information"
    ):
        logger.warning(
            "Found sequencing information in a batch that will be "
            "ignored when slicing. Ignore this warning if you know "
            "what you are doing."
        )

    if seq_lens is None:
        # No sequencing info at all: fall back to fixed-size chunks that,
        # together with the pre-overlap, fill up `zero_pad_max_seq_len`.
        max_seq_len = zero_pad_max_seq_len - pre_overlap
        if max_seq_len <= 0:
            raise ValueError(
                "Cannot derive sequence lengths for time slicing: "
                "`zero_pad_max_seq_len` ({}) must be larger than "
                "`pre_overlap` ({}) when no `seq_lens` are available.".format(
                    zero_pad_max_seq_len, pre_overlap
                )
            )
        if log_once("no_sequence_lengths_available_for_time_slicing"):
            logger.warning(
                "Trying to slice a batch along sequences without "
                "sequence lengths being provided in the batch. Batch will "
                "be sliced into slices of size "
                "{} = {} - {} = zero_pad_max_seq_len - pre_overlap.".format(
                    max_seq_len, zero_pad_max_seq_len, pre_overlap
                )
            )
        num_seq_lens, last_seq_len = divmod(len(sample_batch), max_seq_len)
        # Each full chunk holds `max_seq_len` timesteps (the divisor used
        # above), so that the generated seq_lens sum up to the batch length.
        seq_lens = [max_seq_len] * num_seq_lens + (
            [last_seq_len] if last_seq_len else []
        )

    assert (
        seq_lens is not None and len(seq_lens) > 0
    ), "Cannot timeslice along `seq_lens` when `seq_lens` is empty or None!"

    timeslices = []
    start = 0
    for seq_len in seq_lens:
        # `begin` includes the left overlap; `slice_begin` is where the
        # actual sequence starts; `end` is exclusive.
        begin = start - pre_overlap
        slice_begin = start
        end = start + seq_len
        start = end

        # Number of leading zero rows to prepend (None -> no zero prefix).
        zero_length = None
        data_begin = 0
        zero_init_states_ = zero_init_states

        if begin < 0:
            # Overlap reaches before the start of the batch: replace the
            # whole overlap region by zeros.
            zero_length = pre_overlap
            data_begin = slice_begin
            zero_init_states_ = True
        else:
            eps_ids = sample_batch[SampleBatch.EPS_ID][begin:end]
            is_last_episode_ids = eps_ids == eps_ids[-1]
            if not is_last_episode_ids[0]:
                # The window starts in a different episode than it ends in:
                # zero out every timestep not belonging to the last episode.
                zero_length = int(np.count_nonzero(~is_last_episode_ids))
                data_begin = begin + zero_length
                zero_init_states_ = True

        if zero_length is not None:
            data = {
                k: np.concatenate(
                    [
                        np.zeros(shape=(zero_length,) + v.shape[1:], dtype=v.dtype),
                        v[data_begin:end],
                    ]
                )
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }
        else:
            data = {
                k: v[begin:end]
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }

        if zero_init_states_:
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                data[key] = np.zeros_like(sample_batch[key][0:1])
                # Del state_out_n from data if exists.
                data.pop("state_out_{}".format(i), None)
                i += 1
                key = "state_in_{}".format(i)
        # TODO: This will not work with attention nets as their state_outs are
        #  not compatible with state_ins.
        else:
            # NOTE(review): for `begin == 0` the slice below is `[-1:0]`,
            # which is empty for array-like columns - confirm callers never
            # reach this with `zero_init_states=False`.
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                # Initial state = state output of the timestep right before
                # this slice.
                data[key] = sample_batch["state_out_{}".format(i)][begin - 1 : begin]
                del data["state_out_{}".format(i)]
                i += 1
                key = "state_in_{}".format(i)

        timeslices.append(SampleBatch(data, seq_lens=[end - begin]))

    # Zero-pad each slice if necessary.
    if zero_pad_max_seq_len > 0:
        for ts in timeslices:
            ts.right_zero_pad(max_seq_len=zero_pad_max_seq_len, exclude_states=True)

    return timeslices