Example #1
    def _get_dummy_batch_from_view_requirements(self,
                                                batch_size: int = 1
                                                ) -> SampleBatch:
        """Creates a numpy dummy batch based on the Policy's view requirements.

        Args:
            batch_size (int): The size of the batch to create.

        Returns:
            SampleBatch: The dummy batch containing all zero values.
        """
        ret = {}
        for view_col, view_req in self.view_requirements.items():
            data_col = view_req.data_col or view_col
            # Flattened dummy batch.
            if (isinstance(
                    view_req.space,
                (gym.spaces.Tuple, gym.spaces.Dict))) and (
                    (data_col == SampleBatch.OBS
                     and not self.config["_disable_preprocessor_api"]) or
                    (data_col == SampleBatch.ACTIONS
                     and not self.config.get("_disable_action_flattening"))):
                _, shape = ModelCatalog.get_action_shape(
                    view_req.space, framework=self.config["framework"])
                ret[view_col] = np.zeros((batch_size, ) + shape[1:],
                                         np.float32)
            # Non-flattened dummy batch.
            else:
                # Range of indices on time-axis, e.g. "-50:-1".
                if view_req.shift_from is not None:
                    ret[view_col] = get_dummy_batch_for_space(
                        view_req.space,
                        batch_size=batch_size,
                        time_size=view_req.shift_to - view_req.shift_from + 1,
                    )
                # Sequence of (probably non-consecutive) indices.
                elif isinstance(view_req.shift, (list, tuple)):
                    ret[view_col] = get_dummy_batch_for_space(
                        view_req.space,
                        batch_size=batch_size,
                        time_size=len(view_req.shift),
                    )
                # Single shift int value.
                else:
                    if isinstance(view_req.space, gym.spaces.Space):
                        ret[view_col] = get_dummy_batch_for_space(
                            view_req.space,
                            batch_size=batch_size,
                            fill_value=0.0)
                    else:
                        ret[view_col] = [
                            view_req.space for _ in range(batch_size)
                        ]

        # Due to different view requirements for the different columns,
        # columns in the resulting batch may not all have the same batch size.
        return SampleBatch(ret)
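
As a rough standalone illustration of the zero-batch idea (independent of RLlib; the helper name make_zero_batch and the Box space below are invented for this sketch), a simple flat space can be handled like this:

import gym
import numpy as np


def make_zero_batch(space: gym.spaces.Space, batch_size: int = 1) -> np.ndarray:
    # Sample once to get the per-element shape and dtype, zero it out,
    # then stack batch_size copies along a new leading batch axis.
    zeros = np.zeros_like(np.asarray(space.sample()))
    return np.stack([zeros for _ in range(batch_size)])


obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
dummy_obs = make_zero_batch(obs_space, batch_size=32)
print(dummy_obs.shape)  # (32, 4)
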
Example #2
    def _get_dummy_batch_from_view_requirements(self,
                                                batch_size: int = 1
                                                ) -> SampleBatch:
        """Creates a numpy dummy batch based on the Policy's view requirements.

        Args:
            batch_size (int): The size of the batch to create.

        Returns:
            SampleBatch: The dummy batch containing all zero values.
        """
        ret = {}
        for view_col, view_req in self.view_requirements.items():
            if isinstance(view_req.space, (gym.spaces.Dict, gym.spaces.Tuple)):
                _, shape = ModelCatalog.get_action_shape(
                    view_req.space, framework=self.config["framework"])
                ret[view_col] = \
                    np.zeros((batch_size, ) + shape[1:], np.float32)
            else:
                # Range of indices on time-axis, e.g. "-50:-1".
                if view_req.shift_from is not None:
                    ret[view_col] = np.zeros_like([[
                        view_req.space.sample()
                        for _ in range(view_req.shift_to -
                                       view_req.shift_from + 1)
                    ] for _ in range(batch_size)])
                # Set of (probably non-consecutive) indices.
                elif isinstance(view_req.shift, (list, tuple)):
                    ret[view_col] = np.zeros_like([[
                        view_req.space.sample()
                        for t in range(len(view_req.shift))
                    ] for _ in range(batch_size)])
                # Single shift int value.
                else:
                    if isinstance(view_req.space, gym.spaces.Space):
                        ret[view_col] = np.zeros_like([
                            view_req.space.sample() for _ in range(batch_size)
                        ])
                    else:
                        ret[view_col] = [
                            view_req.space for _ in range(batch_size)
                        ]

        # Due to different view requirements for the different columns,
        # columns in the resulting batch may not all have the same batch size.
        return SampleBatch(ret)
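
The shift_from/shift_to branch above inserts an extra time axis between the batch and feature axes. A minimal standalone sketch of that construction (the space and shift values are made up here):

import gym
import numpy as np

space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
shift_from, shift_to, batch_size = -50, -1, 8

# One zeroed sample per timestep per batch row -> shape (batch, time, *obs).
time_size = shift_to - shift_from + 1  # 50 timesteps
block = np.zeros_like(
    [[space.sample() for _ in range(time_size)] for _ in range(batch_size)])
print(block.shape)  # (8, 50, 4)
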
Example #3
    def _get_dummy_batch_from_view_requirements(
            self, batch_size: int = 1) -> SampleBatch:
        """Creates a numpy dummy batch based on the Policy's view requirements.

        Args:
            batch_size (int): The size of the batch to create.

        Returns:
            SampleBatch: The dummy batch containing all zero values.
        """
        ret = {}
        for view_col, view_req in self.view_requirements.items():
            if isinstance(view_req.space, (gym.spaces.Dict, gym.spaces.Tuple)):
                _, shape = ModelCatalog.get_action_shape(view_req.space)
                ret[view_col] = \
                    np.zeros((batch_size, ) + shape[1:], np.float32)
            else:
                ret[view_col] = np.zeros_like(
                    [view_req.space.sample() for _ in range(batch_size)])
        return SampleBatch(ret)
Example #4
        def _initialize_loss_with_dummy_batch(self):
            # Dummy forward pass to initialize any policy attributes, etc.
            action_dtype, action_shape = ModelCatalog.get_action_shape(
                self.action_space)
            dummy_batch = {
                SampleBatch.CUR_OBS:
                np.array([self.observation_space.sample()]),
                SampleBatch.NEXT_OBS:
                np.array([self.observation_space.sample()]),
                SampleBatch.DONES:
                np.array([False], dtype=bool),
                SampleBatch.ACTIONS:
                tf.nest.map_structure(lambda c: np.array([c]),
                                      self.action_space.sample()),
                SampleBatch.REWARDS:
                np.array([0], dtype=np.float32),
            }
            if obs_include_prev_action_reward:
                dummy_batch.update({
                    SampleBatch.PREV_ACTIONS:
                    dummy_batch[SampleBatch.ACTIONS],
                    SampleBatch.PREV_REWARDS:
                    dummy_batch[SampleBatch.REWARDS],
                })
            for i, h in enumerate(self._state_in):
                dummy_batch["state_in_{}".format(i)] = h
                dummy_batch["state_out_{}".format(i)] = h

            if self._state_in:
                dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)

            # Convert everything to tensors.
            dummy_batch = tf.nest.map_structure(tf.convert_to_tensor,
                                                dummy_batch)

            # for IMPALA which expects a certain sample batch size.
            def tile_to(tensor, n):
                return tf.tile(tensor,
                               [n] + [1 for _ in tensor.shape.as_list()[1:]])

            if get_batch_divisibility_req:
                dummy_batch = tf.nest.map_structure(
                    lambda c: tile_to(c, get_batch_divisibility_req(self)),
                    dummy_batch)

            # Execute a forward pass to get self.action_dist etc initialized,
            # and also obtain the extra action fetches
            _, _, fetches = self.compute_actions(
                dummy_batch[SampleBatch.CUR_OBS], self._state_in,
                dummy_batch.get(SampleBatch.PREV_ACTIONS),
                dummy_batch.get(SampleBatch.PREV_REWARDS))
            dummy_batch.update(fetches)

            postprocessed_batch = self.postprocess_trajectory(
                SampleBatch(dummy_batch))

            # model forward pass for the loss (needed after postprocess to
            # overwrite any tensor state from that call)
            self.model.from_batch(dummy_batch)

            postprocessed_batch = tf.nest.map_structure(
                lambda c: tf.convert_to_tensor(c), postprocessed_batch.data)

            loss_fn(self, self.model, self.dist_class, postprocessed_batch)
            if stats_fn:
                stats_fn(self, postprocessed_batch)
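
The tile_to helper repeats the size-1 dummy batch along the batch axis so that algorithms with a batch-divisibility requirement (IMPALA is named in the comment) see a large enough batch. The same tiling in isolation (tensor shape chosen arbitrarily for the sketch):

import numpy as np
import tensorflow as tf


def tile_to(tensor, n):
    # Repeat along axis 0 only; all other axes keep multiplicity 1.
    return tf.tile(tensor, [n] + [1 for _ in tensor.shape.as_list()[1:]])


dummy = tf.convert_to_tensor(np.zeros((1, 4), dtype=np.float32))
print(tile_to(dummy, 8).shape)  # (8, 4)
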
Example #5
        def _initialize_loss_with_dummy_batch(self):
            # Dummy forward pass to initialize any policy attributes, etc.
            action_dtype, action_shape = ModelCatalog.get_action_shape(
                self.action_space)
            dummy_batch = {
                SampleBatch.CUR_OBS:
                tf.convert_to_tensor(
                    np.array([self.observation_space.sample()])),
                SampleBatch.NEXT_OBS:
                tf.convert_to_tensor(
                    np.array([self.observation_space.sample()])),
                SampleBatch.DONES:
                tf.convert_to_tensor(np.array([False], dtype=bool)),
                SampleBatch.ACTIONS:
                tf.convert_to_tensor(
                    np.zeros((1, ) + action_shape[1:],
                             dtype=action_dtype.as_numpy_dtype())),
                SampleBatch.REWARDS:
                tf.convert_to_tensor(np.array([0], dtype=np.float32)),
            }
            if obs_include_prev_action_reward:
                dummy_batch.update({
                    SampleBatch.PREV_ACTIONS:
                    dummy_batch[SampleBatch.ACTIONS],
                    SampleBatch.PREV_REWARDS:
                    dummy_batch[SampleBatch.REWARDS],
                })
            state_init = self.get_initial_state()
            state_batches = []
            for i, h in enumerate(state_init):
                dummy_batch["state_in_{}".format(i)] = tf.convert_to_tensor(
                    np.expand_dims(h, 0))
                dummy_batch["state_out_{}".format(i)] = tf.convert_to_tensor(
                    np.expand_dims(h, 0))
                state_batches.append(tf.convert_to_tensor(np.expand_dims(h,
                                                                         0)))
            if state_init:
                dummy_batch["seq_lens"] = tf.convert_to_tensor(
                    np.array([1], dtype=np.int32))

            # for IMPALA which expects a certain sample batch size
            def tile_to(tensor, n):
                return tf.tile(tensor,
                               [n] + [1 for _ in tensor.shape.as_list()[1:]])

            if get_batch_divisibility_req:
                dummy_batch = {
                    k: tile_to(v, get_batch_divisibility_req(self))
                    for k, v in dummy_batch.items()
                }

            # Execute a forward pass to get self.action_dist etc initialized,
            # and also obtain the extra action fetches
            _, _, fetches = self.compute_actions(
                dummy_batch[SampleBatch.CUR_OBS], state_batches,
                dummy_batch.get(SampleBatch.PREV_ACTIONS),
                dummy_batch.get(SampleBatch.PREV_REWARDS))
            dummy_batch.update(fetches)

            postprocessed_batch = self.postprocess_trajectory(
                SampleBatch(dummy_batch))

            # model forward pass for the loss (needed after postprocess to
            # overwrite any tensor state from that call)
            self.model.from_batch(dummy_batch)

            postprocessed_batch = {
                k: tf.convert_to_tensor(v)
                for k, v in postprocessed_batch.items()
            }

            loss_fn(self, self.model, self.dist_class, postprocessed_batch)
            if stats_fn:
                stats_fn(self, postprocessed_batch)
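
Unlike Example #4, which reuses the tensors already stored in self._state_in, Example #5 rebuilds the recurrent-state inputs from get_initial_state() by prepending a batch dimension of 1 to each state. A small sketch of that batching step (the state shapes are invented):

import numpy as np

# Hypothetical initial RNN states, e.g. two LSTM state vectors of size 64.
state_init = [np.zeros(64, dtype=np.float32), np.zeros(64, dtype=np.float32)]

# Prepend a batch dimension of 1 to each state, as Example #5 does with
# np.expand_dims(h, 0) before converting to tensors.
state_batches = [np.expand_dims(h, 0) for h in state_init]
print([s.shape for s in state_batches])  # [(1, 64), (1, 64)]

seq_lens = np.array([1], dtype=np.int32)
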