Example #1
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute actions on observation batches to get next actions.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """

    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
        policy = _get_or_raise(policies, policy_id)
        if builder and (policy.compute_actions.__code__ is
                        TFPolicy.compute_actions.__code__):
            # TODO(ekl): how can we make info batch available to TF code?
            # print('First: ' + str(eval_data))
            pending_fetches[policy_id] = policy._build_compute_actions(
                builder, [t.obs for t in eval_data],
                [t.neighbor_obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data])
        else:
            # print('Second: ' + str(eval_data))
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                [t.neighbor_obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data])
    if builder:
        for k, v in pending_fetches.items():
            eval_results[k] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
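
The example above queues one _build_compute_actions() call per TF policy on a shared TFRunBuilder and only resolves the results afterwards, so the fetches for all policies can be resolved together instead of issuing one session call per policy. Below is a minimal, self-contained sketch of that deferral pattern; ToyRunBuilder and fake_run are hypothetical stand-ins for illustration, not the actual RLlib classes.

class ToyRunBuilder:
    """Hypothetical stand-in that mimics TFRunBuilder's deferred fetching."""

    def __init__(self, run_fn, name):
        self._run_fn = run_fn      # plays the role of session.run
        self._name = name
        self._fetches = []         # fetches accumulated across policies
        self._results = None       # resolved lazily, on the first get()

    def add_fetches(self, fetches):
        start = len(self._fetches)
        self._fetches.extend(fetches)
        # Hand back opaque handles (indices) that get() resolves later.
        return list(range(start, start + len(fetches)))

    def get(self, handles):
        if self._results is None:
            # One batched call covers every policy, like a single session.run().
            self._results = self._run_fn(self._fetches)
        return [self._results[i] for i in handles]


calls = []

def fake_run(fetches):
    calls.append(list(fetches))
    return [f.upper() for f in fetches]    # pretend evaluation of the fetches

builder = ToyRunBuilder(fake_run, "policy_eval")
pending = {
    "policy_0": builder.add_fetches(["act_0", "logits_0"]),
    "policy_1": builder.add_fetches(["act_1", "logits_1"]),
}
results = {pid: builder.get(handles) for pid, handles in pending.items()}
assert len(calls) == 1    # both policies were resolved by one batched run
print(results)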
Example #2
    def learn_on_batch(self, samples):
        if log_once("learn_on_batch"):
            logger.info(
                "Training on concatenated sample batches:\n\n{}\n".format(
                    summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            info_out = {}
            to_fetch = {}
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
            else:
                builder = None
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                # Attention: build a dict of the neighboring agents' sample
                # batches (the neighbor list also includes the current agent).
                neighbor_batch_dic = {}
                neighbor_pid_list = [
                    pid_ for pid_ in
                    self.traffic_light_node_dict[self.inter_num_2_id(
                        int(pid.split('_')[1]))]['adjacency_row']
                    if pid_ is not None
                ]
                for neighbor_pid in neighbor_pid_list:
                    neighbor_batch_dic['policy_{}'.format(
                        neighbor_pid)] = samples.policy_batches[
                            'policy_{}'.format(neighbor_pid)]
                # neighbor_batch_dic[pid] = samples.policy_batches[pid]
                # ------------------------------------------------------------------

                policy = self.policy_map[pid]
                if builder and hasattr(policy, "_build_learn_on_batch"):
                    to_fetch[pid] = policy._build_learn_on_batch(
                        builder, batch, neighbor_batch_dic)
                else:
                    info_out[pid] = policy.learn_on_batch(
                        batch, neighbor_batch_dic)
            info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
        else:
            info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(
                samples)
        if log_once("learn_out"):
            logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
        return info_out
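
The neighbor-batch construction in learn_on_batch() (and again in compute_gradients() in Example #8) parses the intersection index out of a policy ID such as policy_3, looks up its adjacency_row in self.traffic_light_node_dict, and collects the corresponding policy batches, skipping None entries. A standalone sketch of that lookup with made-up data follows; build_neighbor_batches and the toy dictionaries are hypothetical, introduced only for illustration.

def build_neighbor_batches(pid, policy_batches, traffic_light_node_dict,
                           inter_num_2_id):
    """Map 'policy_{n}' of each neighboring intersection to its sample batch."""
    inter_num = int(pid.split('_')[1])
    adjacency_row = traffic_light_node_dict[
        inter_num_2_id(inter_num)]['adjacency_row']
    return {
        'policy_{}'.format(n): policy_batches['policy_{}'.format(n)]
        for n in adjacency_row if n is not None
    }


# Toy data: intersection 0 is adjacent to itself and to intersection 1;
# the trailing None stands for a missing neighbor (e.g. a border intersection).
nodes = {
    'intersection_0': {'adjacency_row': [0, 1, None]},
    'intersection_1': {'adjacency_row': [1, 0, None]},
}
batches = {'policy_0': 'batch_0', 'policy_1': 'batch_1'}
print(build_neighbor_batches('policy_0', batches, nodes,
                             lambda n: 'intersection_{}'.format(n)))
# {'policy_0': 'batch_0', 'policy_1': 'batch_1'}

Factoring the lookup out like this would also remove the duplication between the TF and non-TF branches of compute_gradients() in Example #8.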
Example #3
    def compute_actions(self,
                        obs_batch,
                        neighbor_obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        builder = TFRunBuilder(self._sess, "compute_actions")
        fetches = self._build_compute_actions(
            builder,
            obs_batch,
            # neighbor_obs_batch,
            state_batches,
            prev_action_batch,
            prev_reward_batch)
        return builder.get(fetches)
Example #4
    def apply_gradients(self, grads):
        if log_once("apply_gradients"):
            logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
        if isinstance(grads, dict):
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "apply_gradients")
                outputs = {
                    pid:
                    self.policy_map[pid]._build_apply_gradients(builder, grad)
                    for pid, grad in grads.items()
                }
                return {k: builder.get(v) for k, v in outputs.items()}
            else:
                return {
                    pid: self.policy_map[pid].apply_gradients(g)
                    for pid, g in grads.items()
                }
        else:
            return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
Example #5
    def learn_on_batch(self, postprocessed_batch, neighbor_batch_dic):
        assert self.loss_initialized()
        builder = TFRunBuilder(self._sess, "learn_on_batch")
        fetches = self._build_learn_on_batch(builder, postprocessed_batch,
                                             neighbor_batch_dic)
        return builder.get(fetches)
Example #6
    def apply_gradients(self, gradients):
        assert self.loss_initialized()
        builder = TFRunBuilder(self._sess, "apply_gradients")
        fetches = self._build_apply_gradients(builder, gradients)
        builder.get(fetches)
Example #7
    def compute_gradients(self, postprocessed_batch, neighbor_batch_dic):
        assert self.loss_initialized()
        builder = TFRunBuilder(self._sess, "compute_gradients")
        fetches = self._build_compute_gradients(builder, postprocessed_batch,
                                                neighbor_batch_dic)
        return builder.get(fetches)
Example #8
    def compute_gradients(self, samples):
        if log_once("compute_gradients"):
            logger.info("Compute gradients on:\n\n{}\n".format(
                summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            grad_out, info_out = {}, {}
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "compute_gradients")
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    # Attention: build a dict of the neighboring agents' sample
                    # batches (the neighbor list also includes the current agent).
                    neighbor_batch_dic = {}
                    neighbor_pid_list = [
                        pid_ for pid_ in
                        self.traffic_light_node_dict[self.inter_num_2_id(
                            int(pid.split('_')[1]))]['adjacency_row']
                        if pid_ is not None
                    ]
                    for neighbor_pid in neighbor_pid_list:
                        neighbor_batch_dic['policy_{}'.format(
                            neighbor_pid)] = samples.policy_batches[
                                'policy_{}'.format(neighbor_pid)]
                    # neighbor_batch_dic[pid] = samples.policy_batches[pid]
                    # ------------------------------------------------------------------

                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid]._build_compute_gradients(
                            builder, batch, neighbor_batch_dic))
                grad_out = {k: builder.get(v) for k, v in grad_out.items()}
                info_out = {k: builder.get(v) for k, v in info_out.items()}
            else:
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    # Attention: build a dict of the neighboring agents' sample
                    # batches (the neighbor list also includes the current agent).
                    neighbor_batch_dic = {}
                    neighbor_pid_list = [
                        pid_ for pid_ in
                        self.traffic_light_node_dict[self.inter_num_2_id(
                            int(pid.split('_')[1]))]['adjacency_row']
                        if pid_ is not None
                    ]
                    for neighbor_pid in neighbor_pid_list:
                        neighbor_batch_dic['policy_{}'.format(
                            neighbor_pid)] = samples.policy_batches[
                                'policy_{}'.format(neighbor_pid)]
                    # neighbor_batch_dic[pid] = samples.policy_batches[pid]
                    # ------------------------------------------------------------------
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid].compute_gradients(
                            batch, neighbor_batch_dic))
        else:
            grad_out, info_out = (
                self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
        info_out["batch_count"] = samples.count
        if log_once("grad_out"):
            logger.info("Compute grad info:\n\n{}\n".format(
                summarize(info_out)))
        return grad_out, info_out
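
A typical synchronous training step stitches the two halves above together: gradients are computed per policy on a multi-agent batch (Example #8) and then applied through the per-policy dispatch in apply_gradients() (Example #4). The sketch below is a hedged illustration of that pairing; sync_gradient_step and FakeEvaluator are hypothetical names, not part of the original code.

def sync_gradient_step(evaluator, samples):
    """Compute gradients on a (multi-agent) batch, then apply them."""
    grads, infos = evaluator.compute_gradients(samples)   # as in Example #8
    evaluator.apply_gradients(grads)                       # as in Example #4
    return infos


# Smoke test with a fake evaluator that records what was applied.
class FakeEvaluator:
    def compute_gradients(self, samples):
        return {'policy_0': [0.1, -0.2]}, {'policy_0': {'loss': 1.0},
                                           'batch_count': len(samples)}

    def apply_gradients(self, grads):
        self.last_applied = grads

ev = FakeEvaluator()
print(sync_gradient_step(ev, samples=[{'obs': 0}]))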