Code Example #1
File: test_sample_batch.py  Project: zivzone/ray
    def test_dict_properties_of_sample_batches(self):
        base_dict = {
            "a": np.array([1, 2, 3]),
            "b": np.array([[0.1, 0.2], [0.3, 0.4]]),
            "c": True,
        }
        batch = SampleBatch(base_dict)
        keys_ = list(base_dict.keys())
        values_ = list(base_dict.values())
        items_ = list(base_dict.items())
        assert list(batch.keys()) == keys_
        assert list(batch.values()) == values_
        assert list(batch.items()) == items_

        # Add an item and check whether it's in the "added" list.
        batch["d"] = np.array(1)
        assert batch.added_keys == {"d"}, batch.added_keys
        # Access two keys and check whether they are in the
        # "accessed" list.
        print(batch["a"], batch["b"])
        assert batch.accessed_keys == {"a", "b"}, batch.accessed_keys
        # Delete a key and check whether it's in the "deleted" list.
        del batch["c"]
        assert batch.deleted_keys == {"c"}, batch.deleted_keys
Code Example #2
    def add_postprocessed_batch_for_training(
            self, batch: SampleBatch,
            view_requirements: ViewRequirementsDict) -> None:
        """Adds a postprocessed SampleBatch (single agent) to our buffers.

        Args:
            batch (SampleBatch): An individual agent's (one trajectory)
                SampleBatch to be added to the Policy's buffers.
            view_requirements (ViewRequirementsDict): The view
                requirements for the policy. This is so we know whether a
                view-column needs to be copied at all (not needed for
                training).
        """
        for view_col, data in batch.items():
            # 1) If col is not in view_requirements, we must have a direct
            # child of the base Policy that doesn't do auto-view req creation.
            # 2) Col is in view-reqs and needed for training.
            view_req = view_requirements.get(view_col)
            if view_req is None or view_req.used_for_training:
                self.buffers[view_col].extend(data)
        # Add the agent's trajectory length to our count.
        self.agent_steps += batch.count
        # Adjust the seq-lens array depending on the incoming agent sequences.
        if self.seq_lens is not None:
            max_seq_len = self.policy.config["model"]["max_seq_len"]
            count = batch.count
            while count > 0:
                self.seq_lens.append(min(count, max_seq_len))
                count -= max_seq_len
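
The final `while` loop above splits one agent trajectory into sequence lengths of at most `max_seq_len`. A minimal stand-alone sketch of that chopping logic (plain Python, no Ray needed; `chop_seq_lens` is an illustrative helper, not RLlib API):

def chop_seq_lens(count, max_seq_len):
    # Same loop as above: e.g. a 7-step trajectory with max_seq_len=3
    # yields sequence lengths [3, 3, 1].
    seq_lens = []
    while count > 0:
        seq_lens.append(min(count, max_seq_len))
        count -= max_seq_len
    return seq_lens

assert chop_seq_lens(7, 3) == [3, 3, 1]
assert chop_seq_lens(6, 3) == [3, 3]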
Code Example #3
    def call(
        self, input_dict: SampleBatch
    ) -> (TensorType, List[TensorType], Dict[str, TensorType]):
        assert input_dict[SampleBatch.SEQ_LENS] is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _, _ = self.wrapped_keras_model(input_dict)

        # Concat. prev-action/reward if required.
        prev_a_r = []
        if self.use_n_prev_actions:
            if isinstance(self.action_space, Discrete):
                for i in range(self.use_n_prev_actions):
                    prev_a_r.append(
                        one_hot(
                            input_dict[SampleBatch.PREV_ACTIONS][:, i],
                            self.action_space,
                        ))
            elif isinstance(self.action_space, MultiDiscrete):
                for i in range(0, self.use_n_prev_actions,
                               self.action_space.shape[0]):
                    prev_a_r.append(
                        one_hot(
                            tf.cast(
                                input_dict[SampleBatch.PREV_ACTIONS]
                                [:, i:i + self.action_space.shape[0]],
                                tf.float32,
                            ),
                            self.action_space,
                        ))
            else:
                prev_a_r.append(
                    tf.reshape(
                        tf.cast(input_dict[SampleBatch.PREV_ACTIONS],
                                tf.float32),
                        [-1, self.use_n_prev_actions * self.action_dim],
                    ))
        if self.use_n_prev_rewards:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
                    [-1, self.use_n_prev_rewards],
                ))

        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        memory_ins = [
            s for k, s in input_dict.items() if k.startswith("state_in_")
        ]
        model_out, memory_outs, value_outs = self.base_model([wrapped_out] +
                                                             memory_ins)
        return (
            model_out,
            memory_outs,
            {
                SampleBatch.VF_PREDS: tf.reshape(value_outs, [-1])
            },
        )
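
A NumPy-only sketch (not part of the Ray source) of what the Discrete branch above contributes: each of the `use_n_prev_actions` previous actions is one-hot encoded and concatenated onto the wrapped model's output along the feature axis. `concat_prev_actions` and the toy shapes below are illustrative assumptions:

import numpy as np

def concat_prev_actions(features, prev_actions, num_actions):
    # features: [batch, feat_dim]; prev_actions: [batch, n_prev] int indices.
    one_hots = [np.eye(num_actions)[prev_actions[:, i]]
                for i in range(prev_actions.shape[1])]
    return np.concatenate([features] + one_hots, axis=1)

feats = np.zeros((2, 4), dtype=np.float32)   # stand-in for wrapped_out
prev_a = np.array([[1, 0], [2, 2]])          # 2 previous Discrete actions
out = concat_prev_actions(feats, prev_a, num_actions=3)
assert out.shape == (2, 4 + 2 * 3)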
Code Example #4
    def compute_actions_from_input_dict(
            self,
            input_dict: SampleBatch,
            explore: bool = None,
            timestep: Optional[int] = None,
            episodes: Optional[List["MultiAgentEpisode"]] = None,
            **kwargs) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions from collected samples (across multiple-agents).

        Uses the currently "forward-pass-registered" samples from the collector
        to construct the input_dict for the Model.

        Args:
            input_dict (SampleBatch): A SampleBatch containing the Tensors
                to compute actions. `input_dict` already abides to the
                Policy's as well as the Model's view requirements and can
                thus be passed to the Model as-is.
            explore (bool): Whether to pick an exploitation or exploration
                action (default: None -> use self.config["explore"]).
            timestep (Optional[int]): The current (sampling) time step.
            kwargs: forward compatibility placeholder

        Returns:
            Tuple:
                actions (TensorType): Batch of output actions, with shape
                    like [BATCH_SIZE, ACTION_SHAPE].
                state_outs (List[TensorType]): List of RNN state output
                    batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].
                info (dict): Dictionary of extra feature batches, if any, with
                    shape like
                    {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
        """
        # Default implementation just passes obs, prev-a/r, and states on to
        # `self.compute_actions()`.
        state_batches = [
            s for k, s in input_dict.items() if k[:9] == "state_in_"
        ]
        return self.compute_actions(
            input_dict[SampleBatch.OBS],
            state_batches,
            prev_action_batch=input_dict.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=input_dict.get(SampleBatch.PREV_REWARDS),
            info_batch=input_dict.get(SampleBatch.INFOS),
            explore=explore,
            timestep=timestep,
            episodes=episodes,
            **kwargs,
        )
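
The `state_batches` comprehension above selects RNN state inputs purely by the `state_in_` key prefix (here via `k[:9] == "state_in_"`). A trivial stand-alone illustration with a hypothetical input dict:

# Collect RNN state inputs by key prefix, as in the comprehension above.
input_dict = {"obs": [[0.0]], "state_in_0": [[1.0]], "state_in_1": [[2.0]]}
state_batches = [v for k, v in input_dict.items() if k.startswith("state_in_")]
assert len(state_batches) == 2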
Code Example #5
File: visionnet.py  Project: stjordanis/ray
    def call(self, input_dict: SampleBatch) -> \
            (TensorType, List[TensorType], Dict[str, TensorType]):
        obs = input_dict["obs"]
        if self.data_format == "channels_first":
            obs = tf.transpose(obs, [0, 2, 3, 1])
        # Explicit cast to float32 needed in eager.
        model_out, self._value_out = self.base_model(tf.cast(obs, tf.float32))
        state = [v for k, v in input_dict.items() if k.startswith("state_in_")]
        extra_outs = {SampleBatch.VF_PREDS: tf.reshape(self._value_out, [-1])}
        # Our last layer is already flat.
        if self.last_layer_is_flattened:
            return model_out, state, extra_outs
        # Last layer is an n x [1, 1] Conv2D -> Flatten.
        else:
            return tf.squeeze(model_out, axis=[1, 2]), state, extra_outs
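
The `channels_first` branch above reorders NCHW observations to NHWC before they enter the Keras conv stack. A NumPy-only sketch of the same axis permutation (the shapes are illustrative assumptions):

import numpy as np

obs_nchw = np.zeros((8, 3, 84, 84), dtype=np.float32)  # [batch, C, H, W]
obs_nhwc = np.transpose(obs_nchw, [0, 2, 3, 1])         # [batch, H, W, C]
assert obs_nhwc.shape == (8, 84, 84, 3)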
Code Example #6
    def add_postprocessed_batch_for_training(
            self, batch: SampleBatch,
            view_requirements: Dict[str, ViewRequirement]) -> None:
        """Adds a postprocessed SampleBatch (single agent) to our buffers.

        Args:
            batch (SampleBatch): A single agent (one trajectory) SampleBatch
                to be added to the Policy's buffers.
            view_requirements (Dict[str, ViewRequirement]): The view
                requirements for the policy. This is so we know whether a
                view-column needs to be copied at all (not needed for
                training).
        """
        for view_col, data in batch.items():
            # Skip columns that are not used for training.
            if view_col not in view_requirements or \
                    not view_requirements[view_col].used_for_training:
                continue
            self.buffers[view_col].extend(data)
        # Add the agent's trajectory length to our count.
        self.count += batch.count
Code Example #7
    def add_postprocessed_batch_for_training(
            self, batch: SampleBatch,
            view_requirements: ViewRequirementsDict) -> None:
        """Adds a postprocessed SampleBatch (single agent) to our buffers.

        Args:
            batch (SampleBatch): An individual agent's (one trajectory)
                SampleBatch to be added to the Policy's buffers.
            view_requirements (ViewRequirementsDict): The view
                requirements for the policy. This is so we know whether a
                view-column needs to be copied at all (not needed for
                training).
        """
        for view_col, data in batch.items():
            # 1) If col is not in view_requirements, we must have a direct
            # child of the base Policy that doesn't do auto-view req creation.
            # 2) Col is in view-reqs and needed for training.
            if view_col not in view_requirements or \
                    view_requirements[view_col].used_for_training:
                self.buffers[view_col].extend(data)
        # Add the agent's trajectory length to our count.
        self.agent_steps += batch.count
Code Example #8
File: rnn_sequencing.py  Project: zhangbushi10/ray
def pad_batch_to_sequences_of_same_size(
    batch: SampleBatch,
    max_seq_len: int,
    shuffle: bool = False,
    batch_divisibility_req: int = 1,
    feature_keys: Optional[List[str]] = None,
    view_requirements: Optional[ViewRequirementsDict] = None,
):
    """Applies padding to `batch` so it's choppable into same-size sequences.

    Shuffles `batch` (if desired), makes sure the divisibility requirement is met,
    then pads the batch ([B, ...]) into same-size chunks ([B, ...]) w/o
    adding a time dimension (yet).
    Padding depends on episodes found in batch and `max_seq_len`.

    Args:
        batch (SampleBatch): The SampleBatch object. All values in here have
            the shape [B, ...].
        max_seq_len (int): The max. sequence length to use for chopping.
        shuffle (bool): Whether to shuffle batch sequences. Shuffle may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.
        batch_divisibility_req (int): The int by which the batch dimension
            must be divisible.
        feature_keys (Optional[List[str]]): An optional list of keys to apply
            sequence-chopping to. If None, use all keys in batch that are not
            "state_in/out_"-type keys.
        view_requirements (Optional[ViewRequirementsDict]): An optional
            Policy ViewRequirements dict to be able to infer whether
            e.g. dynamic max'ing should be applied over the seq_lens.
    """
    if batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
            # not multiagent
            and max(batch[SampleBatch.AGENT_INDEX]) == 0)
    else:
        meets_divisibility_reqs = True

    states_already_reduced_to_init = False

    # RNN/attention net case. Figure out whether we should apply dynamic
    # max'ing over the list of sequence lengths.
    if "state_in_0" in batch or "state_out_0" in batch:
        # Check whether the state inputs have already been reduced to their
        # init values at the beginning of each max_seq_len chunk.
        if batch.seq_lens is not None and \
                len(batch["state_in_0"]) == len(batch.seq_lens):
            states_already_reduced_to_init = True

        # RNN (or single timestep state-in): Set the max dynamically.
        if view_requirements["state_in_0"].shift_from is None:
            dynamic_max = True
        # Attention Nets (state inputs are over some range): No dynamic maxing
        # possible.
        else:
            dynamic_max = False
    # Multi-agent case.
    elif not meets_divisibility_reqs:
        max_seq_len = batch_divisibility_req
        dynamic_max = False
    # Simple case: No RNN/attention net, nor do we need to pad.
    else:
        if shuffle:
            batch.shuffle()
        return

    # RNN, attention net, or multi-agent case.
    state_keys = []
    feature_keys_ = feature_keys or []
    for k, v in batch.items():
        if k.startswith("state_in_"):
            state_keys.append(k)
        elif not feature_keys and not k.startswith("state_out_") and \
                k not in ["infos", "seq_lens"] and isinstance(v, np.ndarray):
            feature_keys_.append(k)

    feature_sequences, initial_states, seq_lens = \
        chop_into_sequences(
            feature_columns=[batch[k] for k in feature_keys_],
            state_columns=[batch[k] for k in state_keys],
            episode_ids=batch.get(SampleBatch.EPS_ID),
            unroll_ids=batch.get(SampleBatch.UNROLL_ID),
            agent_indices=batch.get(SampleBatch.AGENT_INDEX),
            seq_lens=getattr(batch, "seq_lens", batch.get("seq_lens")),
            max_seq_len=max_seq_len,
            dynamic_max=dynamic_max,
            states_already_reduced_to_init=states_already_reduced_to_init,
            shuffle=shuffle)

    for i, k in enumerate(feature_keys_):
        batch[k] = feature_sequences[i]
    for i, k in enumerate(state_keys):
        batch[k] = initial_states[i]
    batch["seq_lens"] = np.array(seq_lens)

    if log_once("rnn_ma_feed_dict"):
        logger.info("Padded input for RNN/Attn.Nets/MA:\n\n{}\n".format(
            summarize({
                "features": feature_sequences,
                "initial_states": initial_states,
                "seq_lens": seq_lens,
                "max_seq_len": max_seq_len,
            })))
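
A minimal NumPy-only sketch of the padding idea implemented above: given per-episode sequence lengths, each sequence in a flat [B, ...] column is zero-padded up to `max_seq_len` so the result can later be reshaped into same-size chunks. `pad_to_chunks` is an illustrative stand-in, not RLlib's `chop_into_sequences` API:

import numpy as np

def pad_to_chunks(column, seq_lens, max_seq_len):
    # Zero-pad each sequence in `column` up to max_seq_len and re-stack.
    out, offset = [], 0
    for sl in seq_lens:
        chunk = column[offset:offset + sl]
        pad = np.zeros((max_seq_len - sl,) + column.shape[1:], column.dtype)
        out.append(np.concatenate([chunk, pad]))
        offset += sl
    return np.concatenate(out)

obs = np.arange(7, dtype=np.float32)                   # 7 timesteps
padded = pad_to_chunks(obs, seq_lens=[3, 3, 1], max_seq_len=3)
assert padded.shape == (9,)                            # 3 chunks x max_seq_len=3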
Code Example #9
    def _initialize_loss_from_dummy_batch(
            self, auto_remove_unneeded_view_reqs: bool = True) -> None:
        # Test calls depend on variable init, so initialize model first.
        self.get_session().run(tf1.global_variables_initializer())

        # Fields that have not been accessed are not needed for action
        # computations -> Tag them as `used_for_compute_actions=False`.
        for key, view_req in self.view_requirements.items():
            if (not key.startswith("state_in_")
                    and key not in self._input_dict.accessed_keys):
                view_req.used_for_compute_actions = False
        for key, value in self.extra_action_out_fn().items():
            self._dummy_batch[key] = get_dummy_batch_for_space(
                gym.spaces.Box(-1.0,
                               1.0,
                               shape=value.shape.as_list()[1:],
                               dtype=value.dtype.name),
                batch_size=len(self._dummy_batch),
            )
            self._input_dict[key] = get_placeholder(value=value, name=key)
            if key not in self.view_requirements:
                logger.info(
                    "Adding extra-action-fetch `{}` to view-reqs.".format(key))
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(-1.0,
                                         1.0,
                                         shape=value.shape[1:],
                                         dtype=value.dtype.name),
                    used_for_compute_actions=False,
                )
        dummy_batch = self._dummy_batch

        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        self.exploration.postprocess_trajectory(self, dummy_batch,
                                                self.get_session())
        _ = self.postprocess_trajectory(dummy_batch)
        # Add new columns automatically to (loss) input_dict.
        for key in dummy_batch.added_keys:
            if key not in self._input_dict:
                self._input_dict[key] = get_placeholder(value=dummy_batch[key],
                                                        name=key)
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(
                        -1.0,
                        1.0,
                        shape=dummy_batch[key].shape[1:],
                        dtype=dummy_batch[key].dtype,
                    ),
                    used_for_compute_actions=False,
                )

        train_batch = SampleBatch(
            dict(self._input_dict, **self._loss_input_dict),
            _is_training=True,
        )

        if self._state_inputs:
            train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
            self._loss_input_dict.update(
                {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]})

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        losses = self._do_loss_init(train_batch)

        all_accessed_keys = (train_batch.accessed_keys
                             | dummy_batch.accessed_keys
                             | dummy_batch.added_keys
                             | set(self.model.view_requirements.keys()))

        TFPolicy._initialize_loss(
            self,
            losses,
            [(k, v)
             for k, v in train_batch.items() if k in all_accessed_keys] + ([
                 (SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])
             ] if SampleBatch.SEQ_LENS in train_batch else []),
        )

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        self._stats_fetches.update(self.grad_stats_fn(train_batch,
                                                      self._grads))

        # Add new columns automatically to view-reqs.
        if auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | dummy_batch.accessed_keys
            # Tag those only needed for post-processing (with some exceptions).
            for key in dummy_batch.accessed_keys:
                if (key not in train_batch.accessed_keys
                        and key not in self.model.view_requirements
                        and key not in [
                            SampleBatch.EPS_ID,
                            SampleBatch.AGENT_INDEX,
                            SampleBatch.UNROLL_ID,
                            SampleBatch.DONES,
                            SampleBatch.REWARDS,
                            SampleBatch.INFOS,
                            SampleBatch.OBS_EMBEDS,
                        ]):
                    if key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave DONES, REWARDS, and INFOS, no matter what.
            for key in list(self.view_requirements.keys()):
                if (key not in all_accessed_keys and key not in [
                        SampleBatch.EPS_ID,
                        SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID,
                        SampleBatch.DONES,
                        SampleBatch.REWARDS,
                        SampleBatch.INFOS,
                ] and key not in self.model.view_requirements):
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in dummy_batch.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key))
                    # If we are not writing output to disk, safe to erase
                    # this key to save space in the sample batch.
                    elif self.config["output"] is None:
                        del self.view_requirements[key]

                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Add those data_cols (again) that are missing and have
            # dependencies by view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (vr.data_col is not None
                        and vr.data_col not in self.view_requirements):
                    used_for_training = vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training)

        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }
Code Example #10
    def add_batch(self, batch: SampleBatch) -> None:
        """Add the given batch of values to this batch."""

        for k, column in batch.items():
            self.buffers[k].extend(column)
        self.count += batch.count
Code Example #11
    def _initialize_loss_from_dummy_batch(
            self,
            auto_remove_unneeded_view_reqs: bool = True,
            stats_fn=None) -> None:

        # Create the optimizer/exploration optimizer here. Some initialization
        # steps (e.g. exploration postprocessing) may need this.
        self._optimizer = self.optimizer()

        # Test calls depend on variable init, so initialize model first.
        self.get_session().run(tf1.global_variables_initializer())

        logger.info("Testing `compute_actions` w/ dummy batch.")
        actions, state_outs, extra_fetches = \
            self.compute_actions_from_input_dict(
                self._dummy_batch, explore=False, timestep=0)
        for key, value in extra_fetches.items():
            self._dummy_batch[key] = value
            self._input_dict[key] = get_placeholder(value=value, name=key)
            if key not in self.view_requirements:
                logger.info("Adding extra-action-fetch `{}` to "
                            "view-reqs.".format(key))
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(-1.0,
                                         1.0,
                                         shape=value.shape[1:],
                                         dtype=value.dtype),
                    used_for_compute_actions=False,
                )
        dummy_batch = self._dummy_batch

        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        self.exploration.postprocess_trajectory(self, dummy_batch,
                                                self.get_session())
        _ = self.postprocess_trajectory(dummy_batch)
        # Add new columns automatically to (loss) input_dict.
        for key in dummy_batch.added_keys:
            if key not in self._input_dict:
                self._input_dict[key] = get_placeholder(value=dummy_batch[key],
                                                        name=key)
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(-1.0,
                                         1.0,
                                         shape=dummy_batch[key].shape[1:],
                                         dtype=dummy_batch[key].dtype),
                    used_for_compute_actions=False,
                )

        train_batch = SampleBatch(
            dict(self._input_dict, **self._loss_input_dict))

        if self._state_inputs:
            train_batch["seq_lens"] = self._seq_lens
            self._loss_input_dict.update({"seq_lens": train_batch["seq_lens"]})

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        loss = self._do_loss_init(train_batch)

        all_accessed_keys = \
            train_batch.accessed_keys | dummy_batch.accessed_keys | \
            dummy_batch.added_keys | set(
                self.model.view_requirements.keys())

        TFPolicy._initialize_loss(
            self, loss,
            [(k, v)
             for k, v in train_batch.items() if k in all_accessed_keys] +
            ([("seq_lens",
               train_batch["seq_lens"])] if "seq_lens" in train_batch else []))

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads))

        # Add new columns automatically to view-reqs.
        if auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | \
                                dummy_batch.accessed_keys
            # Tag those only needed for post-processing (with some exceptions).
            for key in dummy_batch.accessed_keys:
                if key not in train_batch.accessed_keys and \
                        key not in self.model.view_requirements and \
                        key not in [
                            SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                            SampleBatch.UNROLL_ID, SampleBatch.DONES,
                            SampleBatch.REWARDS, SampleBatch.INFOS]:
                    if key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave DONES, REWARDS, and INFOS, no matter what.
            for key in list(self.view_requirements.keys()):
                if key not in all_accessed_keys and key not in [
                    SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                    SampleBatch.UNROLL_ID, SampleBatch.DONES,
                    SampleBatch.REWARDS, SampleBatch.INFOS] and \
                        key not in self.model.view_requirements:
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in dummy_batch.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key))
                    else:
                        del self.view_requirements[key]
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Add those data_cols (again) that are missing and have
            # dependencies by view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (vr.data_col is not None
                        and vr.data_col not in self.view_requirements):
                    used_for_training = \
                        vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training)

        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }

        # Initialize again after loss init.
        self.get_session().run(tf1.global_variables_initializer())
Code Example #12
    def _initialize_loss_from_dummy_batch(
            self, auto_remove_unneeded_view_reqs: bool = True,
            stats_fn=None) -> None:

        # Create the optimizer/exploration optimizer here. Some initialization
        # steps (e.g. exploration postprocessing) may need this.
        self._optimizer = self.optimizer()

        # Test calls depend on variable init, so initialize model first.
        self._sess.run(tf1.global_variables_initializer())

        if self.config["_use_trajectory_view_api"]:
            logger.info("Testing `compute_actions` w/ dummy batch.")
            actions, state_outs, extra_fetches = \
                self.compute_actions_from_input_dict(
                    self._dummy_batch, explore=False, timestep=0)
            for key, value in extra_fetches.items():
                self._dummy_batch[key] = np.zeros_like(value)
                self._input_dict[key] = get_placeholder(value=value, name=key)
                if key not in self.view_requirements:
                    logger.info("Adding extra-action-fetch `{}` to "
                                "view-reqs.".format(key))
                    self.view_requirements[key] = \
                        ViewRequirement(space=gym.spaces.Box(
                            -1.0, 1.0, shape=value.shape[1:],
                            dtype=value.dtype))
            dummy_batch = self._dummy_batch
        else:

            def fake_array(tensor):
                shape = tensor.shape.as_list()
                shape = [s if s is not None else 1 for s in shape]
                return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

            dummy_batch = {
                SampleBatch.CUR_OBS: fake_array(self._obs_input),
                SampleBatch.NEXT_OBS: fake_array(self._obs_input),
                SampleBatch.DONES: np.array([False], dtype=np.bool),
                SampleBatch.ACTIONS: fake_array(
                    ModelCatalog.get_action_placeholder(self.action_space)),
                SampleBatch.REWARDS: np.array([0], dtype=np.float32),
            }
            if self._obs_include_prev_action_reward:
                dummy_batch.update({
                    SampleBatch.PREV_ACTIONS: fake_array(
                        self._prev_action_input),
                    SampleBatch.PREV_REWARDS: fake_array(
                        self._prev_reward_input),
                })
            state_init = self.get_initial_state()
            state_batches = []
            for i, h in enumerate(state_init):
                dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
                dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
                state_batches.append(np.expand_dims(h, 0))
            if state_init:
                dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
            for k, v in self.extra_compute_action_fetches().items():
                dummy_batch[k] = fake_array(v)
            dummy_batch = SampleBatch(dummy_batch)

        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        self.exploration.postprocess_trajectory(self, dummy_batch, self._sess)
        postprocessed_batch = self.postprocess_trajectory(dummy_batch)
        # Add new columns automatically to (loss) input_dict.
        if self.config["_use_trajectory_view_api"]:
            for key in dummy_batch.added_keys:
                if key not in self._input_dict:
                    self._input_dict[key] = get_placeholder(
                        value=dummy_batch[key], name=key)
                if key not in self.view_requirements:
                    self.view_requirements[key] = \
                        ViewRequirement(space=gym.spaces.Box(
                            -1.0, 1.0, shape=dummy_batch[key].shape[1:],
                            dtype=dummy_batch[key].dtype))

        if not self.config["_use_trajectory_view_api"]:
            train_batch = SampleBatch(
                dict({
                    SampleBatch.CUR_OBS: self._obs_input,
                }, **self._loss_input_dict))
            if self._obs_include_prev_action_reward:
                train_batch.update({
                    SampleBatch.PREV_ACTIONS: self._prev_action_input,
                    SampleBatch.PREV_REWARDS: self._prev_reward_input,
                    SampleBatch.CUR_OBS: self._obs_input,
                })

            for k, v in postprocessed_batch.items():
                if k in train_batch:
                    continue
                elif v.dtype == np.object:
                    continue  # can't handle arbitrary objects in TF
                elif k == "seq_lens" or k.startswith("state_in_"):
                    continue
                shape = (None, ) + v.shape[1:]
                dtype = np.float32 if v.dtype == np.float64 else v.dtype
                placeholder = tf1.placeholder(dtype, shape=shape, name=k)
                train_batch[k] = placeholder

            for i, si in enumerate(self._state_inputs):
                train_batch["state_in_{}".format(i)] = si
        else:
            train_batch = SampleBatch(
                dict(self._input_dict, **self._loss_input_dict))

        if self._state_inputs:
            train_batch["seq_lens"] = self._seq_lens

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})
        loss = self._do_loss_init(train_batch)

        all_accessed_keys = \
            train_batch.accessed_keys | dummy_batch.accessed_keys | \
            dummy_batch.added_keys | set(
                self.model.view_requirements.keys())

        TFPolicy._initialize_loss(self, loss, [(k, v)
                                               for k, v in train_batch.items()
                                               if k in all_accessed_keys])

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads))

        # Add new columns automatically to view-reqs.
        if self.config["_use_trajectory_view_api"] and \
                auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | \
                                dummy_batch.accessed_keys
            # Tag those only needed for post-processing (with some exceptions).
            for key in dummy_batch.accessed_keys:
                if key not in train_batch.accessed_keys and \
                        key not in self.model.view_requirements and \
                        key not in [
                            SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                            SampleBatch.UNROLL_ID, SampleBatch.DONES,
                            SampleBatch.REWARDS, SampleBatch.INFOS]:
                    if key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave DONES, REWARDS, and INFOS, no matter what.
            for key in list(self.view_requirements.keys()):
                if key not in all_accessed_keys and key not in [
                    SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                    SampleBatch.UNROLL_ID, SampleBatch.DONES,
                    SampleBatch.REWARDS, SampleBatch.INFOS] and \
                        key not in self.model.view_requirements:
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in dummy_batch.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key))
                    else:
                        del self.view_requirements[key]
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Add those data_cols (again) that are missing and have
            # dependencies by view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (vr.data_col is not None
                        and vr.data_col not in self.view_requirements):
                    used_for_training = \
                        vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training)

        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }

        # Initialize again after loss init.
        self._sess.run(tf1.global_variables_initializer())
Code Example #13
def compute_advantages(rollout: SampleBatch,
                       last_r: float,
                       gamma: float = 0.9,
                       lambda_: float = 1.0,
                       use_gae: bool = True,
                       use_critic: bool = True):
    """
    Given a rollout, compute its value targets and the advantage.

    Args:
        rollout (SampleBatch): SampleBatch of a single trajectory
        last_r (float): Value estimation for last observation
        gamma (float): Discount factor.
        lambda_ (float): Parameter for GAE
        use_gae (bool): Using Generalized Advantage Estimation
        use_critic (bool): Whether to use critic (value estimates). Setting
                           this to False will use 0 as baseline.

    Returns:
        SampleBatch (SampleBatch): Object with experience from rollout and
            processed rewards.
    """

    rollout_size = len(rollout[SampleBatch.ACTIONS])

    assert SampleBatch.VF_PREDS in rollout or not use_critic, \
        "use_critic=True but values not found"
    assert use_critic or not use_gae, \
        "Can't use gae without using a value function"

    if use_gae:
        vpred_t = np.concatenate(
            [rollout[SampleBatch.VF_PREDS],
             np.array([last_r])])
        delta_t = (rollout[SampleBatch.REWARDS] + gamma * vpred_t[1:] -
                   vpred_t[:-1])
        # This formula for the advantage comes from:
        # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
        rollout[Postprocessing.ADVANTAGES] = discount(delta_t, gamma * lambda_)
        rollout[Postprocessing.VALUE_TARGETS] = (
            rollout[Postprocessing.ADVANTAGES] +
            rollout[SampleBatch.VF_PREDS]).copy().astype(np.float32)
    else:
        rewards_plus_v = np.concatenate(
            [rollout[SampleBatch.REWARDS],
             np.array([last_r])])
        discounted_returns = discount(rewards_plus_v,
                                      gamma)[:-1].copy().astype(np.float32)

        if use_critic:
            rollout[Postprocessing.ADVANTAGES] = discounted_returns - rollout[
                SampleBatch.VF_PREDS]
            rollout[Postprocessing.VALUE_TARGETS] = discounted_returns
        else:
            rollout[Postprocessing.ADVANTAGES] = discounted_returns
            rollout[Postprocessing.VALUE_TARGETS] = np.zeros_like(
                rollout[Postprocessing.ADVANTAGES])

    rollout[Postprocessing.ADVANTAGES] = rollout[
        Postprocessing.ADVANTAGES].copy().astype(np.float32)

    assert all(val.shape[0] == rollout_size for key, val in rollout.items()), \
        "Rollout stacked incorrectly!"
    return rollout
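
For reference, the GAE branch above computes delta_t = r_t + gamma * V(s_{t+1}) - V(s_t) and advantages A_t = sum_l (gamma * lambda)^l * delta_{t+l} (see the GAE paper linked in the code comment above). A NumPy-only sketch of that computation; the `discount` helper below is an illustrative stand-in for the one RLlib uses:

import numpy as np

def discount(x, gamma):
    # Reverse cumulative discounted sum: out[t] = sum_l gamma**l * x[t + l].
    out = np.zeros_like(x)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + gamma * running
        out[t] = running
    return out

rewards = np.array([1.0, 0.0, 1.0], dtype=np.float32)
vf_preds = np.array([0.5, 0.4, 0.3], dtype=np.float32)
last_r, gamma, lambda_ = 0.0, 0.9, 1.0

vpred_t = np.concatenate([vf_preds, [last_r]])
delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
advantages = discount(delta_t, gamma * lambda_)
value_targets = advantages + vf_preds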