Example #1: compute_gradients()
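Computes gradients for a postprocessed batch: the batch is flagged as training data, a _TFRunBuilder wraps the TF session, and the gradient fetches are evaluated in a single session run, returning the gradients together with an info dict.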
    def compute_gradients(
        self, postprocessed_batch: SampleBatch
    ) -> Tuple[ModelGradients, Dict[str, TensorType]]:
        assert self.loss_initialized()
        # Switch on is_training flag in our batch.
        postprocessed_batch.set_training(True)
        builder = _TFRunBuilder(self.get_session(), "compute_gradients")
        fetches = self._build_compute_gradients(builder, postprocessed_batch)
        return builder.get(fetches)
Example #2: compute_log_likelihoods()
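Computes the log-likelihood of given actions under the current policy: actions are normalized if necessary, then actions, observations, RNN state batches, and (optionally) previous actions/rewards are fed into the graph, and the self._log_likelihood op is fetched.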
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorType], TensorType],
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[Union[List[TensorType],
                                          TensorType]] = None,
        prev_reward_batch: Optional[Union[List[TensorType],
                                          TensorType]] = None,
        actions_normalized: bool = True,
    ) -> TensorType:

        if self._log_likelihood is None:
            raise ValueError(
                "Cannot compute log-prob/likelihood w/o a self._log_likelihood op!"
            )

        # Exploration hook before each forward pass.
        self.exploration.before_compute_actions(explore=False,
                                                tf_sess=self.get_session())

        builder = _TFRunBuilder(self.get_session(), "compute_log_likelihoods")

        # Normalize actions if necessary.
        if actions_normalized is False and self.config["normalize_actions"]:
            actions = normalize_action(actions, self.action_space_struct)

        # Feed actions (for which we want logp values) into graph.
        builder.add_feed_dict({self._action_input: actions})
        # Feed observations.
        builder.add_feed_dict({self._obs_input: obs_batch})
        # Internal states.
        state_batches = state_batches or []
        if len(self._state_inputs) != len(state_batches):
            raise ValueError(
                "Must pass in RNN state batches for placeholders {}, got {}".
                format(self._state_inputs, state_batches))
        builder.add_feed_dict(
            {k: v
             for k, v in zip(self._state_inputs, state_batches)})
        if state_batches:
            builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
        # Prev-a and r.
        if self._prev_action_input is not None and prev_action_batch is not None:
            builder.add_feed_dict({self._prev_action_input: prev_action_batch})
        if self._prev_reward_input is not None and prev_reward_batch is not None:
            builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
        # Fetch the log_likelihoods output and return.
        fetches = builder.add_fetches([self._log_likelihood])
        return builder.get(fetches)[0]
Example #3: compute_actions()
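Computes actions for a batch of observations: the inputs are packed into an input_dict, the action fetches are built and evaluated in one session run, and the policy's global timestep is advanced by the batch size.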
    def compute_actions(
        self,
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorType], TensorType] = None,
        prev_reward_batch: Union[List[TensorType], TensorType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes: Optional[List["Episode"]] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ):

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        builder = _TFRunBuilder(self.get_session(), "compute_actions")

        input_dict = {SampleBatch.OBS: obs_batch, "is_training": False}
        if state_batches:
            for i, s in enumerate(state_batches):
                input_dict[f"state_in_{i}"] = s
        if prev_action_batch is not None:
            input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
        if prev_reward_batch is not None:
            input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch

        to_fetch = self._build_compute_actions(builder,
                                               input_dict=input_dict,
                                               explore=explore,
                                               timestep=timestep)

        # Execute session run to get action (and other fetches).
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
        self.global_timestep += (len(obs_batch) if isinstance(obs_batch, list)
                                 else tree.flatten(obs_batch)[0].shape[0])

        return fetched
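A minimal usage sketch for compute_actions() (hypothetical: it assumes a legacy Ray ~1.x installation with the old PPOTrainer/agents API, TF1 graph mode, and Gym's CartPole-v0; newer Ray releases rename these entry points):

    import numpy as np
    import ray
    from ray.rllib.agents.ppo import PPOTrainer  # legacy Trainer API (Ray ~1.x)

    ray.init()
    trainer = PPOTrainer(env="CartPole-v0",
                         config={"framework": "tf", "num_workers": 0})
    policy = trainer.get_policy()  # the underlying graph-mode TF policy

    # One forward pass through the graph; the policy returns a tuple of
    # (actions, rnn_state_outs, extra_fetches) and advances its global timestep.
    obs_batch = np.zeros((1, 4), dtype=np.float32)  # CartPole observation shape
    actions, state_outs, extra = policy.compute_actions(obs_batch, explore=False)
    print(actions)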
Example #4: learn_on_batch()
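Performs one learning update on a postprocessed batch: the on_learn_on_batch callback is invoked first, the learn-on-batch fetches are evaluated, and the returned stats are extended with the callback's custom metrics and the number of agent steps trained.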
    def learn_on_batch(
            self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
        assert self.loss_initialized()

        # Switch on is_training flag in our batch.
        postprocessed_batch.set_training(True)

        builder = _TFRunBuilder(self.get_session(), "learn_on_batch")

        # Callback handling.
        learn_stats = {}
        self.callbacks.on_learn_on_batch(policy=self,
                                         train_batch=postprocessed_batch,
                                         result=learn_stats)

        fetches = self._build_learn_on_batch(builder, postprocessed_batch)
        stats = builder.get(fetches)
        stats.update({
            "custom_metrics": learn_stats,
            NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
        })
        return stats
Example #5: compute_actions_from_input_dict()
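Like compute_actions(), but takes a complete input dict (a SampleBatch or a legacy plain dict): the is_training flag is switched off, the action fetches are built and evaluated, and the global timestep is advanced by the batch size.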
    def compute_actions_from_input_dict(
        self,
        input_dict: Union[SampleBatch, Dict[str, TensorType]],
        explore: bool = None,
        timestep: Optional[int] = None,
        episodes: Optional[List["Episode"]] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        # Switch off is_training flag in our batch.
        if isinstance(input_dict, SampleBatch):
            input_dict.set_training(False)
        else:
            # Deprecated dict input.
            input_dict["is_training"] = False

        builder = _TFRunBuilder(self.get_session(),
                                "compute_actions_from_input_dict")
        obs_batch = input_dict[SampleBatch.OBS]
        to_fetch = self._build_compute_actions(builder,
                                               input_dict=input_dict,
                                               explore=explore,
                                               timestep=timestep)

        # Execute session run to get action (and other fetches).
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
        self.global_timestep += (len(obs_batch) if isinstance(
            obs_batch, list) else len(input_dict) if isinstance(
                input_dict, SampleBatch) else obs_batch.shape[0])

        return fetched
Example #6: apply_gradients()
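Applies externally computed gradients by evaluating the apply-gradients fetches in a session run; it is the counterpart of compute_gradients() in Example #1.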
    def apply_gradients(self, gradients: ModelGradients) -> None:
        assert self.loss_initialized()
        builder = _TFRunBuilder(self.get_session(), "apply_gradients")
        fetches = self._build_apply_gradients(builder, gradients)
        builder.get(fetches)
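Taken together with Example #1, this shows the two halves of a manual update: compute_gradients() evaluates the gradients for a batch, and apply_gradients() feeds them back into the optimizer op. A hypothetical sketch combining the two (assuming policy is a graph-mode TF policy and batch is a postprocessed SampleBatch):

    def train_step(policy, batch):
        # Compute gradients on one batch, then apply them to the local model.
        grads, grad_info = policy.compute_gradients(batch)
        policy.apply_gradients(grads)
        return grad_info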