def compute_actions(self,
                    obs_batch,
                    state_batches=None,
                    prev_action_batch=None,
                    prev_reward_batch=None,
                    info_batch=None,
                    episodes=None,
                    **kwargs):
    """Compute actions for a batch of observations with one builder run.

    Only `obs_batch`, `state_batches`, `prev_action_batch` and
    `prev_reward_batch` are handed to `self._build_compute_actions()`;
    `info_batch`, `episodes` and any extra keyword arguments are accepted
    for interface compatibility but are not used by this implementation.

    Returns:
        The result of `TFRunBuilder.get()` on the fetches returned by
        `self._build_compute_actions()`.
    """
    run_builder = TFRunBuilder(self._sess, "compute_actions")
    return run_builder.get(
        self._build_compute_actions(run_builder, obs_batch, state_batches,
                                    prev_action_batch, prev_reward_batch))
def apply_gradients(self, grads):
    """Apply gradients to one policy or to several policies at once.

    Args:
        grads: Either a dict mapping policy ID -> gradients (multi-agent
            case), or a single gradients object for the default policy.

    Returns:
        For dict input, a dict mapping each policy ID to that policy's
        apply-gradients result; otherwise the default policy's result.
    """
    if not isinstance(grads, dict):
        # Single-policy case: delegate straight to the default policy.
        return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

    if self.tf_sess is None:
        return {
            policy_id: self.policy_map[policy_id].apply_gradients(policy_grads)
            for policy_id, policy_grads in grads.items()
        }

    # Graph mode: register every policy's apply op on one shared builder,
    # then read each policy's result back from that builder.
    run_builder = TFRunBuilder(self.tf_sess, "apply_gradients")
    pending = {
        policy_id: self.policy_map[policy_id]._build_apply_gradients(
            run_builder, policy_grads)
        for policy_id, policy_grads in grads.items()
    }
    return {
        policy_id: run_builder.get(fetch)
        for policy_id, fetch in pending.items()
    }
def compute_log_likelihoods(self,
                            actions,
                            obs_batch,
                            state_batches=None,
                            prev_action_batch=None,
                            prev_reward_batch=None):
    """Compute the log-likelihoods of `actions` under the current policy.

    Feeds the actions, the observations, the RNN state batches and (when
    both the placeholder and the value exist) the previous actions/rewards
    into the graph, then fetches `self._log_likelihood`.

    Raises:
        ValueError: If there is no `self._log_likelihood` op, or if the
            number of state batches differs from `self._state_inputs`.

    Returns:
        The first (and only) fetched value, i.e. `self._log_likelihood`.
    """
    if self._log_likelihood is None:
        raise ValueError("Cannot compute log-prob/likelihood w/o a "
                         "self._log_likelihood op!")

    # Forward pass: feed the actions whose logp we want, plus the
    # observations that parameterize the action distribution.
    run_builder = TFRunBuilder(self._sess, "compute_log_likelihoods")
    run_builder.add_feed_dict({self._action_input: actions})
    run_builder.add_feed_dict({self._obs_input: obs_batch})

    states = state_batches if state_batches else []
    if len(states) != len(self._state_inputs):
        raise ValueError(
            "Must pass in RNN state batches for placeholders {}, got {}".
            format(self._state_inputs, states))
    run_builder.add_feed_dict(dict(zip(self._state_inputs, states)))
    if states:
        # Sequence length of 1 for every row of obs_batch.
        run_builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})

    # Optional previous-action / previous-reward inputs.
    optional_feeds = (
        (self._prev_action_input, prev_action_batch),
        (self._prev_reward_input, prev_reward_batch),
    )
    for placeholder, value in optional_feeds:
        if placeholder is not None and value is not None:
            run_builder.add_feed_dict({placeholder: value})

    fetches = run_builder.add_fetches([self._log_likelihood])
    return run_builder.get(fetches)[0]
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Compute the next actions for every policy that has pending data.

    If `tf_sess` is truthy, policies that do not override
    `TFPolicyGraph.compute_actions` are queued on one shared `TFRunBuilder`
    (no info batch or episodes are passed on that path). Every other policy
    is evaluated immediately via `policy.compute_actions()`.

    Args:
        tf_sess: Optional TF session used to batch graph-mode evaluations.
        to_eval: Mapping of policy ID -> list of PolicyEvalData items.
        policies: Mapping of policy ID -> policy object.
        active_episodes: Indexed by each item's `env_id` to collect the
            episodes passed to `compute_actions()`.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """
    results = {}
    run_builder = TFRunBuilder(tf_sess, "policy_eval") if tf_sess else None
    queued = {}

    for policy_id, eval_data in to_eval.items():
        state_cols = _to_column_format([item.rnn_state for item in eval_data])
        policy = _get_or_raise(policies, policy_id)
        obs = [item.obs for item in eval_data]
        prev_actions = [item.prev_action for item in eval_data]
        prev_rewards = [item.prev_reward for item in eval_data]

        use_builder = run_builder and (
            policy.compute_actions.__code__ is
            TFPolicyGraph.compute_actions.__code__)
        if not use_builder:
            results[policy_id] = policy.compute_actions(
                obs,
                state_cols,
                prev_action_batch=prev_actions,
                prev_reward_batch=prev_rewards,
                info_batch=[item.info for item in eval_data],
                episodes=[active_episodes[item.env_id] for item in eval_data])
            continue

        # TODO(ekl): how can we make info batch available to TF code?
        queued[policy_id] = policy._build_compute_actions(
            run_builder,
            obs,
            state_cols,
            prev_action_batch=prev_actions,
            prev_reward_batch=prev_rewards)

    if run_builder:
        for policy_id, fetch in queued.items():
            results[policy_id] = run_builder.get(fetch)
    return results
def compute_actions_from_input_dict(
    self,
    input_dict: Union[SampleBatch, Dict[str, TensorType]],
    explore: bool = None,
    timestep: Optional[int] = None,
    episodes: Optional[List["Episode"]] = None,
    **kwargs,
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
    """Compute actions from an input dict with a single builder run.

    Args:
        input_dict: A SampleBatch, or (deprecated) a plain dict. It is put
            into non-training mode in place.
        explore: Whether to explore; `None` means `self.config["explore"]`.
        timestep: Timestep passed to the graph builder; `None` means
            `self.global_timestep`.
        episodes: Accepted for interface compatibility; not used here.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The fetched outputs of `self._build_compute_actions()`. As a side
        effect, `self.global_timestep` is advanced by the batch size.
    """
    if explore is None:
        explore = self.config["explore"]
    if timestep is None:
        timestep = self.global_timestep

    if isinstance(input_dict, SampleBatch):
        input_dict.set_training(False)
    else:
        # Deprecated plain-dict input: set the flag by key.
        input_dict["is_training"] = False

    run_builder = TFRunBuilder(self.get_session(),
                               "compute_actions_from_input_dict")
    observations = input_dict[SampleBatch.OBS]
    to_fetch = self._build_compute_actions(
        run_builder, input_dict=input_dict, explore=explore, timestep=timestep)
    outputs = run_builder.get(to_fetch)

    # Advance the global timestep by the number of rows just processed.
    if isinstance(observations, list):
        batch_size = len(observations)
    elif isinstance(input_dict, SampleBatch):
        batch_size = len(input_dict)
    else:
        batch_size = observations.shape[0]
    self.global_timestep += batch_size
    return outputs
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Evaluate every policy in `to_eval` and collect its action outputs.

    Graph-mode batching is used when `tf_sess` is truthy. In that mode,
    each policy that keeps the stock `TFPolicyGraph.compute_actions` has
    its fetches added to one shared `TFRunBuilder`, and those fetches are
    read back at the end. All other policies are called directly.

    Args:
        tf_sess: Optional TF session; falsy disables builder batching.
        to_eval: Mapping of policy ID -> list of PolicyEvalData items.
        policies: Mapping of policy ID -> policy object.
        active_episodes: Indexed by each item's `env_id`.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """
    outputs = {}
    deferred = {}
    shared_builder = None
    if tf_sess:
        shared_builder = TFRunBuilder(tf_sess, "policy_eval")

    for pid, rows in to_eval.items():
        rnn_cols = _to_column_format([r.rnn_state for r in rows])
        policy = _get_or_raise(policies, pid)
        is_stock_tf = shared_builder and (
            policy.compute_actions.__code__ is
            TFPolicyGraph.compute_actions.__code__)
        if is_stock_tf:
            # TODO(ekl): how can we make info batch available to TF code?
            deferred[pid] = policy._build_compute_actions(
                shared_builder, [r.obs for r in rows],
                rnn_cols,
                prev_action_batch=[r.prev_action for r in rows],
                prev_reward_batch=[r.prev_reward for r in rows])
        else:
            outputs[pid] = policy.compute_actions(
                [r.obs for r in rows],
                rnn_cols,
                prev_action_batch=[r.prev_action for r in rows],
                prev_reward_batch=[r.prev_reward for r in rows],
                info_batch=[r.info for r in rows],
                episodes=[active_episodes[r.env_id] for r in rows])

    if shared_builder:
        outputs.update(
            {pid: shared_builder.get(fetch) for pid, fetch in deferred.items()})
    return outputs
def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
    """Run the learn-on-batch fetches for `postprocessed_batch`.

    The batch is switched to training mode. `self.callbacks.on_learn_on_batch`
    then receives a fresh dict it may fill with custom metrics. After that,
    the fetches built by `self._build_learn_on_batch()` are run.

    Returns:
        The fetched stats, extended with `"custom_metrics"` (the dict handed
        to the callback) and `NUM_AGENT_STEPS_TRAINED` (the batch's `count`).
    """
    assert self.loss_initialized()
    postprocessed_batch.set_training(True)
    run_builder = TFRunBuilder(self.get_session(), "learn_on_batch")

    # The callback may write custom metrics into this dict.
    custom_metrics = {}
    self.callbacks.on_learn_on_batch(
        policy=self, train_batch=postprocessed_batch, result=custom_metrics)

    results = run_builder.get(
        self._build_learn_on_batch(run_builder, postprocessed_batch))
    extra = {
        "custom_metrics": custom_metrics,
        NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
    }
    results.update(extra)
    return results
def compute_apply(self, samples):
    """Compute and apply gradients, returning the gradient-side info.

    For a `MultiAgentBatch`, only sub-batches whose policy ID is in
    `self.policies_to_train` are processed. When `self.tf_sess` is set,
    all of them are registered on one shared `TFRunBuilder`. Any other
    batch goes to the default policy.

    Returns:
        For multi-agent input, a dict of policy ID -> first element of each
        policy's compute/apply result; otherwise the first element of the
        default policy's `compute_apply()` result.
    """
    if not isinstance(samples, MultiAgentBatch):
        grad_fetch, _apply_fetch = (
            self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
        return grad_fetch

    trainable = ((pid, batch)
                 for pid, batch in samples.policy_batches.items()
                 if pid in self.policies_to_train)
    info_out = {}
    if self.tf_sess is None:
        for pid, batch in trainable:
            info_out[pid], _ = self.policy_map[pid].compute_apply(batch)
        return info_out

    run_builder = TFRunBuilder(self.tf_sess, "compute_apply")
    for pid, batch in trainable:
        info_out[pid], _ = self.policy_map[pid]._build_compute_apply(
            run_builder, batch)
    return {pid: run_builder.get(fetch) for pid, fetch in info_out.items()}
def compute_apply(self, samples):
    """Compute-and-apply step for a single- or multi-agent sample batch.

    Multi-agent batches are split by policy, and policies not listed in
    `self.policies_to_train` are skipped. With a TF session, every policy's
    fetches share one `TFRunBuilder`. Without one, each policy's
    `compute_apply()` is called directly.

    Returns:
        Multi-agent: dict of policy ID -> first element of that policy's
        result. Otherwise: first element of the default policy's result.
    """
    if isinstance(samples, MultiAgentBatch):
        results = {}
        shared = None
        if self.tf_sess is not None:
            shared = TFRunBuilder(self.tf_sess, "compute_apply")
        for policy_id, sub_batch in samples.policy_batches.items():
            if policy_id not in self.policies_to_train:
                continue
            policy = self.policy_map[policy_id]
            if shared is not None:
                results[policy_id], _ = policy._build_compute_apply(
                    shared, sub_batch)
            else:
                results[policy_id], _ = policy.compute_apply(sub_batch)
        if shared is not None:
            results = {
                policy_id: shared.get(fetch)
                for policy_id, fetch in results.items()
            }
        return results

    grad_fetch, _unused_apply_fetch = (
        self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
    return grad_fetch
def compute_actions(self,
                    obs_batch,
                    state_batches=None,
                    prev_action_batch=None,
                    prev_reward_batch=None,
                    info_batch=None,
                    episodes=None,
                    explore=None,
                    timestep=None,
                    **kwargs):
    """Compute actions for a batch of observations.

    Args:
        explore: Whether to explore; `None` falls back to
            `self.config["explore"]`.
        timestep: Timestep passed to the graph builder; `None` falls back
            to `self.global_timestep`.
        info_batch, episodes, **kwargs: Accepted but unused here.

    Returns:
        The result of `TFRunBuilder.get()` on the fetches built by
        `self._build_compute_actions()`.
    """
    if explore is None:
        explore = self.config["explore"]
    run_builder = TFRunBuilder(self._sess, "compute_actions")
    if timestep is None:
        timestep = self.global_timestep
    action_fetches = self._build_compute_actions(
        run_builder,
        obs_batch,
        state_batches,
        prev_action_batch,
        prev_reward_batch,
        explore=explore,
        timestep=timestep)
    # Execute session run to get action (and other fetches).
    return run_builder.get(action_fetches)
def compute_gradients(self, samples):
    """Compute gradients for the trainable policies covered by `samples`.

    For a `MultiAgentBatch`, each sub-batch whose policy ID is in
    `self.policies_to_train` gets its gradients computed. When
    `self.tf_sess` is set, all of those computations are registered on one
    shared `TFRunBuilder`. Any other batch is handed to the default policy.

    Returns:
        Tuple `(grad_out, info_out)`. For multi-agent input both are dicts
        keyed by policy ID. In every case `info_out["batch_count"]` is set
        to `samples.count`.
    """
    if not isinstance(samples, MultiAgentBatch):
        grad_out, info_out = (
            self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
    else:
        trainable = ((pid, batch)
                     for pid, batch in samples.policy_batches.items()
                     if pid in self.policies_to_train)
        grad_out, info_out = {}, {}
        if self.tf_sess is None:
            for pid, batch in trainable:
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid].compute_gradients(batch))
        else:
            run_builder = TFRunBuilder(self.tf_sess, "compute_gradients")
            for pid, batch in trainable:
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid]._build_compute_gradients(
                        run_builder, batch))
            grad_out = {
                pid: run_builder.get(fetch) for pid, fetch in grad_out.items()
            }
            info_out = {
                pid: run_builder.get(fetch) for pid, fetch in info_out.items()
            }
    info_out["batch_count"] = samples.count
    return grad_out, info_out
def compute_gradients(self, samples):
    """Return `(gradients, info)` for a single- or multi-agent batch.

    Multi-agent batches are processed per policy, and policy IDs missing
    from `self.policies_to_train` are skipped. With a TF session, fetches
    for all policies are collected on one `TFRunBuilder` and read back
    afterwards. `info["batch_count"]` is always set to `samples.count`.
    """
    if isinstance(samples, MultiAgentBatch):
        grads, infos = {}, {}
        shared = None
        if self.tf_sess is not None:
            shared = TFRunBuilder(self.tf_sess, "compute_gradients")
        for policy_id, sub_batch in samples.policy_batches.items():
            if policy_id not in self.policies_to_train:
                continue
            policy = self.policy_map[policy_id]
            if shared is not None:
                grads[policy_id], infos[policy_id] = (
                    policy._build_compute_gradients(shared, sub_batch))
            else:
                grads[policy_id], infos[policy_id] = (
                    policy.compute_gradients(sub_batch))
        if shared is not None:
            grads = {k: shared.get(v) for k, v in grads.items()}
            infos = {k: shared.get(v) for k, v in infos.items()}
    else:
        grads, infos = (
            self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
    infos["batch_count"] = samples.count
    return grads, infos
def compute_apply(self, postprocessed_batch):
    """Build the combined compute/apply fetches for a batch and run them.

    Returns:
        The result of `TFRunBuilder.get()` on the fetches returned by
        `self._build_compute_apply()`.
    """
    run_builder = TFRunBuilder(self._sess, "compute_apply")
    return run_builder.get(
        self._build_compute_apply(run_builder, postprocessed_batch))
def _env_runner(async_vector_env,
                extra_batch_callback,
                policies,
                policy_mapping_fn,
                unroll_length,
                horizon,
                obs_filters,
                clip_rewards,
                pack,
                tf_sess=None):
    """This implements the common experience collection logic.

    Args:
        async_vector_env (AsyncVectorEnv): env implementing AsyncVectorEnv.
        extra_batch_callback (fn): function to send extra batch data to.
        policies (dict): Map of policy ids to PolicyGraph instances.
        policy_mapping_fn (func): Function that maps agent ids to policy ids.
            This is called when an agent first enters the environment. The
            agent is then "bound" to the returned policy for the episode.
        unroll_length (int): Number of episode steps before `SampleBatch` is
            yielded. Set to infinity to yield complete episodes.
        horizon (int): Horizon of the episode.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        clip_rewards (bool): Whether to clip rewards before postprocessing.
        pack (bool): Whether to pack multiple episodes into each batch. This
            guarantees batches will be exactly `unroll_length` in size.
        tf_sess (Session|None): Optional tensorflow session to use for batching
            TF policy evaluations.

    Yields:
        rollout (SampleBatch): Object containing state, action, reward,
            terminal condition, and other fields as dictated by `policy`.
        Whenever an episode terminates, the generator also yields either
        the items returned by `_fetch_atari_metrics()` or a single
        `RolloutMetrics` built from the episode's length and rewards.
    """

    # If no horizon was given, try the first unwrapped env's
    # `spec.max_episode_steps`.
    # NOTE(review): any Exception (not only a missing spec attribute) is
    # swallowed here and reported with print() rather than a logger.
    try:
        if not horizon:
            horizon = (
                async_vector_env.get_unwrapped()[0].spec.max_episode_steps)
    except Exception:
        print("*** WARNING ***: no episode horizon specified, assuming inf")
    if not horizon:
        horizon = float("inf")

    # Pool of batch builders, which can be shared across episodes to pack
    # trajectory data.
    batch_builder_pool = []

    def get_batch_builder():
        # Reuse a pooled builder if one is available, else make a new one.
        if batch_builder_pool:
            return batch_builder_pool.pop()
        else:
            return MultiAgentSampleBatchBuilder(policies, clip_rewards)

    def new_episode():
        return MultiAgentEpisode(policies, policy_mapping_fn,
                                 get_batch_builder, extra_batch_callback)

    # Indexing an unknown env_id creates and registers a fresh episode.
    active_episodes = defaultdict(new_episode)

    while True:
        # Get observations from all ready agents
        unfiltered_obs, rewards, dones, infos, off_policy_actions = \
            async_vector_env.poll()

        # Map of policy_id to list of PolicyEvalData
        to_eval = defaultdict(list)

        # Map of env_id -> agent_id -> action replies
        actions_to_send = defaultdict(dict)

        # For each environment
        for env_id, agent_obs in unfiltered_obs.items():
            # NOTE: this rebinds the local name of the `new_episode` factory
            # above to a bool. The defaultdict already captured the function
            # object, so episode creation is unaffected.
            new_episode = env_id not in active_episodes
            episode = active_episodes[env_id]
            # Only advance counters / rewards for episodes that already
            # existed before this poll.
            if not new_episode:
                episode.length += 1
                episode.batch_builder.count += 1
                episode._add_agent_rewards(rewards[env_id])

            # Check episode termination conditions (env says all done, or
            # the horizon is reached).
            if dones[env_id]["__all__"] or episode.length >= horizon:
                all_done = True
                atari_metrics = _fetch_atari_metrics(async_vector_env)
                if atari_metrics is not None:
                    for m in atari_metrics:
                        yield m
                else:
                    yield RolloutMetrics(episode.length, episode.total_reward,
                                         dict(episode.agent_rewards))
            else:
                all_done = False
                # At least send an empty dict if not done
                actions_to_send[env_id] = {}

            # For each agent in the environment
            for agent_id, raw_obs in agent_obs.items():
                policy_id = episode.policy_for(agent_id)
                filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs)
                agent_done = bool(all_done or dones[env_id].get(agent_id))
                # Only agents that are still running need a new action.
                if not agent_done:
                    to_eval[policy_id].append(
                        PolicyEvalData(env_id, agent_id, filtered_obs,
                                       episode.rnn_state_for(agent_id)))

                last_observation = episode.last_observation_for(agent_id)
                episode._set_last_observation(agent_id, filtered_obs)

                # Record transition info if applicable
                if last_observation is not None and \
                        infos[env_id][agent_id].get("training_enabled", True):
                    episode.batch_builder.add_values(
                        agent_id,
                        policy_id,
                        t=episode.length - 1,
                        eps_id=episode.episode_id,
                        obs=last_observation,
                        actions=episode.last_action_for(agent_id),
                        rewards=rewards[env_id][agent_id],
                        dones=agent_done,
                        infos=infos[env_id][agent_id],
                        new_obs=filtered_obs,
                        **episode.last_pi_info_for(agent_id))

            # Cut the batch if we're not packing multiple episodes into one,
            # or if we've exceeded the requested batch size.
            if episode.batch_builder.has_pending_data():
                if (all_done and not pack) or \
                        episode.batch_builder.count >= unroll_length:
                    yield episode.batch_builder.build_and_reset()
                elif all_done:
                    # Make sure postprocessor stays within one episode
                    episode.batch_builder.postprocess_batch_so_far()

            if all_done:
                # Handle episode termination: return the builder to the pool
                # and forget the finished episode.
                batch_builder_pool.append(episode.batch_builder)
                del active_episodes[env_id]
                resetted_obs = async_vector_env.try_reset(env_id)
                if resetted_obs is None:
                    # Reset not supported, drop this env from the ready list
                    assert horizon == float("inf"), \
                        "Setting episode horizon requires reset() support."
                else:
                    # Creates a new episode (via the defaultdict factory)
                    episode = active_episodes[env_id]
                    for agent_id, raw_obs in resetted_obs.items():
                        policy_id = episode.policy_for(agent_id)
                        filtered_obs = _get_or_raise(obs_filters,
                                                     policy_id)(raw_obs)
                        episode._set_last_observation(agent_id, filtered_obs)
                        to_eval[policy_id].append(
                            PolicyEvalData(env_id, agent_id, filtered_obs,
                                           episode.rnn_state_for(agent_id)))

        # Batch eval policy actions if possible
        if tf_sess:
            builder = TFRunBuilder(tf_sess, "policy_eval")
            pending_fetches = {}
        else:
            builder = None
        eval_results = {}
        rnn_in_cols = {}
        for policy_id, eval_data in to_eval.items():
            rnn_in = _to_column_format([t.rnn_state for t in eval_data])
            rnn_in_cols[policy_id] = rnn_in
            policy = _get_or_raise(policies, policy_id)
            # Policies that keep the stock TFPolicyGraph.compute_actions are
            # queued on the shared builder; all others are called directly.
            if builder and (policy.compute_actions.__code__ is
                            TFPolicyGraph.compute_actions.__code__):
                pending_fetches[policy_id] = policy.build_compute_actions(
                    builder, [t.obs for t in eval_data],
                    rnn_in,
                    is_training=True)
            else:
                eval_results[policy_id] = policy.compute_actions(
                    [t.obs for t in eval_data],
                    rnn_in,
                    is_training=True,
                    episodes=[active_episodes[t.env_id] for t in eval_data])
        if builder:
            for k, v in pending_fetches.items():
                eval_results[k] = builder.get(v)

        # Record the policy eval results
        for policy_id, eval_data in to_eval.items():
            actions, rnn_out_cols, pi_info_cols = eval_results[policy_id]
            # The policy must return as many RNN state columns as it got.
            if len(rnn_in_cols[policy_id]) != len(rnn_out_cols):
                raise ValueError(
                    "Length of RNN in did not match RNN out, got: "
                    "{} vs {}".format(rnn_in_cols[policy_id], rnn_out_cols))
            # Add RNN state info
            for f_i, column in enumerate(rnn_in_cols[policy_id]):
                pi_info_cols["state_in_{}".format(f_i)] = column
            for f_i, column in enumerate(rnn_out_cols):
                pi_info_cols["state_out_{}".format(f_i)] = column
            # Save output rows (row i corresponds to eval_data[i])
            for i, action in enumerate(actions):
                env_id = eval_data[i].env_id
                agent_id = eval_data[i].agent_id
                actions_to_send[env_id][agent_id] = action
                episode = active_episodes[env_id]
                episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
                episode._set_last_pi_info(
                    agent_id, {k: v[i]
                               for k, v in pi_info_cols.items()})
                # Record the env's off-policy action instead of the computed
                # one when the env reported it took its own action.
                if env_id in off_policy_actions and \
                        agent_id in off_policy_actions[env_id]:
                    episode._set_last_action(
                        agent_id, off_policy_actions[env_id][agent_id])
                else:
                    episode._set_last_action(agent_id, action)

        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        async_vector_env.send_actions(dict(actions_to_send))
def apply_gradients(self, gradients: ModelGradients) -> None:
    """Apply `gradients` by running the graph's apply-gradients fetches.

    Raises:
        AssertionError: If the loss has not been initialized (only when
            assertions are enabled).
    """
    assert self.loss_initialized()
    run_builder = TFRunBuilder(self.get_session(), "apply_gradients")
    run_builder.get(self._build_apply_gradients(run_builder, gradients))
def learn_on_batch(self, postprocessed_batch):
    """Build the learn-on-batch fetches for a batch and run them.

    Raises:
        AssertionError: If `self._loss` is None (only when assertions are
            enabled).

    Returns:
        The result of `TFRunBuilder.get()` on the fetches returned by
        `self._build_learn_on_batch()`.
    """
    assert self._loss is not None, "Loss not initialized"
    run_builder = TFRunBuilder(self._sess, "learn_on_batch")
    return run_builder.get(
        self._build_learn_on_batch(run_builder, postprocessed_batch))
def apply_gradients(self, gradients):
    """Run the apply-gradients fetches for `gradients`; returns None.

    Raises:
        AssertionError: If `self._loss` is None (only when assertions are
            enabled).
    """
    assert self._loss is not None, "Loss not initialized"
    run_builder = TFRunBuilder(self._sess, "apply_gradients")
    apply_fetches = self._build_apply_gradients(run_builder, gradients)
    run_builder.get(apply_fetches)
def _do_policy_eval(
        *,
        to_eval: Dict[PolicyID, List[PolicyEvalData]],
        policies: Dict[PolicyID, Policy],
        policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"], PolicyID],
        sample_collector,
        active_episodes: Dict[str, MultiAgentEpisode],
        tf_sess: Optional["tf.Session"] = None,
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
    """Call compute_actions on collected episode/model data to get next action.

    Every policy is evaluated through `compute_actions_from_input_dict()`
    using the input dict supplied by `sample_collector`.

    Args:
        to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy IDs
            to lists of PolicyEvalData objects (items in these lists will be
            the batch's items for the model forward pass).
        policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
            obj.
        policy_mapping_fn (Callable): Called as
            `policy_mapping_fn(agent_id, episode)` for the first item of a
            policy's eval data when `_get_or_raise(policies, policy_id)`
            raises ValueError (e.g. the policy was removed from this
            worker), to pick the policy ID to use instead.
        sample_collector (SampleCollector): The SampleCollector object to
            use. Its `get_inference_input_dict(policy_id)` provides each
            policy's forward-pass input.
        active_episodes (Dict[str, MultiAgentEpisode]): Indexed by each
            PolicyEvalData's `env_id` to get the episodes passed to the
            policy.
        tf_sess (Optional[tf.Session]): Kept for backward compatibility.
            It is not used, because no fetches are batched on a
            `TFRunBuilder` here.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """
    eval_results: Dict[PolicyID, TensorStructType] = {}

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        # In case the policyID has been removed from this worker, we need to
        # re-assign policy_id and re-lookup the Policy object to use.
        try:
            policy: Policy = _get_or_raise(policies, policy_id)
        except ValueError:
            first_item = eval_data[0]
            policy_id = policy_mapping_fn(first_item.agent_id,
                                          active_episodes[first_item.env_id])
            policy = _get_or_raise(policies, policy_id)

        # NOTE(review): if a stale ID is re-mapped onto a policy_id that is
        # also evaluated in this loop, the later result overwrites the
        # earlier one in eval_results. Confirm this is intended.
        input_dict = sample_collector.get_inference_input_dict(policy_id)
        eval_results[policy_id] = policy.compute_actions_from_input_dict(
            input_dict,
            timestep=policy.global_timestep,
            episodes=[active_episodes[t.env_id] for t in eval_data])

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
def apply_gradients(self, gradients):
    """Run the apply-gradients fetches for `gradients`.

    Returns:
        The result of `TFRunBuilder.get()` on the fetches returned by
        `self._build_apply_gradients()`.
    """
    run_builder = TFRunBuilder(self._sess, "apply_gradients")
    return run_builder.get(
        self._build_apply_gradients(run_builder, gradients))
def learn_on_batch(self, postprocessed_batch):
    """Build and run the learn-on-batch fetches for `postprocessed_batch`.

    Raises:
        AssertionError: If the loss has not been initialized (only when
            assertions are enabled).

    Returns:
        The result of `TFRunBuilder.get()` on the built fetches.
    """
    assert self.loss_initialized()
    run_builder = TFRunBuilder(self._sess, "learn_on_batch")
    learn_fetches = self._build_learn_on_batch(run_builder,
                                               postprocessed_batch)
    return run_builder.get(learn_fetches)
def compute_gradients(self, postprocessed_batch):
    """Build and run the compute-gradients fetches for a batch.

    Raises:
        AssertionError: If the loss has not been initialized (only when
            assertions are enabled).

    Returns:
        The result of `TFRunBuilder.get()` on the fetches returned by
        `self._build_compute_gradients()`.
    """
    assert self.loss_initialized()
    run_builder = TFRunBuilder(self._sess, "compute_gradients")
    return run_builder.get(
        self._build_compute_gradients(run_builder, postprocessed_batch))
def _do_policy_eval(*, to_eval, policies, active_episodes, tf_sess=None):
    """Compute next actions for all policies that have pending eval data.

    When `tf_sess` is truthy (graph mode), a policy whose `compute_actions`
    is the unmodified `TFPolicy.compute_actions` is queued on a shared
    `TFRunBuilder`. Every other policy is called immediately through
    `policy.compute_actions()`. Both paths pass
    `timestep=policy.global_timestep`.

    Args:
        to_eval (Dict[str,List[PolicyEvalData]]): Mapping of policy IDs to
            lists of PolicyEvalData objects.
        policies (Dict[str,Policy]): Mapping from policy ID to Policy obj.
        active_episodes (defaultdict[str,MultiAgentEpisode]): Indexed by
            each PolicyEvalData's `env_id` to get the episodes passed to
            `compute_actions()`.
        tf_sess (Optional[tf.Session]): Optional tensorflow session to use
            for batching TF policy evaluations.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """
    results = {}
    run_builder = TFRunBuilder(tf_sess, "policy_eval") if tf_sess else None
    queued = {}

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        states = [item.rnn_state for item in eval_data]
        policy = _get_or_raise(policies, policy_id)
        obs = [item.obs for item in eval_data]
        prev_actions = [item.prev_action for item in eval_data]
        prev_rewards = [item.prev_reward for item in eval_data]

        graph_mode = run_builder and (policy.compute_actions.__code__ is
                                      TFPolicy.compute_actions.__code__)
        if graph_mode:
            # TODO(ekl): how can we make info batch available to TF code?
            queued[policy_id] = policy._build_compute_actions(
                run_builder,
                obs_batch=obs,
                state_batches=_to_column_format(states),
                prev_action_batch=prev_actions,
                prev_reward_batch=prev_rewards,
                timestep=policy.global_timestep)
            continue

        # Transpose per-row RNN states into one stacked array per state
        # slot; the slot count is taken from the first row.
        num_slots = len(states[0])
        state_cols = [
            np.stack([row[slot] for row in states])
            for slot in range(num_slots)
        ]
        results[policy_id] = policy.compute_actions(
            obs,
            state_batches=state_cols,
            prev_action_batch=prev_actions,
            prev_reward_batch=prev_rewards,
            info_batch=[item.info for item in eval_data],
            episodes=[active_episodes[item.env_id] for item in eval_data],
            timestep=policy.global_timestep)

    if run_builder:
        for policy_id, fetch in queued.items():
            results[policy_id] = run_builder.get(fetch)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(results)))

    return results
def apply_gradients(self, gradients):
    """Run the apply-gradients fetches for `gradients`; returns None."""
    run_builder = TFRunBuilder(self._sess, "apply_gradients")
    run_builder.get(self._build_apply_gradients(run_builder, gradients))
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute actions on observation batches to get next actions.

    When `tf_sess` is truthy, policies that do not override
    `TFPolicy.compute_actions` are queued on a shared `TFRunBuilder`. All
    other policies are evaluated immediately via `policy.compute_actions()`.
    Both paths pass `timestep=policy.global_timestep`.

    Args:
        tf_sess: Optional TF session used to batch graph-mode evaluations.
        to_eval: Mapping of policy ID -> list of PolicyEvalData items.
        policies: Mapping of policy ID -> policy object.
        active_episodes: Indexed by each item's `env_id` to collect the
            episodes passed to `compute_actions()`.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """
    eval_results = {}
    # Bound unconditionally so the name always exists. It is only filled
    # and read when a builder is present.
    pending_fetches = {}
    builder = TFRunBuilder(tf_sess, "policy_eval") if tf_sess else None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        rnn_in = [t.rnn_state for t in eval_data]
        policy = _get_or_raise(policies, policy_id)
        if builder and (policy.compute_actions.__code__ is
                        TFPolicy.compute_actions.__code__):
            # TODO(ekl): how can we make info batch available to TF code?
            # (The observation list is now built once, not twice.)
            pending_fetches[policy_id] = policy._build_compute_actions(
                builder,
                obs_batch=[t.obs for t in eval_data],
                state_batches=_to_column_format(rnn_in),
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                timestep=policy.global_timestep)
        else:
            # TODO(sven): Does this work for LSTM torch?
            # Transpose per-row RNN states into one stacked array per state
            # slot; the slot count is taken from the first row.
            rnn_in_cols = [
                np.stack([row[i] for row in rnn_in])
                for i in range(len(rnn_in[0]))
            ]
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                state_batches=rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data],
                timestep=policy.global_timestep)

    if builder:
        for pid, v in pending_fetches.items():
            eval_results[pid] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
def _do_policy_eval(
        *,
        to_eval: Dict[PolicyID, List[PolicyEvalData]],
        policies: Dict[PolicyID, Policy],
        active_episodes: Dict[str, MultiAgentEpisode],
        tf_sess=None,
        _use_trajectory_view_api=False
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
    """Compute next actions for every policy with pending eval data.

    In graph mode (`tf_sess` truthy), policies that keep the stock
    `TFPolicy.compute_actions` have their fetches queued on one shared
    `TFRunBuilder` and read back at the end. All other policies are called
    immediately through `policy.compute_actions()`.

    Args:
        to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy IDs
            to lists of PolicyEvalData objects (items in these lists will be
            the batch's items for the model forward pass).
        policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
            obj.
        active_episodes (Dict[str, MultiAgentEpisode]): Indexed by each
            PolicyEvalData's `env_id` to get the episodes passed to
            `compute_actions()`.
        tf_sess (Optional[tf.Session]): Optional tensorflow session to use
            for batching TF policy evaluations.
        _use_trajectory_view_api (bool): Accepted for interface
            compatibility; this function does not read it.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """
    results: Dict[PolicyID, TensorStructType] = {}
    queued: Dict[PolicyID, Any] = {}
    run_builder = None
    if tf_sess:
        run_builder = TFRunBuilder(tf_sess, "policy_eval")

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for pid, rows in to_eval.items():
        states: List[List[Any]] = [r.rnn_state for r in rows]
        policy: Policy = _get_or_raise(policies, pid)
        uses_stock_tf = run_builder and (
            policy.compute_actions.__code__ is
            TFPolicy.compute_actions.__code__)
        if not uses_stock_tf:
            slot_count = len(states[0])
            state_cols: StateBatch = [
                np.stack([r[slot] for r in states])
                for slot in range(slot_count)
            ]
            results[pid] = policy.compute_actions(
                [r.obs for r in rows],
                state_batches=state_cols,
                prev_action_batch=[r.prev_action for r in rows],
                prev_reward_batch=[r.prev_reward for r in rows],
                info_batch=[r.info for r in rows],
                episodes=[active_episodes[r.env_id] for r in rows],
                timestep=policy.global_timestep)
        else:
            # TODO(ekl): how can we make info batch available to TF code?
            queued[pid] = policy._build_compute_actions(
                run_builder,
                obs_batch=[r.obs for r in rows],
                state_batches=_to_column_format(states),
                prev_action_batch=[r.prev_action for r in rows],
                prev_reward_batch=[r.prev_reward for r in rows],
                timestep=policy.global_timestep)

    if run_builder:
        results.update(
            {pid: run_builder.get(fetch) for pid, fetch in queued.items()})

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(results)))

    return results
def compute_gradients(self, postprocessed_batch):
    """Build and run the compute-gradients fetches for a batch.

    Returns:
        The result of `TFRunBuilder.get()` on the fetches returned by
        `self._build_compute_gradients()`.
    """
    run_builder = TFRunBuilder(self._sess, "compute_gradients")
    return run_builder.get(
        self._build_compute_gradients(run_builder, postprocessed_batch))
def apply_gradients(self, gradients):
    """Run the apply-gradients fetches for `gradients`; returns None.

    Raises:
        AssertionError: If the loss has not been initialized (only when
            assertions are enabled).
    """
    assert self.loss_initialized()
    run_builder = TFRunBuilder(self._sess, "apply_gradients")
    apply_fetches = self._build_apply_gradients(run_builder, gradients)
    run_builder.get(apply_fetches)
def learn_on_batch(
        self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
    """Build and run the learn-on-batch fetches for `postprocessed_batch`.

    Raises:
        AssertionError: If the loss has not been initialized (only when
            assertions are enabled).

    Returns:
        The result of `TFRunBuilder.get()` on the built fetches.
    """
    assert self.loss_initialized()
    run_builder = TFRunBuilder(self._sess, "learn_on_batch")
    return run_builder.get(
        self._build_learn_on_batch(run_builder, postprocessed_batch))
def compute_gradients(self, postprocessed_batch):
    """Return the fetched gradient outputs for `postprocessed_batch`."""
    run_builder = TFRunBuilder(self._sess, "compute_gradients")
    grad_fetches = self._build_compute_gradients(run_builder,
                                                 postprocessed_batch)
    return run_builder.get(grad_fetches)
def learn_on_batch(self, postprocessed_batch):
    """Build and run the learn-on-batch fetches for `postprocessed_batch`.

    Returns:
        The result of `TFRunBuilder.get()` on the fetches returned by
        `self._build_learn_on_batch()`.
    """
    run_builder = TFRunBuilder(self._sess, "learn_on_batch")
    return run_builder.get(
        self._build_learn_on_batch(run_builder, postprocessed_batch))
def learn_on_batch(self, postprocessed_batch):
    """Return the fetched learn-on-batch outputs for `postprocessed_batch`."""
    run_builder = TFRunBuilder(self._sess, "learn_on_batch")
    learn_fetches = self._build_learn_on_batch(run_builder,
                                               postprocessed_batch)
    return run_builder.get(learn_fetches)
def compute_apply(self, postprocessed_batch):
    """Return the fetched compute/apply outputs for `postprocessed_batch`."""
    run_builder = TFRunBuilder(self._sess, "compute_apply")
    combined_fetches = self._build_compute_apply(run_builder,
                                                 postprocessed_batch)
    return run_builder.get(combined_fetches)