Example #1
 def compute_actions(self,
                     obs_batch,
                     state_batches=None,
                     prev_action_batch=None,
                     prev_reward_batch=None,
                     info_batch=None,
                     episodes=None,
                     **kwargs):
     builder = TFRunBuilder(self._sess, "compute_actions")
     fetches = self._build_compute_actions(builder, obs_batch,
                                           state_batches, prev_action_batch,
                                           prev_reward_batch)
     return builder.get(fetches)
Example #2
 def apply_gradients(self, grads):
     if isinstance(grads, dict):
         if self.tf_sess is not None:
             builder = TFRunBuilder(self.tf_sess, "apply_gradients")
             outputs = {
                 pid: self.policy_map[pid]._build_apply_gradients(
                     builder, grad)
                 for pid, grad in grads.items()
             }
             return {k: builder.get(v) for k, v in outputs.items()}
         else:
             return {
                 pid: self.policy_map[pid].apply_gradients(g)
                 for pid, g in grads.items()
             }
     else:
         return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
Example #3
    def compute_log_likelihoods(self,
                                actions,
                                obs_batch,
                                state_batches=None,
                                prev_action_batch=None,
                                prev_reward_batch=None):
        if self._log_likelihood is None:
            raise ValueError("Cannot compute log-prob/likelihood w/o a "
                             "self._log_likelihood op!")

        # Do the forward pass through the model to capture the parameters
        # for the action distribution, then do a logp on that distribution.
        builder = TFRunBuilder(self._sess, "compute_log_likelihoods")
        # Feed actions (for which we want logp values) into graph.
        builder.add_feed_dict({self._action_input: actions})
        # Feed observations.
        builder.add_feed_dict({self._obs_input: obs_batch})
        # Internal states.
        state_batches = state_batches or []
        if len(self._state_inputs) != len(state_batches):
            raise ValueError(
                "Must pass in RNN state batches for placeholders {}, got {}".
                format(self._state_inputs, state_batches))
        builder.add_feed_dict(
            {k: v
             for k, v in zip(self._state_inputs, state_batches)})
        if state_batches:
            builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
        # Prev-a and r.
        if self._prev_action_input is not None and \
           prev_action_batch is not None:
            builder.add_feed_dict({self._prev_action_input: prev_action_batch})
        if self._prev_reward_input is not None and \
           prev_reward_batch is not None:
            builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
        # Fetch the log_likelihoods output and return.
        fetches = builder.add_fetches([self._log_likelihood])
        return builder.get(fetches)[0]
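
Every example in this listing follows the same deferred-execution pattern: feeds and fetches are queued on a TFRunBuilder, and the underlying session.run happens only when builder.get() is first called, so several callers can share a single run. Below is a minimal, self-contained sketch of that pattern; the name RunBuilderSketch and its internals are illustrative stand-ins, not RLlib's actual implementation, and only the add_feed_dict/add_fetches/get surface mirrors the usage shown above.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


class RunBuilderSketch:
    """Queues feeds and fetches; runs the session once, on the first get()."""

    def __init__(self, sess, name):
        self.sess = sess
        self.name = name
        self.feed_dict = {}
        self.fetches = []
        self._results = None

    def add_feed_dict(self, feed_dict):
        self.feed_dict.update(feed_dict)

    def add_fetches(self, fetches):
        # Hand back handles (indices here) that get() can resolve later.
        start = len(self.fetches)
        self.fetches.extend(fetches)
        return list(range(start, len(self.fetches)))

    def get(self, to_fetch):
        if self._results is None:
            # One session.run covering everything queued so far.
            self._results = self.sess.run(
                self.fetches, feed_dict=self.feed_dict)
        if isinstance(to_fetch, (list, tuple)):
            return [self._results[i] for i in to_fetch]
        return self._results[to_fetch]


# Toy usage: one placeholder, two fetches, a single underlying run.
x = tf.placeholder(tf.float32, [None])
doubled, squared = x * 2.0, x * x
with tf.Session() as sess:
    b = RunBuilderSketch(sess, "demo")
    b.add_feed_dict({x: [1.0, 2.0, 3.0]})
    handles = b.add_fetches([doubled, squared])
    print(b.get(handles))  # both results come from one session.run
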
Example #4
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute actions on observation batches to get next actions.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """

    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None
    for policy_id, eval_data in to_eval.items():
        rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
        policy = _get_or_raise(policies, policy_id)
        if builder and (policy.compute_actions.__code__ is
                        TFPolicyGraph.compute_actions.__code__):
            # TODO(ekl): how can we make info batch available to TF code?
            pending_fetches[policy_id] = policy._build_compute_actions(
                builder, [t.obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data])
        else:
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data])
    if builder:
        for k, v in pending_fetches.items():
            eval_results[k] = builder.get(v)

    return eval_results
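
In _do_policy_eval, this is what buys the batching across policies: each TF policy queues its fetches on the shared builder inside the loop, and the trailing builder.get(v) calls all resolve against one session run. The toy snippet below is plain TensorFlow, not RLlib; the two dense heads stand in for two policies and are only meant to illustrate why a single fetch is cheaper than one session.run per policy.

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

obs_ph = tf.placeholder(tf.float32, [None, 4])
# Two toy "policies" sharing one graph and session.
logits_a = tf.layers.dense(obs_ph, 2, name="policy_a")
logits_b = tf.layers.dense(obs_ph, 3, name="policy_b")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    obs = np.random.randn(8, 4).astype(np.float32)
    # The batched path above amounts to this: one run fetches both policies'
    # outputs, instead of one session.run per policy (the fallback
    # compute_actions() path).
    out_a, out_b = sess.run([logits_a, logits_b], feed_dict={obs_ph: obs})
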
Example #5
    def compute_actions_from_input_dict(
        self,
        input_dict: Union[SampleBatch, Dict[str, TensorType]],
        explore: bool = None,
        timestep: Optional[int] = None,
        episodes: Optional[List["Episode"]] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        # Switch off is_training flag in our batch.
        if isinstance(input_dict, SampleBatch):
            input_dict.set_training(False)
        else:
            # Deprecated dict input.
            input_dict["is_training"] = False

        builder = TFRunBuilder(self.get_session(), "compute_actions_from_input_dict")
        obs_batch = input_dict[SampleBatch.OBS]
        to_fetch = self._build_compute_actions(
            builder, input_dict=input_dict, explore=explore, timestep=timestep
        )

        # Execute session run to get action (and other fetches).
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
        self.global_timestep += (
            len(obs_batch)
            if isinstance(obs_batch, list)
            else len(input_dict)
            if isinstance(input_dict, SampleBatch)
            else obs_batch.shape[0]
        )

        return fetched
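
For reference, a hedged usage sketch of compute_actions_from_input_dict: it assumes an already-initialized RLlib TF policy (e.g. obtained via trainer.get_policy(), setup not shown) and that a SampleBatch can be built from a dict of NumPy arrays, as in recent RLlib versions.

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

# `policy` is assumed to be an already-built RLlib TF policy (hypothetical
# setup, e.g. trainer.get_policy(); not shown here).
obs = np.random.randn(32, 4).astype(np.float32)
input_dict = SampleBatch({SampleBatch.OBS: obs})
actions, state_outs, extra_fetches = policy.compute_actions_from_input_dict(
    input_dict, explore=True)
# `actions` has one row per observation; `extra_fetches` may carry action
# log-probs, value estimates, etc., depending on the policy.
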
Example #7
    def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
        assert self.loss_initialized()

        # Switch on is_training flag in our batch.
        postprocessed_batch.set_training(True)

        builder = TFRunBuilder(self.get_session(), "learn_on_batch")

        # Callback handling.
        learn_stats = {}
        self.callbacks.on_learn_on_batch(
            policy=self, train_batch=postprocessed_batch, result=learn_stats
        )

        fetches = self._build_learn_on_batch(builder, postprocessed_batch)
        stats = builder.get(fetches)
        stats.update(
            {
                "custom_metrics": learn_stats,
                NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
            }
        )
        return stats
Example #8
 def compute_apply(self, samples):
     if isinstance(samples, MultiAgentBatch):
         info_out = {}
         if self.tf_sess is not None:
             builder = TFRunBuilder(self.tf_sess, "compute_apply")
             for pid, batch in samples.policy_batches.items():
                 if pid not in self.policies_to_train:
                     continue
                 info_out[pid], _ = (
                     self.policy_map[pid]._build_compute_apply(
                         builder, batch))
             info_out = {k: builder.get(v) for k, v in info_out.items()}
         else:
             for pid, batch in samples.policy_batches.items():
                 if pid not in self.policies_to_train:
                     continue
                 info_out[pid], _ = (
                     self.policy_map[pid].compute_apply(batch))
         return info_out
     else:
         grad_fetch, apply_fetch = (
             self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
         return grad_fetch
Example #10
 def compute_actions(self,
                     obs_batch,
                     state_batches=None,
                     prev_action_batch=None,
                     prev_reward_batch=None,
                     info_batch=None,
                     episodes=None,
                     explore=None,
                     timestep=None,
                     **kwargs):
     explore = explore if explore is not None else self.config["explore"]
     builder = TFRunBuilder(self._sess, "compute_actions")
     fetches = self._build_compute_actions(
         builder,
         obs_batch,
         state_batches,
         prev_action_batch,
         prev_reward_batch,
         explore=explore,
         timestep=timestep
         if timestep is not None else self.global_timestep)
     # Execute session run to get action (and other fetches).
     return builder.get(fetches)
Example #11
 def compute_gradients(self, samples):
     if isinstance(samples, MultiAgentBatch):
         grad_out, info_out = {}, {}
         if self.tf_sess is not None:
             builder = TFRunBuilder(self.tf_sess, "compute_gradients")
             for pid, batch in samples.policy_batches.items():
                 if pid not in self.policies_to_train:
                     continue
                 grad_out[pid], info_out[pid] = (
                     self.policy_map[pid]._build_compute_gradients(
                         builder, batch))
             grad_out = {k: builder.get(v) for k, v in grad_out.items()}
             info_out = {k: builder.get(v) for k, v in info_out.items()}
         else:
             for pid, batch in samples.policy_batches.items():
                 if pid not in self.policies_to_train:
                     continue
                 grad_out[pid], info_out[pid] = (
                     self.policy_map[pid].compute_gradients(batch))
     else:
         grad_out, info_out = (
             self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
     info_out["batch_count"] = samples.count
     return grad_out, info_out
Example #13
 def compute_apply(self, postprocessed_batch):
     builder = TFRunBuilder(self._sess, "compute_apply")
     fetches = self._build_compute_apply(builder, postprocessed_batch)
     return builder.get(fetches)
Example #14
def _env_runner(async_vector_env,
                extra_batch_callback,
                policies,
                policy_mapping_fn,
                unroll_length,
                horizon,
                obs_filters,
                clip_rewards,
                pack,
                tf_sess=None):
    """This implements the common experience collection logic.

    Args:
        async_vector_env (AsyncVectorEnv): env implementing AsyncVectorEnv.
        extra_batch_callback (fn): function to send extra batch data to.
        policies (dict): Map of policy ids to PolicyGraph instances.
        policy_mapping_fn (func): Function that maps agent ids to policy ids.
            This is called when an agent first enters the environment. The
            agent is then "bound" to the returned policy for the episode.
        unroll_length (int): Number of episode steps before `SampleBatch` is
            yielded. Set to infinity to yield complete episodes.
        horizon (int): Horizon of the episode.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        clip_rewards (bool): Whether to clip rewards before postprocessing.
        pack (bool): Whether to pack multiple episodes into each batch. This
            guarantees batches will be exactly `unroll_length` in size.
        tf_sess (Session|None): Optional tensorflow session to use for batching
            TF policy evaluations.

    Yields:
        rollout (SampleBatch): Object containing state, action, reward,
            terminal condition, and other fields as dictated by `policy`.
    """

    try:
        if not horizon:
            horizon = (
                async_vector_env.get_unwrapped()[0].spec.max_episode_steps)
    except Exception:
        print("*** WARNING ***: no episode horizon specified, assuming inf")
    if not horizon:
        horizon = float("inf")

    # Pool of batch builders, which can be shared across episodes to pack
    # trajectory data.
    batch_builder_pool = []

    def get_batch_builder():
        if batch_builder_pool:
            return batch_builder_pool.pop()
        else:
            return MultiAgentSampleBatchBuilder(policies, clip_rewards)

    def new_episode():
        return MultiAgentEpisode(policies, policy_mapping_fn,
                                 get_batch_builder, extra_batch_callback)

    active_episodes = defaultdict(new_episode)

    while True:
        # Get observations from all ready agents
        unfiltered_obs, rewards, dones, infos, off_policy_actions = \
            async_vector_env.poll()

        # Map of policy_id to list of PolicyEvalData
        to_eval = defaultdict(list)

        # Map of env_id -> agent_id -> action replies
        actions_to_send = defaultdict(dict)

        # For each environment
        for env_id, agent_obs in unfiltered_obs.items():
            new_episode = env_id not in active_episodes
            episode = active_episodes[env_id]
            if not new_episode:
                episode.length += 1
                episode.batch_builder.count += 1
                episode._add_agent_rewards(rewards[env_id])

            # Check episode termination conditions
            if dones[env_id]["__all__"] or episode.length >= horizon:
                all_done = True
                atari_metrics = _fetch_atari_metrics(async_vector_env)
                if atari_metrics is not None:
                    for m in atari_metrics:
                        yield m
                else:
                    yield RolloutMetrics(episode.length, episode.total_reward,
                                         dict(episode.agent_rewards))
            else:
                all_done = False
                # At least send an empty dict if not done
                actions_to_send[env_id] = {}

            # For each agent in the environment
            for agent_id, raw_obs in agent_obs.items():
                policy_id = episode.policy_for(agent_id)
                filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs)
                agent_done = bool(all_done or dones[env_id].get(agent_id))
                if not agent_done:
                    to_eval[policy_id].append(
                        PolicyEvalData(env_id, agent_id, filtered_obs,
                                       episode.rnn_state_for(agent_id)))

                last_observation = episode.last_observation_for(agent_id)
                episode._set_last_observation(agent_id, filtered_obs)

                # Record transition info if applicable
                if last_observation is not None and \
                        infos[env_id][agent_id].get("training_enabled", True):
                    episode.batch_builder.add_values(
                        agent_id,
                        policy_id,
                        t=episode.length - 1,
                        eps_id=episode.episode_id,
                        obs=last_observation,
                        actions=episode.last_action_for(agent_id),
                        rewards=rewards[env_id][agent_id],
                        dones=agent_done,
                        infos=infos[env_id][agent_id],
                        new_obs=filtered_obs,
                        **episode.last_pi_info_for(agent_id))

            # Cut the batch if we're not packing multiple episodes into one,
            # or if we've exceeded the requested batch size.
            if episode.batch_builder.has_pending_data():
                if (all_done and not pack) or \
                        episode.batch_builder.count >= unroll_length:
                    yield episode.batch_builder.build_and_reset()
                elif all_done:
                    # Make sure postprocessor stays within one episode
                    episode.batch_builder.postprocess_batch_so_far()

            if all_done:
                # Handle episode termination
                batch_builder_pool.append(episode.batch_builder)
                del active_episodes[env_id]
                resetted_obs = async_vector_env.try_reset(env_id)
                if resetted_obs is None:
                    # Reset not supported, drop this env from the ready list
                    assert horizon == float("inf"), \
                        "Setting episode horizon requires reset() support."
                else:
                    # Creates a new episode
                    episode = active_episodes[env_id]
                    for agent_id, raw_obs in resetted_obs.items():
                        policy_id = episode.policy_for(agent_id)
                        filtered_obs = _get_or_raise(obs_filters,
                                                     policy_id)(raw_obs)
                        episode._set_last_observation(agent_id, filtered_obs)
                        to_eval[policy_id].append(
                            PolicyEvalData(env_id, agent_id, filtered_obs,
                                           episode.rnn_state_for(agent_id)))

        # Batch eval policy actions if possible
        if tf_sess:
            builder = TFRunBuilder(tf_sess, "policy_eval")
            pending_fetches = {}
        else:
            builder = None
        eval_results = {}
        rnn_in_cols = {}
        for policy_id, eval_data in to_eval.items():
            rnn_in = _to_column_format([t.rnn_state for t in eval_data])
            rnn_in_cols[policy_id] = rnn_in
            policy = _get_or_raise(policies, policy_id)
            if builder and (policy.compute_actions.__code__ is
                            TFPolicyGraph.compute_actions.__code__):
                pending_fetches[policy_id] = policy.build_compute_actions(
                    builder, [t.obs for t in eval_data],
                    rnn_in,
                    is_training=True)
            else:
                eval_results[policy_id] = policy.compute_actions(
                    [t.obs for t in eval_data],
                    rnn_in,
                    is_training=True,
                    episodes=[active_episodes[t.env_id] for t in eval_data])
        if builder:
            for k, v in pending_fetches.items():
                eval_results[k] = builder.get(v)

        # Record the policy eval results
        for policy_id, eval_data in to_eval.items():
            actions, rnn_out_cols, pi_info_cols = eval_results[policy_id]
            if len(rnn_in_cols[policy_id]) != len(rnn_out_cols):
                raise ValueError(
                    "Length of RNN in did not match RNN out, got: "
                    "{} vs {}".format(rnn_in_cols[policy_id], rnn_out_cols))
            # Add RNN state info
            for f_i, column in enumerate(rnn_in_cols[policy_id]):
                pi_info_cols["state_in_{}".format(f_i)] = column
            for f_i, column in enumerate(rnn_out_cols):
                pi_info_cols["state_out_{}".format(f_i)] = column
            # Save output rows
            for i, action in enumerate(actions):
                env_id = eval_data[i].env_id
                agent_id = eval_data[i].agent_id
                actions_to_send[env_id][agent_id] = action
                episode = active_episodes[env_id]
                episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
                episode._set_last_pi_info(
                    agent_id, {k: v[i]
                               for k, v in pi_info_cols.items()})
                if env_id in off_policy_actions and \
                        agent_id in off_policy_actions[env_id]:
                    episode._set_last_action(
                        agent_id, off_policy_actions[env_id][agent_id])
                else:
                    episode._set_last_action(agent_id, action)

        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        async_vector_env.send_actions(dict(actions_to_send))
Example #15
 def apply_gradients(self, gradients: ModelGradients) -> None:
     assert self.loss_initialized()
     builder = TFRunBuilder(self.get_session(), "apply_gradients")
     fetches = self._build_apply_gradients(builder, gradients)
     builder.get(fetches)
Example #16
 def learn_on_batch(self, postprocessed_batch):
     assert self._loss is not None, "Loss not initialized"
     builder = TFRunBuilder(self._sess, "learn_on_batch")
     fetches = self._build_learn_on_batch(builder, postprocessed_batch)
     return builder.get(fetches)
Example #17
 def apply_gradients(self, gradients):
     assert self._loss is not None, "Loss not initialized"
     builder = TFRunBuilder(self._sess, "apply_gradients")
     fetches = self._build_apply_gradients(builder, gradients)
     builder.get(fetches)
Example #18
def _do_policy_eval(
    *,
    to_eval: Dict[PolicyID, List[PolicyEvalData]],
    policies: Dict[PolicyID, Policy],
    policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"], PolicyID],
    sample_collector,
    active_episodes: Dict[str, MultiAgentEpisode],
    tf_sess: Optional["tf.Session"] = None,
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
    """Call compute_actions on collected episode/model data to get next action.

    Args:
        to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy
            IDs to lists of PolicyEvalData objects (items in these lists will
            be the batch's items for the model forward pass).
        policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
            obj.
        policy_mapping_fn (Callable[[AgentID, "MultiAgentEpisode"], PolicyID]):
            Maps an agent ID (and its episode) to a policy ID; used to
            re-resolve the policy if the original ID is no longer on this
            worker.
        sample_collector (SampleCollector): The SampleCollector object to use.
        active_episodes (Dict[str, MultiAgentEpisode]): Mapping from episode ID
            to the currently ongoing MultiAgentEpisode object.
        tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
            batching TF policy evaluations.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """

    eval_results: Dict[PolicyID, TensorStructType] = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches: Dict[PolicyID, Any] = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        # In case the policyID has been removed from this worker, we need to
        # re-assign policy_id and re-lookup the Policy object to use.
        try:
            policy: Policy = _get_or_raise(policies, policy_id)
        except ValueError:
            policy_id = policy_mapping_fn(eval_data[0].agent_id,
                                          active_episodes[eval_data[0].env_id])
            policy: Policy = _get_or_raise(policies, policy_id)

        input_dict = sample_collector.get_inference_input_dict(policy_id)
        eval_results[policy_id] = \
            policy.compute_actions_from_input_dict(
                input_dict,
                timestep=policy.global_timestep,
                episodes=[active_episodes[t.env_id] for t in eval_data])

    if builder:
        # type: PolicyID, Tuple[TensorStructType, StateBatch, dict]
        for pid, v in pending_fetches.items():
            eval_results[pid] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
Example #19
 def apply_gradients(self, gradients):
     builder = TFRunBuilder(self._sess, "apply_gradients")
     fetches = self._build_apply_gradients(builder, gradients)
     return builder.get(fetches)
Example #20
 def learn_on_batch(self, postprocessed_batch):
     assert self.loss_initialized()
     builder = TFRunBuilder(self._sess, "learn_on_batch")
     fetches = self._build_learn_on_batch(builder, postprocessed_batch)
     return builder.get(fetches)
Example #21
 def compute_gradients(self, postprocessed_batch):
     assert self.loss_initialized()
     builder = TFRunBuilder(self._sess, "compute_gradients")
     fetches = self._build_compute_gradients(builder, postprocessed_batch)
     return builder.get(fetches)
Example #22
def _do_policy_eval(*, to_eval, policies, active_episodes, tf_sess=None):
    """Call compute_actions on collected episode/model data to get next action.

    Args:
        to_eval (Dict[str, List[PolicyEvalData]]): Mapping of policy IDs to
            lists of PolicyEvalData objects.
        policies (Dict[str, Policy]): Mapping from policy ID to Policy obj.
        active_episodes (defaultdict[str, MultiAgentEpisode]): Mapping from
            episode ID to currently ongoing MultiAgentEpisode object.
        tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
            batching TF policy evaluations.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """

    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        rnn_in = [t.rnn_state for t in eval_data]
        policy = _get_or_raise(policies, policy_id)
        # If tf (non eager) AND TFPolicy's compute_action method has not been
        # overridden -> Use `policy._build_compute_actions()`.
        if builder and (policy.compute_actions.__code__ is
                        TFPolicy.compute_actions.__code__):

            obs_batch = [t.obs for t in eval_data]
            state_batches = _to_column_format(rnn_in)
            # TODO(ekl): how can we make info batch available to TF code?
            prev_action_batch = [t.prev_action for t in eval_data]
            prev_reward_batch = [t.prev_reward for t in eval_data]

            pending_fetches[policy_id] = policy._build_compute_actions(
                builder,
                obs_batch=obs_batch,
                state_batches=state_batches,
                prev_action_batch=prev_action_batch,
                prev_reward_batch=prev_reward_batch,
                timestep=policy.global_timestep)
        else:
            rnn_in_cols = [
                np.stack([row[i] for row in rnn_in])
                for i in range(len(rnn_in[0]))
            ]
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                state_batches=rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data],
                timestep=policy.global_timestep)
    if builder:
        for pid, v in pending_fetches.items():
            eval_results[pid] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
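
The rnn_in_cols comprehension above turns row-major per-agent state (one [h, c] list per agent) into column-major batches (one stacked array per state slot), which is the layout the policy's state placeholders expect. A small, self-contained illustration of that reshaping (toy shapes, not RLlib code):

import numpy as np

# One entry per agent; each entry is that agent's RNN state, e.g. [h, c].
rnn_in = [
    [np.zeros(4), np.ones(4)],
    [np.full(4, 2.0), np.full(4, 3.0)],
]
rnn_in_cols = [
    np.stack([row[i] for row in rnn_in]) for i in range(len(rnn_in[0]))
]
assert rnn_in_cols[0].shape == (2, 4)  # batched h states, one row per agent
assert rnn_in_cols[1].shape == (2, 4)  # batched c states, one row per agent
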
Example #23
 def apply_gradients(self, gradients):
     builder = TFRunBuilder(self._sess, "apply_gradients")
     fetches = self._build_apply_gradients(builder, gradients)
     builder.get(fetches)
Example #24
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute actions on observation batches to get next actions.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """

    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        rnn_in = [t.rnn_state for t in eval_data]
        policy = _get_or_raise(policies, policy_id)
        if builder and (policy.compute_actions.__code__ is
                        TFPolicy.compute_actions.__code__):

            obs_batch = [t.obs for t in eval_data]
            state_batches = _to_column_format(rnn_in)

            # TODO(ekl): how can we make info batch available to TF code?
            prev_action_batch = [t.prev_action for t in eval_data]
            prev_reward_batch = [t.prev_reward for t in eval_data]

            pending_fetches[policy_id] = policy._build_compute_actions(
                builder,
                obs_batch=obs_batch,
                state_batches=state_batches,
                prev_action_batch=prev_action_batch,
                prev_reward_batch=prev_reward_batch,
                timestep=policy.global_timestep)
        else:
            # TODO(sven): Does this work for LSTM torch?
            rnn_in_cols = [
                np.stack([row[i] for row in rnn_in])
                for i in range(len(rnn_in[0]))
            ]
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                state_batches=rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data],
                timestep=policy.global_timestep)
    if builder:
        for pid, v in pending_fetches.items():
            eval_results[pid] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
Example #25
def _do_policy_eval(
        *,
        to_eval: Dict[PolicyID, List[PolicyEvalData]],
        policies: Dict[PolicyID, Policy],
        active_episodes: Dict[str, MultiAgentEpisode],
        tf_sess=None,
        _use_trajectory_view_api=False
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
    """Call compute_actions on collected episode/model data to get next action.

    Args:
        to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy
            IDs to lists of PolicyEvalData objects (items in these lists will
            be the batch's items for the model forward pass).
        policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
            obj.
        active_episodes (defaultdict[str,MultiAgentEpisode]): Mapping from
            episode ID to currently ongoing MultiAgentEpisode object.
        tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
            batching TF policy evaluations.
        _use_trajectory_view_api (bool): Whether to use the (experimental)
            `_use_trajectory_view_api` procedure to collect samples.
            Default: False.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """

    eval_results: Dict[PolicyID, TensorStructType] = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches: Dict[PolicyID, Any] = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    # type: PolicyID, PolicyEvalData
    for policy_id, eval_data in to_eval.items():
        rnn_in: List[List[Any]] = [t.rnn_state for t in eval_data]
        policy: Policy = _get_or_raise(policies, policy_id)
        # If tf (non eager) AND TFPolicy's compute_action method has not been
        # overridden -> Use `policy._build_compute_actions()`.
        if builder and (policy.compute_actions.__code__ is
                        TFPolicy.compute_actions.__code__):

            obs_batch: List[EnvObsType] = [t.obs for t in eval_data]
            state_batches: StateBatch = _to_column_format(rnn_in)
            # TODO(ekl): how can we make info batch available to TF code?
            prev_action_batch = [t.prev_action for t in eval_data]
            prev_reward_batch = [t.prev_reward for t in eval_data]

            pending_fetches[policy_id] = policy._build_compute_actions(
                builder,
                obs_batch=obs_batch,
                state_batches=state_batches,
                prev_action_batch=prev_action_batch,
                prev_reward_batch=prev_reward_batch,
                timestep=policy.global_timestep)
        else:
            rnn_in_cols: StateBatch = [
                np.stack([row[i] for row in rnn_in])
                for i in range(len(rnn_in[0]))
            ]
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                state_batches=rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data],
                timestep=policy.global_timestep)
    if builder:
        # type: PolicyID, Tuple[TensorStructType, StateBatch, dict]
        for pid, v in pending_fetches.items():
            eval_results[pid] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
Example #26
 def compute_gradients(self, postprocessed_batch):
     builder = TFRunBuilder(self._sess, "compute_gradients")
     fetches = self._build_compute_gradients(builder, postprocessed_batch)
     return builder.get(fetches)
Example #27
 def apply_gradients(self, gradients):
     assert self.loss_initialized()
     builder = TFRunBuilder(self._sess, "apply_gradients")
     fetches = self._build_apply_gradients(builder, gradients)
     builder.get(fetches)
Example #28
 def learn_on_batch(
         self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
     assert self.loss_initialized()
     builder = TFRunBuilder(self._sess, "learn_on_batch")
     fetches = self._build_learn_on_batch(builder, postprocessed_batch)
     return builder.get(fetches)
Example #30
 def learn_on_batch(self, postprocessed_batch):
     builder = TFRunBuilder(self._sess, "learn_on_batch")
     fetches = self._build_learn_on_batch(builder, postprocessed_batch)
     return builder.get(fetches)