Example 1

import logging
from typing import List, Optional

import numpy as np
import tree  # dm-tree

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.typing import ViewRequirementsDict
from ray.util.debug import log_once

logger = logging.getLogger(__name__)

# `chop_into_sequences` is defined in the same module
# (ray.rllib.policy.rnn_sequencing).

def pad_batch_to_sequences_of_same_size(
    batch: SampleBatch,
    max_seq_len: int,
    shuffle: bool = False,
    batch_divisibility_req: int = 1,
    feature_keys: Optional[List[str]] = None,
    view_requirements: Optional[ViewRequirementsDict] = None,
):
    """Applies padding to `batch` so it's choppable into same-size sequences.

    Shuffles `batch` (if desired), makes sure the divisibility requirement
    is met, then zero-pads the batch ([B, ...]) so it can later be chopped
    into same-size sequence chunks, without adding a time dimension yet.
    The padding depends on the episodes found in `batch` and on `max_seq_len`.

    Args:
        batch: The SampleBatch object. All values in here have
            the shape [B, ...].
        max_seq_len: The max. sequence length to use for chopping.
        shuffle: Whether to shuffle batch sequences. Shuffle may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.
        batch_divisibility_req: The int by which the batch dimension
            must be divisible.
        feature_keys: An optional list of keys to apply sequence-chopping
            to. If None, use all keys in batch that are not
            "state_in/out_"-type keys.
        view_requirements: An optional Policy ViewRequirements dict, used
            to infer whether e.g. dynamic max'ing should be applied over
            the seq_lens.
    """
    # If already zero-padded, skip.
    if batch.zero_padded:
        return

    batch.zero_padded = True

    if batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
            # not multiagent
            and max(batch[SampleBatch.AGENT_INDEX]) == 0
        )
    else:
        meets_divisibility_reqs = True

    states_already_reduced_to_init = False

    # RNN/attention net case. Figure out whether we should apply dynamic
    # max'ing over the list of sequence lengths.
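    # (Dynamic max'ing pads only up to the longest actual sequence length in
    # this batch, rather than always up to `max_seq_len`.)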
    if "state_in_0" in batch or "state_out_0" in batch:
        # Check whether the state inputs have already been reduced to their
        # init values at the beginning of each max_seq_len chunk.
        if batch.get(SampleBatch.SEQ_LENS) is not None and len(
            batch["state_in_0"]
        ) == len(batch[SampleBatch.SEQ_LENS]):
            states_already_reduced_to_init = True

        # RNN (or single timestep state-in): Set the max dynamically.
        if view_requirements["state_in_0"].shift_from is None:
            dynamic_max = True
        # Attention Nets (state inputs are over some range): No dynamic maxing
        # possible.
        else:
            dynamic_max = False
    # Multi-agent case.
    elif not meets_divisibility_reqs:
        max_seq_len = batch_divisibility_req
        dynamic_max = False
        batch.max_seq_len = max_seq_len
    # Simple case: No RNN/attention net and no padding required.
    else:
        if shuffle:
            batch.shuffle()
        return

    # RNN, attention net, or multi-agent case.
    state_keys = []
    feature_keys_ = feature_keys or []
    for k, v in batch.items():
        if k.startswith("state_in_"):
            state_keys.append(k)
        elif (
            not feature_keys
            and not k.startswith("state_out_")
            and k not in [SampleBatch.INFOS, SampleBatch.SEQ_LENS]
        ):
            feature_keys_.append(k)

    feature_sequences, initial_states, seq_lens = chop_into_sequences(
        feature_columns=[batch[k] for k in feature_keys_],
        state_columns=[batch[k] for k in state_keys],
        episode_ids=batch.get(SampleBatch.EPS_ID),
        unroll_ids=batch.get(SampleBatch.UNROLL_ID),
        agent_indices=batch.get(SampleBatch.AGENT_INDEX),
        seq_lens=batch.get(SampleBatch.SEQ_LENS),
        max_seq_len=max_seq_len,
        dynamic_max=dynamic_max,
        states_already_reduced_to_init=states_already_reduced_to_init,
        shuffle=shuffle,
        handle_nested_data=True,
    )

    for i, k in enumerate(feature_keys_):
        batch[k] = tree.unflatten_as(batch[k], feature_sequences[i])
    for i, k in enumerate(state_keys):
        batch[k] = initial_states[i]
    batch[SampleBatch.SEQ_LENS] = np.array(seq_lens)
    if dynamic_max:
        batch.max_seq_len = max(seq_lens)

    if log_once("rnn_ma_feed_dict"):
        logger.info(
            "Padded input for RNN/Attn.Nets/MA:\n\n{}\n".format(
                summarize(
                    {
                        "features": feature_sequences,
                        "initial_states": initial_states,
                        "seq_lens": seq_lens,
                        "max_seq_len": max_seq_len,
                    }
                )
            )
        )
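
For orientation, here is a minimal usage sketch. The batch contents and the
`ViewRequirement` for `state_in_0` are hypothetical values; the sketch only
assumes RLlib's public `SampleBatch` and `ViewRequirement` APIs:

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement

# Hypothetical 6-step batch holding two episodes (4 + 2 timesteps) of a
# single agent, plus a 4-dim RNN state input.
batch = SampleBatch({
    SampleBatch.OBS: np.arange(6, dtype=np.float32).reshape(6, 1),
    SampleBatch.EPS_ID: np.array([0, 0, 0, 0, 1, 1]),
    SampleBatch.UNROLL_ID: np.zeros(6, dtype=np.int64),
    SampleBatch.AGENT_INDEX: np.zeros(6, dtype=np.int64),
    "state_in_0": np.zeros((6, 4), dtype=np.float32),
})

# For a plain RNN state, `shift_from` is None -> dynamic max'ing is applied.
view_reqs = {"state_in_0": ViewRequirement("state_out_0", shift=-1)}

pad_batch_to_sequences_of_same_size(
    batch, max_seq_len=4, view_requirements=view_reqs)
# batch[SampleBatch.SEQ_LENS] now holds the per-sequence lengths, and all
# feature columns are zero-padded into same-size chunks.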
Example 2

import math

import numpy as np
import tree  # dm-tree

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import ViewRequirementsDict

# `to_float_np_array` is a small list-to-np.ndarray helper defined in the
# same module as this collector class.

    def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
        """Builds a SampleBatch from the thus-far collected agent data.

        If the episode/trajectory does not end in DONE=True, this copies
        the necessary last n timesteps of the trajectory back to the
        beginning of the buffers and waits for new samples to come in.
        SampleBatches created by this method are ready for postprocessing
        by a Policy.

        Args:
            view_requirements (ViewRequirementsDict): The view
                requirements dict needed to build the SampleBatch from the raw
                buffers (which may have data shifts as well as mappings from
                view-col to data-col in them).

        Returns:
            SampleBatch: The built SampleBatch for this agent, ready to go into
                postprocessing.
        """

        batch_data = {}
        np_data = {}
        for view_col, view_req in view_requirements.items():
            # Create the batch of data from the different buffers.
            data_col = view_req.data_col or view_col

            # Some columns don't exist yet (get created during postprocessing).
            # -> skip.
            if data_col not in self.buffers:
                continue

            # OBS are already shifted by -1 (the initial obs starts one ts
            # before all other data columns).
            obs_shift = -1 if data_col == SampleBatch.OBS else 0

            # Keep an np-array cache so we don't have to regenerate the
            # np-array for different view_cols mapping to the same data_col.
            if data_col not in np_data:
                np_data[data_col] = [
                    to_float_np_array(d) for d in self.buffers[data_col]
                ]

            # Range of indices on time-axis, e.g. "-50:-1". Together with
            # the `batch_repeat_value`, this determines the data produced.
            # Example:
            #  batch_repeat_value=10, shift_from=-3, shift_to=-1
            #  buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
            #  resulting data=[[-3, -2, -1], [7, 8, 9]]
            #  Range of 3 consecutive items repeats every 10 timesteps.
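            #  (A runnable version of this arithmetic follows this method.)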
            if view_req.shift_from is not None:
                # Batch repeat value > 1: Only repeat the shift_from/to range
                # every n timesteps.
                if view_req.batch_repeat_value > 1:
                    count = int(
                        math.ceil(
                            (len(np_data[data_col][0]) - self.shift_before) /
                            view_req.batch_repeat_value))
                    data = [
                        np.asarray([
                            d[self.shift_before +
                              (i * view_req.batch_repeat_value) +
                              view_req.shift_from +
                              obs_shift:self.shift_before +
                              (i * view_req.batch_repeat_value) +
                              view_req.shift_to + 1 + obs_shift]
                            for i in range(count)
                        ]) for d in np_data[data_col]
                    ]
                # Batch repeat value = 1: Repeat the shift_from/to range at
                # each timestep.
                else:
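                    # Build a zero-copy sliding-window view via numpy stride
                    # tricks: output row t is buffer[t : t + shift_win], so
                    # consecutive windows overlap by shift_win - 1 items.
                    # Reusing `data_size` (the byte size of one timestep) as
                    # both the batch- and the window-stride creates that
                    # overlap.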
                    d0 = np_data[data_col][0]
                    shift_win = view_req.shift_to - view_req.shift_from + 1
                    data_size = d0.itemsize * int(np.prod(d0.shape[1:]))
                    strides = [
                        d0.itemsize * int(np.prod(d0.shape[i + 1:]))
                        for i in range(1, len(d0.shape))
                    ]
                    start = (self.shift_before - shift_win + 1 + obs_shift +
                             view_req.shift_to)
                    data = [
                        np.lib.stride_tricks.as_strided(
                            d[start:start + self.agent_steps],
                            [self.agent_steps, shift_win] +
                            [d.shape[i] for i in range(1, len(d.shape))],
                            [data_size, data_size] + strides,
                        ) for d in np_data[data_col]
                    ]
            # Set of (probably non-consecutive) indices.
            # Example:
            #  shift=[-3, 0]
            #  buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
            #  resulting data=[[-3, 0], [-2, 1], [-1, 2], [0, 3], [1, 4], ...]
            elif isinstance(view_req.shift, np.ndarray):
                data = [
                    d[self.shift_before + obs_shift + view_req.shift]
                    for d in np_data[data_col]
                ]
            # Single int shift value: Use the trajectory as-is, shifted by
            # `shift` if `shift` != 0.
            else:
                shift = view_req.shift + obs_shift

                # Batch repeat (only provide a value every n timesteps).
                if view_req.batch_repeat_value > 1:
                    count = int(
                        math.ceil(
                            (len(np_data[data_col][0]) - self.shift_before) /
                            view_req.batch_repeat_value))
                    data = [
                        np.asarray([
                            d[self.shift_before +
                              (i * view_req.batch_repeat_value) + shift]
                            for i in range(count)
                        ]) for d in np_data[data_col]
                    ]
                # Shift is exactly 0: Use trajectory as is.
                elif shift == 0:
                    data = [d[self.shift_before:] for d in np_data[data_col]]
                # Shift is positive: We still need to 0-pad at the end.
                elif shift > 0:
                    data = [
                        to_float_np_array(
                            np.concatenate([
                                d[self.shift_before + shift:],
                                [
                                    np.zeros(
                                        shape=view_req.space.shape,
                                        dtype=view_req.space.dtype,
                                    ) for _ in range(shift)
                                ],
                            ])) for d in np_data[data_col]
                    ]
                # Shift is negative: Shift into the already existing and
                # 0-padded "before" area of our buffers.
                else:
                    data = [
                        d[self.shift_before + shift:shift]
                        for d in np_data[data_col]
                    ]

            if len(data) > 0:
                if data_col not in self.buffer_structs:
                    batch_data[view_col] = data[0]
                else:
                    batch_data[view_col] = tree.unflatten_as(
                        self.buffer_structs[data_col], data)

        # Due to possible batch-repeats > 1, columns in the resulting batch
        # may not all have the same batch size.
        batch = SampleBatch(batch_data)

        # Adjust the seq-lens array depending on the incoming agent sequences.
        if self.policy.is_recurrent():
            seq_lens = []
            max_seq_len = self.policy.config["model"]["max_seq_len"]
            count = batch.count
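            # E.g. count=10, max_seq_len=4 -> seq_lens=[4, 4, 2].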
            while count > 0:
                seq_lens.append(min(count, max_seq_len))
                count -= max_seq_len
            batch["seq_lens"] = np.array(seq_lens)
            batch.max_seq_len = max_seq_len

        # This trajectory is continuing -> Copy data at the end (in the size of
        # self.shift_before) to the beginning of buffers and erase everything
        # else.
        if not self.buffers[SampleBatch.DONES][0][-1]:
            # Copy data to beginning of buffer and cut lists.
            if self.shift_before > 0:
                for k, data in self.buffers.items():
                    # Loop through each (flattened) buffer component and keep
                    # only the last `shift_before` items.
                    for i in range(len(data)):
                        self.buffers[k][i] = data[i][-self.shift_before:]
            self.agent_steps = 0

        # Reset our unroll_id.
        self.unroll_id = None

        return batch
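
To make the `batch_repeat_value > 1` example from the comments above concrete,
here is a standalone sketch of the indexing arithmetic; `shift_before=3` and
`obs_shift=0` are hypothetical values chosen to reproduce that example:

import math

import numpy as np

buffer = np.arange(-3, 13)  # [-3, -2, ..., 12]; first 3 items = "before" area
shift_before, obs_shift = 3, 0
batch_repeat_value, shift_from, shift_to = 10, -3, -1

count = int(math.ceil((len(buffer) - shift_before) / batch_repeat_value))
data = np.asarray([
    buffer[shift_before + i * batch_repeat_value + shift_from + obs_shift:
           shift_before + i * batch_repeat_value + shift_to + 1 + obs_shift]
    for i in range(count)
])
print(data)  # -> [[-3 -2 -1]
             #     [ 7  8  9]]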