Example #1
    def add_postprocessed_batch_for_training(
            self, batch: SampleBatch,
            view_requirements: ViewRequirementsDict) -> None:
        """Adds a postprocessed SampleBatch (single agent) to our buffers.

        Args:
            batch (SampleBatch): An individual agent's (one trajectory)
                SampleBatch to be added to the Policy's buffers.
            view_requirements (ViewRequirementsDict): The view
                requirements for the policy. Used to determine whether a
                view-column needs to be copied at all (columns not needed
                for training are skipped).
        """
        for view_col, data in batch.items():
            # 1) If col is not in view_requirements, we must have a direct
            # child of the base Policy that doesn't do auto-view req creation.
            # 2) Col is in view-reqs and needed for training.
            view_req = view_requirements.get(view_col)
            if view_req is None or view_req.used_for_training:
                self.buffers[view_col].extend(data)
        # Add the agent's trajectory length to our count.
        self.agent_steps += batch.count
        # Adjust the seq-lens array depending on the incoming agent sequences.
        if self.seq_lens is not None:
            max_seq_len = self.policy.config["model"]["max_seq_len"]
            count = batch.count
            while count > 0:
                self.seq_lens.append(min(count, max_seq_len))
                count -= max_seq_len
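
The seq-lens bookkeeping at the end chops the added trajectory into chunks of at most `max_seq_len` timesteps. A minimal, standalone sketch of that chunking (the `chunk_seq_lens` helper name is made up for illustration):

def chunk_seq_lens(count, max_seq_len):
    # Mirror the while-loop above: emit max_seq_len-sized chunks
    # until the remaining count is exhausted.
    seq_lens = []
    while count > 0:
        seq_lens.append(min(count, max_seq_len))
        count -= max_seq_len
    return seq_lens

# A 23-step trajectory with max_seq_len=10 yields [10, 10, 3].
assert chunk_seq_lens(23, 10) == [10, 10, 3]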
Example #2
    def add_postprocessed_batch_for_training(
            self, batch: SampleBatch,
            view_requirements: ViewRequirementsDict) -> None:
        """Adds a postprocessed SampleBatch (single agent) to our buffers.

        Args:
            batch (SampleBatch): An individual agent's (one trajectory)
                SampleBatch to be added to the Policy's buffers.
            view_requirements (ViewRequirementsDict): The view
                requirements for the policy. Used to determine whether a
                view-column needs to be copied at all (columns not needed
                for training are skipped).
        """
        # Add the agent's trajectory length to our count.
        self.agent_steps += batch.count
        # And remove columns not needed for training.
        for view_col, view_req in view_requirements.items():
            if view_col in batch and not view_req.used_for_training:
                del batch[view_col]
        self.batches.append(batch)
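
Compared to Example #1, this variant stores whole batches and filters columns up front: anything flagged used_for_training=False is deleted before the batch is kept. A minimal sketch of that filtering, with plain dicts standing in for SampleBatch and the view-requirements (all names illustrative):

from types import SimpleNamespace

batch = {"obs": [1, 2, 3], "infos": ["a", "b", "c"]}
view_requirements = {
    "obs": SimpleNamespace(used_for_training=True),
    "infos": SimpleNamespace(used_for_training=False),
}

# Same loop as above: drop columns not needed for training.
for view_col, view_req in view_requirements.items():
    if view_col in batch and not view_req.used_for_training:
        del batch[view_col]

assert list(batch) == ["obs"]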
Example #3
    def get_single_step_input_dict(
        self,
        view_requirements: ViewRequirementsDict,
        index: Union[str, int] = "last",
    ) -> "SampleBatch":
        """Creates single ts SampleBatch at given index from `self`.

        For usage as input-dict for model (action or value function) calls.

        Args:
            view_requirements: A view requirements dict from the model for
                which to produce the input_dict.
            index: An integer index value (or the string "last") indicating
                the position in the trajectory for which to generate the
                compute_actions input dict. Set to "last" to generate the dict
                at the very end of the trajectory (e.g. for value estimation).
                Note that "last" is different from -1, as "last" will use the
                final NEXT_OBS as observation input.

        Returns:
            The (single-timestep) input dict for ModelV2 calls.
        """
        last_mappings = {
            SampleBatch.OBS: SampleBatch.NEXT_OBS,
            SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS,
            SampleBatch.PREV_REWARDS: SampleBatch.REWARDS,
        }

        input_dict = {}
        for view_col, view_req in view_requirements.items():
            if view_req.used_for_compute_actions is False:
                continue

            # Create batches of size 1 (single-agent input-dict).
            data_col = view_req.data_col or view_col
            if index == "last":
                data_col = last_mappings.get(data_col, data_col)
                # Range needed.
                if view_req.shift_from is not None:
                    # Batch repeat value > 1: We have single frames in the
                    # batch at each timestep (for the `data_col`).
                    data = self[view_col][-1]
                    traj_len = len(self[data_col])
                    missing_at_end = traj_len % view_req.batch_repeat_value
                    # Index into the observations column must be shifted by
                    # -1 b/c index=0 for observations means the current (last
                    # seen) observation (after having taken an action).
                    obs_shift = (-1 if data_col in [
                        SampleBatch.OBS, SampleBatch.NEXT_OBS
                    ] else 0)
                    from_ = view_req.shift_from + obs_shift
                    to_ = view_req.shift_to + obs_shift + 1
                    if to_ == 0:
                        to_ = None
                    input_dict[view_col] = np.array([
                        np.concatenate(
                            [data,
                             self[data_col][-missing_at_end:]])[from_:to_]
                    ])
                # Single index.
                else:
                    input_dict[view_col] = tree.map_structure(
                        lambda v: v[-1:],  # keep as array (w/ 1 element)
                        self[data_col],
                    )
            # Single index somewhere inside the trajectory (non-last).
            else:
                input_dict[view_col] = self[data_col][
                    index:index + 1 if index != -1 else None]

        return SampleBatch(input_dict, seq_lens=np.array([1], dtype=np.int32))
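
Two details here are easy to miss: for index="last", `last_mappings` swaps OBS/PREV_ACTIONS/PREV_REWARDS for their post-step counterparts, and for integer indices the slice expression keeps a batch dimension of size 1 even at index=-1 (where a naive `index + 1` stop of 0 would yield an empty slice). A quick check of that slice trick on a plain NumPy array:

import numpy as np

data = np.array([10, 11, 12, 13])

def single_ts(data, index):
    # Same expression as above: for index=-1 the stop must be None,
    # not 0, or the slice would come back empty.
    return data[index:index + 1 if index != -1 else None]

assert single_ts(data, 1).tolist() == [11]
assert single_ts(data, -1).tolist() == [13]  # not []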
Example #4
    def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
        """Builds a SampleBatch from the thus-far collected agent data.

        If the episode/trajectory has no DONE=True at the end, will copy
        the necessary n timesteps at the end of the trajectory back to the
        beginning of the buffers and wait for new samples coming in.
        SampleBatches created by this method will be ready for postprocessing
        by a Policy.

        Args:
            view_requirements (ViewRequirementsDict): The view
                requirements dict needed to build the SampleBatch from the raw
                buffers (which may have data shifts as well as mappings from
                view-col to data-col in them).

        Returns:
            SampleBatch: The built SampleBatch for this agent, ready to go into
                postprocessing.
        """

        batch_data = {}
        np_data = {}
        for view_col, view_req in view_requirements.items():
            # Create the batch of data from the different buffers.
            data_col = view_req.data_col or view_col

            # Some columns don't exist yet (get created during postprocessing).
            # -> skip.
            if data_col not in self.buffers:
                continue

            # OBS are already shifted by -1 (the initial obs starts one ts
            # before all other data columns).
            obs_shift = -1 if data_col == SampleBatch.OBS else 0

            # Keep an np-array cache so we don't have to regenerate the
            # np-array for different view_cols that map to the same data_col.
            if data_col not in np_data:
                np_data[data_col] = to_float_np_array(self.buffers[data_col])

            # Range of indices on time-axis, e.g. "-50:-1". Together with
            # the `batch_repeat_value`, this determines the data produced.
            # Example:
            #  batch_repeat_value=10, shift_from=-3, shift_to=-1
            #  buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
            #  resulting data=[[-3, -2, -1], [7, 8, 9]]
            #  Range of 3 consecutive items repeats every 10 timesteps.
            if view_req.shift_from is not None:
                # Batch repeat value > 1: Only repeat the shift_from/to range
                # every n timesteps.
                if view_req.batch_repeat_value > 1:
                    count = int(
                        math.ceil((len(np_data[data_col]) - self.shift_before)
                                  / view_req.batch_repeat_value))
                    data = np.asarray([
                        np_data[data_col][self.shift_before +
                                          (i * view_req.batch_repeat_value) +
                                          view_req.shift_from +
                                          obs_shift:self.shift_before +
                                          (i * view_req.batch_repeat_value) +
                                          view_req.shift_to + 1 + obs_shift]
                        for i in range(count)
                    ])
                # Batch repeat value = 1: Repeat the shift_from/to range at
                # each timestep.
                else:
                    d = np_data[data_col]
                    shift_win = view_req.shift_to - view_req.shift_from + 1
                    data_size = d.itemsize * int(np.prod(d.shape[1:]))
                    strides = [
                        d.itemsize * int(np.prod(d.shape[i + 1:]))
                        for i in range(1, len(d.shape))
                    ]
                    data = np.lib.stride_tricks.as_strided(
                        d[self.shift_before - shift_win:],
                        [self.agent_steps, shift_win
                         ] + [d.shape[i] for i in range(1, len(d.shape))],
                        [data_size, data_size] + strides)
            # Set of (probably non-consecutive) indices.
            # Example:
            #  shift=[-3, 0]
            #  buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
            #  resulting data=[[-3, 0], [-2, 1], [-1, 2], [0, 3], [1, 4], ...]
            elif isinstance(view_req.shift, np.ndarray):
                data = np_data[data_col][self.shift_before + obs_shift +
                                         view_req.shift]
            # Single shift int value. Use the trajectory as-is, and if
            # `shift` != 0: shifted by that value.
            else:
                shift = view_req.shift + obs_shift

                # Batch repeat (only provide a value every n timesteps).
                if view_req.batch_repeat_value > 1:
                    count = int(
                        math.ceil((len(np_data[data_col]) - self.shift_before)
                                  / view_req.batch_repeat_value))
                    data = np.asarray([
                        np_data[data_col][self.shift_before + (
                            i * view_req.batch_repeat_value) + shift]
                        for i in range(count)
                    ])
                # Shift is exactly 0: Use trajectory as is.
                elif shift == 0:
                    data = np_data[data_col][self.shift_before:]
                # Shift is positive: We still need to 0-pad at the end.
                elif shift > 0:
                    data = to_float_np_array(
                        self.buffers[data_col][self.shift_before + shift:] + [
                            np.zeros(
                                shape=view_req.space.shape,
                                dtype=view_req.space.dtype)
                            for _ in range(shift)
                        ])
                # Shift is negative: Shift into the already existing and
                # 0-padded "before" area of our buffers.
                else:
                    data = np_data[data_col][self.shift_before + shift:shift]

            if len(data) > 0:
                batch_data[view_col] = data

        # Due to possible batch-repeats > 1, columns in the resulting batch
        # may not all have the same batch size.
        batch = SampleBatch(batch_data, _dont_check_lens=True)

        # Add EPS_ID and UNROLL_ID to batch.
        batch.data[SampleBatch.EPS_ID] = np.repeat(self.episode_id,
                                                   batch.count)
        if SampleBatch.UNROLL_ID not in batch.data:
            # TODO: (sven) Once we have the additional
            #  model.preprocess_train_batch in place (attention net PR), we
            #  should not even need UNROLL_ID anymore:
            #  Add "if SampleBatch.UNROLL_ID in view_requirements:" here.
            batch.data[SampleBatch.UNROLL_ID] = np.repeat(
                _AgentCollector._next_unroll_id, batch.count)
            _AgentCollector._next_unroll_id += 1

        # This trajectory is continuing -> Copy data at the end (in the size of
        # self.shift_before) to the beginning of buffers and erase everything
        # else.
        if not self.buffers[SampleBatch.DONES][-1]:
            # Copy data to beginning of buffer and cut lists.
            if self.shift_before > 0:
                for k, data in self.buffers.items():
                    self.buffers[k] = data[-self.shift_before:]
            self.agent_steps = 0

        return batch
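
The batch_repeat_value == 1 branch builds all shift_from/shift_to windows without copying via np.lib.stride_tricks.as_strided: by reusing the element stride for both the batch axis and the window axis, row i becomes the window starting at element i. A minimal sketch of the same trick on a toy 1-D float buffer (sizes made up, no shift_before offset):

import numpy as np

d = np.arange(8, dtype=np.float32)     # toy trajectory buffer
shift_win = 3                          # == shift_to - shift_from + 1
agent_steps = len(d) - shift_win + 1   # number of full windows

# Both axes advance by one element, so consecutive rows overlap by
# shift_win - 1 elements and no data is copied.
windows = np.lib.stride_tricks.as_strided(
    d, shape=(agent_steps, shift_win), strides=(d.itemsize, d.itemsize))

assert windows[0].tolist() == [0.0, 1.0, 2.0]
assert windows[1].tolist() == [1.0, 2.0, 3.0]

On NumPy >= 1.20, np.lib.stride_tricks.sliding_window_view produces the same windows with bounds checking and is the safer choice outside hot paths.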
Example #5
    def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
        """Builds a SampleBatch from the thus-far collected agent data.

        If the episode/trajectory has no DONE=True at the end, will copy
        the necessary n timesteps at the end of the trajectory back to the
        beginning of the buffers and wait for new samples coming in.
        SampleBatches created by this method will be ready for postprocessing
        by a Policy.

        Args:
            view_requirements (ViewRequirementsDict): The view
                requirements dict needed to build the SampleBatch from the raw
                buffers (which may have data shifts as well as mappings from
                view-col to data-col in them).

        Returns:
            SampleBatch: The built SampleBatch for this agent, ready to go into
                postprocessing.
        """

        batch_data = {}
        np_data = {}
        for view_col, view_req in view_requirements.items():
            # Create the batch of data from the different buffers.
            data_col = view_req.data_col or view_col

            # Some columns don't exist yet (get created during postprocessing).
            # -> skip.
            if data_col not in self.buffers:
                continue

            # OBS are already shifted by -1 (the initial obs starts one ts
            # before all other data columns).
            obs_shift = -1 if data_col == SampleBatch.OBS else 0

            # Keep an np-array cache so we don't have to regenerate the
            # np-array for different view_cols that map to the same data_col.
            if data_col not in np_data:
                np_data[data_col] = [
                    to_float_np_array(d) for d in self.buffers[data_col]
                ]

            # Range of indices on time-axis, e.g. "-50:-1". Together with
            # the `batch_repeat_value`, this determines the data produced.
            # Example:
            #  batch_repeat_value=10, shift_from=-3, shift_to=-1
            #  buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
            #  resulting data=[[-3, -2, -1], [7, 8, 9]]
            #  Range of 3 consecutive items repeats every 10 timesteps.
            if view_req.shift_from is not None:
                # Batch repeat value > 1: Only repeat the shift_from/to range
                # every n timesteps.
                if view_req.batch_repeat_value > 1:
                    count = int(
                        math.ceil(
                            (len(np_data[data_col][0]) - self.shift_before) /
                            view_req.batch_repeat_value))
                    data = [
                        np.asarray([
                            d[self.shift_before +
                              (i * view_req.batch_repeat_value) +
                              view_req.shift_from +
                              obs_shift:self.shift_before +
                              (i * view_req.batch_repeat_value) +
                              view_req.shift_to + 1 + obs_shift]
                            for i in range(count)
                        ]) for d in np_data[data_col]
                    ]
                # Batch repeat value = 1: Repeat the shift_from/to range at
                # each timestep.
                else:
                    d0 = np_data[data_col][0]
                    shift_win = view_req.shift_to - view_req.shift_from + 1
                    data_size = d0.itemsize * int(np.prod(d0.shape[1:]))
                    strides = [
                        d0.itemsize * int(np.prod(d0.shape[i + 1:]))
                        for i in range(1, len(d0.shape))
                    ]
                    start = (self.shift_before - shift_win + 1 + obs_shift +
                             view_req.shift_to)
                    data = [
                        np.lib.stride_tricks.as_strided(
                            d[start:start + self.agent_steps],
                            [self.agent_steps, shift_win] +
                            [d.shape[i] for i in range(1, len(d.shape))],
                            [data_size, data_size] + strides,
                        ) for d in np_data[data_col]
                    ]
            # Set of (probably non-consecutive) indices.
            # Example:
            #  shift=[-3, 0]
            #  buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
            #  resulting data=[[-3, 0], [-2, 1], [-1, 2], [0, 3], [1, 4], ...]
            elif isinstance(view_req.shift, np.ndarray):
                data = [
                    d[self.shift_before + obs_shift + view_req.shift]
                    for d in np_data[data_col]
                ]
            # Single shift int value. Use the trajectory as-is, and if
            # `shift` != 0: shifted by that value.
            else:
                shift = view_req.shift + obs_shift

                # Batch repeat (only provide a value every n timesteps).
                if view_req.batch_repeat_value > 1:
                    count = int(
                        math.ceil(
                            (len(np_data[data_col][0]) - self.shift_before) /
                            view_req.batch_repeat_value))
                    data = [
                        np.asarray([
                            d[self.shift_before +
                              (i * view_req.batch_repeat_value) + shift]
                            for i in range(count)
                        ]) for d in np_data[data_col]
                    ]
                # Shift is exactly 0: Use trajectory as is.
                elif shift == 0:
                    data = [d[self.shift_before:] for d in np_data[data_col]]
                # Shift is positive: We still need to 0-pad at the end.
                elif shift > 0:
                    data = [
                        to_float_np_array(
                            np.concatenate([
                                d[self.shift_before + shift:],
                                [
                                    np.zeros(
                                        shape=view_req.space.shape,
                                        dtype=view_req.space.dtype,
                                    ) for _ in range(shift)
                                ],
                            ])) for d in np_data[data_col]
                    ]
                # Shift is negative: Shift into the already existing and
                # 0-padded "before" area of our buffers.
                else:
                    data = [
                        d[self.shift_before + shift:shift]
                        for d in np_data[data_col]
                    ]

            if len(data) > 0:
                if data_col not in self.buffer_structs:
                    batch_data[view_col] = data[0]
                else:
                    batch_data[view_col] = tree.unflatten_as(
                        self.buffer_structs[data_col], data)

        # Due to possible batch-repeats > 1, columns in the resulting batch
        # may not all have the same batch size.
        batch = SampleBatch(batch_data)

        # Adjust the seq-lens array depending on the incoming agent sequences.
        if self.policy.is_recurrent():
            seq_lens = []
            max_seq_len = self.policy.config["model"]["max_seq_len"]
            count = batch.count
            while count > 0:
                seq_lens.append(min(count, max_seq_len))
                count -= max_seq_len
            batch["seq_lens"] = np.array(seq_lens)
            batch.max_seq_len = max_seq_len

        # This trajectory is continuing -> Copy data at the end (in the size of
        # self.shift_before) to the beginning of buffers and erase everything
        # else.
        if not self.buffers[SampleBatch.DONES][0][-1]:
            # Copy data to beginning of buffer and cut lists.
            if self.shift_before > 0:
                for k, data in self.buffers.items():
                    # Loop through all flattened sub-buffers of this column.
                    for i in range(len(data)):
                        self.buffers[k][i] = data[i][-self.shift_before:]
            self.agent_steps = 0

        # Reset our unroll_id.
        self.unroll_id = None

        return batch
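
New in this version: each column's buffer is a list of flat per-leaf lists, and `tree.unflatten_as` packs the per-leaf arrays back into the nested structure remembered in `self.buffer_structs`. A minimal dm-tree sketch (the dict observation structure is illustrative):

import numpy as np
import tree  # dm-tree

# Structure recorded when the column was first added
# (cf. `self.buffer_structs`).
struct = {"pos": None, "vel": None}

# One flat np-array per leaf, as built by the loop above
# (dm-tree flattens dicts in sorted-key order).
flat = [np.array([1.0, 2.0]), np.array([0.1, 0.2])]

batch_col = tree.unflatten_as(struct, flat)
assert batch_col["pos"].tolist() == [1.0, 2.0]
assert batch_col["vel"].tolist() == [0.1, 0.2]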
Example #6
    def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
        """Builds a SampleBatch from the thus-far collected agent data.

        If the episode/trajectory has no DONE=True at the end, will copy
        the necessary n timesteps at the end of the trajectory back to the
        beginning of the buffers and wait for new samples coming in.
        SampleBatches created by this method will be ready for postprocessing
        by a Policy.

        Args:
            view_requirements (ViewRequirementsDict): The view
                requirements dict needed to build the SampleBatch from the raw
                buffers (which may have data shifts as well as mappings from
                view-col to data-col in them).

        Returns:
            SampleBatch: The built SampleBatch for this agent, ready to go into
                postprocessing.
        """

        batch_data = {}
        np_data = {}
        for view_col, view_req in view_requirements.items():
            # Create the batch of data from the different buffers.
            data_col = view_req.data_col or view_col

            # Some columns don't exist yet (get created during postprocessing).
            # -> skip.
            if data_col not in self.buffers:
                continue
            # OBS are already shifted by -1 (the initial obs starts one ts
            # before all other data columns).
            shift = view_req.shift - (
                1 if data_col == SampleBatch.OBS else 0)
            if data_col not in np_data:
                np_data[data_col] = to_float_np_array(self.buffers[data_col])
            # Shift is exactly 0: Send trajectory as is.
            if shift == 0:
                data = np_data[data_col][self.shift_before:]
            # Shift is positive: We still need to 0-pad at the end here.
            elif shift > 0:
                data = to_float_np_array(
                    self.buffers[data_col][self.shift_before + shift:] + [
                        np.zeros(shape=view_req.space.shape,
                                 dtype=view_req.space.dtype)
                        for _ in range(shift)
                    ])
            # Shift is negative: Shift into the already existing and 0-padded
            # "before" area of our buffers.
            else:
                data = np_data[data_col][self.shift_before + shift:shift]
            if len(data) > 0:
                batch_data[view_col] = data

        batch = SampleBatch(batch_data)

        # Add EPS_ID and UNROLL_ID to batch.
        batch.data[SampleBatch.EPS_ID] = np.repeat(self.episode_id,
                                                   batch.count)
        if SampleBatch.UNROLL_ID not in batch.data:
            # TODO: (sven) Once we have the additional
            #  model.preprocess_train_batch in place (attention net PR), we
            #  should not even need UNROLL_ID anymore:
            #  Add "if SampleBatch.UNROLL_ID in view_requirements:" here.
            batch.data[SampleBatch.UNROLL_ID] = np.repeat(
                _AgentCollector._next_unroll_id, batch.count)
            _AgentCollector._next_unroll_id += 1

        # This trajectory is continuing -> Copy data at the end (in the size of
        # self.shift_before) to the beginning of buffers and erase everything
        # else.
        if not self.buffers[SampleBatch.DONES][-1]:
            # Copy data to beginning of buffer and cut lists.
            if self.shift_before > 0:
                for k, data in self.buffers.items():
                    self.buffers[k] = data[-self.shift_before:]
            self.agent_steps = 0

        return batch
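
The three shift cases can be verified on a toy buffer laid out like `self.buffers[data_col]`: `shift_before` zero-padded slots at the front, then the actual trajectory. A sketch with plain numbers in place of per-timestep arrays (all values illustrative):

import numpy as np

shift_before = 1
# One zero-padded "before" slot followed by four real timesteps.
buf = np.array([0, 10, 11, 12, 13], dtype=np.float32)

# shift == 0: the trajectory as-is.
assert buf[shift_before:].tolist() == [10, 11, 12, 13]

# shift < 0 (e.g. PREV_REWARDS): reach into the padded "before" area.
shift = -1
assert buf[shift_before + shift:shift].tolist() == [0, 10, 11, 12]

# shift > 0 (e.g. NEXT_OBS): 0-pad at the end instead.
shift = 1
padded = np.concatenate([buf[shift_before + shift:], np.zeros(shift)])
assert padded.tolist() == [11, 12, 13, 0]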