def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute_actions() on observation batches to get next actions.

    Arguments:
        tf_sess (Session|None): TF session used to batch policy fetches,
            if available.
        to_eval (dict): Map of policy id to list of pending eval data.
        policies (dict): Map of policy id to policy instance.
        active_episodes (dict): Map of env id to the episode being rolled out.

    Returns:
        eval_results: dict of policy id to compute_actions() outputs.
    """

    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
        policy = _get_or_raise(policies, policy_id)
        if builder and (policy.compute_actions.__code__ is
                        TFPolicy.compute_actions.__code__):
            # TODO(ekl): how can we make info batch available to TF code?
            pending_fetches[policy_id] = policy._build_compute_actions(
                builder,
                [t.obs for t in eval_data],
                [t.neighbor_obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data])
        else:
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                [t.neighbor_obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data])

    if builder:
        for k, v in pending_fetches.items():
            eval_results[k] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
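# A minimal sketch of the per-agent record consumed by _do_policy_eval above,
# assuming this repo extends RLlib's PolicyEvalData namedtuple with a
# neighbor_obs field. Every field name below is taken from the attribute
# accesses above (t.env_id, t.obs, t.neighbor_obs, t.info, t.rnn_state,
# t.prev_action, t.prev_reward); the exact definition in the repo may differ.
from collections import namedtuple

PolicyEvalData = namedtuple("PolicyEvalData", [
    "env_id", "obs", "neighbor_obs", "info", "rnn_state", "prev_action",
    "prev_reward"
])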
def learn_on_batch(self, samples):
    if log_once("learn_on_batch"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize(samples)))
    if isinstance(samples, MultiAgentBatch):
        info_out = {}
        to_fetch = {}
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
        else:
            builder = None
        for pid, batch in samples.policy_batches.items():
            if pid not in self.policies_to_train:
                continue
            # For attention: build a dict of sample batches for the agent's
            # neighbors (the current agent included).
            inter_id = self.inter_num_2_id(int(pid.split('_')[1]))
            neighbor_pid_list = [
                pid_ for pid_ in
                self.traffic_light_node_dict[inter_id]['adjacency_row']
                if pid_ is not None
            ]
            neighbor_batch_dic = {}
            for neighbor_pid in neighbor_pid_list:
                neighbor_batch_dic['policy_{}'.format(
                    neighbor_pid)] = samples.policy_batches[
                        'policy_{}'.format(neighbor_pid)]
            policy = self.policy_map[pid]
            if builder and hasattr(policy, "_build_learn_on_batch"):
                to_fetch[pid] = policy._build_learn_on_batch(
                    builder, batch, neighbor_batch_dic)
            else:
                info_out[pid] = policy.learn_on_batch(
                    batch, neighbor_batch_dic)
        info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
    else:
        info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(
            samples)
    if log_once("learn_out"):
        logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
    return info_out
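# The neighbor-batch block in learn_on_batch is repeated twice more in
# compute_gradients below. A sketch of factoring it into one helper on the
# same class, using only names that already appear in these methods
# (traffic_light_node_dict, inter_num_2_id, samples.policy_batches); the
# helper name itself is hypothetical.
def _build_neighbor_batch_dict(self, pid, policy_batches):
    """Collect the sample batches of an agent's neighbors (itself included)."""
    inter_id = self.inter_num_2_id(int(pid.split('_')[1]))
    neighbor_batch_dic = {}
    for neighbor_pid in self.traffic_light_node_dict[inter_id]['adjacency_row']:
        if neighbor_pid is None:
            continue
        key = 'policy_{}'.format(neighbor_pid)
        neighbor_batch_dic[key] = policy_batches[key]
    return neighbor_batch_dic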
def compute_actions(self,
                    obs_batch,
                    neighbor_obs_batch,
                    state_batches=None,
                    prev_action_batch=None,
                    prev_reward_batch=None,
                    info_batch=None,
                    episodes=None,
                    **kwargs):
    builder = TFRunBuilder(self._sess, "compute_actions")
    # Pass neighbor_obs_batch through positionally so the argument order
    # matches the direct _build_compute_actions() call in _do_policy_eval.
    fetches = self._build_compute_actions(builder, obs_batch,
                                          neighbor_obs_batch, state_batches,
                                          prev_action_batch,
                                          prev_reward_batch)
    return builder.get(fetches)
def apply_gradients(self, grads):
    if log_once("apply_gradients"):
        logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
    if isinstance(grads, dict):
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "apply_gradients")
            outputs = {
                pid: self.policy_map[pid]._build_apply_gradients(
                    builder, grad)
                for pid, grad in grads.items()
            }
            return {k: builder.get(v) for k, v in outputs.items()}
        else:
            return {
                pid: self.policy_map[pid].apply_gradients(g)
                for pid, g in grads.items()
            }
    else:
        return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
def learn_on_batch(self, postprocessed_batch, neighbor_batch_dic):
    assert self.loss_initialized()
    builder = TFRunBuilder(self._sess, "learn_on_batch")
    fetches = self._build_learn_on_batch(builder, postprocessed_batch,
                                         neighbor_batch_dic)
    return builder.get(fetches)
def apply_gradients(self, gradients):
    assert self.loss_initialized()
    builder = TFRunBuilder(self._sess, "apply_gradients")
    fetches = self._build_apply_gradients(builder, gradients)
    # get() forces the queued apply op to run; apply_gradients has no
    # meaningful return value.
    builder.get(fetches)
def compute_gradients(self, postprocessed_batch, neighbor_batch_dic):
    assert self.loss_initialized()
    builder = TFRunBuilder(self._sess, "compute_gradients")
    fetches = self._build_compute_gradients(builder, postprocessed_batch,
                                            neighbor_batch_dic)
    return builder.get(fetches)
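# Why the _build_* variants above exist: TFRunBuilder queues feeds and fetches
# from several policies and executes them in one fused session.run(). A rough
# sketch of the pattern, assuming RLlib's TFRunBuilder interface
# (add_feed_dict / add_fetches / get); sess, feed, loss_op, and train_op are
# hypothetical placeholders.
builder = TFRunBuilder(sess, "example")
builder.add_feed_dict(feed)                # queue the input feed dict
loss_id = builder.add_fetches([loss_op])   # handle only; nothing runs yet
train_id = builder.add_fetches([train_op])
loss = builder.get(loss_id)                # first get() triggers the fused run
builder.get(train_id)                      # resolved from the same run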
def compute_gradients(self, samples):
    if log_once("compute_gradients"):
        logger.info("Compute gradients on:\n\n{}\n".format(
            summarize(samples)))
    if isinstance(samples, MultiAgentBatch):
        grad_out, info_out = {}, {}
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "compute_gradients")
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                # For attention: build a dict of sample batches for the
                # agent's neighbors (the current agent included).
                inter_id = self.inter_num_2_id(int(pid.split('_')[1]))
                neighbor_pid_list = [
                    pid_ for pid_ in
                    self.traffic_light_node_dict[inter_id]['adjacency_row']
                    if pid_ is not None
                ]
                neighbor_batch_dic = {}
                for neighbor_pid in neighbor_pid_list:
                    neighbor_batch_dic['policy_{}'.format(
                        neighbor_pid)] = samples.policy_batches[
                            'policy_{}'.format(neighbor_pid)]
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid]._build_compute_gradients(
                        builder, batch, neighbor_batch_dic))
            grad_out = {k: builder.get(v) for k, v in grad_out.items()}
            info_out = {k: builder.get(v) for k, v in info_out.items()}
        else:
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                # Same neighbor-batch construction as in the TF branch above.
                inter_id = self.inter_num_2_id(int(pid.split('_')[1]))
                neighbor_pid_list = [
                    pid_ for pid_ in
                    self.traffic_light_node_dict[inter_id]['adjacency_row']
                    if pid_ is not None
                ]
                neighbor_batch_dic = {}
                for neighbor_pid in neighbor_pid_list:
                    neighbor_batch_dic['policy_{}'.format(
                        neighbor_pid)] = samples.policy_batches[
                            'policy_{}'.format(neighbor_pid)]
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid].compute_gradients(
                        batch, neighbor_batch_dic))
    else:
        grad_out, info_out = (
            self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
    info_out["batch_count"] = samples.count
    if log_once("grad_out"):
        logger.info("Compute grad info:\n\n{}\n".format(
            summarize(info_out)))
    return grad_out, info_out
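# Hypothetical end-to-end usage of the two worker-level methods above:
# compute per-policy gradients on a MultiAgentBatch, then apply them. The
# `worker` and `samples` names are placeholders for an evaluator instance
# and a collected batch.
grads, infos = worker.compute_gradients(samples)
worker.apply_gradients(grads)
print(infos["batch_count"])  # total sample count attached by compute_gradients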