Example #1
    def compute_actions(self,
                        obs_batch: Union[List[TensorType], TensorType],
                        state_batches: Optional[List[TensorType]] = None,
                        prev_action_batch: Union[List[TensorType],
                                                 TensorType] = None,
                        prev_reward_batch: Union[List[TensorType],
                                                 TensorType] = None,
                        info_batch: Optional[Dict[str, list]] = None,
                        episodes: Optional[List["MultiAgentEpisode"]] = None,
                        explore: Optional[bool] = None,
                        timestep: Optional[int] = None,
                        **kwargs):

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        builder = TFRunBuilder(self._sess, "compute_actions")
        to_fetch = self._build_compute_actions(
            builder,
            obs_batch,
            state_batches=state_batches,
            prev_action_batch=prev_action_batch,
            prev_reward_batch=prev_reward_batch,
            explore=explore,
            timestep=timestep)

        # Execute session run to get action (and other fetches).
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
        self.global_timestep += len(obs_batch) if isinstance(obs_batch, list) \
            else obs_batch.shape[0]

        return fetched
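
Every example on this page follows the same deferred-execution pattern: each `_build_*` helper registers feeds and fetches on a shared TFRunBuilder, and `builder.get()` resolves them with a single `session.run()`. The snippet below is only a simplified sketch of that idea for illustration; it is not the actual TFRunBuilder from `ray.rllib.utils.tf_run_builder`, and its method names merely mirror the usage seen in these examples.

class SimpleRunBuilder:
    """Simplified illustration of the deferred-run idea (not the RLlib class):
    callers queue feeds and fetches, and the first get() runs the session once."""

    def __init__(self, session, name):
        self.session = session
        self.name = name
        self.feed_dict = {}
        self.fetches = []
        self._results = None

    def add_feed_dict(self, feed_dict):
        # Merge placeholder -> value pairs contributed by one caller.
        self.feed_dict.update(feed_dict)

    def add_fetches(self, fetches):
        # Remember which tensors to fetch and hand them back as the handle
        # that the caller later passes to get().
        self.fetches.extend(fetches)
        return fetches

    def get(self, to_fetch):
        # All queued fetches are evaluated in one session.run(); subsequent
        # get() calls are served from the cached results.
        if self._results is None:
            values = self.session.run(self.fetches, feed_dict=self.feed_dict)
            self._results = dict(zip(self.fetches, values))
        if isinstance(to_fetch, list):
            return [self._results[t] for t in to_fetch]
        return self._results[to_fetch]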
Example #2
 def compute_gradients(self, samples):
     if log_once("compute_gradients"):
         logger.info("Compute gradients on:\n\n{}\n".format(
             summarize(samples)))
     if isinstance(samples, MultiAgentBatch):
         grad_out, info_out = {}, {}
         if self.tf_sess is not None:
             builder = TFRunBuilder(self.tf_sess, "compute_gradients")
             for pid, batch in samples.policy_batches.items():
                 if pid not in self.policies_to_train:
                     continue
                 grad_out[pid], info_out[pid] = (
                     self.policy_map[pid]._build_compute_gradients(
                         builder, batch))
             grad_out = {k: builder.get(v) for k, v in grad_out.items()}
             info_out = {k: builder.get(v) for k, v in info_out.items()}
         else:
             for pid, batch in samples.policy_batches.items():
                 if pid not in self.policies_to_train:
                     continue
                 grad_out[pid], info_out[pid] = (
                     self.policy_map[pid].compute_gradients(batch))
     else:
         grad_out, info_out = (
             self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
     info_out["batch_count"] = samples.count
     if log_once("grad_out"):
         logger.info("Compute grad info:\n\n{}\n".format(
             summarize(info_out)))
     return grad_out, info_out
Example #3
 def learn_on_batch(self, samples):
     if log_once("learn_on_batch"):
         logger.info(
             "Training on concatenated sample batches:\n\n{}\n".format(
                 summarize(samples)))
     if isinstance(samples, MultiAgentBatch):
         info_out = {}
         to_fetch = {}
         if self.tf_sess is not None:
             builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
         else:
             builder = None
         for pid, batch in samples.policy_batches.items():
             if pid not in self.policies_to_train:
                 continue
             policy = self.policy_map[pid]
             if builder and hasattr(policy, "_build_learn_on_batch"):
                 to_fetch[pid] = policy._build_learn_on_batch(
                     builder, batch)
             else:
                 info_out[pid] = policy.learn_on_batch(batch)
         info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
     else:
         info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(
             samples)
     if log_once("learn_out"):
         logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
     return info_out
Example #4
    def apply_gradients(self, grads):
        """Applies the given gradients to this worker's weights.

        Examples:
            >>> samples = worker.sample()
            >>> grads, info = worker.compute_gradients(samples)
            >>> worker.apply_gradients(grads)
        """
        if log_once("apply_gradients"):
            logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
        if isinstance(grads, dict):
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "apply_gradients")
                outputs = {
                    pid:
                    self.policy_map[pid]._build_apply_gradients(builder, grad)
                    for pid, grad in grads.items()
                }
                return {k: builder.get(v) for k, v in outputs.items()}
            else:
                return {
                    pid: self.policy_map[pid].apply_gradients(g)
                    for pid, g in grads.items()
                }
        else:
            return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
Example #5
    def compute_actions_from_input_dict(
            self,
            input_dict: Dict[str, TensorType],
            explore: bool = None,
            timestep: Optional[int] = None,
            episodes: Optional[List["MultiAgentEpisode"]] = None,
            **kwargs) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        builder = TFRunBuilder(self._sess, "compute_actions_from_input_dict")
        obs_batch = input_dict[SampleBatch.OBS]
        to_fetch = self._build_compute_actions(builder,
                                               input_dict=input_dict,
                                               explore=explore,
                                               timestep=timestep)

        # Execute session run to get action (and other fetches).
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
        self.global_timestep += len(obs_batch) if isinstance(obs_batch, list) \
            else obs_batch.shape[0]

        return fetched
Example #6
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute actions on observation batches to get next actions.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """

    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        rnn_in = [t.rnn_state for t in eval_data]
        policy = _get_or_raise(policies, policy_id)
        if builder and (policy.compute_actions.__code__ is
                        TFPolicy.compute_actions.__code__):
            rnn_in_cols = _to_column_format(rnn_in)
            # TODO(ekl): how can we make info batch available to TF code?
            # TODO(sven): Return dict from _build_compute_actions.
            # it's becoming more and more unclear otherwise, what's where in
            # the return tuple.
            pending_fetches[policy_id] = policy._build_compute_actions(
                builder,
                obs_batch=[t.obs for t in eval_data],
                state_batches=rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                timestep=policy.global_timestep)
        else:
            # TODO(sven): Does this work for LSTM torch?
            rnn_in_cols = [
                np.stack([row[i] for row in rnn_in])
                for i in range(len(rnn_in[0]))
            ]
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                state_batches=rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data],
                timestep=policy.global_timestep)
    if builder:
        for k, v in pending_fetches.items():
            eval_results[k] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
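
The `__code__` identity check above is how the sampler decides whether a policy still uses the stock TFPolicy.compute_actions (and can therefore go through the batched builder path) or has overridden it. A small, hypothetical standalone illustration of that test:

class Base:
    def compute_actions(self):
        return "base"

class CustomPolicy(Base):
    def compute_actions(self):  # overridden -> different code object
        return "custom"

class PlainPolicy(Base):
    pass  # inherits the base implementation unchanged

# True only when the method was NOT overridden.
print(PlainPolicy.compute_actions.__code__ is Base.compute_actions.__code__)   # True
print(CustomPolicy.compute_actions.__code__ is Base.compute_actions.__code__)  # False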
Example #7
 def compute_actions(self,
                     obs_batch,
                     state_batches=None,
                     is_training=False):
     builder = TFRunBuilder(self._sess, "compute_actions")
     fetches = self.build_compute_actions(builder, obs_batch, state_batches,
                                          is_training)
     return builder.get(fetches)
Example #8
 def compute_gradients(
         self,
         postprocessed_batch: SampleBatch) -> \
         Tuple[ModelGradients, Dict[str, TensorType]]:
     assert self.loss_initialized()
     builder = TFRunBuilder(self.get_session(), "compute_gradients")
     fetches = self._build_compute_gradients(builder, postprocessed_batch)
     return builder.get(fetches)
Example #9
 def compute_rnn_state_out(self):
     if self._episode_buffer:
         if not self._rnn_state_out:
             builder = TFRunBuilder(self._sess, "compute_rnn_state_out")
             fetches = self.build_compute_rnn_state_out(builder)
             self._rnn_state_out = builder.get(fetches)
         return self._rnn_state_out
     else:
         return self.model.rnn_state_out_init
Example #10
 def compute_gradients(
     self, postprocessed_batch: SampleBatch
 ) -> Tuple[ModelGradients, Dict[str, TensorType]]:
     assert self.loss_initialized()
     # Switch on is_training flag in our batch.
     postprocessed_batch.set_training(True)
     builder = TFRunBuilder(self.get_session(), "compute_gradients")
     fetches = self._build_compute_gradients(builder, postprocessed_batch)
     return builder.get(fetches)
Example #11
def _do_policy_eval(
    *,
    to_eval: Dict[PolicyID, List[PolicyEvalData]],
    policies: Dict[PolicyID, Policy],
    sample_collector,
    active_episodes: Dict[str, MultiAgentEpisode],
    tf_sess: Optional["tf.Session"] = None,
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
    """Call compute_actions on collected episode/model data to get next action.

    Args:
        to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy
            IDs to lists of PolicyEvalData objects (items in these lists will
            be the batch's items for the model forward pass).
        policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
            obj.
        sample_collector (SampleCollector): The SampleCollector object to use.
        tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
            batching TF policy evaluations.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """

    eval_results: Dict[PolicyID, TensorStructType] = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches: Dict[PolicyID, Any] = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        policy: Policy = _get_or_raise(policies, policy_id)
        input_dict = sample_collector.get_inference_input_dict(policy_id)
        eval_results[policy_id] = \
            policy.compute_actions_from_input_dict(
                input_dict,
                timestep=policy.global_timestep,
                episodes=[active_episodes[t.env_id] for t in eval_data])

    if builder:
        # type: PolicyID, Tuple[TensorStructType, StateBatch, dict]
        for pid, v in pending_fetches.items():
            eval_results[pid] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
Example #12
 def compute_inner_gradients(self, postprocessed_batch):
     builder = TFRunBuilder(self._sess, "compute_inner_gradients")
     self._before_compute_grads()
     self._grads = self._inner_grads
     self._loss_inputs = self._inner_loss_inputs
     self._loss_input_dict = self._inner_loss_input_dict
     self.stats_fetches = self.a3c_stats_fetches
     fetches = self.build_compute_gradients(builder, postprocessed_batch)
     results = builder.get(fetches)
     self._after_compute_grads()
     return results
Example #13
 def compute_actions(self,
                     obs_batch,
                     state_batches=None,
                     prev_action_batch=None,
                     prev_reward_batch=None,
                     episodes=None):
     builder = TFRunBuilder(self._sess, "compute_actions")
     fetches = self.build_compute_actions(builder, obs_batch, state_batches,
                                          prev_action_batch,
                                          prev_reward_batch)
     return builder.get(fetches)
Example #14
 def compute_actions(self,
                     obs_batch,
                     state_batches=None,
                     prev_action_batch=None,
                     prev_reward_batch=None,
                     info_batch=None,
                     episodes=None,
                     **kwargs):
     builder = TFRunBuilder(self._sess, "compute_actions")
     fetches = self._build_compute_actions(builder, obs_batch,
                                           state_batches, prev_action_batch,
                                           prev_reward_batch)
     return builder.get(fetches)
Example #15
    def compute_actions(
            self,
            obs_batch,
            state_batches = None,
            prev_action_batch = None,
            prev_reward_batch = None,
            info_batch = None,
            episodes = None,
            explore = None,
            timestep = None,
            **kwargs):
        """
        method used in PPOTFPolicy but edited to handle dict inputs at runtime  (it is handled 
            at training by existing rllib code, but not for using already trained model)
        """
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        builder = TFRunBuilder(self._sess, "compute_actions")
        if type(obs_batch) is dict:
            some_batch_data = list(obs_batch.values())[0]
            obs_batch_len = len(some_batch_data) if isinstance(some_batch_data, list) \
                else some_batch_data.shape[0]
            flattened_obs = []
            for k in self.observation_space.original_space.spaces.keys():
                if k in obs_batch:
                    obs = np.array(obs_batch[k])
                    flattened_obs.append(obs.reshape(obs_batch_len, np.prod(obs.shape[1:])))
            obs_batch = np.concatenate(flattened_obs, axis=-1)
        else:
            obs_batch_len = len(obs_batch) if isinstance(obs_batch, list) \
                else obs_batch.shape[0]
            obs_batch = np.array(obs_batch)

        to_fetch = self._build_compute_actions(
            builder,
            obs_batch=obs_batch,
            state_batches=state_batches,
            prev_action_batch=prev_action_batch,
            prev_reward_batch=prev_reward_batch,
            explore=explore,
            timestep=timestep)

        # Execute session run to get action (and other fetches).
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
        self.global_timestep += obs_batch_len
        return fetched
Example #16
    def learn_on_batch(
            self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
        assert self.loss_initialized()

        builder = TFRunBuilder(self._sess, "learn_on_batch")

        # Callback handling.
        learn_stats = {}
        self.callbacks.on_learn_on_batch(
            policy=self, train_batch=postprocessed_batch, result=learn_stats)

        fetches = self._build_learn_on_batch(builder, postprocessed_batch)
        stats = builder.get(fetches)
        stats.update({"custom_metrics": learn_stats})
        return stats
Example #17
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute actions on observation batches to get next actions.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """

    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
        policy = _get_or_raise(policies, policy_id)

        if builder and (policy.compute_actions.__code__ is
                        TFPolicy.compute_actions.__code__):
            # TODO(ekl): how can we make info batch available to TF code?
            pending_fetches[policy_id] = policy._build_compute_actions(
                builder, [t.obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data])
        else:
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data])
    if builder:
        for k, v in pending_fetches.items():
            eval_results[k] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
Example #18
 def apply_gradients(self, grads):
     if isinstance(grads, dict):
         if self.tf_sess is not None:
             builder = TFRunBuilder(self.tf_sess, "apply_gradients")
             outputs = {
                 pid:
                 self.policy_map[pid].build_apply_gradients(builder, grad)
                 for pid, grad in grads.items()
             }
             return {k: builder.get(v) for k, v in outputs.items()}
         else:
             return {
                 pid: self.policy_map[pid].apply_gradients(g)
                 for pid, g in grads.items()
             }
     else:
         return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
Example #19
 def apply_gradients(self, grads):
     if isinstance(grads, dict):
         if self.tf_sess is not None:
             builder = TFRunBuilder(self.tf_sess, "apply_gradients")
             outputs = {
                 pid: self.policy_map[pid]._build_apply_gradients(
                     builder, grad)
                 for pid, grad in grads.items()
             }
             return {k: builder.get(v) for k, v in outputs.items()}
         else:
             return {
                 pid: self.policy_map[pid].apply_gradients(g)
                 for pid, g in grads.items()
             }
     else:
         return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
Example #20
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes, clip_actions):
    """Call compute actions on observation batches to get next actions.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """

    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None
    for policy_id, eval_data in to_eval.items():
        rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
        policy = _get_or_raise(policies, policy_id)
        if builder and (policy.compute_actions.__code__ is
                        TFPolicyGraph.compute_actions.__code__):
            pending_fetches[policy_id] = policy.build_compute_actions(
                builder, [t.obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data])
        else:
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data])
    if builder:
        for k, v in pending_fetches.items():
            eval_results[k] = builder.get(v)

    if clip_actions:
        for policy_id, results in eval_results.items():
            policy = _get_or_raise(policies, policy_id)
            actions, rnn_out_cols, pi_info_cols = results
            eval_results[policy_id] = (_clip_actions(actions,
                                                     policy.action_space),
                                       rnn_out_cols, pi_info_cols)

    return eval_results
Example #21
    def compute_gradients(
            self, samples: SampleBatchType) -> Tuple[ModelGradients, dict]:
        """Returns a gradient computed w.r.t the specified samples.

        Returns:
            (grads, info): A list of gradients that can be applied on a
            compatible worker. In the multi-agent case, returns a dict
            of gradients keyed by policy ids. An info dictionary of
            extra metadata is also returned.

        Examples:
            >>> batch = worker.sample()
            >>> grads, info = worker.compute_gradients(samples)
        """
        if log_once("compute_gradients"):
            logger.info("Compute gradients on:\n\n{}\n".format(
                summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            grad_out, info_out = {}, {}
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "compute_gradients")
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid]._build_compute_gradients(
                            builder, batch))
                grad_out = {k: builder.get(v) for k, v in grad_out.items()}
                info_out = {k: builder.get(v) for k, v in info_out.items()}
            else:
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid].compute_gradients(batch))
        else:
            grad_out, info_out = (
                self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
        info_out["batch_count"] = samples.count
        if log_once("grad_out"):
            logger.info("Compute grad info:\n\n{}\n".format(
                summarize(info_out)))
        return grad_out, info_out
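
Together with apply_gradients() shown earlier, this supports the two-step training loop hinted at in the docstring. A hypothetical sketch, assuming `worker` is a rollout worker exposing the methods above and `num_iterations` is chosen by the caller:

# Hypothetical training loop built from the worker methods shown above.
for _ in range(num_iterations):
    samples = worker.sample()
    grads, info = worker.compute_gradients(samples)
    worker.apply_gradients(grads)
    # "batch_count" is attached to the info dict by compute_gradients().
    print("trained on", info["batch_count"], "timesteps")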
Example #22
 def compute_apply(self, samples):
     if isinstance(samples, MultiAgentBatch):
         info_out = {}
         if self.tf_sess is not None:
             builder = TFRunBuilder(self.tf_sess, "compute_apply")
             for pid, batch in samples.policy_batches.items():
                 info_out[pid], _ = (
                     self.policy_map[pid].build_compute_apply(
                         builder, batch))
             info_out = {k: builder.get(v) for k, v in info_out.items()}
         else:
             for pid, batch in samples.policy_batches.items():
                 info_out[pid], _ = (
                     self.policy_map[pid].compute_apply(batch))
         return info_out
     else:
         grad_fetch, apply_fetch = (
             self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
         return grad_fetch
Example #23
 def compute_gradients(self, samples):
     if isinstance(samples, MultiAgentBatch):
         grad_out, info_out = {}, {}
         if self.tf_sess is not None:
             builder = TFRunBuilder(self.tf_sess, "compute_gradients")
             for pid, batch in samples.policy_batches.items():
                 grad_out[pid], info_out[pid] = (
                     self.policy_map[pid].build_compute_gradients(
                         builder, batch))
             grad_out = {k: builder.get(v) for k, v in grad_out.items()}
             info_out = {k: builder.get(v) for k, v in info_out.items()}
         else:
             for pid, batch in samples.policy_batches.items():
                 grad_out[pid], info_out[pid] = (
                     self.policy_map[pid].compute_gradients(batch))
         return grad_out, info_out
     else:
         return self.policy_map[DEFAULT_POLICY_ID].compute_gradients(
             samples)
Example #24
 def apply_gradients(self, grads):
     if log_once("apply_gradients"):
         logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
     if isinstance(grads, dict):
         if self.tf_sess is not None:
             builder = TFRunBuilder(self.tf_sess, "apply_gradients")
             outputs = {
                 pid: self.policy_map[pid]._build_apply_gradients(
                     builder, grad)
                 for pid, grad in grads.items()
             }
             return {k: builder.get(v) for k, v in outputs.items()}
         else:
             return {
                 pid: self.policy_map[pid].apply_gradients(g)
                 for pid, g in grads.items()
             }
     else:
         return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
Example #25
    def compute_actions(
        self,
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorType], TensorType] = None,
        prev_reward_batch: Union[List[TensorType], TensorType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes: Optional[List["Episode"]] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ):

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        builder = TFRunBuilder(self.get_session(), "compute_actions")

        input_dict = {SampleBatch.OBS: obs_batch, "is_training": False}
        if state_batches:
            for i, s in enumerate(state_batches):
                input_dict[f"state_in_{i}"] = s
        if prev_action_batch is not None:
            input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
        if prev_reward_batch is not None:
            input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch

        to_fetch = self._build_compute_actions(
            builder, input_dict=input_dict, explore=explore, timestep=timestep
        )

        # Execute session run to get action (and other fetches).
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
        self.global_timestep += (
            len(obs_batch)
            if isinstance(obs_batch, list)
            else tree.flatten(obs_batch)[0].shape[0]
        )

        return fetched
Example #26
    def learn_on_batch(self, samples: SampleBatchType) -> dict:
        """Update policies based on the given batch.

        This is the equivalent to apply_gradients(compute_gradients(samples)),
        but can be optimized to avoid pulling gradients into CPU memory.

        Returns:
            info: dictionary of extra metadata from compute_gradients().

        Examples:
            >>> batch = worker.sample()
            >>> worker.learn_on_batch(samples)
        """
        if log_once("learn_on_batch"):
            logger.info(
                "Training on concatenated sample batches:\n\n{}\n".format(
                    summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            info_out = {}
            to_fetch = {}
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
            else:
                builder = None
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                policy = self.policy_map[pid]
                if builder and hasattr(policy, "_build_learn_on_batch"):
                    to_fetch[pid] = policy._build_learn_on_batch(
                        builder, batch)
                else:
                    info_out[pid] = policy.learn_on_batch(batch)
            info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
        else:
            info_out = {
                DEFAULT_POLICY_ID: self.policy_map[DEFAULT_POLICY_ID]
                .learn_on_batch(samples)
            }
        if log_once("learn_out"):
            logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
        return info_out
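
As the docstring notes, learn_on_batch() fuses the compute/apply pair into one call. A hypothetical comparison, assuming `worker` exposes the methods shown in the earlier examples:

batch = worker.sample()

# Two-step path: gradients are pulled out of the policy and re-applied.
grads, info = worker.compute_gradients(batch)
worker.apply_gradients(grads)

# Fused path: equivalent result, but gradients can stay on the device.
info = worker.learn_on_batch(batch)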
Example #27
    def learn_on_batch(self, samples):
        if log_once("learn_on_batch"):
            logger.info(
                "Training on concatenated sample batches:\n\n{}\n".format(
                    summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            info_out = {}
            to_fetch = {}
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
            else:
                builder = None
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                policy = self.policy_map[pid]
                if builder and hasattr(policy, "_build_learn_on_batch"):
                    to_fetch[pid] = policy._build_learn_on_batch(
                        builder, batch)
                else:
                    info_out[pid] = policy.learn_on_batch(batch)
            info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
        else:
            learn_on_batch_outputs = self.policy_map[
                DEFAULT_POLICY_ID].learn_on_batch(samples)

            if isinstance(learn_on_batch_outputs, tuple):
                info_out, beholder_arrays = learn_on_batch_outputs
            else:
                info_out, beholder_arrays = learn_on_batch_outputs, {}

            if self.policy_config["evaluation_config"]["beholder"]:
                with self.tf_sess.graph.as_default():
                    if self._beholder is None:
                        self._beholder = Beholder(self.io_context.log_dir)
                    self._beholder.update(self.tf_sess,
                                          arrays=beholder_arrays or None)

        if log_once("learn_out"):
            logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
        return info_out
Example #28
    def learn_on_batch(
            self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
        assert self.loss_initialized()

        # Switch on is_training flag in our batch.
        postprocessed_batch.set_training(True)

        builder = TFRunBuilder(self.get_session(), "learn_on_batch")

        # Callback handling.
        learn_stats = {}
        self.callbacks.on_learn_on_batch(policy=self,
                                         train_batch=postprocessed_batch,
                                         result=learn_stats)

        fetches = self._build_learn_on_batch(builder, postprocessed_batch)
        stats = builder.get(fetches)
        stats.update({
            "custom_metrics": learn_stats,
            NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
        })
        return stats
Example #29
    def compute_actions_from_input_dict(
        self,
        input_dict: Union[SampleBatch, Dict[str, TensorType]],
        explore: bool = None,
        timestep: Optional[int] = None,
        episodes: Optional[List["Episode"]] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        # Switch off is_training flag in our batch.
        if isinstance(input_dict, SampleBatch):
            input_dict.set_training(False)
        else:
            # Deprecated dict input.
            input_dict["is_training"] = False

        builder = TFRunBuilder(self.get_session(), "compute_actions_from_input_dict")
        obs_batch = input_dict[SampleBatch.OBS]
        to_fetch = self._build_compute_actions(
            builder, input_dict=input_dict, explore=explore, timestep=timestep
        )

        # Execute session run to get action (and other fetches).
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
        self.global_timestep += (
            len(obs_batch)
            if isinstance(obs_batch, list)
            else len(input_dict)
            if isinstance(input_dict, SampleBatch)
            else obs_batch.shape[0]
        )

        return fetched
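
A hypothetical call for this SampleBatch-based variant, assuming `policy` is a session-based TF policy with a 4-dimensional Box observation space (the SampleBatch import path matches recent RLlib layouts):

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

# Batch of two dummy observations; set_training(False) is handled internally.
batch = SampleBatch({SampleBatch.OBS: np.zeros((2, 4), dtype=np.float32)})
actions, state_outs, extra_fetches = policy.compute_actions_from_input_dict(batch)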
Example #30
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute actions on observation batches to get next actions.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """

    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None
    for policy_id, eval_data in to_eval.items():
        rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
        policy = _get_or_raise(policies, policy_id)
        if builder and (policy.compute_actions.__code__ is
                        TFPolicyGraph.compute_actions.__code__):
            # TODO(ekl): how can we make info batch available to TF code?
            pending_fetches[policy_id] = policy._build_compute_actions(
                builder, [t.obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data])
        else:
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data])
    if builder:
        for k, v in pending_fetches.items():
            eval_results[k] = builder.get(v)

    return eval_results
Example #31
 def compute_actions(self,
                     obs_batch,
                     state_batches=None,
                     prev_action_batch=None,
                     prev_reward_batch=None,
                     info_batch=None,
                     episodes=None,
                     explore=None,
                     timestep=None,
                     **kwargs):
     explore = explore if explore is not None else self.config["explore"]
     builder = TFRunBuilder(self._sess, "compute_actions")
     fetches = self._build_compute_actions(
         builder,
         obs_batch,
         state_batches,
         prev_action_batch,
         prev_reward_batch,
         explore=explore,
         timestep=timestep
         if timestep is not None else self.global_timestep)
     # Execute session run to get action (and other fetches).
     return builder.get(fetches)
Example #32
 def compute_apply(self, samples):
     if isinstance(samples, MultiAgentBatch):
         info_out = {}
         if self.tf_sess is not None:
             builder = TFRunBuilder(self.tf_sess, "compute_apply")
             for pid, batch in samples.policy_batches.items():
                 if pid not in self.policies_to_train:
                     continue
                 info_out[pid], _ = (
                     self.policy_map[pid]._build_compute_apply(
                         builder, batch))
             info_out = {k: builder.get(v) for k, v in info_out.items()}
         else:
             for pid, batch in samples.policy_batches.items():
                 if pid not in self.policies_to_train:
                     continue
                 info_out[pid], _ = (
                     self.policy_map[pid].compute_apply(batch))
         return info_out
     else:
         grad_fetch, apply_fetch = (
             self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
         return grad_fetch
Example #33
 def learn_on_batch(self, samples):
     if isinstance(samples, MultiAgentBatch):
         info_out = {}
         if self.tf_sess is not None:
             builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
             for pid, batch in samples.policy_batches.items():
                 if pid not in self.policies_to_train:
                     continue
                 info_out[pid], _ = (
                     self.policy_map[pid]._build_learn_on_batch(
                         builder, batch))
             info_out = {k: builder.get(v) for k, v in info_out.items()}
         else:
             for pid, batch in samples.policy_batches.items():
                 if pid not in self.policies_to_train:
                     continue
                 info_out[pid], _ = (
                     self.policy_map[pid].learn_on_batch(batch))
         return info_out
     else:
         grad_fetch, apply_fetch = (
             self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples))
         return grad_fetch
Example #34
 def compute_gradients(self, samples):
     if isinstance(samples, MultiAgentBatch):
         grad_out, info_out = {}, {}
         if self.tf_sess is not None:
             builder = TFRunBuilder(self.tf_sess, "compute_gradients")
             for pid, batch in samples.policy_batches.items():
                 if pid not in self.policies_to_train:
                     continue
                 grad_out[pid], info_out[pid] = (
                     self.policy_map[pid]._build_compute_gradients(
                         builder, batch))
             grad_out = {k: builder.get(v) for k, v in grad_out.items()}
             info_out = {k: builder.get(v) for k, v in info_out.items()}
         else:
             for pid, batch in samples.policy_batches.items():
                 if pid not in self.policies_to_train:
                     continue
                 grad_out[pid], info_out[pid] = (
                     self.policy_map[pid].compute_gradients(batch))
     else:
         grad_out, info_out = (
             self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
     info_out["batch_count"] = samples.count
     return grad_out, info_out
Example #35
 def learn_on_batch(self, postprocessed_batch):
     builder = TFRunBuilder(self._sess, "learn_on_batch")
     fetches = self._build_learn_on_batch(builder, postprocessed_batch)
     return builder.get(fetches)
Example #36
 def compute_apply(self, postprocessed_batch):
     builder = TFRunBuilder(self._sess, "compute_apply")
     fetches = self._build_compute_apply(builder, postprocessed_batch)
     return builder.get(fetches)
Example #37
 def compute_gradients(self, postprocessed_batch):
     builder = TFRunBuilder(self._sess, "compute_gradients")
     fetches = self._build_compute_gradients(builder, postprocessed_batch)
     return builder.get(fetches)
Example #38
 def apply_gradients(self, gradients):
     builder = TFRunBuilder(self._sess, "apply_gradients")
     fetches = self._build_apply_gradients(builder, gradients)
     return builder.get(fetches)