def compute_gradients(self, samples):
    """Compute gradients for ``samples`` without applying them.

    For a ``MultiAgentBatch`` only the policies listed in
    ``self.policies_to_train`` are processed and both results are dicts
    keyed by policy id; otherwise the default policy handles the batch.

    Returns:
        (grad_out, info_out) tuple. ``info_out["batch_count"]`` is always set
        to ``samples.count``.
    """
    if log_once("compute_gradients"):
        logger.info("Compute gradients on:\n\n{}\n".format(
            summarize(samples)))

    if not isinstance(samples, MultiAgentBatch):
        default_policy = self.policy_map[DEFAULT_POLICY_ID]
        grad_out, info_out = default_policy.compute_gradients(samples)
    else:
        grad_out = {}
        info_out = {}
        if self.tf_sess is None:
            for policy_id, policy_batch in samples.policy_batches.items():
                if policy_id in self.policies_to_train:
                    grad_out[policy_id], info_out[policy_id] = (
                        self.policy_map[policy_id].compute_gradients(
                            policy_batch))
        else:
            # Register every policy's fetches on one builder, then resolve
            # them all through builder.get().
            builder = TFRunBuilder(self.tf_sess, "compute_gradients")
            for policy_id, policy_batch in samples.policy_batches.items():
                if policy_id in self.policies_to_train:
                    grad_out[policy_id], info_out[policy_id] = (
                        self.policy_map[policy_id]._build_compute_gradients(
                            builder, policy_batch))
            grad_out = {
                policy_id: builder.get(fetch)
                for policy_id, fetch in grad_out.items()
            }
            info_out = {
                policy_id: builder.get(fetch)
                for policy_id, fetch in info_out.items()
            }

    info_out["batch_count"] = samples.count
    if log_once("grad_out"):
        logger.info("Compute grad info:\n\n{}\n".format(
            summarize(info_out)))
    return grad_out, info_out
def learn_on_batch(self, samples):
    """Run one training update on ``samples`` and return the learner info.

    For a ``MultiAgentBatch`` the result is a dict keyed by policy id, limited
    to ``self.policies_to_train``. Otherwise it is whatever the default
    policy's ``learn_on_batch`` returns.
    """
    if log_once("learn_on_batch"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize(samples)))

    if not isinstance(samples, MultiAgentBatch):
        info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(
            samples)
    else:
        builder = (TFRunBuilder(self.tf_sess, "learn_on_batch")
                   if self.tf_sess is not None else None)
        info_out = {}
        to_fetch = {}
        for policy_id, policy_batch in samples.policy_batches.items():
            if policy_id not in self.policies_to_train:
                continue
            policy = self.policy_map[policy_id]
            if builder and hasattr(policy, "_build_learn_on_batch"):
                # Deferred: resolved through the shared builder below.
                to_fetch[policy_id] = policy._build_learn_on_batch(
                    builder, policy_batch)
            else:
                info_out[policy_id] = policy.learn_on_batch(policy_batch)
        for policy_id, fetch in to_fetch.items():
            info_out[policy_id] = builder.get(fetch)

    if log_once("learn_out"):
        logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
    return info_out
def learn_on_batch(self, samples):
    """Run one training update on ``samples`` and return the learner info.

    Every policy call used here is unpacked as a 2-tuple of which only the
    first element is kept. For a ``MultiAgentBatch`` the result is a dict
    keyed by policy id, limited to ``self.policies_to_train``.
    """
    if log_once("learn_on_batch"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize(samples)))

    if not isinstance(samples, MultiAgentBatch):
        default_policy = self.policy_map[DEFAULT_POLICY_ID]
        info_out, _ = default_policy.learn_on_batch(samples)
    else:
        info_out = {}
        if self.tf_sess is None:
            for policy_id, policy_batch in samples.policy_batches.items():
                if policy_id in self.policies_to_train:
                    policy_info, _ = (
                        self.policy_map[policy_id].learn_on_batch(
                            policy_batch))
                    info_out[policy_id] = policy_info
        else:
            builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
            for policy_id, policy_batch in samples.policy_batches.items():
                if policy_id in self.policies_to_train:
                    pending, _ = (
                        self.policy_map[policy_id]._build_learn_on_batch(
                            builder, policy_batch))
                    info_out[policy_id] = pending
            # Swap the pending fetches for their resolved values.
            info_out = {
                policy_id: builder.get(pending)
                for policy_id, pending in info_out.items()
            }

    if log_once("learn_out"):
        logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
    return info_out
def sample(self):
    """Evaluate the current policies and return a batch of experiences.

    Batches are pulled from ``self.input_reader`` until
    ``self.sample_batch_size`` steps are collected (or the per-mode batch
    limit is hit), concatenated, written to ``self.output_writer``, fed to
    the reward estimators, and optionally compressed.

    Return:
        SampleBatch|MultiAgentBatch from evaluating the current policies.
    """
    # With the fake sampler enabled, keep returning the first batch produced.
    if self._fake_sampler and self.last_batch is not None:
        return self.last_batch

    if log_once("sample_start"):
        logger.info("Generating sample batch of size {}".format(
            self.sample_batch_size))

    first_batch = self.input_reader.next()
    collected = [first_batch]
    num_steps = first_batch.count

    # In truncate_episodes mode, never pull more than 1 batch per env.
    # This avoids over-running the target batch size.
    batch_limit = (self.num_envs if self.batch_mode == "truncate_episodes"
                   else float("inf"))

    while (num_steps < self.sample_batch_size
           and len(collected) < batch_limit):
        next_batch = self.input_reader.next()
        num_steps += next_batch.count
        collected.append(next_batch)
    batch = collected[0].concat_samples(collected)

    if self.callbacks.get("on_sample_end"):
        callback_info = {"evaluator": self, "samples": batch}
        self.callbacks["on_sample_end"](callback_info)

    # Always do writes prior to compression for consistency and to allow
    # for better compression inside the writer.
    self.output_writer.write(batch)

    # Do off-policy estimation if needed
    if self.reward_estimators:
        for episode_batch in batch.split_by_episode():
            for estimator in self.reward_estimators:
                estimator.process(episode_batch)

    if log_once("sample_end"):
        logger.info("Completed sample batch:\n\n{}\n".format(
            summarize(batch)))

    compression = self.compress_observations
    if compression == "bulk":
        batch.compress(bulk=True)
    elif compression:
        batch.compress()

    if self._fake_sampler:
        self.last_batch = batch
    return batch
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute actions on observation batches to get next actions.

    Policies whose ``compute_actions`` is the stock ``TFPolicy``
    implementation are batched through one ``TFRunBuilder`` when ``tf_sess``
    is given; every other policy is evaluated directly.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """
    eval_results = {}
    pending_fetches = {}
    builder = TFRunBuilder(tf_sess, "policy_eval") if tf_sess else None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        rnn_states = [item.rnn_state for item in eval_data]
        policy = _get_or_raise(policies, policy_id)
        batch_via_tf = builder and (policy.compute_actions.__code__ is
                                    TFPolicy.compute_actions.__code__)
        if batch_via_tf:
            state_columns = _to_column_format(rnn_states)
            # TODO(ekl): how can we make info batch available to TF code?
            # TODO(sven): Return dict from _build_compute_actions.
            # it's becoming more and more unclear otherwise, what's where in
            # the return tuple.
            pending_fetches[policy_id] = policy._build_compute_actions(
                builder,
                obs_batch=[item.obs for item in eval_data],
                state_batches=state_columns,
                prev_action_batch=[item.prev_action for item in eval_data],
                prev_reward_batch=[item.prev_reward for item in eval_data],
                timestep=policy.global_timestep)
        else:
            # TODO(sven): Does this work for LSTM torch?
            num_state_slots = len(rnn_states[0])
            state_columns = [
                np.stack([state[slot] for state in rnn_states])
                for slot in range(num_state_slots)
            ]
            eval_results[policy_id] = policy.compute_actions(
                [item.obs for item in eval_data],
                state_batches=state_columns,
                prev_action_batch=[item.prev_action for item in eval_data],
                prev_reward_batch=[item.prev_reward for item in eval_data],
                info_batch=[item.info for item in eval_data],
                episodes=[active_episodes[item.env_id] for item in eval_data],
                timestep=policy.global_timestep)

    if builder:
        # Resolve the deferred TF fetches.
        for policy_id, fetch in pending_fetches.items():
            eval_results[policy_id] = builder.get(fetch)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute actions on observation batches to get next actions.

    This variant additionally passes each item's ``neighbor_obs`` as the
    second positional batch to the policy.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """
    eval_results = {}
    pending_fetches = {}
    builder = TFRunBuilder(tf_sess, "policy_eval") if tf_sess else None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        state_columns = _to_column_format(
            [item.rnn_state for item in eval_data])
        policy = _get_or_raise(policies, policy_id)
        batch_via_tf = builder and (policy.compute_actions.__code__ is
                                    TFPolicy.compute_actions.__code__)
        if batch_via_tf:
            # TODO(ekl): how can we make info batch available to TF code?
            pending_fetches[policy_id] = policy._build_compute_actions(
                builder,
                [item.obs for item in eval_data],
                [item.neighbor_obs for item in eval_data],
                state_columns,
                prev_action_batch=[item.prev_action for item in eval_data],
                prev_reward_batch=[item.prev_reward for item in eval_data])
        else:
            eval_results[policy_id] = policy.compute_actions(
                [item.obs for item in eval_data],
                [item.neighbor_obs for item in eval_data],
                state_columns,
                prev_action_batch=[item.prev_action for item in eval_data],
                prev_reward_batch=[item.prev_reward for item in eval_data],
                info_batch=[item.info for item in eval_data],
                episodes=[active_episodes[item.env_id] for item in eval_data])

    if builder:
        # Resolve the deferred TF fetches.
        for policy_id, fetch in pending_fetches.items():
            eval_results[policy_id] = builder.get(fetch)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
def value_function(self):
    """Build and return the value-function output, reshaped with ``[-1]``.

    With ``model_config["vf_share_layers"]`` set, a single linear layer is
    stacked on the current instance's last layer. Otherwise a separate
    legacy-model branch is created with ``use_lstm`` and ``free_log_std``
    disabled and one output.
    """
    assert self.cur_instance, "must call forward first"

    with tf.variable_scope(self.variable_scope):
        with tf.variable_scope("value_function", reuse=tf.AUTO_REUSE):
            # Simple case: sharing the feature layer
            if self.model_config["vf_share_layers"]:
                shared_head = linear(self.cur_instance.last_layer, 1,
                                     "value_function",
                                     normc_initializer(1.0))
                return tf.reshape(shared_head, [-1])

            # Create a new separate model with no RNN state, etc.
            branch_config = self.model_config.copy()
            branch_config["free_log_std"] = False
            lstm_requested = branch_config["use_lstm"]
            if lstm_requested:
                branch_config["use_lstm"] = False
            if lstm_requested and log_once("vf_warn"):
                logger.warning(
                    "It is not recommended to use a LSTM model "
                    "with vf_share_layers=False (consider setting "
                    "it to True). If you want to not share "
                    "layers, you can implement a custom LSTM "
                    "model that overrides the value_function() "
                    "method.")

            value_branch = self.legacy_model_cls(
                self.cur_instance.input_dict,
                self.obs_space,
                self.action_space,
                1,
                branch_config,
                state_in=None,
                seq_lens=None)
            return tf.reshape(value_branch.outputs, [-1])
def run_timeline(sess, ops, debug_name, feed_dict=None, timeline_dir=None):
    """Run ``ops`` in ``sess``, optionally dumping a TF timeline trace.

    Args:
        sess: session object exposing ``run()``.
        ops: fetches passed straight to ``sess.run``.
        debug_name (str): label embedded in the trace file name.
        feed_dict (dict|None): feed dict for the run. ``None`` (the default)
            means an empty dict; a fresh one is created per call so state
            never leaks between calls.
        timeline_dir (str|None): if set, the run is fully traced and a
            Chrome-trace JSON file is written into this directory.

    Returns:
        Whatever ``sess.run`` returns for ``ops``.
    """
    global _count

    if feed_dict is None:
        feed_dict = {}

    if timeline_dir:
        from tensorflow.python.client import timeline

        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        start = time.time()
        fetches = sess.run(
            ops,
            options=run_options,
            run_metadata=run_metadata,
            feed_dict=feed_dict)
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        # The file index cycles through 0..9, so each (name, pid) pair keeps
        # at most ten trace files.
        outf = os.path.join(
            timeline_dir,
            "timeline-{}-{}-{}.json".format(debug_name, os.getpid(),
                                            _count % 10))
        _count += 1
        # Measured before the write so the reported duration covers the same
        # work as before.
        elapsed = time.time() - start
        # The context manager closes (and flushes) the file even if trace
        # generation raises.
        with open(outf, "w") as trace_file:
            trace_file.write(trace.generate_chrome_trace_format())
        # Log only once the file has actually been written.
        logger.info("Wrote tf timeline ({} s) to {}".format(
            elapsed, os.path.abspath(outf)))
    else:
        if log_once("tf_timeline"):
            logger.info(
                "Executing TF run without tracing. To dump TF timeline traces "
                "to disk, set the TF_TIMELINE_DIR environment variable.")
        fetches = sess.run(ops, feed_dict=feed_dict)

    return fetches
def learn_on_batch(self, samples):
    """Train on ``samples``, giving each policy its neighbors' batches.

    For a ``MultiAgentBatch``, every policy in ``self.policies_to_train`` is
    trained on its own batch plus a dict of neighbor batches. Neighbors come
    from the ``'adjacency_row'`` entry of ``self.traffic_light_node_dict``
    for the policy's intersection. Otherwise the default policy is trained
    on ``samples`` directly.

    Returns:
        Learner info: a dict keyed by policy id in the multi-agent case,
        else the default policy's ``learn_on_batch`` result.

    Raises:
        KeyError: if a neighbor's ``'policy_<id>'`` batch is missing from
            ``samples.policy_batches``.
    """
    if log_once("learn_on_batch"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize(samples)))
    if isinstance(samples, MultiAgentBatch):
        info_out = {}
        to_fetch = {}
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
        else:
            builder = None
        for pid, batch in samples.policy_batches.items():
            if pid not in self.policies_to_train:
                continue

            # Attention input: collect the batches of the neighboring agents.
            # NOTE(review): the original comment says the current agent is
            # included too; that depends on the adjacency row's contents,
            # which are not visible here - confirm.
            # Policy ids are parsed as '<prefix>_<number>'.
            intersection_id = self.inter_num_2_id(int(pid.split('_')[1]))
            adjacency_row = (
                self.traffic_light_node_dict[intersection_id]['adjacency_row'])
            neighbor_batch_dic = {}
            for neighbor_pid in adjacency_row:
                # None entries in the adjacency row are skipped.
                if neighbor_pid is None:
                    continue
                neighbor_key = 'policy_{}'.format(neighbor_pid)
                neighbor_batch_dic[neighbor_key] = (
                    samples.policy_batches[neighbor_key])

            policy = self.policy_map[pid]
            if builder and hasattr(policy, "_build_learn_on_batch"):
                to_fetch[pid] = policy._build_learn_on_batch(
                    builder, batch, neighbor_batch_dic)
            else:
                info_out[pid] = policy.learn_on_batch(
                    batch, neighbor_batch_dic)
        # to_fetch is only populated when a builder exists.
        info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
    else:
        info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(
            samples)
    if log_once("learn_out"):
        logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
    return info_out
def postprocess_batch_so_far(self, episode):
    """Apply policy postprocessors to any unprocessed rows.

    Each agent builder is built and reset, rewards are optionally replaced
    by their sign, every agent batch is postprocessed by its policy, and the
    results are appended to the per-policy builders. Per-agent state is
    cleared at the end.

    Arguments:
        episode: current MultiAgentEpisode object or None

    Raises:
        ValueError: if an agent batch has a done flag before its last row or
            contains more than one ``eps_id``.
    """
    # Materialize the batches so far
    pre_batches = {
        agent_id: (self.policy_map[self.agent_to_policy[agent_id]],
                   agent_builder.build_and_reset())
        for agent_id, agent_builder in self.agent_builders.items()
    }

    # Apply postprocessor
    if self.clip_rewards:
        for _, agent_batch in pre_batches.values():
            agent_batch["rewards"] = np.sign(agent_batch["rewards"])

    post_batches = {}
    for agent_id, (_, agent_batch) in pre_batches.items():
        # Every other agent's (policy, batch) pair.
        other_batches = dict(pre_batches)
        other_batches.pop(agent_id)
        policy = self.policy_map[self.agent_to_policy[agent_id]]
        spans_episodes = (any(agent_batch["dones"][:-1])
                          or len(set(agent_batch["eps_id"])) > 1)
        if spans_episodes:
            raise ValueError(
                "Batches sent to postprocessing must only contain steps "
                "from a single trajectory.", agent_batch)
        post_batches[agent_id] = policy.postprocess_trajectory(
            agent_batch, other_batches, episode)

    if log_once("after_post"):
        logger.info(
            "Trajectory fragment after postprocess_trajectory():\n\n{}\n".
            format(summarize(post_batches)))

    # Append into policy batches and reset
    for agent_id, post_batch in sorted(post_batches.items()):
        policy_id = self.agent_to_policy[agent_id]
        self.policy_builders[policy_id].add_batch(post_batch)
        if self.postp_callback:
            callback_info = {
                "episode": episode,
                "agent_id": agent_id,
                "pre_batch": pre_batches[agent_id],
                "post_batch": post_batch,
                "all_pre_batches": pre_batches,
            }
            self.postp_callback(callback_info)

    self.agent_builders.clear()
    self.agent_to_policy.clear()
def learn_on_batch(self, samples):
    """Train on ``samples``; optionally push arrays to a Beholder.

    For a ``MultiAgentBatch``, each policy in ``self.policies_to_train`` is
    trained on its own batch and the result is a dict keyed by policy id.
    Otherwise the default policy is trained; its result may be either the
    info object alone or an ``(info, beholder_arrays)`` tuple.

    Returns:
        The learner info (``info_out``).
    """
    if log_once("learn_on_batch"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize(samples)))
    if isinstance(samples, MultiAgentBatch):
        info_out = {}
        to_fetch = {}
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
        else:
            builder = None
        for pid, batch in samples.policy_batches.items():
            if pid not in self.policies_to_train:
                continue
            policy = self.policy_map[pid]
            if builder and hasattr(policy, "_build_learn_on_batch"):
                # Deferred: resolved through the shared builder below.
                to_fetch[pid] = policy._build_learn_on_batch(
                    builder, batch)
            else:
                info_out[pid] = policy.learn_on_batch(batch)
        # to_fetch is only populated when a builder exists.
        info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
    else:
        learn_on_batch_outputs = self.policy_map[
            DEFAULT_POLICY_ID].learn_on_batch(samples)
        # Accept both the plain info return and an (info, arrays) tuple.
        if isinstance(learn_on_batch_outputs, tuple):
            info_out, beholder_arrays = learn_on_batch_outputs
        else:
            info_out, beholder_arrays = learn_on_batch_outputs, {}
        # NOTE(review): this section is placed in the single-agent branch
        # because `beholder_arrays` is only bound here; confirm this matches
        # the intended nesting.
        if self.policy_config["evaluation_config"]["beholder"]:
            with self.tf_sess.graph.as_default():
                # The Beholder is created on first use.
                if self._beholder is None:
                    self._beholder = Beholder(self.io_context.log_dir)
                # An empty arrays dict is passed on as None.
                self._beholder.update(self.tf_sess,
                                      arrays=beholder_arrays or None)
    if log_once("learn_out"):
        logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
    return info_out
def _get_loss_inputs_dict(self, batch):
    """Return a feed dict mapping loss placeholders to data from ``batch``.

    If there are no state inputs and the batch already meets the
    divisibility requirement, the columns are fed as-is. Otherwise the batch
    is first cut into sequences by ``chop_into_sequences``.
    """
    divisor = self._batch_divisibility_req
    if divisor > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % divisor == 0
            and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
    else:
        meets_divisibility_reqs = True

    # Simple case: not RNN nor do we need to pad
    if not self._state_inputs and meets_divisibility_reqs:
        return {
            placeholder: batch[key]
            for key, placeholder in self._loss_inputs
        }

    if self._state_inputs:
        max_seq_len, dynamic_max = self._max_seq_len, True
    else:
        max_seq_len, dynamic_max = divisor, False

    # RNN or multi-agent case
    feature_keys = [key for key, _ in self._loss_inputs]
    state_keys = [
        "state_in_{}".format(idx)
        for idx, _ in enumerate(self._state_inputs)
    ]
    feature_sequences, initial_states, seq_lens = chop_into_sequences(
        batch[SampleBatch.EPS_ID],
        batch[SampleBatch.UNROLL_ID],
        batch[SampleBatch.AGENT_INDEX],
        [batch[key] for key in feature_keys],
        [batch[key] for key in state_keys],
        max_seq_len,
        dynamic_max=dynamic_max)

    feed_dict = {}
    for key, sequence in zip(feature_keys, feature_sequences):
        feed_dict[self._loss_input_dict[key]] = sequence
    for key, state in zip(state_keys, initial_states):
        feed_dict[self._loss_input_dict[key]] = state
    feed_dict[self._seq_lens] = seq_lens

    if log_once("rnn_feed_dict"):
        padded_summary = {
            "features": feature_sequences,
            "initial_states": initial_states,
            "seq_lens": seq_lens,
            "max_seq_len": max_seq_len,
        }
        logger.info("Padded input for RNN:\n\n{}\n".format(
            summarize(padded_summary)))
    return feed_dict
def _initialize_loss(self, loss, loss_inputs):
    """Wire up loss, optimizer, gradients and the apply op for this policy.

    Args:
        loss: loss value; passed through ``self.model.custom_loss`` when a
            model is set.
        loss_inputs: iterable of (name, placeholder) pairs; stored and also
            turned into ``self._loss_input_dict``.

    Ends by running ``tf.global_variables_initializer()`` in ``self._sess``.
    """
    self._loss_inputs = loss_inputs
    self._loss_input_dict = dict(self._loss_inputs)
    # State placeholders are addressable as "state_in_<i>".
    for i, ph in enumerate(self._state_inputs):
        self._loss_input_dict["state_in_{}".format(i)] = ph

    if self.model:
        # The model may modify the loss and contributes its own stats.
        self._loss = self.model.custom_loss(loss, self._loss_input_dict)
        self._stats_fetches.update({
            "model": self.model.metrics() if isinstance(
                self.model, ModelV2) else self.model.custom_stats()
        })
    else:
        self._loss = loss

    self._optimizer = self.optimizer()
    # Drop (grad, var) pairs whose gradient is None.
    self._grads_and_vars = [
        (g, v) for (g, v) in self.gradients(self._optimizer, self._loss)
        if g is not None
    ]
    self._grads = [g for (g, v) in self._grads_and_vars]
    # TODO(sven/ekl): Deprecate support for v1 models.
    if hasattr(self, "model") and isinstance(self.model, ModelV2):
        self._variables = ray.experimental.tf_utils.TensorFlowVariables(
            [], self._sess, self.variables())
    else:
        self._variables = ray.experimental.tf_utils.TensorFlowVariables(
            self._loss, self._sess)

    # gather update ops for any batch norm layers
    if not self._update_ops:
        self._update_ops = tf.get_collection(
            tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
    if self._update_ops:
        logger.info("Update ops to run on apply gradient: {}".format(
            self._update_ops))
    # The apply op is built under a control dependency on the update ops.
    with tf.control_dependencies(self._update_ops):
        self._apply_op = self.build_apply_op(self._optimizer,
                                             self._grads_and_vars)

    if log_once("loss_used"):
        logger.debug(
            "These tensors were used in the loss_fn:\n\n{}\n".format(
                summarize(self._loss_input_dict)))

    self._sess.run(tf.global_variables_initializer())
def _compute_gradients(self, samples):
    """Computes and returns grads as eager tensors.

    ``gradients_fn`` and ``loss_fn`` are not defined in this block; they are
    presumably supplied by an enclosing scope (verify against the caller).

    Returns:
        (grads_and_vars, stats): the (gradient, variable) pairs and the
        result of ``self._stats``.
    """
    self._is_training = True

    # The tape is persistent only when a custom gradients_fn is in use.
    with tf.GradientTape(persistent=gradients_fn is not None) as tape:
        # TODO: set seq len and state-in properly
        state_in = []
        for i in range(self.num_state_tensors()):
            state_in.append(samples["state_in_{}".format(i)])
        self._state_in = state_in

        self._seq_lens = None
        if len(state_in) > 0:
            # One sequence length of 1 per row of the observation batch.
            self._seq_lens = tf.ones(
                samples[SampleBatch.CUR_OBS].shape[0], dtype=tf.int32)
            # NOTE: this writes into the caller's `samples` mapping.
            samples["seq_lens"] = self._seq_lens

        # The outputs are not used here; only the call itself matters.
        model_out, _ = self.model(samples, self._state_in, self._seq_lens)
        loss = loss_fn(self, self.model, self.dist_class, samples)

    variables = self.model.trainable_variables()

    if gradients_fn:

        class OptimizerWrapper:
            """Exposes ``compute_gradients`` backed by a GradientTape."""

            def __init__(self, tape):
                self.tape = tape

            def compute_gradients(self, loss, var_list):
                return list(
                    zip(self.tape.gradient(loss, var_list), var_list))

        grads_and_vars = gradients_fn(self, OptimizerWrapper(tape), loss)
    else:
        grads_and_vars = list(
            zip(tape.gradient(loss, variables), variables))

    if log_once("grad_vars"):
        for _, v in grads_and_vars:
            logger.info("Optimizing variable {}".format(v.name))

    grads = [g for g, v in grads_and_vars]
    stats = self._stats(self, samples, grads)
    return grads_and_vars, stats
def apply_gradients(self, grads):
    """Apply ``grads`` to the local policies.

    A dict is treated as a mapping from policy id to that policy's
    gradients, and the result is a dict with the same keys. Anything else is
    handed to the default policy.
    """
    if log_once("apply_gradients"):
        logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))

    if not isinstance(grads, dict):
        return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

    if self.tf_sess is None:
        results = {}
        for policy_id, policy_grads in grads.items():
            results[policy_id] = (
                self.policy_map[policy_id].apply_gradients(policy_grads))
        return results

    # Register every policy's apply op on one builder, then resolve them.
    builder = TFRunBuilder(self.tf_sess, "apply_gradients")
    pending = {}
    for policy_id, policy_grads in grads.items():
        pending[policy_id] = (
            self.policy_map[policy_id]._build_apply_gradients(
                builder, policy_grads))
    return {
        policy_id: builder.get(fetch)
        for policy_id, fetch in pending.items()
    }
def _compute_gradients(self, samples):
    """Computes and returns grads as eager tensors.

    Columns of ``samples`` with object dtype are dropped; all others are
    converted to tensors before the model and loss are evaluated.
    ``gradients_fn`` and ``loss_fn`` are not defined in this block; they are
    presumably supplied by an enclosing scope (verify against the caller).

    Returns:
        (grads_and_vars, stats): the (gradient, variable) pairs and the
        result of ``self._stats``.
    """
    self._is_training = True

    # Compare against builtin `object`: the `np.object` alias was removed in
    # NumPy 1.24, and a dtype compares equal to `object` exactly when it is
    # the object dtype.
    samples = {
        k: tf.convert_to_tensor(v)
        for k, v in samples.items() if v.dtype != object
    }

    # The tape is persistent only when a custom gradients_fn is in use.
    with tf.GradientTape(persistent=gradients_fn is not None) as tape:
        # TODO: set seq len and state in properly
        self._seq_lens = tf.ones(len(samples[SampleBatch.CUR_OBS]))
        self._state_in = []
        # The outputs are not used here; only the call itself matters.
        self.model(samples, self._state_in, self._seq_lens)
        loss = loss_fn(self, self.model, self.dist_class, samples)

    variables = self.model.trainable_variables()

    if gradients_fn:

        class OptimizerWrapper(object):
            """Exposes ``compute_gradients`` backed by a GradientTape."""

            def __init__(self, tape):
                self.tape = tape

            def compute_gradients(self, loss, var_list):
                return list(
                    zip(self.tape.gradient(loss, var_list), var_list))

        grads_and_vars = gradients_fn(self, OptimizerWrapper(tape), loss)
    else:
        grads_and_vars = list(
            zip(tape.gradient(loss, variables), variables))

    if log_once("grad_vars"):
        for _, v in grads_and_vars:
            logger.info("Optimizing variable {}".format(v.name))

    grads = [g for g, v in grads_and_vars]
    stats = self._stats(self, samples, grads)
    return grads_and_vars, stats
def _debug_vars(self):
    """Log each variable in ``self._grads_and_vars``, gated by log_once."""
    if not log_once("grad_vars"):
        return
    for _grad, variable in self._grads_and_vars:
        logger.info("Optimizing variable {}".format(variable))
def _process_observations(base_env, policies, batch_builder_pool,
                          active_episodes, unfiltered_obs, rewards, dones,
                          infos, off_policy_actions, horizon, preprocessors,
                          obs_filters, unroll_length, pack, callbacks):
    """Record new data from the environment and prepare for policy evaluation.

    Note: ``off_policy_actions`` is accepted but not read in this function.

    Returns:
        active_envs: set of non-terminated env ids
        to_eval: map of policy_id to list of agent PolicyEvalData
        outputs: list of metrics and samples to return from the sampler

    Raises:
        ValueError: if an episode ended, the env could not be reset
            (``try_reset`` returned None) and a finite horizon is set.
    """
    active_envs = set()
    to_eval = defaultdict(list)
    outputs = []

    # For each environment
    for env_id, agent_obs in unfiltered_obs.items():
        # Must be checked before indexing: looking up a missing env id may
        # itself create the episode (presumably a defaultdict - confirm).
        new_episode = env_id not in active_episodes
        episode = active_episodes[env_id]
        if not new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > max(1000, unroll_length * 10)
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check that "
                "that you set a horizon on your environment correctly. Note "
                "that in multi-agent environments, `sample_batch_size` sets "
                "the batch size based on environment steps, not the steps of "
                "individual agents, which can result in unexpectedly large "
                "batches.")

        # Check episode termination conditions
        if dones[env_id]["__all__"] or episode.length >= horizon:
            all_done = True
            # Prefer the Atari metrics when available, else build one
            # RolloutMetrics entry from the episode.
            atari_metrics = _fetch_atari_metrics(base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics))
        else:
            all_done = False
            active_envs.add(env_id)

        # For each agent in the environment
        for agent_id, raw_obs in agent_obs.items():
            policy_id = episode.policy_for(agent_id)
            prep_obs = _get_or_raise(preprocessors,
                                     policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_done or dones[env_id].get(agent_id))
            # Only agents that are not done are queued for action
            # computation.
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            # Read the previous observation before it is overwritten below.
            last_observation = episode.last_observation_for(agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=agent_done,
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

        # Invoke the step callback after the step is logged to the episode
        if callbacks.get("on_episode_step"):
            callbacks["on_episode_step"]({"env": base_env, "episode": episode})

        # Cut the batch if we're not packing multiple episodes into one,
        # or if we've exceeded the requested batch size.
        if episode.batch_builder.has_pending_data():
            if dones[env_id]["__all__"]:
                episode.batch_builder.check_missing_dones()
            if (all_done and not pack) or \
                    episode.batch_builder.count >= unroll_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            elif all_done:
                # Make sure postprocessor stays within one episode
                episode.batch_builder.postprocess_batch_so_far(episode)

        if all_done:
            # Handle episode termination
            # Hand the builder back to the shared pool.
            batch_builder_pool.append(episode.batch_builder)
            if callbacks.get("on_episode_end"):
                callbacks["on_episode_end"]({
                    "env": base_env,
                    "policy": policies,
                    "episode": episode
                })
            del active_episodes[env_id]
            resetted_obs = base_env.try_reset(env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            else:
                # Creates a new episode
                episode = active_episodes[env_id]
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id = episode.policy_for(agent_id)
                    policy = _get_or_raise(policies, policy_id)
                    prep_obs = _get_or_raise(preprocessors,
                                             policy_id).transform(raw_obs)
                    filtered_obs = _get_or_raise(obs_filters,
                                                 policy_id)(prep_obs)
                    episode._set_last_observation(agent_id, filtered_obs)
                    # First step of the new episode: the previous action is
                    # zeros shaped like a flattened sampled action, and the
                    # previous reward is 0.0.
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                _flatten_action(policy.action_space.sample())),
                            0.0))

    return active_envs, to_eval, outputs
def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
                unroll_length, horizon, preprocessors, obs_filters,
                clip_rewards, clip_actions, pack, callbacks, tf_sess=None):
    """This implements the common experience collection logic.

    This is a generator that loops forever: poll the env, process the
    observations, yield any outputs, evaluate the policies, and send the
    resulting actions back to the env.

    Args:
        base_env (BaseEnv): env implementing BaseEnv.
        extra_batch_callback (fn): function to send extra batch data to.
        policies (dict): Map of policy ids to PolicyGraph instances.
        policy_mapping_fn (func): Function that maps agent ids to policy ids.
            This is called when an agent first enters the environment. The
            agent is then "bound" to the returned policy for the episode.
        unroll_length (int): Number of episode steps before `SampleBatch` is
            yielded. Set to infinity to yield complete episodes.
        horizon (int): Horizon of the episode. If falsy, the first unwrapped
            env's `spec.max_episode_steps` is tried, then infinity.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        clip_rewards (bool): Whether to clip rewards before postprocessing.
        pack (bool): Whether to pack multiple episodes into each batch. This
            guarantees batches will be exactly `unroll_length` in size.
        clip_actions (bool): Whether to clip actions to the space range.
        callbacks (dict): User callbacks to run on episode events.
        tf_sess (Session|None): Optional tensorflow session to use for batching
            TF policy evaluations.

    Yields:
        rollout (SampleBatch): Object containing state, action, reward,
            terminal condition, and other fields as dictated by `policy`.
    """
    # Best effort: any failure to read the env's own limit is logged at
    # debug level and the horizon falls back to infinity below.
    try:
        if not horizon:
            horizon = (base_env.get_unwrapped()[0].spec.max_episode_steps)
    except Exception:
        logger.debug("no episode horizon specified, assuming inf")
    if not horizon:
        horizon = float("inf")

    # Pool of batch builders, which can be shared across episodes to pack
    # trajectory data.
    batch_builder_pool = []

    def get_batch_builder():
        # Reuse a pooled builder if one is available, else create one.
        if batch_builder_pool:
            return batch_builder_pool.pop()
        else:
            return MultiAgentSampleBatchBuilder(policies, clip_rewards)

    def new_episode():
        # Build the episode and fire the optional start callback.
        episode = MultiAgentEpisode(policies, policy_mapping_fn,
                                    get_batch_builder, extra_batch_callback)
        if callbacks.get("on_episode_start"):
            callbacks["on_episode_start"]({
                "env": base_env,
                "policy": policies,
                "episode": episode,
            })
        return episode

    # Looking up an unknown env id creates a fresh episode via new_episode().
    active_episodes = defaultdict(new_episode)

    while True:
        # Get observations from all ready agents
        unfiltered_obs, rewards, dones, infos, off_policy_actions = \
            base_env.poll()

        if log_once("env_returns"):
            logger.info("Raw obs from env: {}".format(
                summarize(unfiltered_obs)))
            logger.info("Info return from env: {}".format(summarize(infos)))

        # Process observations and prepare for policy evaluation
        active_envs, to_eval, outputs = _process_observations(
            base_env, policies, batch_builder_pool, active_episodes,
            unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
            preprocessors, obs_filters, unroll_length, pack, callbacks)
        for o in outputs:
            yield o

        # Do batched policy eval
        eval_results = _do_policy_eval(tf_sess, to_eval, policies,
                                       active_episodes)

        # Process results and update episode state
        actions_to_send = _process_policy_eval_results(
            to_eval, eval_results, active_episodes, active_envs,
            off_policy_actions, policies, clip_actions)

        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        base_env.send_actions(actions_to_send)
def _get_loss_inputs_dict(self, batch, shuffle):
    """Return a feed dict from a batch.

    Arguments:
        batch (SampleBatch): batch of data to derive inputs from
        shuffle (bool): whether to shuffle batch sequences. Shuffle may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.

    Returns:
        feed dict of data
    """
    has_state = bool(self._state_inputs)

    if self._batch_divisibility_req > 1:
        num_rows = len(batch[SampleBatch.CUR_OBS])
        meets_divisibility_reqs = (
            num_rows % self._batch_divisibility_req == 0
            and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
    else:
        meets_divisibility_reqs = True

    # Simple case: not RNN nor do we need to pad
    if not has_state and meets_divisibility_reqs:
        if shuffle:
            batch.shuffle()
        return {
            placeholder: batch[name]
            for name, placeholder in self._loss_inputs
        }

    if has_state:
        max_seq_len = self._max_seq_len
    else:
        max_seq_len = self._batch_divisibility_req
    dynamic_max = has_state

    # RNN or multi-agent case
    feature_keys = [name for name, _ in self._loss_inputs]
    state_keys = [
        "state_in_{}".format(idx) for idx in range(len(self._state_inputs))
    ]
    feature_columns = [batch[name] for name in feature_keys]
    state_columns = [batch[name] for name in state_keys]
    feature_sequences, initial_states, seq_lens = chop_into_sequences(
        batch[SampleBatch.EPS_ID],
        batch[SampleBatch.UNROLL_ID],
        batch[SampleBatch.AGENT_INDEX],
        feature_columns,
        state_columns,
        max_seq_len,
        dynamic_max=dynamic_max,
        shuffle=shuffle)

    feed_dict = {}
    for name, sequence in zip(feature_keys, feature_sequences):
        feed_dict[self._loss_input_dict[name]] = sequence
    for name, state in zip(state_keys, initial_states):
        feed_dict[self._loss_input_dict[name]] = state
    feed_dict[self._seq_lens] = seq_lens

    if log_once("rnn_feed_dict"):
        logger.info("Padded input for RNN:\n\n{}\n".format(
            summarize({
                "features": feature_sequences,
                "initial_states": initial_states,
                "seq_lens": seq_lens,
                "max_seq_len": max_seq_len,
            })))
    return feed_dict
def _initialize_loss(self):
    """Build the loss graph by tracing the loss function on a dummy batch.

    A one-row dummy batch is run through ``postprocess_trajectory`` to
    discover which columns the loss may consume. A placeholder is created
    for every non-object column, the loss is built on a
    ``UsageTrackingDict`` of those placeholders, and the keys the loss
    actually accessed are appended to the loss inputs handed to
    ``TFPolicy._initialize_loss``.
    """

    def fake_array(tensor):
        # One-row zero array matching the tensor. The leading dimension is
        # forced to 1 and any other unknown (None) dimension is replaced by
        # 1 so that np.zeros always receives a concrete shape.
        shape = tensor.shape.as_list()
        shape[0] = 1
        shape = [1 if dim is None else dim for dim in shape]
        return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

    # The builtin ``bool`` / ``object`` are used instead of the ``np.bool``
    # / ``np.object`` aliases, which newer NumPy releases no longer provide.
    dummy_batch = {
        SampleBatch.CUR_OBS: fake_array(self._obs_input),
        SampleBatch.NEXT_OBS: fake_array(self._obs_input),
        SampleBatch.DONES: np.array([False], dtype=bool),
        SampleBatch.ACTIONS: fake_array(
            ModelCatalog.get_action_placeholder(self.action_space)),
        SampleBatch.REWARDS: np.array([0], dtype=np.float32),
    }
    if self._obs_include_prev_action_reward:
        dummy_batch.update({
            SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
            SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
        })
    state_init = self.get_initial_state()
    for i, h in enumerate(state_init):
        dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
        dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
    if state_init:
        dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
    for k, v in self.extra_compute_action_fetches().items():
        dummy_batch[k] = fake_array(v)

    # postprocessing might depend on variable init, so run it first here
    self._sess.run(tf.global_variables_initializer())
    postprocessed_batch = self.postprocess_trajectory(
        SampleBatch(dummy_batch))

    if self._obs_include_prev_action_reward:
        batch_tensors = UsageTrackingDict({
            SampleBatch.PREV_ACTIONS: self._prev_action_input,
            SampleBatch.PREV_REWARDS: self._prev_reward_input,
            SampleBatch.CUR_OBS: self._obs_input,
        })
        loss_inputs = [
            (SampleBatch.PREV_ACTIONS, self._prev_action_input),
            (SampleBatch.PREV_REWARDS, self._prev_reward_input),
            (SampleBatch.CUR_OBS, self._obs_input),
        ]
    else:
        batch_tensors = UsageTrackingDict({
            SampleBatch.CUR_OBS: self._obs_input,
        })
        loss_inputs = [
            (SampleBatch.CUR_OBS, self._obs_input),
        ]

    # Create one placeholder per remaining postprocessed column.
    for k, v in postprocessed_batch.items():
        if k in batch_tensors:
            continue
        elif v.dtype == object:
            continue  # can't handle arbitrary objects in TF
        shape = (None, ) + v.shape[1:]
        dtype = np.float32 if v.dtype == np.float64 else v.dtype
        placeholder = tf.placeholder(dtype, shape=shape, name=k)
        batch_tensors[k] = placeholder

    if log_once("loss_init"):
        logger.info(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(batch_tensors)))

    loss = self._do_loss_init(batch_tensors)
    # Only columns the loss function actually read become loss inputs;
    # sorting keeps their order deterministic.
    for k in sorted(batch_tensors.accessed_keys):
        loss_inputs.append((k, batch_tensors[k]))

    TFPolicy._initialize_loss(self, loss, loss_inputs)
    if self._grad_stats_fn:
        self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
    # Second initializer run, after the loss graph has been built.
    self._sess.run(tf.global_variables_initializer())
def compute_gradients(self, samples):
    """Compute gradients for ``samples`` without applying them.

    For a ``MultiAgentBatch`` every policy listed in
    ``self.policies_to_train`` receives its own batch plus a dict with the
    batches of its neighbor policies. Any other batch is passed unchanged
    to the policy stored under ``DEFAULT_POLICY_ID``.

    Arguments:
        samples: a ``MultiAgentBatch`` or a single-policy batch.

    Returns:
        (grad_out, info_out): for a ``MultiAgentBatch`` both are dicts
        keyed by policy id; otherwise whatever the default policy returns.
        ``info_out["batch_count"]`` is set to ``samples.count``.

    Raises:
        KeyError: if a neighbor listed in the adjacency row has no batch in
            ``samples.policy_batches``.
    """
    if log_once("compute_gradients"):
        logger.info("Compute gradients on:\n\n{}\n".format(
            summarize(samples)))

    def neighbor_batches_for(pid):
        # Batches of the policies adjacent to ``pid`` according to
        # ``self.traffic_light_node_dict``. ``None`` entries in the
        # adjacency row are skipped.
        # NOTE(review): the original comment says the row also contains
        # the current agent itself - confirm against the node dict.
        node = self.traffic_light_node_dict[self.inter_num_2_id(
            int(pid.split('_')[1]))]
        return {
            'policy_{}'.format(neighbor):
            samples.policy_batches['policy_{}'.format(neighbor)]
            for neighbor in node['adjacency_row'] if neighbor is not None
        }

    if isinstance(samples, MultiAgentBatch):
        grad_out, info_out = {}, {}
        if self.tf_sess is not None:
            # Queue the fetches of every policy on one builder and resolve
            # them afterwards via builder.get().
            builder = TFRunBuilder(self.tf_sess, "compute_gradients")
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                neighbor_batch_dic = neighbor_batches_for(pid)
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid]._build_compute_gradients(
                        builder, batch, neighbor_batch_dic))
            grad_out = {k: builder.get(v) for k, v in grad_out.items()}
            info_out = {k: builder.get(v) for k, v in info_out.items()}
        else:
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                neighbor_batch_dic = neighbor_batches_for(pid)
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid].compute_gradients(
                        batch, neighbor_batch_dic))
    else:
        grad_out, info_out = (
            self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
    info_out["batch_count"] = samples.count
    if log_once("grad_out"):
        logger.info("Compute grad info:\n\n{}\n".format(
            summarize(info_out)))
    return grad_out, info_out
def _initialize_loss(self):
    """Build the loss graph by tracing the loss function on a dummy batch.

    A dummy batch is run through ``postprocess_trajectory`` to discover
    which columns the loss may consume. A placeholder is created for every
    non-object, non-state column, the state inputs and ``seq_lens`` are
    added, and the loss is built on the resulting ``UsageTrackingDict``.
    The keys the loss actually accessed (except ``seq_lens`` and
    ``state_in_*``) are appended to the loss inputs handed to
    ``TFPolicy._initialize_loss``.
    """

    def fake_array(tensor):
        # Zero array matching the tensor, with unknown dims replaced by 1.
        shape = tensor.shape.as_list()
        shape = [s if s is not None else 1 for s in shape]
        return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

    # The builtin ``bool`` / ``object`` are used instead of the ``np.bool``
    # / ``np.object`` aliases, which newer NumPy releases no longer provide.
    dummy_batch = {
        SampleBatch.CUR_OBS: fake_array(self._obs_input),
        SampleBatch.NEXT_OBS: fake_array(self._obs_input),
        SampleBatch.DONES: np.array([False], dtype=bool),
        SampleBatch.ACTIONS: fake_array(
            ModelCatalog.get_action_placeholder(self.action_space)),
        SampleBatch.REWARDS: np.array([0], dtype=np.float32),
    }
    # NOTE(PENGZHENGHAO): a hook that added dummy entries for
    # ``self.model.mask_placeholder_dict`` used to live here and is
    # currently disabled.
    if self._obs_include_prev_action_reward:
        dummy_batch.update({
            SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
            SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
        })
    state_init = self.get_initial_state()
    for i, h in enumerate(state_init):
        dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
        dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
    if state_init:
        dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
    for k, v in self.extra_compute_action_fetches().items():
        dummy_batch[k] = fake_array(v)

    # postprocessing might depend on variable init, so run it first here
    self._sess.run(tf.global_variables_initializer())
    postprocessed_batch = self.postprocess_trajectory(
        SampleBatch(dummy_batch))

    # model forward pass for the loss (needed after postprocess to
    # overwrite any tensor state from that call)
    self.model(self._input_dict, self._state_in, self._seq_lens)

    if self._obs_include_prev_action_reward:
        train_batch = UsageTrackingDict({
            SampleBatch.PREV_ACTIONS: self._prev_action_input,
            SampleBatch.PREV_REWARDS: self._prev_reward_input,
            SampleBatch.CUR_OBS: self._obs_input,
        })
        loss_inputs = [
            (SampleBatch.PREV_ACTIONS, self._prev_action_input),
            (SampleBatch.PREV_REWARDS, self._prev_reward_input),
            (SampleBatch.CUR_OBS, self._obs_input),
        ]
    else:
        train_batch = UsageTrackingDict({
            SampleBatch.CUR_OBS: self._obs_input,
        })
        loss_inputs = [
            (SampleBatch.CUR_OBS, self._obs_input),
        ]

    # Debug note from the original author: in their experiments the
    # postprocessed batch held the keys 'obs', 'new_obs', 'dones',
    # 'actions', 'rewards', 'fc_1_mask', 'fc_2_mask', 'prev_actions',
    # 'prev_rewards', 'action_prob', 'action_logp', 'vf_preds',
    # 'behaviour_logits', 'layer0', 'layer1', 'advantages' and
    # 'value_targets', both with and without the mask.
    for k, v in postprocessed_batch.items():
        if k in train_batch:
            continue
        elif v.dtype == object:
            continue  # can't handle arbitrary objects in TF
        elif k == "seq_lens" or k.startswith("state_in_"):
            continue  # fed from the model's own tensors below
        shape = (None, ) + v.shape[1:]
        dtype = np.float32 if v.dtype == np.float64 else v.dtype
        placeholder = tf.placeholder(dtype, shape=shape, name=k)
        train_batch[k] = placeholder

    for i, si in enumerate(self._state_in):
        train_batch["state_in_{}".format(i)] = si
    train_batch["seq_lens"] = self._seq_lens

    if log_once("loss_init"):
        logger.debug(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(train_batch)))

    self._loss_input_dict = train_batch
    loss = self._do_loss_init(train_batch)
    # Debug note from the original author: after building the loss the
    # accessed keys were 'action_logp', 'actions', 'advantages',
    # 'behaviour_logits', 'obs', 'prev_actions', 'prev_rewards',
    # 'value_targets' and 'vf_preds' in both the mask and no-mask runs.
    # loss_inputs already holds prev_actions, prev_rewards and obs here
    # when prev action/reward are part of the observation.
    for k in sorted(train_batch.accessed_keys):
        if k != "seq_lens" and not k.startswith("state_in_"):
            loss_inputs.append((k, train_batch[k]))

    # NOTE(PENGZHENGHAO): appending ``self.model.mask_placeholder_dict``
    # entries to loss_inputs is currently disabled.

    TFPolicy._initialize_loss(self, loss, loss_inputs)
    if self._grad_stats_fn:
        self._stats_fetches.update(
            self._grad_stats_fn(self, train_batch, self._grads))
    # Second initializer run, after the loss graph has been built.
    self._sess.run(tf.global_variables_initializer())
def _initialize_loss(self):
    """Build the loss graph by tracing the loss function on a dummy batch.

    A dummy batch is run through ``postprocess_trajectory`` to discover
    which columns the loss may consume. A placeholder is created for every
    non-object, non-state column, the state inputs and ``seq_lens`` are
    added, and the loss is built on the resulting ``UsageTrackingDict``.
    The keys the loss actually accessed (except ``seq_lens`` and
    ``state_in_*``) are appended to the loss inputs handed to
    ``TFPolicy._initialize_loss``.
    """

    def fake_array(tensor):
        # Zero array matching the tensor, with unknown dims replaced by 1.
        shape = tensor.shape.as_list()
        shape = [s if s is not None else 1 for s in shape]
        return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

    # The builtin ``bool`` / ``object`` are used instead of the ``np.bool``
    # / ``np.object`` aliases, which newer NumPy releases no longer provide.
    dummy_batch = {
        SampleBatch.CUR_OBS: fake_array(self._obs_input),
        SampleBatch.NEXT_OBS: fake_array(self._obs_input),
        SampleBatch.DONES: np.array([False], dtype=bool),
        SampleBatch.ACTIONS: fake_array(
            ModelCatalog.get_action_placeholder(self.action_space)),
        SampleBatch.REWARDS: np.array([0], dtype=np.float32),
    }
    if self._obs_include_prev_action_reward:
        dummy_batch.update({
            SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
            SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
        })
    state_init = self.get_initial_state()
    for i, h in enumerate(state_init):
        dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
        dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
    if state_init:
        dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
    for k, v in self.extra_compute_action_fetches().items():
        dummy_batch[k] = fake_array(v)

    # postprocessing might depend on variable init, so run it first here
    self._sess.run(tf.global_variables_initializer())
    postprocessed_batch = self.postprocess_trajectory(
        SampleBatch(dummy_batch))

    # model forward pass for the loss (needed after postprocess to
    # overwrite any tensor state from that call)
    self.model(self._input_dict, self._state_in, self._seq_lens)

    if self._obs_include_prev_action_reward:
        train_batch = UsageTrackingDict({
            SampleBatch.PREV_ACTIONS: self._prev_action_input,
            SampleBatch.PREV_REWARDS: self._prev_reward_input,
            SampleBatch.CUR_OBS: self._obs_input,
        })
        loss_inputs = [
            (SampleBatch.PREV_ACTIONS, self._prev_action_input),
            (SampleBatch.PREV_REWARDS, self._prev_reward_input),
            (SampleBatch.CUR_OBS, self._obs_input),
        ]
    else:
        train_batch = UsageTrackingDict({
            SampleBatch.CUR_OBS: self._obs_input,
        })
        loss_inputs = [
            (SampleBatch.CUR_OBS, self._obs_input),
        ]

    # Create one placeholder per remaining postprocessed column.
    for k, v in postprocessed_batch.items():
        if k in train_batch:
            continue
        elif v.dtype == object:
            continue  # can't handle arbitrary objects in TF
        elif k == "seq_lens" or k.startswith("state_in_"):
            continue  # fed from the model's own tensors below
        shape = (None, ) + v.shape[1:]
        dtype = np.float32 if v.dtype == np.float64 else v.dtype
        placeholder = tf.placeholder(dtype, shape=shape, name=k)
        train_batch[k] = placeholder

    for i, si in enumerate(self._state_in):
        train_batch["state_in_{}".format(i)] = si
    train_batch["seq_lens"] = self._seq_lens

    if log_once("loss_init"):
        logger.debug(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(train_batch)))

    self._loss_input_dict = train_batch
    loss = self._do_loss_init(train_batch)
    # Only columns the loss function actually read become loss inputs;
    # sorting keeps their order deterministic.
    for k in sorted(train_batch.accessed_keys):
        if k != "seq_lens" and not k.startswith("state_in_"):
            loss_inputs.append((k, train_batch[k]))

    TFPolicy._initialize_loss(self, loss, loss_inputs)
    if self._grad_stats_fn:
        self._stats_fetches.update(
            self._grad_stats_fn(self, train_batch, self._grads))
    # Second initializer run, after the loss graph has been built.
    self._sess.run(tf.global_variables_initializer())
def _process_observations(base_env, policies, batch_builder_pool,
                          active_episodes, unfiltered_obs, rewards, dones,
                          infos, off_policy_actions, horizon, preprocessors,
                          obs_filters, unroll_length, pack, callbacks,
                          soft_horizon, no_done_at_end):
    """Record new data from the environment and prepare for policy evaluation.

    Besides its own filtered observation, every ``PolicyEvalData`` entry
    also carries a ``neighbor_obs`` array built for the neighbors of the
    agent's policy, as given by the module-global
    ``traffic_light_node_dict``.

    The module globals ``i`` (call counter), ``tmp_dic`` and
    ``traffic_light_node_dict`` are read and written. ``i`` must already
    exist at module level. While ``i <= 1`` the neighbor topology is
    loaded from ``base_env.envs[0]``.

    Returns:
        active_envs: set of non-terminated env ids
        to_eval: map of policy_id to list of agent PolicyEvalData
        outputs: list of metrics and samples to return from the sampler
    """
    global i
    global tmp_dic
    global traffic_light_node_dict
    i += 1

    def inter_num_2_id(num):
        # Reverse lookup: intersection index -> intersection id.
        return list(tmp_dic.keys())[list(tmp_dic.values()).index(num)]

    def collect_neighbor_obs(policy_id, raw_obs):
        # Look up the neighbors of ``policy_id`` in the road network and
        # return (neighbor policy ids in "policy_<n>" form, nested list of
        # shape [1][num_neighbors] with the filtered observations).
        # The extra leading axis is what the neighbor_obs placeholder
        # expects according to the original author.
        # NOTE(review): every neighbor entry is derived from THIS agent's
        # raw_obs, run through the neighbor's preprocessor and filter -
        # confirm this is intended rather than the neighbor's own
        # observation.
        node = traffic_light_node_dict[inter_num_2_id(
            int(policy_id.split('_')[1]))]
        neighbor_pids = [
            'policy_{}'.format(pid_) for pid_ in node['adjacency_row']
            if pid_ is not None
        ]
        stacked = [[]]
        for neighbor_id in neighbor_pids:
            neighbor_prep_obs = _get_or_raise(
                preprocessors, neighbor_id).transform(raw_obs)
            neighbor_filtered_obs = _get_or_raise(
                obs_filters, neighbor_id)(neighbor_prep_obs)
            stacked[0].append(neighbor_filtered_obs)
        return neighbor_pids, stacked

    if i <= 1:
        # Load the neighbor topology from the first sub-environment.
        traffic_light_node_dict = base_env.envs[0].traffic_light_node_dict
        tmp_dic = traffic_light_node_dict['intersection_1_1'][
            'inter_id_to_index']

    active_envs = set()
    to_eval = defaultdict(list)
    outputs = []

    # For each environment
    for env_id, agent_obs in unfiltered_obs.items():
        new_episode = env_id not in active_episodes
        episode = active_episodes[env_id]
        if not new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > max(1000, unroll_length * 10)
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check that "
                "that you set a horizon on your environment correctly. Note "
                "that in multi-agent environments, `sample_batch_size` sets "
                "the batch size based on environment steps, not the steps of "
                "individual agents, which can result in unexpectedly large "
                "batches.")

        # Check episode termination conditions
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_done = True
            atari_metrics = _fetch_atari_metrics(base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {}))
        else:
            hit_horizon = False
            all_done = False
            active_envs.add(env_id)

        # For each agent in the environment
        for agent_id, raw_obs in agent_obs.items():
            policy_id = episode.policy_for(agent_id)  # eg: "policy_0"
            prep_obs = _get_or_raise(preprocessors,
                                     policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)

            # Attention: Q evaluation happens online here, so the
            # neighbor observations must be passed to the Q eval network.
            neighbor_pid_list, neighbor_obs = collect_neighbor_obs(
                policy_id, raw_obs)
            # (num_neighbors, obs_len), e.g. (5, 29) per the original author
            neighbor_obs = np.array(neighbor_obs).reshape(
                (len(neighbor_pid_list), len(raw_obs)))

            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   neighbor_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            last_observation = episode.last_observation_for(agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=(False if (no_done_at_end
                                     or (hit_horizon and soft_horizon)) else
                           agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

        # Invoke the step callback after the step is logged to the episode
        if callbacks.get("on_episode_step"):
            callbacks["on_episode_step"]({"env": base_env, "episode": episode})

        # Cut the batch if we're not packing multiple episodes into one,
        # or if we've exceeded the requested batch size.
        if episode.batch_builder.has_pending_data():
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()
            if (all_done and not pack) or \
                    episode.batch_builder.count >= unroll_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            elif all_done:
                # Make sure postprocessor stays within one episode
                episode.batch_builder.postprocess_batch_so_far(episode)

        if all_done:
            # Handle episode termination
            batch_builder_pool.append(episode.batch_builder)
            if callbacks.get("on_episode_end"):
                callbacks["on_episode_end"]({
                    "env": base_env,
                    "policy": policies,
                    "episode": episode
                })
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs = agent_obs
            else:
                del active_episodes[env_id]
                resetted_obs = base_env.try_reset(env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # Creates a new episode if this is not async return
                # If reset is async, we will get its result in some future poll
                episode = active_episodes[env_id]
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id = episode.policy_for(agent_id)  # eg: "policy_0"
                    policy = _get_or_raise(policies, policy_id)
                    prep_obs = _get_or_raise(preprocessors,
                                             policy_id).transform(raw_obs)
                    filtered_obs = _get_or_raise(obs_filters,
                                                 policy_id)(prep_obs)

                    # Attention: the episode has terminated and a new one
                    # starts. Q evaluation happens online here as well, so
                    # neighbor observations are passed to the Q eval
                    # network.
                    _, neighbor_obs = collect_neighbor_obs(policy_id, raw_obs)
                    neighbor_obs = np.squeeze(np.array(neighbor_obs))

                    episode._set_last_observation(agent_id, filtered_obs)
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs, neighbor_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                _flatten_action(policy.action_space.sample())),
                            0.0))

    return active_envs, to_eval, outputs
def _process_observations(base_env, policies, policies_to_train, dead_policies, policy_config, observation_filter, tf_sess, batch_builder_pool, active_episodes, unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon, preprocessors, obs_filters, unroll_length, pack, callbacks, soft_horizon, no_done_at_end):  #===MOD===
    """Record new data from the environment and prepare for policy evaluation.

    Compared with the stock sampler, ``episode.policy_for`` is expected to
    return ``(policy_id, policy_constructor_tuple)`` and
    ``generate_policies`` is called for every agent; its result replaces
    the local ``policies``, ``preprocessors``, ``obs_filters``,
    ``policies_to_train`` and ``dead_policies``.

    NOTE(review): this block was reconstructed from a flattened source, so
    the nesting of the ``#===MOD===`` statements is a best guess - confirm
    against the original file.

    Returns:
        active_envs: set of non-terminated env ids
        to_eval: map of policy_id to list of agent PolicyEvalData
        outputs: list of metrics and samples to return from the sampler
        pols_tuple: tuple ``(policies, preprocessors, obs_filters,
            policies_to_train, dead_policies)`` as they stand at return

    Raises:
        ValueError: if the env cannot be reset but a finite horizon is set.
        NotImplementedError: when an episode ends and the reset result is
            anything other than ``None`` or ``ASYNC_RESET_RETURN``.
    """

    active_envs = set()
    to_eval = defaultdict(list)
    outputs = []

    # For each environment
    for env_id, agent_obs in unfiltered_obs.items():
        new_episode = env_id not in active_episodes
        episode = active_episodes[env_id]
        if not new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > max(1000, unroll_length * 10)
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check that "
                "that you set a horizon on your environment correctly. Note "
                "that in multi-agent environments, `sample_batch_size` sets "
                "the batch size based on environment steps, not the steps of "
                "individual agents, which can result in unexpectedly large "
                "batches.")

        # Check episode termination conditions
        if dones[env_id]["__all__"] or episode.length >= horizon:
            # DEBUG
            # print("Trying to terminate.")
            # print("Dones of __all__ is set:", dones[env_id]["__all__"])
            # print("Horizon hit:", episode.length >= horizon)
            # hit_horizon: the episode ended only because of the horizon.
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_done = True
            atari_metrics = _fetch_atari_metrics(base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {}))
        else:
            hit_horizon = False
            all_done = False
            active_envs.add(env_id)

        #===MOD===
        # Agent ids for which a transition is recorded in this step; only
        # used by the debug print below.
        additional_builders_ids = set()
        #===MOD===

        # For each agent in the environment
        for agent_id, raw_obs in agent_obs.items():

            #===MOD===
            policy_id, policy_constructor_tuple = episode.policy_for(agent_id)
            # The local policy/preprocessor/filter collections are replaced
            # by whatever generate_policies returns for this agent.
            pols_tuple = generate_policies(
                policy_id,
                policy_constructor_tuple,
                policies,
                policies_to_train,
                dead_policies,
                policy_config,
                preprocessors,
                obs_filters,
                observation_filter,
                tf_sess,
            )
            policies, preprocessors, obs_filters, policies_to_train, dead_policies = pols_tuple
            #===MOD===

            prep_obs = _get_or_raise(preprocessors,
                                     policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            last_observation = episode.last_observation_for(agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):

                #===MOD===
                additional_builders_ids.add(agent_id)
                #===MOD===

                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=(False if (no_done_at_end
                                     or (hit_horizon and soft_horizon)) else
                           agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

            #===MOD===
            # NOTE(review): pop() without a default raises KeyError when
            # the agent has no entry in agent_builders - confirm that an
            # entry always exists when an agent is done.
            if agent_done:
                # Does it make sense to remove agent id from `agent_builders`?
                dead_policies.add(policy_id)
                print("Removing agent id from agent builders: %s" % str(agent_id))
                episode.batch_builder.agent_builders.pop(agent_id)
                if policy_id in to_eval:
                    to_eval.pop(policy_id)
                    # print("Popping policy id from toeval.")
            #===MOD===

        # NOTE(review): `start` is assigned but never read in this block.
        start = time.time()

        #===MOD===
        print("sampler.py: ids added to agent builders:\t", additional_builders_ids)
        # Update ``self.policy_map`` in ``MultiAgentSampleBatchBuilder``.
        # TODO: policies is not being pruned in this file.
        episode.batch_builder.policy_map = policies
        print("sampler.py: policies: \t", policies.keys())
        #===MOD===

        # Invoke the step callback after the step is logged to the episode
        if callbacks.get("on_episode_step"):
            callbacks["on_episode_step"]({"env": base_env, "episode": episode})

        # Cut the batch if we're not packing multiple episodes into one,
        # or if we've exceeded the requested batch size.
        if episode.batch_builder.has_pending_data():
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()
            if (all_done and not pack) or \
                    episode.batch_builder.count >= unroll_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            elif all_done:
                # Make sure postprocessor stays within one episode
                # KEYERROR
                episode.batch_builder.postprocess_batch_so_far(episode)

        if all_done:
            # Handle episode termination
            batch_builder_pool.append(episode.batch_builder)
            if callbacks.get("on_episode_end"):
                callbacks["on_episode_end"]({
                    "env": base_env,
                    "policy": policies,
                    "episode": episode
                })
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs = agent_obs
            else:
                del active_episodes[env_id]
                resetted_obs = base_env.try_reset(env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # print("Executing new epsiode non-async return.")
                # NOTE(review): the one-second sleep only delays the error
                # raised on the next line.
                time.sleep(1)
                raise NotImplementedError(
                    "Multiple episodes not supported by design.")
                # NOTE(review): everything below in this branch is
                # unreachable because of the unconditional raise above.
                # Creates a new episode if this is not async return
                # If reset is async, we will get its result in some future poll
                episode = active_episodes[env_id]
                for agent_id, raw_obs in resetted_obs.items():

                    #===MOD===
                    policy_id, policy_constructor_tuple = episode.policy_for(
                        agent_id)
                    # with tf_sess.as_default():
                    pols_tuple = generate_policies(
                        policy_id,
                        policy_constructor_tuple,
                        policies,
                        policies_to_train,
                        dead_policies,
                        policy_config,
                        preprocessors,
                        obs_filters,
                        observation_filter,
                        tf_sess,
                    )
                    policies, preprocessors, obs_filters, policies_to_train, dead_policies = pols_tuple
                    #===MOD===

                    policy = _get_or_raise(policies, policy_id)
                    prep_obs = _get_or_raise(preprocessors,
                                             policy_id).transform(raw_obs)
                    filtered_obs = _get_or_raise(obs_filters,
                                                 policy_id)(prep_obs)
                    episode._set_last_observation(agent_id, filtered_obs)
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                _flatten_action(policy.action_space.sample())),
                            0.0))

    #===MOD===
    # Hand the (possibly replaced) policy collections back to the caller.
    pols_tuple = (policies, preprocessors, obs_filters, policies_to_train, dead_policies)
    #===MOD===

    #===MOD===
    return active_envs, to_eval, outputs, pols_tuple
def _get_loss_inputs_dict(self, batch, neighbor_batch_dic, shuffle):
    """Return a feed dict from a batch.

    Arguments:
        batch (SampleBatch): batch of data to derive inputs from
        neighbor_batch_dic (dict, SampleBatch): batch of data for neighbor
            of the main policy, keyed by neighbor policy id. Any number of
            neighbors is accepted (the neighbor count used to be
            hard-coded to 5).
        shuffle (bool): whether to shuffle batch sequences. Shuffle may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.

    Returns:
        feed dict of data

    Raises:
        ValueError: if a ``neighbor_obs`` loss input has to be fed but
            ``neighbor_batch_dic`` is empty.
    """

    def stack_neighbor_obs():
        # Gather the 'obs' column of every neighbor batch and transpose
        # [neighbor, batch, feature] -> [batch, neighbor, feature].
        neighbor_ids = list(neighbor_batch_dic)
        if not neighbor_ids:
            raise ValueError(
                "neighbor_batch_dic must contain at least one neighbor "
                "batch to feed neighbor_obs")
        columns = [neighbor_batch_dic[nid]['obs'] for nid in neighbor_ids]
        # The row count is taken from the first neighbor.
        num_rows = len(columns[0])
        return np.array(
            [[column[row] for column in columns] for row in range(num_rows)])

    feed_dict = {}

    if self._batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) %
            self._batch_divisibility_req == 0
            and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
    else:
        meets_divisibility_reqs = True

    # Simple case: not RNN nor do we need to pad
    if not self._state_inputs and meets_divisibility_reqs:
        if shuffle:
            # NOTE(review): only `batch` is shuffled here; the neighbor
            # batches keep their order, so neighbor rows may no longer
            # line up with the shuffled rows - confirm with callers.
            batch.shuffle()
        neighbor_obs = None
        for k, ph in self._loss_inputs:
            if 'neighbor_obs' in k:
                # Attention: this feeds the replay data for the Q target;
                # the neighbor observations are assembled from the
                # neighbor SampleBatches. The array is built once and
                # reused for every neighbor_obs placeholder.
                if neighbor_obs is None:
                    neighbor_obs = stack_neighbor_obs()
                feed_dict[ph] = neighbor_obs
            else:
                feed_dict[ph] = batch[k]
        return feed_dict

    if self._state_inputs:
        max_seq_len = self._max_seq_len
        dynamic_max = True
    else:
        max_seq_len = self._batch_divisibility_req
        dynamic_max = False

    # RNN or multi-agent case
    # NOTE(review): this path reads every loss input from `batch` and does
    # not use neighbor_batch_dic.
    feature_keys = [k for k, v in self._loss_inputs]
    state_keys = [
        "state_in_{}".format(i) for i in range(len(self._state_inputs))
    ]
    feature_sequences, initial_states, seq_lens = chop_into_sequences(
        batch[SampleBatch.EPS_ID],
        batch[SampleBatch.UNROLL_ID],
        batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
        [batch[k] for k in state_keys],
        max_seq_len,
        dynamic_max=dynamic_max,
        shuffle=shuffle)
    for k, v in zip(feature_keys, feature_sequences):
        feed_dict[self._loss_input_dict[k]] = v
    for k, v in zip(state_keys, initial_states):
        feed_dict[self._loss_input_dict[k]] = v
    feed_dict[self._seq_lens] = seq_lens

    if log_once("rnn_feed_dict"):
        logger.info("Padded input for RNN:\n\n{}\n".format(
            summarize({
                "features": feature_sequences,
                "initial_states": initial_states,
                "seq_lens": seq_lens,
                "max_seq_len": max_seq_len,
            })))
    return feed_dict
def _initialize_loss(self):
    """Build the loss graph and register its input placeholders.

    A one-row dummy batch is assembled from the policy's input tensors,
    run through `postprocess_trajectory` to discover which fields exist
    after postprocessing, and a placeholder is created for every
    postprocessed field that does not already have an input tensor.
    `_do_loss_init` is then called on those tensors, and the keys it
    actually accessed are appended to the loss inputs handed to
    `TFPolicy._initialize_loss`.

    Raises:
        ValueError: if `config["use_eager"]` is set but the policy has
            no model.
    """

    def fake_array(tensor):
        # Zero-filled array shaped like `tensor`, with the leading
        # dimension forced to 1.
        shape = tensor.shape.as_list()
        shape[0] = 1
        return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

    dummy_batch = {
        SampleBatch.CUR_OBS: fake_array(self._obs_input),
        SampleBatch.NEXT_OBS: fake_array(self._obs_input),
        # Builtin `bool` instead of the `np.bool` alias, which was removed
        # from NumPy; the resulting dtype is identical.
        SampleBatch.DONES: np.array([False], dtype=bool),
        SampleBatch.ACTIONS: fake_array(
            ModelCatalog.get_action_placeholder(self.action_space)),
        SampleBatch.REWARDS: np.array([0], dtype=np.float32),
    }
    if self._obs_include_prev_action_reward:
        dummy_batch.update({
            SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
            SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
        })
    state_init = self.get_initial_state()
    for i, h in enumerate(state_init):
        dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
        dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
    if state_init:
        dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
    for k, v in self.extra_compute_action_fetches().items():
        dummy_batch[k] = fake_array(v)

    # postprocessing might depend on variable init, so run it first here
    self._sess.run(tf.global_variables_initializer())
    postprocessed_batch = self.postprocess_trajectory(
        SampleBatch(dummy_batch))

    if self._obs_include_prev_action_reward:
        batch_tensors = UsageTrackingDict({
            SampleBatch.PREV_ACTIONS: self._prev_action_input,
            SampleBatch.PREV_REWARDS: self._prev_reward_input,
            SampleBatch.CUR_OBS: self._obs_input,
        })
        loss_inputs = [
            (SampleBatch.PREV_ACTIONS, self._prev_action_input),
            (SampleBatch.PREV_REWARDS, self._prev_reward_input),
            (SampleBatch.CUR_OBS, self._obs_input),
        ]
    else:
        batch_tensors = UsageTrackingDict({
            SampleBatch.CUR_OBS: self._obs_input,
        })
        loss_inputs = [
            (SampleBatch.CUR_OBS, self._obs_input),
        ]

    # Create a placeholder for every postprocessed field that does not
    # already have an input tensor.
    for k, v in postprocessed_batch.items():
        if k in batch_tensors:
            continue
        elif v.dtype == object:
            # Builtin `object` replaces the removed `np.object` alias.
            continue  # can't handle arbitrary objects in TF
        shape = (None, ) + v.shape[1:]
        dtype = np.float32 if v.dtype == np.float64 else v.dtype
        placeholder = tf.placeholder(dtype, shape=shape, name=k)
        batch_tensors[k] = placeholder

    if log_once("loss_init"):
        logger.info(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(batch_tensors)))

    loss = self._do_loss_init(batch_tensors)
    # sorted() snapshots the accessed keys, so the lookups below do not
    # disturb the iteration.
    for k in sorted(batch_tensors.accessed_keys):
        loss_inputs.append((k, batch_tensors[k]))

    # XXX experimental support for automatically eagerifying the loss.
    # The main limitation right now is that TF doesn't support mixing eager
    # and non-eager tensors, so losses that read non-eager tensors through
    # `policy` need to use `policy.convert_to_eager(tensor)`.
    if self.config["use_eager"]:
        if not self.model:
            raise ValueError("eager not implemented in this case")
        graph_tensors = list(self._needs_eager_conversion)

        def gen_loss(model_outputs, *args):
            # fill in the batch tensor dict with eager tensors
            eager_inputs = dict(
                zip([k for (k, v) in loss_inputs],
                    args[:len(loss_inputs)]))
            # fill in the eager versions of all accessed graph tensors
            self._eager_tensors = dict(
                zip(graph_tensors, args[len(loss_inputs):]))
            # patch the action dist to use eager mode tensors
            self.action_dist.inputs = model_outputs
            return self._loss_fn(self, eager_inputs)

        # TODO(ekl) also handle the stats funcs
        loss = tf.py_function(
            gen_loss,
            # cast works around TypeError: Cannot convert provided value
            # to EagerTensor. Provided value: 0.0 Requested dtype: int64
            [self.model.outputs] +
            [tf.cast(v, tf.float32) for (k, v) in loss_inputs] +
            [tf.cast(t, tf.float32) for t in graph_tensors],
            tf.float32)

    TFPolicy._initialize_loss(self, loss, loss_inputs)
    if self._grad_stats_fn:
        self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
    self._sess.run(tf.global_variables_initializer())
def load_data(self, sess, inputs, state_inputs):
    """Bulk loads the specified inputs into device memory.

    The shape of the inputs must conform to the shapes of the input
    placeholders this optimizer was constructed with.

    The data is split equally across all the devices. If the data is not
    evenly divisible by the batch size, excess data will be discarded.

    Args:
        sess: TensorFlow session.
        inputs: List of arrays matching the input placeholders, of shape
            [BATCH_SIZE, ...].
        state_inputs: List of RNN input arrays. These arrays have size
            [BATCH_SIZE / MAX_SEQ_LEN, ...].

    Returns:
        The number of tuples loaded per device.

    Raises:
        ValueError: if fewer than one sequence per device can be loaded.
    """
    if log_once("load_data"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize({
                    "placeholders": self.loss_inputs,
                    "inputs": inputs,
                    "state_inputs": state_inputs
                })))

    feed_dict = {}
    assert len(self.loss_inputs) == len(inputs + state_inputs), \
        (self.loss_inputs, inputs, state_inputs)

    num_devices = len(self.devices)

    # Let's suppose we have the following input data, and 2 devices:
    # 1 2 3 4 5 6 7                              <- state inputs shape
    # A A A B B B C C C D D D E E E F F F G G G  <- inputs shape
    # The data is truncated and split across devices as follows:
    # |---| seq len = 3
    # |---------------------------------| seq batch size = 6 seqs
    # |----------------| per device batch size = 9 tuples

    if len(state_inputs) > 0:
        smallest_array = state_inputs[0]
        seq_len = len(inputs[0]) // len(state_inputs[0])
        self._loaded_max_seq_len = seq_len
    else:
        smallest_array = inputs[0]
        self._loaded_max_seq_len = 1

    sequences_per_minibatch = (
        self.max_per_device_batch_size // self._loaded_max_seq_len *
        num_devices)
    if sequences_per_minibatch < 1:
        # `Logger.warn` is a deprecated alias; `warning` is the supported
        # spelling.
        logger.warning(
            ("Target minibatch size is {}, however the rollout sequence "
             "length is {}, hence the minibatch size will be raised to "
             "{}.").format(self.max_per_device_batch_size,
                           self._loaded_max_seq_len,
                           self._loaded_max_seq_len * num_devices))
        # NOTE(review): with more than one device this value still fails
        # the per-device check below -- confirm whether that is intended.
        sequences_per_minibatch = 1

    if len(smallest_array) < sequences_per_minibatch:
        # Dynamically shrink the batch size if insufficient data
        sequences_per_minibatch = make_divisible_by(
            len(smallest_array), num_devices)

    if log_once("data_slicing"):
        logger.info(
            ("Divided {} rollout sequences, each of length {}, among "
             "{} devices.").format(
                 len(smallest_array), self._loaded_max_seq_len,
                 num_devices))

    if sequences_per_minibatch < num_devices:
        raise ValueError(
            "Must load at least 1 tuple sequence per device. Try "
            "increasing `sgd_minibatch_size` or reducing `max_seq_len` "
            "to ensure that at least one sequence fits per device.")
    self._loaded_per_device_batch_size = (
        sequences_per_minibatch // num_devices * self._loaded_max_seq_len)

    if len(state_inputs) > 0:
        # First truncate the RNN state arrays to the sequences_per_minib.
        state_inputs = [
            make_divisible_by(arr, sequences_per_minibatch)
            for arr in state_inputs
        ]
        # Then truncate the data inputs to match
        inputs = [arr[:len(state_inputs[0]) * seq_len] for arr in inputs]
        assert len(state_inputs[0]) * seq_len == len(inputs[0]), \
            (len(state_inputs[0]), sequences_per_minibatch, seq_len,
             len(inputs[0]))
        for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
            feed_dict[ph] = arr
        truncated_len = len(inputs[0])
    else:
        for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
            truncated_arr = make_divisible_by(arr, sequences_per_minibatch)
            feed_dict[ph] = truncated_arr
            truncated_len = len(truncated_arr)

    sess.run([t.init_op for t in self._towers], feed_dict=feed_dict)

    self.num_tuples_loaded = truncated_len
    tuples_per_device = truncated_len // num_devices
    assert tuples_per_device > 0, "No data loaded?"
    assert tuples_per_device % self._loaded_per_device_batch_size == 0
    return tuples_per_device
def _env_runner(base_env, extra_batch_callback, policies, policies_to_train,
                policy_config, observation_filter, policy_mapping_fn,
                unroll_length, horizon, preprocessors, obs_filters,
                clip_rewards, clip_actions, pack, callbacks, tf_sess,
                perf_stats, soft_horizon, no_done_at_end):
    """This implements the common experience collection logic.

    Runs forever: each iteration polls the env, processes observations,
    yields any finished batches, evaluates the policies and sends the
    resulting actions back to the env.

    Args:
        base_env (BaseEnv): env implementing BaseEnv.
        extra_batch_callback (fn): function to send extra batch data to.
        policies (dict): Map of policy ids to Policy instances. Replaced
            every iteration by the map `_process_observations` returns.
        policies_to_train: forwarded to `_process_observations` and
            replaced every iteration by the value it returns.
        policy_config: forwarded unchanged to `_process_observations`.
        observation_filter: forwarded unchanged to
            `_process_observations`.
        policy_mapping_fn (func): Function that maps agent ids to policy
            ids. This is called when an agent first enters the
            environment. The agent is then "bound" to the returned policy
            for the episode.
        unroll_length (int): Number of episode steps before `SampleBatch`
            is yielded. Set to infinity to yield complete episodes.
        horizon (int): Horizon of the episode. When falsy, the env spec's
            `max_episode_steps` is tried, then infinity is assumed.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        clip_rewards (bool): Whether to clip rewards before postprocessing.
        clip_actions (bool): Whether to clip actions to the space range.
        pack (bool): Whether to pack multiple episodes into each batch.
            This guarantees batches will be exactly `unroll_length` in
            size.
        callbacks (dict): User callbacks to run on episode events.
        tf_sess (Session|None): Optional tensorflow session. Here it is
            only forwarded to `_process_observations`; policy evaluation
            is invoked with None in its place.
        perf_stats (PerfStats): Record perf stats into this object.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the
            episode and instead record done=False.

    Yields:
        rollout (SampleBatch): Object containing state, action, reward,
            terminal condition, and other fields as dictated by `policy`.
    """
    try:
        if not horizon:
            env_spec = base_env.get_unwrapped()[0].spec
            horizon = env_spec.max_episode_steps
    except Exception:
        logger.debug("no episode horizon specified, assuming inf")
    if not horizon:
        horizon = float("inf")

    # Pool of batch builders, which can be shared across episodes to pack
    # trajectory data.
    batch_builder_pool = []

    def get_batch_builder():
        # Hand out a pooled builder when available, else make a new one.
        # `policies` is read at call time, so rebinding it in the loop
        # below is visible here.
        if batch_builder_pool:
            return batch_builder_pool.pop()
        return MultiAgentSampleBatchBuilder(
            policies, clip_rewards, callbacks.get("on_postprocess_traj"))

    def new_episode():
        # Create an episode and fire the optional start callback.
        fresh_episode = MultiAgentEpisode(
            policies, policy_mapping_fn, get_batch_builder,
            extra_batch_callback)
        if callbacks.get("on_episode_start"):
            on_start = callbacks["on_episode_start"]
            on_start({
                "env": base_env,
                "policy": policies,
                "episode": fresh_episode,
            })
        return fresh_episode

    active_episodes = defaultdict(new_episode)

    # Local modification: policy ids asserted below to never be scheduled
    # for evaluation.
    dead_policies = set()

    while True:
        perf_stats.iters += 1

        # Get observations from all ready agents
        poll_started = time.time()
        polled = base_env.poll()
        unfiltered_obs, rewards, dones, infos, off_policy_actions = polled
        perf_stats.env_wait_time += time.time() - poll_started

        if log_once("env_returns"):
            logger.info("Raw obs from env: {}".format(
                summarize(unfiltered_obs)))
            logger.info("Info return from env: {}".format(summarize(infos)))

        # Process observations and prepare for policy evaluation.
        # Local modification: policies_to_train, dead_policies,
        # policy_config, observation_filter and tf_sess are passed in as
        # extra arguments, and the policy bookkeeping comes back as a
        # fourth return value.
        process_started = time.time()
        processed = _process_observations(
            base_env, policies, policies_to_train, dead_policies,
            policy_config, observation_filter, tf_sess, batch_builder_pool,
            active_episodes, unfiltered_obs, rewards, dones, infos,
            off_policy_actions, horizon, preprocessors, obs_filters,
            unroll_length, pack, callbacks, soft_horizon, no_done_at_end)
        active_envs, to_eval, outputs, pols_tuple = processed
        (policies, preprocessors, obs_filters, policies_to_train,
         dead_policies) = pols_tuple
        perf_stats.processing_time += time.time() - process_started
        for finished_batch in outputs:
            yield finished_batch

        # Do batched policy eval
        eval_started = time.time()
        for dead_id in dead_policies:
            assert dead_id not in to_eval
        # Local modification: None is passed where tf_sess used to go.
        eval_results = _do_policy_eval(
            None, to_eval, policies, active_episodes)
        perf_stats.inference_time += time.time() - eval_started

        # Process results and update episode state
        results_started = time.time()
        actions_to_send = _process_policy_eval_results(
            to_eval, eval_results, active_episodes, active_envs,
            off_policy_actions, policies, clip_actions)
        perf_stats.processing_time += time.time() - results_started

        # Return computed actions to ready envs. We also send to envs that
        # have taken off-policy actions; those envs are free to ignore the
        # action.
        send_started = time.time()
        base_env.send_actions(actions_to_send)
        perf_stats.env_wait_time += time.time() - send_started