def compute_actions(self,
                    obs_batch: Union[List[TensorType], TensorType],
                    state_batches: Optional[List[TensorType]] = None,
                    prev_action_batch: Union[List[TensorType],
                                             TensorType] = None,
                    prev_reward_batch: Union[List[TensorType],
                                             TensorType] = None,
                    info_batch: Optional[Dict[str, list]] = None,
                    episodes: Optional[List["MultiAgentEpisode"]] = None,
                    explore: Optional[bool] = None,
                    timestep: Optional[int] = None,
                    **kwargs):
    explore = explore if explore is not None else self.config["explore"]
    timestep = timestep if timestep is not None else self.global_timestep

    builder = TFRunBuilder(self._sess, "compute_actions")
    to_fetch = self._build_compute_actions(
        builder,
        obs_batch,
        state_batches=state_batches,
        prev_action_batch=prev_action_batch,
        prev_reward_batch=prev_reward_batch,
        explore=explore,
        timestep=timestep)

    # Execute session run to get action (and other fetches).
    fetched = builder.get(to_fetch)

    # Update our global timestep by the batch size.
    self.global_timestep += len(obs_batch) if isinstance(obs_batch, list) \
        else obs_batch.shape[0]

    return fetched

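# A minimal usage sketch of the method above. `policy` (a loss-initialized
# TF policy instance) and the observation shape are assumptions made here
# for illustration only.
import numpy as np

obs_batch = np.zeros((4, 8), dtype=np.float32)  # 4 observations of dim 8
actions, state_outs, extra_fetches = policy.compute_actions(
    obs_batch, explore=False)
# After the call, policy.global_timestep has advanced by 4 (the batch size).
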
def compute_gradients(self, samples):
    if log_once("compute_gradients"):
        logger.info("Compute gradients on:\n\n{}\n".format(
            summarize(samples)))
    if isinstance(samples, MultiAgentBatch):
        grad_out, info_out = {}, {}
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "compute_gradients")
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid]._build_compute_gradients(
                        builder, batch))
            grad_out = {k: builder.get(v) for k, v in grad_out.items()}
            info_out = {k: builder.get(v) for k, v in info_out.items()}
        else:
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid].compute_gradients(batch))
    else:
        grad_out, info_out = (
            self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
    info_out["batch_count"] = samples.count
    if log_once("grad_out"):
        logger.info("Compute grad info:\n\n{}\n".format(
            summarize(info_out)))
    return grad_out, info_out

def learn_on_batch(self, samples):
    if log_once("learn_on_batch"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize(samples)))
    if isinstance(samples, MultiAgentBatch):
        info_out = {}
        to_fetch = {}
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
        else:
            builder = None
        for pid, batch in samples.policy_batches.items():
            if pid not in self.policies_to_train:
                continue
            policy = self.policy_map[pid]
            if builder and hasattr(policy, "_build_learn_on_batch"):
                to_fetch[pid] = policy._build_learn_on_batch(
                    builder, batch)
            else:
                info_out[pid] = policy.learn_on_batch(batch)
        info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
    else:
        info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(
            samples)
    if log_once("learn_out"):
        logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
    return info_out

def apply_gradients(self, grads):
    """Applies the given gradients to this worker's weights.

    Examples:
        >>> samples = worker.sample()
        >>> grads, info = worker.compute_gradients(samples)
        >>> worker.apply_gradients(grads)
    """
    if log_once("apply_gradients"):
        logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
    if isinstance(grads, dict):
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "apply_gradients")
            outputs = {
                pid: self.policy_map[pid]._build_apply_gradients(
                    builder, grad)
                for pid, grad in grads.items()
            }
            return {k: builder.get(v) for k, v in outputs.items()}
        else:
            return {
                pid: self.policy_map[pid].apply_gradients(g)
                for pid, g in grads.items()
            }
    else:
        return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

def compute_actions_from_input_dict(
        self,
        input_dict: Dict[str, TensorType],
        explore: bool = None,
        timestep: Optional[int] = None,
        episodes: Optional[List["MultiAgentEpisode"]] = None,
        **kwargs
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
    explore = explore if explore is not None else self.config["explore"]
    timestep = timestep if timestep is not None else self.global_timestep

    builder = TFRunBuilder(self._sess, "compute_actions_from_input_dict")
    obs_batch = input_dict[SampleBatch.OBS]
    to_fetch = self._build_compute_actions(
        builder, input_dict=input_dict, explore=explore, timestep=timestep)

    # Execute session run to get action (and other fetches).
    fetched = builder.get(to_fetch)

    # Update our global timestep by the batch size.
    self.global_timestep += len(obs_batch) if isinstance(obs_batch, list) \
        else obs_batch.shape[0]

    return fetched

def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute_actions() on observation batches to get next actions.

    Returns:
        eval_results: dict of policy to compute_actions() outputs.
    """
    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        rnn_in = [t.rnn_state for t in eval_data]
        policy = _get_or_raise(policies, policy_id)
        if builder and (policy.compute_actions.__code__ is
                        TFPolicy.compute_actions.__code__):
            rnn_in_cols = _to_column_format(rnn_in)
            # TODO(ekl): how can we make info batch available to TF code?
            # TODO(sven): Return a dict from _build_compute_actions;
            #  otherwise it's becoming more and more unclear what's where
            #  in the return tuple.
            pending_fetches[policy_id] = policy._build_compute_actions(
                builder,
                obs_batch=[t.obs for t in eval_data],
                state_batches=rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                timestep=policy.global_timestep)
        else:
            # TODO(sven): Does this work for LSTM torch?
            rnn_in_cols = [
                np.stack([row[i] for row in rnn_in])
                for i in range(len(rnn_in[0]))
            ]
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                state_batches=rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data],
                timestep=policy.global_timestep)
    if builder:
        for k, v in pending_fetches.items():
            eval_results[k] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results

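# The function above shows the deferred-evaluation pattern that motivates
# TFRunBuilder: each _build_compute_actions() call merely registers feeds
# and fetches on the builder, and the first builder.get() triggers a single
# session.run() covering every queued policy. Below is a minimal sketch of
# that idea; it is a simplification for illustration, not RLlib's actual
# TFRunBuilder (which also handles logging and timeline tracing).

class SimpleRunBuilder:
    """Accumulates feeds/fetches, then runs the session exactly once."""

    def __init__(self, session, debug_name):
        self.session = session
        self.debug_name = debug_name
        self.feed_dict = {}
        self.fetches = []
        self._executed = None  # Cached results of the single run.

    def add_feed_dict(self, feed_dict):
        # Queue placeholder -> value bindings for the eventual run.
        self.feed_dict.update(feed_dict)

    def add_fetches(self, fetches):
        # Queue tensors to fetch; return handles (indices) to them.
        base = len(self.fetches)
        self.fetches.extend(fetches)
        return list(range(base, len(self.fetches)))

    def get(self, to_fetch):
        # The first get() executes everything queued so far in one run.
        if self._executed is None:
            self._executed = self.session.run(
                self.fetches, feed_dict=self.feed_dict)
        if isinstance(to_fetch, (list, tuple)):
            return [self.get(x) for x in to_fetch]
        return self._executed[to_fetch]
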
def compute_actions(self, obs_batch, state_batches=None, is_training=False):
    builder = TFRunBuilder(self._sess, "compute_actions")
    fetches = self.build_compute_actions(builder, obs_batch, state_batches,
                                         is_training)
    return builder.get(fetches)

def compute_gradients(
        self, postprocessed_batch: SampleBatch
) -> Tuple[ModelGradients, Dict[str, TensorType]]:
    assert self.loss_initialized()
    builder = TFRunBuilder(self.get_session(), "compute_gradients")
    fetches = self._build_compute_gradients(builder, postprocessed_batch)
    return builder.get(fetches)

def compute_rnn_state_out(self):
    if self._episode_buffer:
        if not self._rnn_state_out:
            builder = TFRunBuilder(self._sess, "compute_rnn_state_out")
            fetches = self.build_compute_rnn_state_out(builder)
            self._rnn_state_out = builder.get(fetches)
        return self._rnn_state_out
    else:
        return self.model.rnn_state_out_init

def compute_gradients(
        self, postprocessed_batch: SampleBatch
) -> Tuple[ModelGradients, Dict[str, TensorType]]:
    assert self.loss_initialized()

    # Switch on is_training flag in our batch.
    postprocessed_batch.set_training(True)

    builder = TFRunBuilder(self.get_session(), "compute_gradients")
    fetches = self._build_compute_gradients(builder, postprocessed_batch)
    return builder.get(fetches)

def _do_policy_eval(
        *,
        to_eval: Dict[PolicyID, List[PolicyEvalData]],
        policies: Dict[PolicyID, Policy],
        sample_collector,
        active_episodes: Dict[str, MultiAgentEpisode],
        tf_sess: Optional["tf.Session"] = None,
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
    """Call compute_actions on collected episode/model data to get next action.

    Args:
        to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy
            IDs to lists of PolicyEvalData objects (items in these lists
            will be the batch's items for the model forward pass).
        policies (Dict[PolicyID, Policy]): Mapping from policy ID to
            Policy obj.
        sample_collector (SampleCollector): The SampleCollector object to
            use.
        active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
            episode ID to the currently ongoing MultiAgentEpisode object.
        tf_sess (Optional[tf.Session]): Optional tensorflow session to use
            for batching TF policy evaluations.

    Returns:
        eval_results: dict of policy to compute_actions() outputs.
    """
    eval_results: Dict[PolicyID, TensorStructType] = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches: Dict[PolicyID, Any] = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        policy: Policy = _get_or_raise(policies, policy_id)
        input_dict = sample_collector.get_inference_input_dict(policy_id)
        eval_results[policy_id] = \
            policy.compute_actions_from_input_dict(
                input_dict,
                timestep=policy.global_timestep,
                episodes=[active_episodes[t.env_id] for t in eval_data])

    if builder:
        # Resolve any fetches queued with the builder (a single session
        # call for all queued TF policies).
        for pid, v in pending_fetches.items():  \
                # type: PolicyID, Tuple[TensorStructType, StateBatch, dict]
            eval_results[pid] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results

def compute_inner_gradients(self, postprocessed_batch):
    builder = TFRunBuilder(self._sess, "compute_inner_gradients")
    self._before_compute_grads()
    self._grads = self._inner_grads
    self._loss_inputs = self._inner_loss_inputs
    self._loss_input_dict = self._inner_loss_input_dict
    self.stats_fetches = self.a3c_stats_fetches
    fetches = self.build_compute_gradients(builder, postprocessed_batch)
    results = builder.get(fetches)
    self._after_compute_grads()
    return results

def compute_actions(self,
                    obs_batch,
                    state_batches=None,
                    prev_action_batch=None,
                    prev_reward_batch=None,
                    episodes=None):
    builder = TFRunBuilder(self._sess, "compute_actions")
    fetches = self.build_compute_actions(builder, obs_batch, state_batches,
                                         prev_action_batch,
                                         prev_reward_batch)
    return builder.get(fetches)

def compute_actions(self,
                    obs_batch,
                    state_batches=None,
                    prev_action_batch=None,
                    prev_reward_batch=None,
                    info_batch=None,
                    episodes=None,
                    **kwargs):
    builder = TFRunBuilder(self._sess, "compute_actions")
    fetches = self._build_compute_actions(builder, obs_batch,
                                          state_batches,
                                          prev_action_batch,
                                          prev_reward_batch)
    return builder.get(fetches)

def compute_actions(self,
                    obs_batch,
                    state_batches=None,
                    prev_action_batch=None,
                    prev_reward_batch=None,
                    info_batch=None,
                    episodes=None,
                    explore=None,
                    timestep=None,
                    **kwargs):
    """Variant of PPOTFPolicy's compute_actions() that also handles dict
    observations at inference time. Dict inputs are flattened by existing
    RLlib code during training, but not when querying an already trained
    model."""
    explore = explore if explore is not None else self.config["explore"]
    timestep = timestep if timestep is not None else self.global_timestep

    builder = TFRunBuilder(self._sess, "compute_actions")

    if type(obs_batch) is dict:
        some_batch_data = list(obs_batch.values())[0]
        obs_batch_len = len(some_batch_data) \
            if isinstance(some_batch_data, list) \
            else some_batch_data.shape[0]
        # Flatten the dict observation into a single 2D array, following
        # the key order of the original (un-flattened) observation space.
        flattened_obs = []
        for k in self.observation_space.original_space.spaces.keys():
            if k in obs_batch:
                obs = np.array(obs_batch[k])
                # int() guards against np.prod returning a float for
                # scalar (0-d per-item) keys.
                flattened_obs.append(
                    obs.reshape(obs_batch_len,
                                int(np.prod(obs.shape[1:]))))
        obs_batch = np.concatenate(flattened_obs, axis=-1)
    else:
        obs_batch_len = len(obs_batch) if isinstance(obs_batch, list) \
            else obs_batch.shape[0]
        obs_batch = np.array(obs_batch)

    to_fetch = self._build_compute_actions(
        builder,
        obs_batch=obs_batch,
        state_batches=state_batches,
        prev_action_batch=prev_action_batch,
        prev_reward_batch=prev_reward_batch,
        explore=explore,
        timestep=timestep)

    # Execute session run to get action (and other fetches).
    fetched = builder.get(to_fetch)

    # Update our global timestep by the batch size.
    self.global_timestep += obs_batch_len

    return fetched

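# Usage sketch for the dict-observation variant above. The keys and shapes
# are hypothetical; they must match the policy's original (un-flattened)
# observation space. `policy` is assumed to be an instance of the class
# defining this compute_actions().
import numpy as np

dict_obs_batch = {
    "camera": np.zeros((2, 16, 16, 3), dtype=np.float32),
    "inventory": np.zeros((2, 5), dtype=np.float32),
}
actions, state_outs, extra_fetches = policy.compute_actions(dict_obs_batch)
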
def learn_on_batch(
        self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
    assert self.loss_initialized()
    builder = TFRunBuilder(self._sess, "learn_on_batch")

    # Callback handling.
    learn_stats = {}
    self.callbacks.on_learn_on_batch(
        policy=self, train_batch=postprocessed_batch, result=learn_stats)

    fetches = self._build_learn_on_batch(builder, postprocessed_batch)
    stats = builder.get(fetches)
    stats.update({"custom_metrics": learn_stats})
    return stats

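# A usage sketch for learn_on_batch() above, with hypothetical column
# values. A real call needs every column the policy's loss reads (e.g.
# advantages, value targets), which depends on the algorithm.
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

train_batch = SampleBatch({
    SampleBatch.OBS: np.zeros((32, 8), dtype=np.float32),
    SampleBatch.ACTIONS: np.zeros((32,), dtype=np.int64),
    SampleBatch.REWARDS: np.ones((32,), dtype=np.float32),
})
stats = policy.learn_on_batch(train_batch)  # One session.run() per call.
print(stats["custom_metrics"])
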
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute_actions() on observation batches to get next actions.

    Returns:
        eval_results: dict of policy to compute_actions() outputs.
    """
    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
        policy = _get_or_raise(policies, policy_id)
        if builder and (policy.compute_actions.__code__ is
                        TFPolicy.compute_actions.__code__):
            # TODO(ekl): how can we make info batch available to TF code?
            pending_fetches[policy_id] = policy._build_compute_actions(
                builder, [t.obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data])
        else:
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data])
    if builder:
        for k, v in pending_fetches.items():
            eval_results[k] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results

def apply_gradients(self, grads):
    if isinstance(grads, dict):
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "apply_gradients")
            outputs = {
                pid: self.policy_map[pid].build_apply_gradients(
                    builder, grad)
                for pid, grad in grads.items()
            }
            return {k: builder.get(v) for k, v in outputs.items()}
        else:
            return {
                pid: self.policy_map[pid].apply_gradients(g)
                for pid, g in grads.items()
            }
    else:
        return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

def apply_gradients(self, grads):
    if isinstance(grads, dict):
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "apply_gradients")
            outputs = {
                pid: self.policy_map[pid]._build_apply_gradients(
                    builder, grad)
                for pid, grad in grads.items()
            }
            return {k: builder.get(v) for k, v in outputs.items()}
        else:
            return {
                pid: self.policy_map[pid].apply_gradients(g)
                for pid, g in grads.items()
            }
    else:
        return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

def _do_policy_eval(tf_sess, to_eval, policies, active_episodes,
                    clip_actions):
    """Call compute_actions() on observation batches to get next actions.

    Returns:
        eval_results: dict of policy to compute_actions() outputs.
    """
    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None
    for policy_id, eval_data in to_eval.items():
        rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
        policy = _get_or_raise(policies, policy_id)
        if builder and (policy.compute_actions.__code__ is
                        TFPolicyGraph.compute_actions.__code__):
            pending_fetches[policy_id] = policy.build_compute_actions(
                builder, [t.obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data])
        else:
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data])
    if builder:
        for k, v in pending_fetches.items():
            eval_results[k] = builder.get(v)

    if clip_actions:
        for policy_id, results in eval_results.items():
            policy = _get_or_raise(policies, policy_id)
            actions, rnn_out_cols, pi_info_cols = results
            eval_results[policy_id] = (
                _clip_actions(actions, policy.action_space),
                rnn_out_cols, pi_info_cols)

    return eval_results

def compute_gradients(
        self, samples: SampleBatchType) -> Tuple[ModelGradients, dict]:
    """Returns a gradient computed w.r.t. the specified samples.

    Returns:
        (grads, info): A list of gradients that can be applied on a
            compatible worker. In the multi-agent case, returns a dict of
            gradients keyed by policy ids. An info dictionary of extra
            metadata is also returned.

    Examples:
        >>> samples = worker.sample()
        >>> grads, info = worker.compute_gradients(samples)
    """
    if log_once("compute_gradients"):
        logger.info("Compute gradients on:\n\n{}\n".format(
            summarize(samples)))
    if isinstance(samples, MultiAgentBatch):
        grad_out, info_out = {}, {}
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "compute_gradients")
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid]._build_compute_gradients(
                        builder, batch))
            grad_out = {k: builder.get(v) for k, v in grad_out.items()}
            info_out = {k: builder.get(v) for k, v in info_out.items()}
        else:
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid].compute_gradients(batch))
    else:
        grad_out, info_out = (
            self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
    info_out["batch_count"] = samples.count
    if log_once("grad_out"):
        logger.info("Compute grad info:\n\n{}\n".format(
            summarize(info_out)))
    return grad_out, info_out

def compute_apply(self, samples):
    if isinstance(samples, MultiAgentBatch):
        info_out = {}
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "compute_apply")
            for pid, batch in samples.policy_batches.items():
                info_out[pid], _ = (
                    self.policy_map[pid].build_compute_apply(
                        builder, batch))
            info_out = {k: builder.get(v) for k, v in info_out.items()}
        else:
            for pid, batch in samples.policy_batches.items():
                info_out[pid], _ = (
                    self.policy_map[pid].compute_apply(batch))
        return info_out
    else:
        grad_fetch, apply_fetch = (
            self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
        return grad_fetch

def compute_gradients(self, samples):
    if isinstance(samples, MultiAgentBatch):
        grad_out, info_out = {}, {}
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "compute_gradients")
            for pid, batch in samples.policy_batches.items():
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid].build_compute_gradients(
                        builder, batch))
            grad_out = {k: builder.get(v) for k, v in grad_out.items()}
            info_out = {k: builder.get(v) for k, v in info_out.items()}
        else:
            for pid, batch in samples.policy_batches.items():
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid].compute_gradients(batch))
        return grad_out, info_out
    else:
        return self.policy_map[DEFAULT_POLICY_ID].compute_gradients(
            samples)

def apply_gradients(self, grads):
    if log_once("apply_gradients"):
        logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
    if isinstance(grads, dict):
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "apply_gradients")
            outputs = {
                pid: self.policy_map[pid]._build_apply_gradients(
                    builder, grad)
                for pid, grad in grads.items()
            }
            return {k: builder.get(v) for k, v in outputs.items()}
        else:
            return {
                pid: self.policy_map[pid].apply_gradients(g)
                for pid, g in grads.items()
            }
    else:
        return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

def compute_actions(
        self,
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorType], TensorType] = None,
        prev_reward_batch: Union[List[TensorType], TensorType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes: Optional[List["Episode"]] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
):
    explore = explore if explore is not None else self.config["explore"]
    timestep = timestep if timestep is not None else self.global_timestep

    builder = TFRunBuilder(self.get_session(), "compute_actions")

    input_dict = {SampleBatch.OBS: obs_batch, "is_training": False}
    if state_batches:
        for i, s in enumerate(state_batches):
            input_dict[f"state_in_{i}"] = s
    if prev_action_batch is not None:
        input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
    if prev_reward_batch is not None:
        input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch

    to_fetch = self._build_compute_actions(
        builder, input_dict=input_dict, explore=explore, timestep=timestep)

    # Execute session run to get action (and other fetches).
    fetched = builder.get(to_fetch)

    # Update our global timestep by the batch size.
    self.global_timestep += (
        len(obs_batch) if isinstance(obs_batch, list)
        else tree.flatten(obs_batch)[0].shape[0])

    return fetched

def learn_on_batch(self, samples: SampleBatchType) -> dict:
    """Update policies based on the given batch.

    This is equivalent to apply_gradients(compute_gradients(samples)),
    but can be optimized to avoid pulling gradients into CPU memory.

    Returns:
        info: dictionary of extra metadata from compute_gradients().

    Examples:
        >>> samples = worker.sample()
        >>> worker.learn_on_batch(samples)
    """
    if log_once("learn_on_batch"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize(samples)))
    if isinstance(samples, MultiAgentBatch):
        info_out = {}
        to_fetch = {}
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
        else:
            builder = None
        for pid, batch in samples.policy_batches.items():
            if pid not in self.policies_to_train:
                continue
            policy = self.policy_map[pid]
            if builder and hasattr(policy, "_build_learn_on_batch"):
                to_fetch[pid] = policy._build_learn_on_batch(
                    builder, batch)
            else:
                info_out[pid] = policy.learn_on_batch(batch)
        info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
    else:
        info_out = {
            DEFAULT_POLICY_ID: self.policy_map[DEFAULT_POLICY_ID]
            .learn_on_batch(samples)
        }
    if log_once("learn_out"):
        logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
    return info_out

def learn_on_batch(self, samples):
    if log_once("learn_on_batch"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize(samples)))
    if isinstance(samples, MultiAgentBatch):
        info_out = {}
        to_fetch = {}
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
        else:
            builder = None
        for pid, batch in samples.policy_batches.items():
            if pid not in self.policies_to_train:
                continue
            policy = self.policy_map[pid]
            if builder and hasattr(policy, "_build_learn_on_batch"):
                to_fetch[pid] = policy._build_learn_on_batch(
                    builder, batch)
            else:
                info_out[pid] = policy.learn_on_batch(batch)
        info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
    else:
        learn_on_batch_outputs = self.policy_map[
            DEFAULT_POLICY_ID].learn_on_batch(samples)
        if isinstance(learn_on_batch_outputs, tuple):
            info_out, beholder_arrays = learn_on_batch_outputs
        else:
            info_out, beholder_arrays = learn_on_batch_outputs, {}
        if self.policy_config["evaluation_config"]["beholder"]:
            with self.tf_sess.graph.as_default():
                if self._beholder is None:
                    self._beholder = Beholder(self.io_context.log_dir)
                self._beholder.update(
                    self.tf_sess, arrays=beholder_arrays or None)
    if log_once("learn_out"):
        logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
    return info_out

def learn_on_batch(
        self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
    assert self.loss_initialized()

    # Switch on is_training flag in our batch.
    postprocessed_batch.set_training(True)

    builder = TFRunBuilder(self.get_session(), "learn_on_batch")

    # Callback handling.
    learn_stats = {}
    self.callbacks.on_learn_on_batch(
        policy=self, train_batch=postprocessed_batch, result=learn_stats)

    fetches = self._build_learn_on_batch(builder, postprocessed_batch)
    stats = builder.get(fetches)
    stats.update({
        "custom_metrics": learn_stats,
        NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
    })
    return stats

def compute_actions_from_input_dict(
        self,
        input_dict: Union[SampleBatch, Dict[str, TensorType]],
        explore: bool = None,
        timestep: Optional[int] = None,
        episodes: Optional[List["Episode"]] = None,
        **kwargs,
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
    explore = explore if explore is not None else self.config["explore"]
    timestep = timestep if timestep is not None else self.global_timestep

    # Switch off is_training flag in our batch.
    if isinstance(input_dict, SampleBatch):
        input_dict.set_training(False)
    else:
        # Deprecated dict input.
        input_dict["is_training"] = False

    builder = TFRunBuilder(self.get_session(),
                           "compute_actions_from_input_dict")
    obs_batch = input_dict[SampleBatch.OBS]
    to_fetch = self._build_compute_actions(
        builder, input_dict=input_dict, explore=explore, timestep=timestep)

    # Execute session run to get action (and other fetches).
    fetched = builder.get(to_fetch)

    # Update our global timestep by the batch size.
    self.global_timestep += (
        len(obs_batch) if isinstance(obs_batch, list)
        else len(input_dict) if isinstance(input_dict, SampleBatch)
        else obs_batch.shape[0])

    return fetched

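# Usage sketch for the input_dict entry point above. Column names come
# from SampleBatch; an RNN policy would additionally need state_in_*
# entries. `policy` and the shapes are assumptions for illustration.
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

input_dict = SampleBatch({
    SampleBatch.OBS: np.zeros((2, 8), dtype=np.float32),
})
actions, state_outs, extra_fetches = \
    policy.compute_actions_from_input_dict(input_dict, explore=False)
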
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute_actions() on observation batches to get next actions.

    Returns:
        eval_results: dict of policy to compute_actions() outputs.
    """
    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None
    for policy_id, eval_data in to_eval.items():
        rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
        policy = _get_or_raise(policies, policy_id)
        if builder and (policy.compute_actions.__code__ is
                        TFPolicyGraph.compute_actions.__code__):
            # TODO(ekl): how can we make info batch available to TF code?
            pending_fetches[policy_id] = policy._build_compute_actions(
                builder, [t.obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data])
        else:
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data])
    if builder:
        for k, v in pending_fetches.items():
            eval_results[k] = builder.get(v)

    return eval_results

def compute_actions(self,
                    obs_batch,
                    state_batches=None,
                    prev_action_batch=None,
                    prev_reward_batch=None,
                    info_batch=None,
                    episodes=None,
                    explore=None,
                    timestep=None,
                    **kwargs):
    explore = explore if explore is not None else self.config["explore"]

    builder = TFRunBuilder(self._sess, "compute_actions")
    fetches = self._build_compute_actions(
        builder,
        obs_batch,
        state_batches,
        prev_action_batch,
        prev_reward_batch,
        explore=explore,
        timestep=timestep
        if timestep is not None else self.global_timestep)

    # Execute session run to get action (and other fetches).
    return builder.get(fetches)

def compute_apply(self, samples):
    if isinstance(samples, MultiAgentBatch):
        info_out = {}
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "compute_apply")
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                info_out[pid], _ = (
                    self.policy_map[pid]._build_compute_apply(
                        builder, batch))
            info_out = {k: builder.get(v) for k, v in info_out.items()}
        else:
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                info_out[pid], _ = (
                    self.policy_map[pid].compute_apply(batch))
        return info_out
    else:
        grad_fetch, apply_fetch = (
            self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
        return grad_fetch

def learn_on_batch(self, samples):
    if isinstance(samples, MultiAgentBatch):
        info_out = {}
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                info_out[pid], _ = (
                    self.policy_map[pid]._build_learn_on_batch(
                        builder, batch))
            info_out = {k: builder.get(v) for k, v in info_out.items()}
        else:
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                info_out[pid], _ = (
                    self.policy_map[pid].learn_on_batch(batch))
        return info_out
    else:
        grad_fetch, apply_fetch = (
            self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples))
        return grad_fetch

def compute_gradients(self, samples):
    if isinstance(samples, MultiAgentBatch):
        grad_out, info_out = {}, {}
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "compute_gradients")
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid]._build_compute_gradients(
                        builder, batch))
            grad_out = {k: builder.get(v) for k, v in grad_out.items()}
            info_out = {k: builder.get(v) for k, v in info_out.items()}
        else:
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid].compute_gradients(batch))
    else:
        grad_out, info_out = (
            self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
    info_out["batch_count"] = samples.count
    return grad_out, info_out

def learn_on_batch(self, postprocessed_batch):
    builder = TFRunBuilder(self._sess, "learn_on_batch")
    fetches = self._build_learn_on_batch(builder, postprocessed_batch)
    return builder.get(fetches)

def compute_apply(self, postprocessed_batch):
    builder = TFRunBuilder(self._sess, "compute_apply")
    fetches = self._build_compute_apply(builder, postprocessed_batch)
    return builder.get(fetches)

def compute_gradients(self, postprocessed_batch):
    builder = TFRunBuilder(self._sess, "compute_gradients")
    fetches = self._build_compute_gradients(builder, postprocessed_batch)
    return builder.get(fetches)

def apply_gradients(self, gradients):
    builder = TFRunBuilder(self._sess, "apply_gradients")
    fetches = self._build_apply_gradients(builder, gradients)
    return builder.get(fetches)

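# Every variant above follows the same three-step shape: (1) create a
# TFRunBuilder bound to the policy's session, (2) call the matching
# _build_* method, which registers feeds/fetches on the builder, and
# (3) call builder.get() to execute one session.run() and pull results.
# A hypothetical helper capturing that template (not part of RLlib):

def _run_via_builder(self, build_fn, *args, **kwargs):
    # build_fn is any of the _build_* methods above, e.g.
    # self._build_compute_gradients or self._build_apply_gradients.
    builder = TFRunBuilder(self._sess, build_fn.__name__)
    fetches = build_fn(builder, *args, **kwargs)
    return builder.get(fetches)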