Example #1
0
File: policy.py Project: alipay/ray
    def _get_dummy_batch_from_view_requirements(self,
                                                batch_size: int = 1
                                                ) -> SampleBatch:
        """Builds a numpy dummy batch from this Policy's view requirements.

        One entry is created per view column in `self.view_requirements`:

        - Tuple/Dict spaces whose data column is OBS (with the preprocessor
          API enabled) or ACTIONS (with action flattening enabled) become a
          float32 zeros array of shape `(batch_size,) + shape[1:]`, where
          `shape` is taken from `ModelCatalog.get_action_shape`.
        - Columns with a time range (`shift_from`/`shift_to`) or with a
          list/tuple of shifts are created via `get_dummy_batch_for_space`,
          passing `time_size` equal to the inclusive range length or the
          number of shifts, respectively.
        - Single-shift columns whose space is a gym Space are created via
          `get_dummy_batch_for_space(..., fill_value=0.0)`; any other
          "space" value is repeated `batch_size` times in a plain list.

        Args:
            batch_size (int): The size of the batch to create.

        Returns:
            SampleBatch: The dummy batch. Columns are handled independently,
                so they are not guaranteed to share the same batch size.
        """
        columns = {}
        for view_col, req in self.view_requirements.items():
            data_col = req.data_col or view_col
            # Config keys are only looked up when the space is Tuple/Dict and
            # the data column matches (short-circuit, as before).
            flatten = isinstance(
                req.space, (gym.spaces.Tuple, gym.spaces.Dict)) and (
                    (data_col == SampleBatch.OBS
                     and not self.config["_disable_preprocessor_api"])
                    or (data_col == SampleBatch.ACTIONS
                        and not self.config.get("_disable_action_flattening")))

            if flatten:
                _, flat_shape = ModelCatalog.get_action_shape(
                    req.space, framework=self.config["framework"])
                columns[view_col] = np.zeros(
                    (batch_size, ) + flat_shape[1:], np.float32)
            elif req.shift_from is not None:
                # Inclusive range of time indices, e.g. "-50:-1".
                columns[view_col] = get_dummy_batch_for_space(
                    req.space,
                    batch_size=batch_size,
                    time_size=req.shift_to - req.shift_from + 1,
                )
            elif isinstance(req.shift, (list, tuple)):
                # Explicit (possibly non-consecutive) list of shifts.
                columns[view_col] = get_dummy_batch_for_space(
                    req.space,
                    batch_size=batch_size,
                    time_size=len(req.shift),
                )
            elif isinstance(req.space, gym.spaces.Space):
                # Single shift value backed by a real gym Space.
                columns[view_col] = get_dummy_batch_for_space(
                    req.space, batch_size=batch_size, fill_value=0.0)
            else:
                # Not a gym Space: repeat the raw value once per row.
                columns[view_col] = [req.space] * batch_size

        return SampleBatch(columns)
Example #2
0
    def _initialize_loss_from_dummy_batch(
            self, auto_remove_unneeded_view_reqs: bool = True) -> None:
        """Initializes the loss via the dummy batch and prunes view-reqs.

        Steps, in order:
        1. Runs `tf1.global_variables_initializer()` in the policy's session.
        2. Tags every view-req whose key was not accessed in
           `self._input_dict` (except "state_in_*" keys) as
           `used_for_compute_actions=False`.
        3. Adds each output of `self.extra_action_out_fn()` to the dummy
           batch, to `self._input_dict` (as a placeholder) and, if missing,
           to `self.view_requirements`.
        4. Runs the exploration's and this policy's `postprocess_trajectory`
           on the dummy batch. It then adds placeholders/view-reqs for any
           columns postprocessing added.
        5. Builds a training SampleBatch from `self._input_dict` and
           `self._loss_input_dict`, calls `self._do_loss_init()` and
           `TFPolicy._initialize_loss()`, and merges `grad_stats_fn` output
           into `self._stats_fetches`.
        6. Optionally prunes view-reqs and loss inputs (see Args).
        7. Sets `self._loss_input_dict_no_rnn`.

        Args:
            auto_remove_unneeded_view_reqs (bool): If True:
                - View-reqs accessed only during postprocessing are tagged
                  `used_for_training=False`.
                - View-reqs not accessed at all are deleted. Exceptions: an
                  exempt key list, model view-reqs, keys deleted manually in
                  postprocessing (which only trigger a warning), and any key
                  when `config["output"]` is not None.
                - data_cols referenced by the remaining view-reqs are re-added.
        """
        # Test calls depend on variable init, so initialize model first.
        self.get_session().run(tf1.global_variables_initializer())

        # Fields that have not been accessed are not needed for action
        # computations -> Tag them as `used_for_compute_actions=False`.
        for key, view_req in self.view_requirements.items():
            if (not key.startswith("state_in_")
                    and key not in self._input_dict.accessed_keys):
                view_req.used_for_compute_actions = False
        # Register every extra action output: dummy data for a
        # Box(-1.0, 1.0) space in the dummy batch, a placeholder in the input
        # dict, and (if missing) a view-req not used for compute_actions.
        for key, value in self.extra_action_out_fn().items():
            # NOTE(review): shape is taken via `value.shape.as_list()[1:]`
            # here but via `value.shape[1:]` for the ViewRequirement below;
            # presumably equivalent for the Box space -- confirm.
            self._dummy_batch[key] = get_dummy_batch_for_space(
                gym.spaces.Box(-1.0,
                               1.0,
                               shape=value.shape.as_list()[1:],
                               dtype=value.dtype.name),
                batch_size=len(self._dummy_batch),
            )
            self._input_dict[key] = get_placeholder(value=value, name=key)
            if key not in self.view_requirements:
                logger.info(
                    "Adding extra-action-fetch `{}` to view-reqs.".format(key))
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(-1.0,
                                         1.0,
                                         shape=value.shape[1:],
                                         dtype=value.dtype.name),
                    used_for_compute_actions=False,
                )
        dummy_batch = self._dummy_batch

        # Run both postprocessing steps on the dummy batch. The batch's
        # accessed/added/deleted key sets are read further below to decide
        # which columns are needed.
        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        self.exploration.postprocess_trajectory(self, dummy_batch,
                                                self.get_session())
        _ = self.postprocess_trajectory(dummy_batch)
        # Add new columns automatically to (loss) input_dict.
        for key in dummy_batch.added_keys:
            if key not in self._input_dict:
                self._input_dict[key] = get_placeholder(value=dummy_batch[key],
                                                        name=key)
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(
                        -1.0,
                        1.0,
                        shape=dummy_batch[key].shape[1:],
                        dtype=dummy_batch[key].dtype,
                    ),
                    used_for_compute_actions=False,
                )

        # Training batch = action input placeholders merged with the existing
        # loss inputs (entries of `self._loss_input_dict` win on clashes).
        train_batch = SampleBatch(
            dict(self._input_dict, **self._loss_input_dict),
            _is_training=True,
        )

        # Stateful models: feed the seq-lens placeholder to the loss as well.
        if self._state_inputs:
            train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
            self._loss_input_dict.update(
                {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]})

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        losses = self._do_loss_init(train_batch)

        # Keys whose train-batch entries are handed to `_initialize_loss`.
        # These are keys accessed by the loss or by postprocessing, columns
        # added by postprocessing, and the model's own view-req keys.
        all_accessed_keys = (train_batch.accessed_keys
                             | dummy_batch.accessed_keys
                             | dummy_batch.added_keys
                             | set(self.model.view_requirements.keys()))

        # SEQ_LENS (if present) is appended explicitly after the filtered
        # items. NOTE(review): if SEQ_LENS is also in `all_accessed_keys`, it
        # appears twice in this list -- confirm that is harmless.
        TFPolicy._initialize_loss(
            self,
            losses,
            [(k, v)
             for k, v in train_batch.items() if k in all_accessed_keys] + ([
                 (SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])
             ] if SampleBatch.SEQ_LENS in train_batch else []),
        )

        # Drop the `is_training` entry, if present, from the loss inputs.
        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        self._stats_fetches.update(self.grad_stats_fn(train_batch,
                                                      self._grads))

        # Add new columns automatically to view-reqs.
        if auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            # NOTE(review): this overwrites the broader set computed above.
            # That set also held `dummy_batch.added_keys` and the model's
            # view-req keys. Model view-reqs are exempted separately below,
            # but added-only keys are not -- confirm intended.
            all_accessed_keys = train_batch.accessed_keys | dummy_batch.accessed_keys
            # Tag those only needed for post-processing (with some exceptions).
            for key in dummy_batch.accessed_keys:
                if (key not in train_batch.accessed_keys
                        and key not in self.model.view_requirements
                        and key not in [
                            SampleBatch.EPS_ID,
                            SampleBatch.AGENT_INDEX,
                            SampleBatch.UNROLL_ID,
                            SampleBatch.DONES,
                            SampleBatch.REWARDS,
                            SampleBatch.INFOS,
                            SampleBatch.OBS_EMBEDS,
                        ]):
                    if key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave DONES, REWARDS, and INFOS, no matter what.
            # NOTE(review): unlike the exception list above, OBS_EMBEDS is
            # not exempted here -- confirm intended.
            for key in list(self.view_requirements.keys()):
                if (key not in all_accessed_keys and key not in [
                        SampleBatch.EPS_ID,
                        SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID,
                        SampleBatch.DONES,
                        SampleBatch.REWARDS,
                        SampleBatch.INFOS,
                ] and key not in self.model.view_requirements):
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in dummy_batch.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key))
                    # If we are not writing output to disk, safe to erase
                    # this key to save space in the sample batch.
                    elif self.config["output"] is None:
                        del self.view_requirements[key]

                    # Unneeded keys are always dropped from the loss inputs,
                    # even when the view-req itself is kept.
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Add those data_cols (again) that are missing and have
            # dependencies by view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (vr.data_col is not None
                        and vr.data_col not in self.view_requirements):
                    used_for_training = vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training)

        # Loss inputs minus the state-in placeholders and the seq-lens
        # placeholder.
        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }
Example #3
0
    def get_inference_input_dict(self,
                                 policy_id: PolicyID) -> Dict[str, TensorType]:
        """Assembles the forward-pass input batch for one policy.

        For every agent queued for a forward pass of `policy_id`, the view
        columns the policy needs for action computation are read from that
        agent's collector buffers and stacked across agents. Missing buffers
        are first created from dummy data for the column's space.

        Args:
            policy_id: Key into `self.policy_map` /
                `self.forward_pass_agent_keys`.

        Returns:
            SampleBatch: One entry per view column with
                `used_for_compute_actions=True`. `seq_lens` is set to all
                ones (int32, one per agent) if a "state_in_0" column exists.
                If no agents are queued, an empty SampleBatch is returned.
        """
        policy = self.policy_map[policy_id]
        agent_keys = self.forward_pass_agent_keys[policy_id]
        num_agents = len(agent_keys)

        # Nothing queued for this policy -> return an empty batch.
        if num_agents == 0:
            return SampleBatch()

        # References to each agent's buffers (later `_build_buffers` calls on
        # the collectors are observed through these same objects).
        agent_buffers = {
            key: self.agent_collectors[key].buffers
            for key in agent_keys
        }
        # Buffer structs should be identical for all agents; take the first.
        buffer_structs = self.agent_collectors[agent_keys[0]].buffer_structs

        # For these data columns, requested shifts are offset by -1 before
        # indexing into the buffers.
        minus_one_cols = (
            SampleBatch.OBS,
            SampleBatch.ENV_ID,
            SampleBatch.EPS_ID,
            SampleBatch.AGENT_INDEX,
            SampleBatch.T,
        )

        input_dict = {}
        for view_col, view_req in policy.view_requirements.items():
            # Columns not needed for action computation are skipped.
            if not view_req.used_for_compute_actions:
                continue

            data_col = view_req.data_col or view_col
            delta = -1 if data_col in minus_one_cols else 0

            if view_req.shift_from is None:
                # Single shift (e.g. -1) or list of shifts (e.g. [-4, -1, 0]).
                time_indices = view_req.shift + delta
            else:
                # Inclusive range of shifts, e.g. "-100:0".
                time_indices = (view_req.shift_from + delta,
                                view_req.shift_to + delta)

            # One list per (flattened) buffer component, collecting the
            # per-agent slices in agent order.
            data = None
            for key in agent_keys:
                if data_col not in agent_buffers[key]:
                    # No buffer for this column yet: build one from dummy
                    # data for the column's space (or the raw value if the
                    # "space" is not a gym Space).
                    space = (policy.view_requirements[view_req.data_col].space
                             if view_req.data_col is not None else
                             view_req.space)
                    fill_value = (get_dummy_batch_for_space(space, batch_size=0)
                                  if isinstance(space, Space) else space)
                    self.agent_collectors[key]._build_buffers(
                        {data_col: fill_value})

                if data is None:
                    num_components = len(agent_buffers[agent_keys[0]][data_col])
                    data = [[] for _ in range(num_components)]

                column = agent_buffers[key][data_col]
                if isinstance(time_indices, tuple):
                    start, end = time_indices
                    # `end == -1` means "through the last item (inclusive)",
                    # which `end + 1 == 0` cannot express as a slice stop.
                    stop = None if end == -1 else end + 1
                    for component, buf in zip(data, column):
                        component.append(buf[start:stop])
                else:
                    for component, buf in zip(data, column):
                        component.append(buf[time_indices])

            stacked = [np.array(component) for component in data]
            if data_col in buffer_structs:
                input_dict[view_col] = tree.unflatten_as(
                    buffer_structs[data_col], stacked)
            else:
                input_dict[view_col] = stacked[0]

        self._reset_inference_calls(policy_id)

        seq_lens = (np.ones(num_agents, dtype=np.int32)
                    if "state_in_0" in input_dict else None)
        return SampleBatch(input_dict, seq_lens=seq_lens)