def record_env_wrapper(env, record_env, log_dir, policy_config):
    """Optionally wraps `env` in a video-recording Monitor wrapper.

    Args:
        env: The environment to (maybe) wrap. If it is a `MultiAgentEnv`,
            `VideoMonitor` is used as the wrapper, otherwise
            `wrappers.Monitor`.
        record_env: Falsy to disable recording (then `env` is returned
            unchanged). If a str, it is used as the recording directory
            (relative paths are joined onto `log_dir`); any other truthy
            value records into `log_dir`.
        log_dir: The log directory; base for relative recording paths.
        policy_config: Config dict; its "in_evaluation" entry selects the
            wrapper's "evaluation" vs "training" mode.

    Returns:
        The wrapped env if recording is enabled, otherwise `env` itself.
    """
    if not record_env:
        return env

    path_ = record_env if isinstance(record_env, str) else log_dir
    # Relative path: Add logdir here, otherwise, this would
    # not work for non-local workers.
    # NOTE(review): if `record_env` is not a str and `log_dir` itself is
    # relative, this joins `log_dir` with itself — confirm `log_dir` is
    # always absolute here.
    if not os.path.isabs(path_):
        path_ = os.path.join(log_dir, path_)
    print(f"Setting the path for recording to {path_}")

    if isinstance(env, MultiAgentEnv):
        # Mix `MultiAgentEnv` into the wrapper class so the wrapped env still
        # passes `isinstance(env, MultiAgentEnv)` checks downstream.
        wrapper_cls = add_mixins(VideoMonitor, [MultiAgentEnv], reversed=True)
    else:
        # Single-agent env: do NOT add the `MultiAgentEnv` mixin, otherwise a
        # plain env would be misidentified as multi-agent after wrapping.
        wrapper_cls = wrappers.Monitor

    return wrapper_cls(
        env,
        path_,
        resume=True,
        force=True,
        video_callable=lambda _: True,
        mode="evaluation" if policy_config["in_evaluation"] else "training")
def build_trainer(
        name: str,
        *,
        default_config: Optional[TrainerConfigDict] = None,
        validate_config: Optional[Callable[[TrainerConfigDict], None]] = None,
        default_policy: Optional[Type[Policy]] = None,
        get_policy_class: Optional[Callable[[TrainerConfigDict], Optional[
            Type[Policy]]]] = None,
        validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
        before_init: Optional[Callable[[Trainer], None]] = None,
        after_init: Optional[Callable[[Trainer], None]] = None,
        before_evaluate_fn: Optional[Callable[[Trainer], None]] = None,
        mixins: Optional[List[type]] = None,
        execution_plan: Optional[Callable[[
            WorkerSet, TrainerConfigDict], Iterable[ResultDict]]] = (
                default_execution_plan)
) -> Type[Trainer]:
    """Helper function for defining a custom trainer.

    Functions will be run in this order to initialize the trainer:
        1. Config setup: validate_config, get_policy
        2. Worker setup: before_init, execution_plan
        3. Post setup: after_init

    Args:
        name (str): name of the trainer (e.g., "PPO")
        default_config (Optional[TrainerConfigDict]): The default config dict
            of the algorithm, otherwise uses the Trainer default config.
        validate_config (Optional[Callable[[TrainerConfigDict], None]]):
            Optional callable that takes the config to check for correctness.
            It may mutate the config as needed.
        default_policy (Optional[Type[Policy]]): The default Policy class to
            use if `get_policy_class` returns None.
        get_policy_class (Optional[Callable[
            TrainerConfigDict, Optional[Type[Policy]]]]): Optional callable
            that takes a config and returns the policy class or None. If None
            is returned, will use `default_policy` (which must be provided
            then).
        validate_env (Optional[Callable[[EnvType, EnvContext], None]]):
            Optional callable to validate the generated environment (only
            on worker=0).
        before_init (Optional[Callable[[Trainer], None]]): Optional callable
            to run before anything is constructed inside Trainer (Workers with
            Policies, execution plan, etc..). Takes the Trainer instance as
            argument.
        after_init (Optional[Callable[[Trainer], None]]): Optional callable
            to run at the end of trainer init (after all Workers and the exec.
            plan have been constructed). Takes the Trainer instance as
            argument.
        before_evaluate_fn (Optional[Callable[[Trainer], None]]): Callback to
            run before evaluation. This takes the trainer instance as
            argument.
        mixins (list): list of any class mixins for the returned trainer
            class. These mixins will be applied in order and will have higher
            precedence than the Trainer class.
        execution_plan (Optional[Callable[[WorkerSet, TrainerConfigDict],
            Iterable[ResultDict]]]): Optional callable that sets up the
            distributed execution workflow.

    Returns:
        Type[Trainer]: A Trainer sub-class configured by the specified args.
    """

    # Must be the first statement: captures exactly the call arguments so
    # `with_updates()` can rebuild this class with some of them overridden.
    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy_class = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config: TrainerConfigDict,
                  env_creator: Callable[[EnvConfigDict], EnvType]):
            # No `get_policy_class` function.
            if get_policy_class is None:
                # Default_policy must be provided (unless in multi-agent mode,
                # where each policy can have its own default policy class).
                if not config["multiagent"]["policies"]:
                    assert default_policy is not None, \
                        "`default_policy` must be provided if neither " \
                        "`get_policy_class` nor multi-agent policies are " \
                        "given!"
                self._policy_class = default_policy
            # Query the function for a class to use.
            else:
                self._policy_class = get_policy_class(config)
                # If None returned, use default policy (must be provided).
                if self._policy_class is None:
                    assert default_policy is not None, \
                        "`get_policy_class` returned None, but no " \
                        "`default_policy` was provided!"
                    self._policy_class = default_policy

            if before_init:
                before_init(self)

            # Creating all workers (excluding evaluation workers).
            self.workers = self._make_workers(
                env_creator=env_creator,
                validate_env=validate_env,
                policy_class=self._policy_class,
                config=config,
                num_workers=self.config["num_workers"])
            self.execution_plan = execution_plan
            self.train_exec_impl = execution_plan(self.workers, config)

            if after_init:
                after_init(self)

        @override(Trainer)
        def step(self):
            # self._iteration gets incremented after this function returns,
            # meaning that e. g. the first time this function is called,
            # self._iteration will be 0.
            evaluate_this_iter = \
                self.config["evaluation_interval"] and \
                (self._iteration + 1) % self.config["evaluation_interval"] == 0

            # No evaluation necessary.
            if not evaluate_this_iter:
                res = next(self.train_exec_impl)
            # We have to evaluate in this training iteration.
            else:
                if self.config["evaluation_parallel_to_training"]:
                    # Parallel eval + training: kick off the evaluation loop
                    # in a background thread while this thread trains.
                    with concurrent.futures.ThreadPoolExecutor() as executor:
                        eval_future = executor.submit(self.evaluate)
                        res = next(self.train_exec_impl)
                        evaluation_metrics = eval_future.result()
                else:
                    # Sequential: train first, then eval.
                    res = next(self.train_exec_impl)
                    evaluation_metrics = self.evaluate()

                assert isinstance(evaluation_metrics, dict), \
                    "_evaluate() needs to return a dict."
                res.update(evaluation_metrics)

            # Check `env_task_fn` for possible update of the env's task.
            if self.config["env_task_fn"] is not None:
                if not callable(self.config["env_task_fn"]):
                    raise ValueError(
                        "`env_task_fn` must be None or a callable taking "
                        "[train_results, env, env_ctx] as args!")

                def fn(env, env_context, task_fn):
                    # Only switch tasks if the suggested one differs from the
                    # env's current task.
                    new_task = task_fn(res, env, env_context)
                    cur_task = env.get_task()
                    if cur_task != new_task:
                        env.set_task(new_task)

                fn = partial(fn, task_fn=self.config["env_task_fn"])
                self.workers.foreach_env_with_context(fn)

            return res

        @staticmethod
        @override(Trainer)
        def _validate_config(config: PartialTrainerConfigDict,
                             trainer_obj_or_none: Optional["Trainer"] = None):
            # Call super (Trainer) validation method first.
            Trainer._validate_config(config, trainer_obj_or_none)
            # Then call user defined one, if any.
            if validate_config is not None:
                validate_config(config)

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        @override(Trainer)
        def __getstate__(self):
            state = Trainer.__getstate__(self)
            state["train_exec_impl"] = (
                self.train_exec_impl.shared_metrics.get().save())
            return state

        @override(Trainer)
        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            self.train_exec_impl.shared_metrics.get().restore(
                state["train_exec_impl"])

        @staticmethod
        @override(Trainer)
        def with_updates(**overrides) -> Type[Trainer]:
            """Build a copy of this trainer class with the specified overrides.

            Keyword Args:
                overrides (dict): use this to override any of the arguments
                    originally passed to build_trainer() for this trainer.

            Returns:
                Type[Trainer]: A Trainer sub-class built from
                    `original_kwargs`, updated with `overrides`.

            Examples:
                >>> MyClass = SomeOtherClass.with_updates(name="Mine")
                >>> issubclass(MyClass, SomeOtherClass)
                False
                >>> issubclass(MyClass, Trainer)
                True
            """
            return build_trainer(**dict(original_kwargs, **overrides))

        def __repr__(self):
            return self._name

    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
def build_trainer(
        name: str,
        *,
        default_config: Optional[TrainerConfigDict] = None,
        validate_config: Optional[Callable[[TrainerConfigDict], None]] = None,
        default_policy: Optional[Type[Policy]] = None,
        get_policy_class: Optional[Callable[[TrainerConfigDict], Optional[
            Type[Policy]]]] = None,
        validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
        before_init: Optional[Callable[[Trainer], None]] = None,
        after_init: Optional[Callable[[Trainer], None]] = None,
        before_evaluate_fn: Optional[Callable[[Trainer], None]] = None,
        mixins: Optional[List[type]] = None,
        execution_plan: Optional[Union[
            Callable[[WorkerSet, TrainerConfigDict], Iterable[ResultDict]],
            Callable[[Trainer, WorkerSet, TrainerConfigDict],
                     Iterable[ResultDict]]]] = None,
        allow_unknown_configs: bool = False,
        allow_unknown_subkeys: Optional[List[str]] = None,
        override_all_subkeys_if_type_changes: Optional[List[str]] = None,
) -> Type[Trainer]:
    """Helper function for defining a custom Trainer class.

    Functions will be run in this order to initialize the trainer:
        1. Config setup: validate_config, get_policy.
        2. Worker setup: before_init, execution_plan.
        3. Post setup: after_init.

    Args:
        name: name of the trainer (e.g., "PPO")
        default_config: The default config dict of the algorithm,
            otherwise uses the Trainer default config.
        validate_config: Optional callable that takes the config to check
            for correctness. It may mutate the config as needed.
        default_policy: The default Policy class to use if `get_policy_class`
            returns None.
        get_policy_class: Optional callable that takes a config and returns
            the policy class or None. If None is returned, will use
            `default_policy` (which must be provided then).
        validate_env: Optional callable to validate the generated environment
            (only on worker=0).
        before_init: Optional callable to run before anything is constructed
            inside Trainer (Workers with Policies, execution plan, etc..).
            Takes the Trainer instance as argument.
        after_init: Optional callable to run at the end of trainer init
            (after all Workers and the exec. plan have been constructed).
            Takes the Trainer instance as argument.
        before_evaluate_fn: Callback to run before evaluation. This takes
            the trainer instance as argument.
        mixins: List of any class mixins for the returned trainer class.
            These mixins will be applied in order and will have higher
            precedence than the Trainer class.
        execution_plan: Optional callable that sets up the
            distributed execution workflow.
        allow_unknown_configs: Whether to allow unknown top-level config keys.
        allow_unknown_subkeys: List of top-level keys
            with value=dict, for which new sub-keys are allowed to be added to
            the value dict. Appends to Trainer class defaults.
        override_all_subkeys_if_type_changes: List of top level keys with
            value=dict, for which we always override the entire value (dict),
            iff the "type" key in that value dict changes. Appends to Trainer
            class defaults.

    Returns:
        A Trainer sub-class configured by the specified args.
    """

    # Must be the first statement: captures exactly the call arguments so
    # `with_updates()` can rebuild this class with some of them overridden.
    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy_class = default_policy

        def __init__(self,
                     config: TrainerConfigDict = None,
                     env: Union[str, EnvType, None] = None,
                     logger_creator: Callable[[], Logger] = None,
                     remote_checkpoint_dir: Optional[str] = None,
                     sync_function_tpl: Optional[str] = None):
            Trainer.__init__(self, config, env, logger_creator,
                             remote_checkpoint_dir, sync_function_tpl)

        @override(base)
        def setup(self, config: PartialTrainerConfigDict):
            # Rebind to a fresh list instead of using `+=`: if these lists
            # are inherited class attributes, `+=` (list.__iadd__) would
            # extend the shared class-level list in place, leaking the extra
            # keys into every other Trainer and growing the list again on
            # each instantiation.
            if allow_unknown_subkeys is not None:
                self._allow_unknown_subkeys = \
                    list(self._allow_unknown_subkeys) + \
                    list(allow_unknown_subkeys)
            self._allow_unknown_configs = allow_unknown_configs
            if override_all_subkeys_if_type_changes is not None:
                self._override_all_subkeys_if_type_changes = \
                    list(self._override_all_subkeys_if_type_changes) + \
                    list(override_all_subkeys_if_type_changes)
            Trainer.setup(self, config)

        def _init(self, config: TrainerConfigDict,
                  env_creator: Callable[[EnvConfigDict], EnvType]):
            # No `get_policy_class` function.
            if get_policy_class is None:
                # Default_policy must be provided (unless in multi-agent mode,
                # where each policy can have its own default policy class).
                # `self._policy_class` already defaults to `default_policy`
                # via the class attribute.
                if not config["multiagent"]["policies"]:
                    assert default_policy is not None, \
                        "`default_policy` must be provided if neither " \
                        "`get_policy_class` nor multi-agent policies are " \
                        "given!"
            # Query the function for a class to use.
            else:
                self._policy_class = get_policy_class(config)
                # If None returned, use default policy (must be provided).
                if self._policy_class is None:
                    assert default_policy is not None, \
                        "`get_policy_class` returned None, but no " \
                        "`default_policy` was provided!"
                    self._policy_class = default_policy

            if before_init:
                before_init(self)

            # Creating all workers (excluding evaluation workers).
            self.workers = self._make_workers(
                env_creator=env_creator,
                validate_env=validate_env,
                policy_class=self._policy_class,
                config=config,
                num_workers=self.config["num_workers"])

            self.train_exec_impl = self.execution_plan(
                self.workers, config, **self._kwargs_for_execution_plan())

            if after_init:
                after_init(self)

        @override(Trainer)
        def validate_config(self, config: PartialTrainerConfigDict):
            # Call super (Trainer) validation method first.
            Trainer.validate_config(self, config)
            # Then call user defined one, if any (class scope is skipped by
            # name lookup, so this is the `build_trainer` argument).
            if validate_config is not None:
                validate_config(config)

        @staticmethod
        @override(Trainer)
        def execution_plan(workers, config, **kwargs):
            # `execution_plan` is provided, use it inside
            # `self.execution_plan()`.
            if execution_plan is not None:
                return execution_plan(workers, config, **kwargs)
            # If `execution_plan` is not provided (None), the Trainer will use
            # its already existing default `execution_plan()` static method
            # instead.
            return Trainer.execution_plan(workers, config, **kwargs)

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        @staticmethod
        @override(Trainer)
        def with_updates(**overrides) -> Type[Trainer]:
            """Build a copy of this trainer class with the specified overrides.

            Keyword Args:
                overrides (dict): use this to override any of the arguments
                    originally passed to build_trainer() for this trainer.

            Returns:
                Type[Trainer]: A Trainer sub-class built from
                    `original_kwargs`, updated with `overrides`.

            Examples:
                >>> from ray.rllib.agents.ppo import PPOTrainer
                >>> MyPPOClass = PPOTrainer.with_updates(name="MyPPO")
                >>> issubclass(MyPPOClass, PPOTrainer)
                False
                >>> issubclass(MyPPOClass, Trainer)
                True
                >>> trainer = MyPPOClass()
                >>> print(trainer)
                MyPPO
            """
            return build_trainer(**dict(original_kwargs, **overrides))

        def __repr__(self):
            return self._name

    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
def build_trainer(
        name: str,
        default_policy: Optional[Policy],
        *,
        default_config: TrainerConfigDict = None,
        validate_config: Callable[[TrainerConfigDict], None] = None,
        get_policy_class: Callable[[TrainerConfigDict], Policy] = None,
        before_init: Callable[[Trainer], None] = None,
        after_init: Callable[[Trainer], None] = None,
        before_evaluate_fn: Callable[[Trainer], None] = None,
        mixins: List[type] = None,
        execution_plan: Callable[[WorkerSet, TrainerConfigDict], Iterable[
            ResultDict]] = default_execution_plan):
    """Helper function for defining a custom trainer.

    Functions will be run in this order to initialize the trainer:
        1. Config setup: validate_config, get_policy
        2. Worker setup: before_init, execution_plan
        3. Post setup: after_init

    Arguments:
        name (str): name of the trainer (e.g., "PPO")
        default_policy (cls): the default Policy class to use; also used as
            the fallback if `get_policy_class` returns None.
        default_config (dict): The default config dict of the algorithm,
            otherwise uses the Trainer default config.
        validate_config (Optional[callable]): Optional callable that takes the
            config to check for correctness. It may mutate the config as
            needed.
        get_policy_class (Optional[callable]): Optional callable that takes a
            config and returns the policy class to override the default with,
            or None to keep `default_policy`.
        before_init (Optional[callable]): Optional callable to run at the
            start of trainer init that takes the trainer instance as argument.
        after_init (Optional[callable]): Optional callable to run at the end
            of trainer init that takes the trainer instance as argument.
        before_evaluate_fn (Optional[callable]): callback to run before
            evaluation. This takes the trainer instance as argument.
        mixins (list): list of any class mixins for the returned trainer
            class. These mixins will be applied in order and will have higher
            precedence than the Trainer class.
        execution_plan (func): Setup the distributed execution workflow.

    Returns:
        A Trainer sub-class (not an instance) that uses the specified args.
    """

    # Must be the first statement: captures exactly the call arguments so
    # `with_updates()` can rebuild this class with some of them overridden.
    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config, env_creator):
            if validate_config:
                validate_config(config)

            if get_policy_class is None:
                self._policy = default_policy
            else:
                self._policy = get_policy_class(config)
                # Fall back to the default policy if the function declines to
                # pick one (consistent with the newer `build_trainer`).
                if self._policy is None:
                    self._policy = default_policy

            if before_init:
                before_init(self)

            # Creating all workers (excluding evaluation workers).
            self.workers = self._make_workers(env_creator, self._policy,
                                              config,
                                              self.config["num_workers"])
            self.execution_plan = execution_plan
            self.train_exec_impl = execution_plan(self.workers, config)

            if after_init:
                after_init(self)

        @override(Trainer)
        def step(self):
            return next(self.train_exec_impl)

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        def __getstate__(self):
            state = Trainer.__getstate__(self)
            state["train_exec_impl"] = (
                self.train_exec_impl.shared_metrics.get().save())
            return state

        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            self.train_exec_impl.shared_metrics.get().restore(
                state["train_exec_impl"])

    def with_updates(**overrides):
        """Build a copy of this trainer class with the specified overrides.

        Arguments:
            overrides (dict): use this to override any of the arguments
                originally passed to build_trainer() for this trainer.
        """
        return build_trainer(**dict(original_kwargs, **overrides))

    trainer_cls.with_updates = staticmethod(with_updates)
    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
def build_eager_tf_policy( name, loss_fn, get_default_config=None, postprocess_fn=None, stats_fn=None, optimizer_fn=None, compute_gradients_fn=None, apply_gradients_fn=None, grad_stats_fn=None, extra_learn_fetches_fn=None, extra_action_out_fn=None, validate_spaces=None, before_init=None, before_loss_init=None, after_init=None, make_model=None, action_sampler_fn=None, action_distribution_fn=None, mixins=None, get_batch_divisibility_req=None, # Deprecated args. obs_include_prev_action_reward=DEPRECATED_VALUE, extra_action_fetches_fn=None, gradients_fn=None, ): """Build an eager TF policy. An eager policy runs all operations in eager mode, which makes debugging much simpler, but has lower performance. You shouldn't need to call this directly. Rather, prefer to build a TF graph policy and use set {"framework": "tfe"} in the trainer config to have it automatically be converted to an eager policy. This has the same signature as build_tf_policy().""" base = add_mixins(Policy, mixins) if obs_include_prev_action_reward != DEPRECATED_VALUE: deprecation_warning(old="obs_include_prev_action_reward", error=False) if extra_action_fetches_fn is not None: deprecation_warning(old="extra_action_fetches_fn", new="extra_action_out_fn", error=False) extra_action_out_fn = extra_action_fetches_fn if gradients_fn is not None: deprecation_warning(old="gradients_fn", new="compute_gradients_fn", error=False) compute_gradients_fn = gradients_fn class eager_policy_cls(base): def __init__(self, observation_space, action_space, config): # If this class runs as a @ray.remote actor, eager mode may not # have been activated yet. if not tf1.executing_eagerly(): tf1.enable_eager_execution() self.framework = config.get("framework", "tfe") Policy.__init__(self, observation_space, action_space, config) # Log device and worker index. 
from ray.rllib.evaluation.rollout_worker import get_global_worker worker = get_global_worker() worker_idx = worker.worker_index if worker else 0 if get_gpu_devices(): logger.info( "TF-eager Policy (worker={}) running on GPU.".format( worker_idx if worker_idx > 0 else "local")) else: logger.info( "TF-eager Policy (worker={}) running on CPU.".format( worker_idx if worker_idx > 0 else "local")) self._is_training = False # Only for `config.eager_tracing=True`: A counter to keep track of # how many times an eager-traced method (e.g. # `self._compute_actions_helper`) has been re-traced by tensorflow. # We will raise an error if more than n re-tracings have been # detected, since this would considerably slow down execution. # The variable below should only get incremented during the # tf.function trace operations, never when calling the already # traced function after that. self._re_trace_counter = 0 self._loss_initialized = False # To ensure backward compatibility: # Old way: If `loss` provided here, use as-is (as a function). if loss_fn is not None: self._loss = loss_fn # New way: Convert the overridden `self.loss` into a plain # function, so it can be called the same way as `loss` would # be, ensuring backward compatibility. elif self.loss.__func__.__qualname__ != "Policy.loss": self._loss = self.loss.__func__ # `loss` not provided nor overridden from Policy -> Set to None. 
else: self._loss = None self.batch_divisibility_req = (get_batch_divisibility_req(self) if callable(get_batch_divisibility_req) else (get_batch_divisibility_req or 1)) self._max_seq_len = config["model"]["max_seq_len"] if get_default_config: config = dict(get_default_config(), **config) if validate_spaces: validate_spaces(self, observation_space, action_space, config) if before_init: before_init(self, observation_space, action_space, config) self.config = config self.dist_class = None if action_sampler_fn or action_distribution_fn: if not make_model: raise ValueError( "`make_model` is required if `action_sampler_fn` OR " "`action_distribution_fn` is given") else: self.dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) if make_model: self.model = make_model(self, observation_space, action_space, config) else: self.model = ModelCatalog.get_model_v2( observation_space, action_space, logit_dim, config["model"], framework=self.framework, ) # Lock used for locking some methods on the object-level. # This prevents possible race conditions when calling the model # first, then its value function (e.g. in a loss function), in # between of which another model call is made (e.g. to compute an # action). self._lock = threading.RLock() # Auto-update model's inference view requirements, if recurrent. self._update_model_view_requirements_from_init_state() self.exploration = self._create_exploration() self._state_inputs = self.model.get_initial_state() self._is_recurrent = len(self._state_inputs) > 0 # Combine view_requirements for Model and Policy. 
self.view_requirements.update(self.model.view_requirements) if before_loss_init: before_loss_init(self, observation_space, action_space, config) if optimizer_fn: optimizers = optimizer_fn(self, config) else: optimizers = tf.keras.optimizers.Adam(config["lr"]) optimizers = force_list(optimizers) if getattr(self, "exploration", None): optimizers = self.exploration.get_exploration_optimizer( optimizers) # The list of local (tf) optimizers (one per loss term). self._optimizers: List[LocalOptimizer] = optimizers # Backward compatibility: A user's policy may only support a single # loss term and optimizer (no lists). self._optimizer: LocalOptimizer = optimizers[ 0] if optimizers else None self._initialize_loss_from_dummy_batch( auto_remove_unneeded_view_reqs=True, stats_fn=stats_fn, ) self._loss_initialized = True if after_init: after_init(self, observation_space, action_space, config) # Got to reset global_timestep again after fake run-throughs. self.global_timestep = 0 @override(Policy) def compute_actions_from_input_dict( self, input_dict: Dict[str, TensorType], explore: bool = None, timestep: Optional[int] = None, episodes: Optional[List[Episode]] = None, **kwargs, ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: if not self.config.get( "eager_tracing") and not tf1.executing_eagerly(): tf1.enable_eager_execution() self._is_training = False explore = explore if explore is not None else self.config["explore"] timestep = timestep if timestep is not None else self.global_timestep if isinstance(timestep, tf.Tensor): timestep = int(timestep.numpy()) # Pass lazy (eager) tensor dict to Model as `input_dict`. input_dict = self._lazy_tensor_dict(input_dict) input_dict.set_training(False) # Pack internal state inputs into (separate) list. state_batches = [ input_dict[k] for k in input_dict.keys() if "state_in" in k[:8] ] self._state_in = state_batches self._is_recurrent = state_batches != [] # Call the exploration before_compute_actions hook. 
self.exploration.before_compute_actions(timestep=timestep, explore=explore, tf_sess=self.get_session()) ret = self._compute_actions_helper( input_dict, state_batches, # TODO: Passing episodes into a traced method does not work. None if self.config["eager_tracing"] else episodes, explore, timestep, ) # Update our global timestep by the batch size. self.global_timestep += int(tree.flatten(ret[0])[0].shape[0]) return convert_to_numpy(ret) @override(Policy) def compute_actions( self, obs_batch, state_batches=None, prev_action_batch=None, prev_reward_batch=None, info_batch=None, episodes=None, explore=None, timestep=None, **kwargs, ): # Create input dict to simply pass the entire call to # self.compute_actions_from_input_dict(). input_dict = SampleBatch( { SampleBatch.CUR_OBS: obs_batch, }, _is_training=tf.constant(False), ) if state_batches is not None: for s in enumerate(state_batches): input_dict["state_in_{i}"] = s if prev_action_batch is not None: input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch if prev_reward_batch is not None: input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch if info_batch is not None: input_dict[SampleBatch.INFOS] = info_batch return self.compute_actions_from_input_dict( input_dict=input_dict, explore=explore, timestep=timestep, episodes=episodes, **kwargs, ) @with_lock @override(Policy) def compute_log_likelihoods( self, actions, obs_batch, state_batches=None, prev_action_batch=None, prev_reward_batch=None, actions_normalized=True, ): if action_sampler_fn and action_distribution_fn is None: raise ValueError("Cannot compute log-prob/likelihood w/o an " "`action_distribution_fn` and a provided " "`action_sampler_fn`!") seq_lens = tf.ones(len(obs_batch), dtype=tf.int32) input_batch = SampleBatch( {SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch)}, _is_training=False, ) if prev_action_batch is not None: input_batch[SampleBatch.PREV_ACTIONS] = tf.convert_to_tensor( prev_action_batch) if prev_reward_batch is not None: 
input_batch[SampleBatch.PREV_REWARDS] = tf.convert_to_tensor( prev_reward_batch) # Exploration hook before each forward pass. self.exploration.before_compute_actions(explore=False) # Action dist class and inputs are generated via custom function. if action_distribution_fn: dist_inputs, dist_class, _ = action_distribution_fn( self, self.model, input_batch, explore=False, is_training=False) # Default log-likelihood calculation. else: dist_inputs, _ = self.model(input_batch, state_batches, seq_lens) dist_class = self.dist_class action_dist = dist_class(dist_inputs, self.model) # Normalize actions if necessary. if not actions_normalized and self.config["normalize_actions"]: actions = normalize_action(actions, self.action_space_struct) log_likelihoods = action_dist.logp(actions) return log_likelihoods @override(Policy) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, episode=None): assert tf.executing_eagerly() # Call super's postprocess_trajectory first. sample_batch = Policy.postprocess_trajectory(self, sample_batch) if postprocess_fn: return postprocess_fn(self, sample_batch, other_agent_batches, episode) return sample_batch @with_lock @override(Policy) def learn_on_batch(self, postprocessed_batch): # Callback handling. 
learn_stats = {} self.callbacks.on_learn_on_batch(policy=self, train_batch=postprocessed_batch, result=learn_stats) pad_batch_to_sequences_of_same_size( postprocessed_batch, max_seq_len=self._max_seq_len, shuffle=False, batch_divisibility_req=self.batch_divisibility_req, view_requirements=self.view_requirements, ) self._is_training = True postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch) postprocessed_batch.set_training(True) stats = self._learn_on_batch_helper(postprocessed_batch) stats.update({ "custom_metrics": learn_stats, NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count, }) return convert_to_numpy(stats) @override(Policy) def compute_gradients( self, postprocessed_batch: SampleBatch ) -> Tuple[ModelGradients, Dict[str, TensorType]]: pad_batch_to_sequences_of_same_size( postprocessed_batch, shuffle=False, max_seq_len=self._max_seq_len, batch_divisibility_req=self.batch_divisibility_req, view_requirements=self.view_requirements, ) self._is_training = True self._lazy_tensor_dict(postprocessed_batch) postprocessed_batch.set_training(True) grads_and_vars, grads, stats = self._compute_gradients_helper( postprocessed_batch) return convert_to_numpy((grads, stats)) @override(Policy) def apply_gradients(self, gradients: ModelGradients) -> None: self._apply_gradients_helper( list( zip( [(tf.convert_to_tensor(g) if g is not None else None) for g in gradients], self.model.trainable_variables(), ))) @override(Policy) def get_weights(self, as_dict=False): variables = self.variables() if as_dict: return {v.name: v.numpy() for v in variables} return [v.numpy() for v in variables] @override(Policy) def set_weights(self, weights): variables = self.variables() assert len(weights) == len(variables), (len(weights), len(variables)) for v, w in zip(variables, weights): v.assign(w) @override(Policy) def get_exploration_state(self): return convert_to_numpy(self.exploration.get_state()) @override(Policy) def is_recurrent(self): return self._is_recurrent 
@override(Policy) def num_state_tensors(self): return len(self._state_inputs) @override(Policy) def get_initial_state(self): if hasattr(self, "model"): return self.model.get_initial_state() return [] @override(Policy) def get_state(self): state = super().get_state() if self._optimizer and len(self._optimizer.variables()) > 0: state["_optimizer_variables"] = self._optimizer.variables() # Add exploration state. state["_exploration_state"] = self.exploration.get_state() return state @override(Policy) def set_state(self, state): state = state.copy() # shallow copy # Set optimizer vars first. optimizer_vars = state.get("_optimizer_variables", None) if optimizer_vars and self._optimizer.variables(): logger.warning( "Cannot restore an optimizer's state for tf eager! Keras " "is not able to save the v1.x optimizers (from " "tf.compat.v1.train) since they aren't compatible with " "checkpoints.") for opt_var, value in zip(self._optimizer.variables(), optimizer_vars): opt_var.assign(value) # Set exploration's state. if hasattr(self, "exploration") and "_exploration_state" in state: self.exploration.set_state(state=state["_exploration_state"]) # Then the Policy's (NN) weights. super().set_state(state) @override(Policy) def export_checkpoint(self, export_dir): raise NotImplementedError # TODO: implement this @override(Policy) def export_model(self, export_dir): raise NotImplementedError # TODO: implement this def variables(self): """Return the list of all savable variables for this policy.""" if isinstance(self.model, tf.keras.Model): return self.model.variables else: return self.model.variables() def loss_initialized(self): return self._loss_initialized @with_lock def _compute_actions_helper(self, input_dict, state_batches, episodes, explore, timestep): # Increase the tracing counter to make sure we don't re-trace too # often. 
If eager_tracing=True, this counter should only get # incremented during the @tf.function trace operations, never when # calling the already traced function after that. self._re_trace_counter += 1 # Calculate RNN sequence lengths. batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0] seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches else None # Add default and custom fetches. extra_fetches = {} # Use Exploration object. with tf.variable_creator_scope(_disallow_var_creation): if action_sampler_fn: dist_inputs = None state_out = [] actions, logp = action_sampler_fn( self, self.model, input_dict[SampleBatch.CUR_OBS], explore=explore, timestep=timestep, episodes=episodes, ) else: if action_distribution_fn: # Try new action_distribution_fn signature, supporting # state_batches and seq_lens. try: ( dist_inputs, self.dist_class, state_out, ) = action_distribution_fn( self, self.model, input_dict=input_dict, state_batches=state_batches, seq_lens=seq_lens, explore=explore, timestep=timestep, is_training=False, ) # Trying the old way (to stay backward compatible). # TODO: Remove in future. except TypeError as e: if ("positional argument" in e.args[0] or "unexpected keyword argument" in e.args[0]): ( dist_inputs, self.dist_class, state_out, ) = action_distribution_fn( self, self.model, input_dict[SampleBatch.OBS], explore=explore, timestep=timestep, is_training=False, ) else: raise e elif isinstance(self.model, tf.keras.Model): input_dict = SampleBatch(input_dict, seq_lens=seq_lens) if state_batches and "state_in_0" not in input_dict: for i, s in enumerate(state_batches): input_dict[f"state_in_{i}"] = s self._lazy_tensor_dict(input_dict) dist_inputs, state_out, extra_fetches = self.model( input_dict) else: dist_inputs, state_out = self.model( input_dict, state_batches, seq_lens) action_dist = self.dist_class(dist_inputs, self.model) # Get the exploration action from the forward results. 
actions, logp = self.exploration.get_exploration_action( action_distribution=action_dist, timestep=timestep, explore=explore, ) # Action-logp and action-prob. if logp is not None: extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp) extra_fetches[SampleBatch.ACTION_LOGP] = logp # Action-dist inputs. if dist_inputs is not None: extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs # Custom extra fetches. if extra_action_out_fn: extra_fetches.update(extra_action_out_fn(self)) return actions, state_out, extra_fetches def _learn_on_batch_helper(self, samples): # Increase the tracing counter to make sure we don't re-trace too # often. If eager_tracing=True, this counter should only get # incremented during the @tf.function trace operations, never when # calling the already traced function after that. self._re_trace_counter += 1 with tf.variable_creator_scope(_disallow_var_creation): grads_and_vars, _, stats = self._compute_gradients_helper( samples) self._apply_gradients_helper(grads_and_vars) return stats def _get_is_training_placeholder(self): return tf.convert_to_tensor(self._is_training) @with_lock def _compute_gradients_helper(self, samples): """Computes and returns grads as eager tensors.""" # Increase the tracing counter to make sure we don't re-trace too # often. If eager_tracing=True, this counter should only get # incremented during the @tf.function trace operations, never when # calling the already traced function after that. self._re_trace_counter += 1 # Gather all variables for which to calculate losses. if isinstance(self.model, tf.keras.Model): variables = self.model.trainable_variables else: variables = self.model.trainable_variables() # Calculate the loss(es) inside a tf GradientTape. with tf.GradientTape( persistent=compute_gradients_fn is not None) as tape: losses = self._loss(self, self.model, self.dist_class, samples) losses = force_list(losses) # User provided a compute_gradients_fn. 
if compute_gradients_fn: # Wrap our tape inside a wrapper, such that the resulting # object looks like a "classic" tf.optimizer. This way, custom # compute_gradients_fn will work on both tf static graph # and tf-eager. optimizer = OptimizerWrapper(tape) # More than one loss terms/optimizers. if self.config["_tf_policy_handles_more_than_one_loss"]: grads_and_vars = compute_gradients_fn( self, [optimizer] * len(losses), losses) # Only one loss and one optimizer. else: grads_and_vars = [ compute_gradients_fn(self, optimizer, losses[0]) ] # Default: Compute gradients using the above tape. else: grads_and_vars = [ list(zip(tape.gradient(loss, variables), variables)) for loss in losses ] if log_once("grad_vars"): for g_and_v in grads_and_vars: for g, v in g_and_v: if g is not None: logger.info(f"Optimizing variable {v.name}") # `grads_and_vars` is returned a list (len=num optimizers/losses) # of lists of (grad, var) tuples. if self.config["_tf_policy_handles_more_than_one_loss"]: grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars] # `grads_and_vars` is returned as a list of (grad, var) tuples. else: grads_and_vars = grads_and_vars[0] grads = [g for g, _ in grads_and_vars] stats = self._stats(self, samples, grads) return grads_and_vars, grads, stats def _apply_gradients_helper(self, grads_and_vars): # Increase the tracing counter to make sure we don't re-trace too # often. If eager_tracing=True, this counter should only get # incremented during the @tf.function trace operations, never when # calling the already traced function after that. 
self._re_trace_counter += 1 if apply_gradients_fn: if self.config["_tf_policy_handles_more_than_one_loss"]: apply_gradients_fn(self, self._optimizers, grads_and_vars) else: apply_gradients_fn(self, self._optimizer, grads_and_vars) else: if self.config["_tf_policy_handles_more_than_one_loss"]: for i, o in enumerate(self._optimizers): o.apply_gradients([(g, v) for g, v in grads_and_vars[i] if g is not None]) else: self._optimizer.apply_gradients([(g, v) for g, v in grads_and_vars if g is not None]) def _stats(self, outputs, samples, grads): fetches = {} if stats_fn: fetches[LEARNER_STATS_KEY] = { k: v for k, v in stats_fn(outputs, samples).items() } else: fetches[LEARNER_STATS_KEY] = {} if extra_learn_fetches_fn: fetches.update( {k: v for k, v in extra_learn_fetches_fn(self).items()}) if grad_stats_fn: fetches.update({ k: v for k, v in grad_stats_fn(self, samples, grads).items() }) return fetches def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch): # TODO: (sven): Keep for a while to ensure backward compatibility. if not isinstance(postprocessed_batch, SampleBatch): postprocessed_batch = SampleBatch(postprocessed_batch) postprocessed_batch.set_get_interceptor(_convert_to_tf) return postprocessed_batch @classmethod def with_tracing(cls): return traced_eager_policy(cls) eager_policy_cls.__name__ = name + "_eager" eager_policy_cls.__qualname__ = name + "_eager" return eager_policy_cls
def build_torch_policy(name,
                       loss_fn,
                       get_default_config=None,
                       stats_fn=None,
                       postprocess_fn=None,
                       extra_action_out_fn=None,
                       extra_grad_process_fn=None,
                       optimizer_fn=None,
                       before_init=None,
                       after_init=None,
                       make_model_and_action_dist=None,
                       mixins=None):
    """Create a TorchPolicy subclass at runtime from plain functions.

    Each optional callable below replaces one hook of the generated class.
    When a callable is omitted, the corresponding hook falls back to calling
    the `TorchPolicy` implementation directly (i.e. `TorchPolicy.<method>`),
    so a mixin's own version of such a hook is not consulted by the fallback.

    Arguments:
        name (str): Name of the policy class (e.g., "PPOTorchPolicy"). Used
            as both `__name__` and `__qualname__` of the generated class.
        loss_fn (func): Returns a loss tensor given
            (policy, model, dist_class, train_batch).
        get_default_config (func): Optional zero-argument function returning
            a default config dict; the config passed to the constructor is
            merged on top of it (constructor values win).
        stats_fn (func): Optional function returning a dict of values given
            (policy, train_batch). Runs under `torch.no_grad()` and its
            result is converted to non-torch types.
        postprocess_fn (func): Optional experience post-processing function
            called as (policy, sample_batch, other_agent_batches, episode).
            Its batch inputs are converted to non-torch types and it runs
            under `torch.no_grad()`. If omitted, batches pass through as-is.
        extra_action_out_fn (func): Optional function returning a dict of
            extra values to include in experiences, called as
            (policy, input_dict, state_batches, model, action_dist).
        extra_grad_process_fn (func): Optional function called with the
            policy after gradients are computed; returns processing info.
        optimizer_fn (func): Optional function returning a torch optimizer
            given (policy, config).
        before_init (func): Optional function run at the start of policy
            init with (policy, obs_space, action_space, config).
        after_init (func): Optional function run at the end of policy init
            with (policy, obs_space, action_space, config).
        make_model_and_action_dist (func): Optional function taking
            (policy, obs_space, action_space, config) and returning a
            (TorchModelV2 instance, action distribution class) tuple. If
            omitted, the catalog's default model and action distribution
            are used.
        mixins (list): Optional list of mixin classes combined with
            TorchPolicy (via `add_mixins`) to form the base class.

    Returns:
        type: A new TorchPolicy subclass (not an instance). The class also
            exposes a static `with_updates(**overrides)` that rebuilds it
            with some of the above arguments replaced.
    """
    # Must stay the first statement: at this point `locals()` holds exactly
    # the build arguments, which `with_updates` re-uses later.
    original_kwargs = locals().copy()
    base = add_mixins(TorchPolicy, mixins)

    class policy_cls(base):
        def __init__(self, obs_space, action_space, config):
            # Constructor-provided entries override the defaults.
            config = (dict(get_default_config(), **config)
                      if get_default_config else config)
            self.config = config

            if before_init:
                before_init(self, obs_space, action_space, config)

            if not make_model_and_action_dist:
                # No factory given: let the model catalog supply both the
                # action distribution class and a matching model.
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
                self.model = ModelCatalog.get_model_v2(
                    obs_space,
                    action_space,
                    logit_dim,
                    self.config["model"],
                    framework="torch")
            else:
                self.model, self.dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
                # Reject factories that do not produce a torch model.
                assert isinstance(self.model, TorchModelV2), \
                    "ERROR: TorchPolicy::make_model_and_action_dist must " \
                    "return a TorchModelV2 object!"

            TorchPolicy.__init__(self, obs_space, action_space, config,
                                 self.model, loss_fn, self.dist_class)

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            if postprocess_fn:
                # Post-process without gradient tracking; skipping no_grad()
                # here leads to a memory leak (issue #6962).
                with torch.no_grad():
                    return postprocess_fn(
                        self, convert_to_non_torch_type(sample_batch),
                        convert_to_non_torch_type(other_agent_batches),
                        episode)
            return sample_batch

        @override(TorchPolicy)
        def extra_grad_process(self):
            if not extra_grad_process_fn:
                return TorchPolicy.extra_grad_process(self)
            return extra_grad_process_fn(self)

        @override(TorchPolicy)
        def extra_action_out(self,
                             input_dict,
                             state_batches,
                             model,
                             action_dist=None):
            with torch.no_grad():
                # Prefer the user hook; otherwise use TorchPolicy's version.
                hook = extra_action_out_fn or TorchPolicy.extra_action_out
                return convert_to_non_torch_type(
                    hook(self, input_dict, state_batches, model, action_dist))

        @override(TorchPolicy)
        def optimizer(self):
            if not optimizer_fn:
                return TorchPolicy.optimizer(self)
            return optimizer_fn(self, self.config)

        @override(TorchPolicy)
        def extra_grad_info(self, train_batch):
            with torch.no_grad():
                # Prefer the user stats hook; otherwise TorchPolicy's.
                info_hook = stats_fn or TorchPolicy.extra_grad_info
                return convert_to_non_torch_type(
                    info_hook(self, train_batch))

    def with_updates(**overrides):
        """Rebuild this policy class with some build arguments replaced."""
        return build_torch_policy(**dict(original_kwargs, **overrides))

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
def build_tf_policy(name,
                    loss_fn,
                    get_default_config=None,
                    postprocess_fn=None,
                    stats_fn=None,
                    optimizer_fn=None,
                    gradients_fn=None,
                    apply_gradients_fn=None,
                    grad_stats_fn=None,
                    extra_action_fetches_fn=None,
                    extra_action_feed_fn=None,
                    extra_learn_fetches_fn=None,
                    extra_learn_feed_fn=None,
                    before_init=None,
                    before_loss_init=None,
                    after_init=None,
                    make_model=None,
                    action_sampler_fn=None,
                    mixins=None,
                    get_batch_divisibility_req=None,
                    obs_include_prev_action_reward=True):
    """Helper function for creating a dynamic tf policy class at runtime.

    Functions will be run in this order to initialize the policy:
        1. Placeholder setup: postprocess_fn
        2. Loss init: loss_fn, stats_fn
        3. Optimizer init: optimizer_fn, gradients_fn, apply_gradients_fn,
                           grad_stats_fn

    This means that you can e.g., depend on any policy attributes created in
    the running of `loss_fn` in later functions such as `stats_fn`.

    In eager mode (to be implemented), the following functions will be run
    repeatedly on each eager execution: loss_fn, stats_fn

    This means that these functions should not define any variables
    internally, otherwise they will fail in eager mode execution. Variables
    should only be created in make_model (if defined).

    Whenever one of the optional hook functions below is None, the generated
    class falls back to the implementation found on its base class (the
    given mixins combined with DynamicTFPolicy). Mixin overrides of these
    hooks are therefore honored, as promised by the `mixins` argument.

    Arguments:
        name (str): name of the policy (e.g., "PPOTFPolicy")
        loss_fn (func): function that returns a loss tensor given the policy
            and a dict of experience tensor placeholders
        get_default_config (func): optional zero-argument function that
            returns the default config to merge with any overrides (the
            config passed to the constructor wins)
        postprocess_fn (func): optional experience postprocessing function
            that takes the same args as Policy.postprocess_trajectory()
        stats_fn (func): optional function that returns a dict of TF fetches
            given the policy and batch input tensors
        optimizer_fn (func): optional function that returns a tf.Optimizer
            given the policy and config. If not specified, the base class'
            `optimizer()` is used
        gradients_fn (func): optional function that returns a list of
            gradients given (policy, optimizer, loss). If not specified, the
            base class' `gradients()` is used (by default
            optimizer.compute_gradients(loss))
        apply_gradients_fn (func): optional function that returns an apply
            gradients op given (policy, optimizer, grads_and_vars). If not
            specified, the base class' `build_apply_op()` is used
        grad_stats_fn (func): optional function that returns a dict of
            TF fetches given the policy and loss gradient tensors
        extra_action_fetches_fn (func): optional function that returns
            a dict of TF fetches given the policy object. It is evaluated
            inside the before-loss-init step, after `before_loss_init`
        extra_action_feed_fn (func): optional function that returns a feed
            dict to also feed to TF when computing actions
        extra_learn_fetches_fn (func): optional function that returns a dict
            of extra values to fetch and return when learning on a batch. An
            empty LEARNER_STATS_KEY dict is added unless it provides one
        extra_learn_feed_fn (func): optional function that returns a feed
            dict to also feed to TF when learning on a batch
        before_init (func): optional function to run at the beginning of
            policy init that takes the same arguments as the policy
            constructor
        before_loss_init (func): optional function to run prior to loss
            init that takes the same arguments as the policy constructor
        after_init (func): optional function to run at the end of policy
            init that takes the same arguments as the policy constructor
        make_model (func): optional function that returns a ModelV2 object
            given (policy, obs_space, action_space, config).
            All policy variables should be created in this function. If not
            specified, a default model will be created.
        action_sampler_fn (func): optional function that returns a
            tuple of action and action prob tensors given
            (policy, model, input_dict, obs_space, action_space, config).
            If not specified, a default action distribution will be used.
        mixins (list): list of any class mixins for the returned policy
            class. These mixins will be applied in order and will have
            higher precedence than the DynamicTFPolicy class
        get_batch_divisibility_req (func): optional function that returns
            the divisibility requirement for sample batches
        obs_include_prev_action_reward (bool): whether to include the
            previous action and reward in the model input

    Returns:
        type: A DynamicTFPolicy subclass (not an instance) built from the
            specified args. It exposes a static `with_updates(**overrides)`
            that rebuilds the class with some of the args replaced.
    """
    # Must stay the first statement: here `locals()` contains exactly the
    # build arguments, which `with_updates` re-uses.
    original_kwargs = locals().copy()
    base = add_mixins(DynamicTFPolicy, mixins)

    class policy_cls(base):
        def __init__(self,
                     obs_space,
                     action_space,
                     config,
                     existing_model=None,
                     existing_inputs=None):
            if get_default_config:
                # Constructor-provided entries override the defaults.
                config = dict(get_default_config(), **config)

            if before_init:
                before_init(self, obs_space, action_space, config)

            def before_loss_init_wrapper(policy, obs_space, action_space,
                                         config):
                if before_loss_init:
                    before_loss_init(policy, obs_space, action_space, config)

                # Resolve the extra action fetches right before the loss is
                # built (after the user's before_loss_init hook has run).
                if extra_action_fetches_fn is None:
                    self._extra_action_fetches = {}
                else:
                    self._extra_action_fetches = extra_action_fetches_fn(self)

            DynamicTFPolicy.__init__(
                self,
                obs_space,
                action_space,
                config,
                loss_fn,
                stats_fn=stats_fn,
                grad_stats_fn=grad_stats_fn,
                before_loss_init=before_loss_init_wrapper,
                make_model=make_model,
                action_sampler_fn=action_sampler_fn,
                existing_model=existing_model,
                existing_inputs=existing_inputs,
                get_batch_divisibility_req=get_batch_divisibility_req,
                obs_include_prev_action_reward=obs_include_prev_action_reward)

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            if not postprocess_fn:
                return sample_batch
            return postprocess_fn(self, sample_batch, other_agent_batches,
                                  episode)

        # NOTE: All fallbacks below dispatch through `base` (mixins +
        # DynamicTFPolicy) instead of calling `TFPolicy.<method>` directly.
        # The latter silently skipped any mixin (or DynamicTFPolicy)
        # override of the same hook, contradicting the documented mixin
        # precedence.

        @override(TFPolicy)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            return base.optimizer(self)

        @override(TFPolicy)
        def gradients(self, optimizer, loss):
            if gradients_fn:
                return gradients_fn(self, optimizer, loss)
            return base.gradients(self, optimizer, loss)

        @override(TFPolicy)
        def build_apply_op(self, optimizer, grads_and_vars):
            if apply_gradients_fn:
                return apply_gradients_fn(self, optimizer, grads_and_vars)
            return base.build_apply_op(self, optimizer, grads_and_vars)

        @override(TFPolicy)
        def extra_compute_action_fetches(self):
            return dict(
                base.extra_compute_action_fetches(self),
                **self._extra_action_fetches)

        @override(TFPolicy)
        def extra_compute_action_feed_dict(self):
            if extra_action_feed_fn:
                return extra_action_feed_fn(self)
            return base.extra_compute_action_feed_dict(self)

        @override(TFPolicy)
        def extra_compute_grad_fetches(self):
            if extra_learn_fetches_fn:
                # Auto-add an empty learner stats dict if needed.
                return dict({
                    LEARNER_STATS_KEY: {}
                }, **extra_learn_fetches_fn(self))
            return base.extra_compute_grad_fetches(self)

        @override(TFPolicy)
        def extra_compute_grad_feed_dict(self):
            if extra_learn_feed_fn:
                return extra_learn_feed_fn(self)
            return base.extra_compute_grad_feed_dict(self)

    def with_updates(**overrides):
        """Return a new policy class built with `overrides` applied."""
        return build_tf_policy(**dict(original_kwargs, **overrides))

    # Wrapped here (rather than decorating at function scope) so the local
    # name stays a plain function, consistent with the sibling factories;
    # the resulting class attribute is identical.
    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
def build_tf_policy(
        name: str,
        *,
        loss_fn: Callable[[Policy, ModelV2, type, SampleBatch], TensorType],
        get_default_config: Optional[Callable[[], TrainerConfigDict]] = None,
        postprocess_fn: Optional[Callable[[
            Policy, SampleBatch, List[SampleBatch], "MultiAgentEpisode"
        ], SampleBatch]] = None,
        stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
            str, TensorType]]] = None,
        optimizer_fn: Optional[Callable[[
            Policy, TrainerConfigDict
        ], "tf.keras.optimizers.Optimizer"]] = None,
        gradients_fn: Optional[Callable[[
            Policy, "tf.keras.optimizers.Optimizer", TensorType
        ], ModelGradients]] = None,
        apply_gradients_fn: Optional[Callable[[
            Policy, "tf.keras.optimizers.Optimizer", ModelGradients
        ], "tf.Operation"]] = None,
        grad_stats_fn: Optional[Callable[[Policy, SampleBatch, ModelGradients],
                                         Dict[str, TensorType]]] = None,
        extra_action_fetches_fn: Optional[Callable[[Policy], Dict[
            str, TensorType]]] = None,
        extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[
            str, TensorType]]] = None,
        validate_spaces: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_loss_init: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], None]] = None,
        after_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        make_model: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], ModelV2]] = None,
        action_sampler_fn: Optional[Callable[[TensorType, List[
            TensorType]], Tuple[TensorType, TensorType]]] = None,
        action_distribution_fn: Optional[Callable[[
            Policy, ModelV2, TensorType, TensorType, TensorType
        ], Tuple[TensorType, type, List[TensorType]]]] = None,
        mixins: Optional[List[type]] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
        obs_include_prev_action_reward: bool = True):
    """Helper function for creating a dynamic tf policy class at runtime.

    Functions will be run in this order to initialize the policy:
        1. Placeholder setup: postprocess_fn
        2. Loss init: loss_fn, stats_fn
        3. Optimizer init: optimizer_fn, gradients_fn, apply_gradients_fn,
                           grad_stats_fn

    This means that you can e.g., depend on any policy attributes created in
    the running of `loss_fn` in later functions such as `stats_fn`.

    In eager mode, the following functions will be run repeatedly on each
    eager execution: loss_fn, stats_fn, gradients_fn, apply_gradients_fn,
    and grad_stats_fn.

    This means that these functions should not define any variables
    internally, otherwise they will fail in eager mode execution. Variables
    should only be created in make_model (if defined).

    Args:
        name (str): Name of the policy (e.g., "PPOTFPolicy").
        loss_fn (Callable[[Policy, ModelV2, type, SampleBatch], TensorType]):
            Callable for calculating a loss tensor.
        get_default_config (Optional[Callable[[], TrainerConfigDict]]):
            Optional zero-argument callable that returns the default config
            to merge with any overrides (the constructor's config wins). If
            None, uses only(!) the user-provided PartialTrainerConfigDict as
            dict for this Policy.
        postprocess_fn (Optional[Callable[[Policy, SampleBatch,
            List[SampleBatch], MultiAgentEpisode], SampleBatch]]): Optional
            callable for post-processing experience batches (called after
            the super's `postprocess_trajectory` method). Its return value
            is used as the post-processed batch.
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional callable that returns a dict
            of TF tensors to fetch given the policy and batch input tensors.
            If None, will not compute any stats.
        optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict],
            "tf.keras.optimizers.Optimizer"]]): Optional callable that
            returns a tf.Optimizer given the policy and config. If None,
            will call the base class' `optimizer()` method instead (which
            returns a tf1.train.AdamOptimizer).
        gradients_fn (Optional[Callable[[Policy,
            "tf.keras.optimizers.Optimizer", TensorType], ModelGradients]]):
            Optional callable that returns a list of gradients. If None,
            will call the base class' `gradients()` method instead.
        apply_gradients_fn (Optional[Callable[[Policy,
            "tf.keras.optimizers.Optimizer", ModelGradients],
            "tf.Operation"]]): Optional callable that returns an apply
            gradients op given policy, tf-optimizer, and grads_and_vars. If
            None, will call the base class' `build_apply_op()` method
            instead.
        grad_stats_fn (Optional[Callable[[Policy, SampleBatch, ModelGradients],
            Dict[str, TensorType]]]): Optional callable that returns a dict
            of TF fetches given the policy, batch input, and gradient
            tensors. If None, will not collect any gradient stats.
        extra_action_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns a dict
            of TF fetches given the policy object. It is evaluated during
            the before-loss-init step (after `before_loss_init`). If None,
            will not perform any extra fetches.
        extra_learn_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns a dict
            of extra values to fetch and return when learning on a batch. An
            empty LEARNER_STATS_KEY dict is added unless it provides one. If
            None, will call the base class' `extra_compute_grad_fetches()`
            method instead.
        validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check
            the spaces for correctness. If None, no spaces checking will be
            done.
        before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the
            beginning of policy init that takes the same arguments as the
            policy constructor. If None, this step will be skipped.
        before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable
            to run prior to loss init. If None, this step will be skipped.
        after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the end
            of policy init. If None, this step will be skipped.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional
            callable that returns a ModelV2 object.
            All policy variables should be created in this function. If None,
            a default ModelV2 object will be created.
        action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]]): A callable returning a sampled
            action and its log-likelihood given observation and state
            inputs. If None, will either use `action_distribution_fn` or
            compute actions by calling self.model, then sampling from the
            so parameterized action distribution.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
            TensorType, TensorType],
            Tuple[TensorType, type, List[TensorType]]]]): Optional callable
            returning distribution inputs (parameters), a dist-class to
            generate an action distribution object from, and internal-state
            outputs (or an empty list if not applicable). If None, will
            either use `action_sampler_fn` or compute actions by calling
            self.model, then sampling from the so parameterized action
            distribution.
        mixins (Optional[List[type]]): Optional list of any class mixins for
            the returned policy class. These mixins will be applied in order
            and will have higher precedence than the DynamicTFPolicy class.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.
        obs_include_prev_action_reward (bool): Whether to include the
            previous action and reward in the model input.

    Returns:
        type: A DynamicTFPolicy subclass (not an instance) that uses the
            specified args. It exposes the static helpers
            `with_updates(**overrides)` and `as_eager()`.
    """
    # Must stay the first statement: here `locals()` holds exactly the build
    # arguments, which `with_updates` and `as_eager` re-use.
    original_kwargs = locals().copy()
    base = add_mixins(DynamicTFPolicy, mixins)

    class policy_cls(base):
        def __init__(self,
                     obs_space,
                     action_space,
                     config,
                     existing_model=None,
                     existing_inputs=None):
            if get_default_config:
                # Constructor-provided entries override the defaults.
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, obs_space, action_space, config)

            if before_init:
                before_init(self, obs_space, action_space, config)

            def before_loss_init_wrapper(policy, obs_space, action_space,
                                         config):
                if before_loss_init:
                    before_loss_init(policy, obs_space, action_space, config)

                # Resolve the extra action fetches right before the loss is
                # built (after the user's before_loss_init hook has run).
                if extra_action_fetches_fn is None:
                    self._extra_action_fetches = {}
                else:
                    self._extra_action_fetches = extra_action_fetches_fn(self)

            DynamicTFPolicy.__init__(
                self,
                obs_space=obs_space,
                action_space=action_space,
                config=config,
                loss_fn=loss_fn,
                stats_fn=stats_fn,
                grad_stats_fn=grad_stats_fn,
                before_loss_init=before_loss_init_wrapper,
                make_model=make_model,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                existing_model=existing_model,
                existing_inputs=existing_inputs,
                get_batch_divisibility_req=get_batch_divisibility_req,
                obs_include_prev_action_reward=obs_include_prev_action_reward)

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            # Call super's postprocess_trajectory first.
            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
            if postprocess_fn:
                return postprocess_fn(self, sample_batch, other_agent_batches,
                                      episode)
            return sample_batch

        # The fallbacks below dispatch through `base` (mixins +
        # DynamicTFPolicy) so that mixin overrides of these hooks are used.

        @override(TFPolicy)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            else:
                return base.optimizer(self)

        @override(TFPolicy)
        def gradients(self, optimizer, loss):
            if gradients_fn:
                return gradients_fn(self, optimizer, loss)
            else:
                return base.gradients(self, optimizer, loss)

        @override(TFPolicy)
        def build_apply_op(self, optimizer, grads_and_vars):
            if apply_gradients_fn:
                return apply_gradients_fn(self, optimizer, grads_and_vars)
            else:
                return base.build_apply_op(self, optimizer, grads_and_vars)

        @override(TFPolicy)
        def extra_compute_action_fetches(self):
            return dict(
                base.extra_compute_action_fetches(self),
                **self._extra_action_fetches)

        @override(TFPolicy)
        def extra_compute_grad_fetches(self):
            if extra_learn_fetches_fn:
                # Auto-add empty learner stats dict if needed.
                return dict({
                    LEARNER_STATS_KEY: {}
                }, **extra_learn_fetches_fn(self))
            else:
                return base.extra_compute_grad_fetches(self)

    def with_updates(**overrides):
        """Allows creating a TFPolicy cls based on settings of another one.

        Keyword Args:
            **overrides: The settings (passed into `build_tf_policy`) that
                should be different from the class that this method is
                called on.

        Returns:
            type: A new TFPolicy sub-class.

        Examples:
        >> MySpecialDQNPolicyClass = DQNTFPolicy.with_updates(
        ..    name="MySpecialDQNPolicyClass",
        ..    loss_fn=some_new_loss_function,
        .. )
        """
        return build_tf_policy(**dict(original_kwargs, **overrides))

    def as_eager():
        """Build an eager-mode policy class from the same build arguments.

        Delegates to `eager_tf_policy.build_eager_tf_policy`.
        """
        return eager_tf_policy.build_eager_tf_policy(**original_kwargs)

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.as_eager = staticmethod(as_eager)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
def build_torch_policy(
        name: str,
        *,
        loss_fn: Callable[[Policy, ModelV2, type, SampleBatch], TensorType],
        get_default_config: Optional[Callable[[], TrainerConfigDict]] = None,
        stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
            str, TensorType]]] = None,
        postprocess_fn: Optional[Callable[[
            Policy, SampleBatch, List[SampleBatch], "MultiAgentEpisode"
        ], SampleBatch]] = None,
        extra_action_out_fn: Optional[Callable[[
            Policy, Dict[str, TensorType], List[TensorType], ModelV2,
            TorchDistributionWrapper
        ], Dict[str, TensorType]]] = None,
        extra_grad_process_fn: Optional[Callable[[
            Policy, "torch.optim.Optimizer", TensorType
        ], Dict[str, TensorType]]] = None,
        # TODO: (sven) Replace "fetches" with "process".
        extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[
            str, TensorType]]] = None,
        optimizer_fn: Optional[Callable[[Policy, TrainerConfigDict],
                                        "torch.optim.Optimizer"]] = None,
        validate_spaces: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        after_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        action_sampler_fn: Optional[Callable[[TensorType, List[
            TensorType]], Tuple[TensorType, TensorType]]] = None,
        action_distribution_fn: Optional[Callable[[
            Policy, ModelV2, TensorType, TensorType, TensorType
        ], Tuple[TensorType, type, List[TensorType]]]] = None,
        make_model: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], ModelV2]] = None,
        make_model_and_action_dist: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], Tuple[ModelV2, TorchDistributionWrapper]]] = None,
        apply_gradients_fn: Optional[Callable[
            [Policy, List[TensorType]], None]] = None,
        mixins: Optional[List[type]] = None,
        training_view_requirements_fn: Optional[Callable[[Policy], Dict[
            str, ViewRequirement]]] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
) -> type:
    """Helper function for creating a torch policy class at runtime.

    Args:
        name (str): name of the policy (e.g., "PPOTorchPolicy")
        loss_fn (Callable[[Policy, ModelV2, type, SampleBatch], TensorType]):
            Callable that returns a loss tensor.
        get_default_config (Optional[Callable[[], TrainerConfigDict]]):
            Optional callable (called without arguments) that returns the
            default config to merge with any overrides. If None, uses
            only(!) the user-provided PartialTrainerConfigDict as dict for
            this Policy.
        postprocess_fn (Optional[Callable[[Policy, SampleBatch,
            List[SampleBatch], MultiAgentEpisode], SampleBatch]]): Optional
            callable for post-processing experience batches (called after the
            super's `postprocess_trajectory` method, inside a
            `torch.no_grad()` context). Its return value is returned as the
            post-processed batch.
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            values given the policy and batch input tensors. It is called
            under `torch.no_grad()` and its result is converted to non-torch
            types. If None, will use `TorchPolicy.extra_grad_info()` instead.
        extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType],
            List[TensorType], ModelV2, TorchDistributionWrapper],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            extra values to include in experiences. It is called under
            `torch.no_grad()` and its result is converted to non-torch types.
            If None, will call the `TorchPolicy.extra_action_out()` method
            instead.
        extra_grad_process_fn (Optional[Callable[[Policy,
            "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]]):
            Optional callable that is called after gradients are computed and
            returns a processing info dict. If None, will call the
            `TorchPolicy.extra_grad_process()` method instead.
        # TODO: (sven) dissolve naming mismatch between "learn" and "compute.."
        extra_learn_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            extra tensors from the policy after loss evaluation. The result is
            converted to non-torch types and an empty `LEARNER_STATS_KEY`
            entry is added unless the returned dict already contains one.
            If None, will call the `TorchPolicy.extra_compute_grad_fetches()`
            method instead.
        optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict],
            "torch.optim.Optimizer"]]): Optional callable that returns a
            torch optimizer (or a list of them) given the policy and config.
            If None, will call the `TorchPolicy.optimizer()` method instead
            (which returns a torch Adam optimizer). In either case the result
            is turned into a list and, if the policy has an `exploration`
            attribute, the exploration's optimizers are appended to it.
        validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check for
            correctness. If None, no spaces checking will be done.
        before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the
            beginning of `Policy.__init__` that takes the same arguments as
            the Policy constructor. If None, this step will be skipped.
        after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the end of
            policy init that takes the same arguments as the policy
            constructor. If None, this step will be skipped.
        action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]]): Optional callable returning a
            sampled action and its log-likelihood given some (obs and state)
            inputs. If None, will either use `action_distribution_fn` or
            compute actions by calling self.model, then sampling from the
            so parameterized action distribution.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
            TensorType, TensorType], Tuple[TensorType, type,
            List[TensorType]]]]): A callable that takes the Policy, Model, the
            observation batch, an explore-flag, a timestep, and an
            is_training flag and returns a tuple of a) distribution inputs
            (parameters), b) a dist-class to generate an action distribution
            object from, and c) internal-state outputs (empty list if not
            applicable). If None, will either use `action_sampler_fn` or
            compute actions by calling self.model, then sampling from the
            parameterized action distribution.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional callable
            that takes the same arguments as Policy.__init__ and returns a
            model instance. The distribution class will be determined
            automatically. Note: Only one of `make_model` or
            `make_model_and_action_dist` should be provided. If both are None,
            a default Model will be created.
        make_model_and_action_dist (Optional[Callable[[Policy,
            gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
            Tuple[ModelV2, TorchDistributionWrapper]]]): Optional callable that
            takes the same arguments as Policy.__init__ and returns a tuple
            of model instance and torch action distribution class.
            Note: Only one of `make_model` or `make_model_and_action_dist`
            should be provided. If both are None, a default Model will be
            created.
        apply_gradients_fn (Optional[Callable[[Policy, List[TensorType]],
            None]]): Optional callable that takes the policy and a grads list
            and applies these to the Model's parameters. If None, will call
            the `TorchPolicy.apply_gradients()` method instead.
        mixins (Optional[List[type]]): Optional list of any class mixins for
            the returned policy class. These mixins will be applied in order
            and will have higher precedence than the TorchPolicy class.
        training_view_requirements_fn (Optional[Callable[[Policy],
            Dict[str, ViewRequirement]]]): An optional callable that takes the
            Policy and returns additional train view requirements. These are
            merged (via `dict.update`) into the requirements returned by the
            super class.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.

    Returns:
        type: TorchPolicy child class constructed from the specified args.
    """

    # NOTE: Must remain the first statement of this function. It snapshots
    # exactly the call arguments so `with_updates()` can re-invoke this
    # builder with them; any local defined earlier would leak into it.
    original_kwargs = locals().copy()
    base = add_mixins(TorchPolicy, mixins)

    class policy_cls(base):
        def __init__(self, obs_space, action_space, config):
            # Shallow-merge user overrides on top of the default config.
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            if validate_spaces:
                validate_spaces(self, obs_space, action_space, self.config)

            if before_init:
                before_init(self, obs_space, action_space, self.config)

            # Model is customized (use default action dist class).
            if make_model:
                assert make_model_and_action_dist is None, \
                    "Either `make_model` or `make_model_and_action_dist`" \
                    " must be None!"
                self.model = make_model(self, obs_space, action_space, config)
                dist_class, _ = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
            # Model and action dist class are customized.
            elif make_model_and_action_dist:
                self.model, dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
            # Use default model and default action dist.
            else:
                dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=logit_dim,
                    model_config=self.config["model"],
                    framework="torch",
                    **self.config["model"].get("custom_model_config", {}))

            # Make sure, we passed in a correct Model factory.
            assert isinstance(self.model, TorchModelV2), \
                "ERROR: Generated Model must be a TorchModelV2 object!"

            # The model must exist before this call: it is handed to the
            # TorchPolicy constructor together with the dist class.
            TorchPolicy.__init__(
                self,
                observation_space=obs_space,
                action_space=action_space,
                config=config,
                model=self.model,
                loss=loss_fn,
                action_distribution_class=dist_class,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=config["model"]["max_seq_len"],
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(TorchPolicy)
        def training_view_requirements(self):
            # Start from the super's requirements and add custom ones.
            req = super().training_view_requirements()
            if callable(training_view_requirements_fn):
                req.update(training_view_requirements_fn(self))
            return req

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            # Do all post-processing always with no_grad().
            # Not using this here will introduce a memory leak (issue #6962).
            with torch.no_grad():
                # Call super's postprocess_trajectory first.
                sample_batch = super().postprocess_trajectory(
                    convert_to_non_torch_type(sample_batch),
                    convert_to_non_torch_type(other_agent_batches), episode)
                if postprocess_fn:
                    return postprocess_fn(self, sample_batch,
                                          other_agent_batches, episode)

                return sample_batch

        @override(TorchPolicy)
        def extra_grad_process(self, optimizer, loss):
            """Called after optimizer.zero_grad() and loss.backward() calls.

            Allows for gradient processing before optimizer.step() is called.
            E.g. for gradient clipping.
            """
            if extra_grad_process_fn:
                return extra_grad_process_fn(self, optimizer, loss)
            else:
                return TorchPolicy.extra_grad_process(self, optimizer, loss)

        @override(TorchPolicy)
        def extra_compute_grad_fetches(self):
            if extra_learn_fetches_fn:
                fetches = convert_to_non_torch_type(
                    extra_learn_fetches_fn(self))
                # Auto-add empty learner stats dict if needed.
                return dict({LEARNER_STATS_KEY: {}}, **fetches)
            else:
                return TorchPolicy.extra_compute_grad_fetches(self)

        @override(TorchPolicy)
        def apply_gradients(self, gradients):
            if apply_gradients_fn:
                apply_gradients_fn(self, gradients)
            else:
                TorchPolicy.apply_gradients(self, gradients)

        @override(TorchPolicy)
        def extra_action_out(self, input_dict, state_batches, model,
                             action_dist):
            # No autograd graph is needed for these extra outputs.
            with torch.no_grad():
                if extra_action_out_fn:
                    stats_dict = extra_action_out_fn(
                        self, input_dict, state_batches, model, action_dist)
                else:
                    stats_dict = TorchPolicy.extra_action_out(
                        self, input_dict, state_batches, model, action_dist)
                return convert_to_non_torch_type(stats_dict)

        @override(TorchPolicy)
        def optimizer(self):
            if optimizer_fn:
                optimizers = optimizer_fn(self, self.config)
            else:
                optimizers = TorchPolicy.optimizer(self)
            # Always return a list so exploration optimizers can be appended.
            optimizers = force_list(optimizers)
            if hasattr(self, "exploration"):
                exploration_optimizers = force_list(
                    self.exploration.get_exploration_optimizer(self.config))
                optimizers.extend(exploration_optimizers)
            return optimizers

        @override(TorchPolicy)
        def extra_grad_info(self, train_batch):
            with torch.no_grad():
                if stats_fn:
                    stats_dict = stats_fn(self, train_batch)
                else:
                    stats_dict = TorchPolicy.extra_grad_info(self, train_batch)
                return convert_to_non_torch_type(stats_dict)

    def with_updates(**overrides):
        """Allows creating a TorchPolicy cls based on settings of another one.

        Keyword Args:
            **overrides: The settings (passed into `build_torch_policy`) that
                should be different from the class that this method is called
                on.

        Returns:
            type: A new TorchPolicy sub-class.

        Examples:
        >> MySpecialDQNPolicyClass = DQNTorchPolicy.with_updates(
        ..    name="MySpecialDQNPolicyClass",
        ..    loss_fn=some_new_loss_function,
        .. )
        """
        return build_torch_policy(**dict(original_kwargs, **overrides))

    policy_cls.with_updates = staticmethod(with_updates)
    # Give the generated class the requested name (for repr/logging).
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
def build_torch_policy(name,
                       *,
                       loss_fn,
                       get_default_config=None,
                       stats_fn=None,
                       postprocess_fn=None,
                       extra_action_out_fn=None,
                       extra_grad_process_fn=None,
                       extra_learn_fetches_fn=None,
                       optimizer_fn=None,
                       validate_spaces=None,
                       before_init=None,
                       after_init=None,
                       action_sampler_fn=None,
                       action_distribution_fn=None,
                       make_model=None,
                       make_model_and_action_dist=None,
                       apply_gradients_fn=None,
                       mixins=None,
                       get_batch_divisibility_req=None):
    """Helper function for creating a torch policy class at runtime.

    Arguments:
        name (str): name of the policy (e.g., "PPOTorchPolicy")
        loss_fn (callable): Callable that returns a loss tensor as arguments
            given (policy, model, dist_class, train_batch).
        get_default_config (Optional[callable]): Optional callable (called
            without arguments) that returns the default config to merge with
            any overrides. If None, the passed config is used as-is.
        stats_fn (Optional[callable]): Optional callable that returns a dict of
            values given the policy and batch input tensors. Called under
            `torch.no_grad()`; result converted to non-torch types. If None,
            `TorchPolicy.extra_grad_info()` is used.
        postprocess_fn (Optional[callable]): Optional experience postprocessing
            function that takes the same args as
            Policy.postprocess_trajectory(). It runs after the super's
            `postprocess_trajectory`, and its return value is returned as the
            post-processed batch.
        extra_action_out_fn (Optional[callable]): Optional callable that
            returns a dict of extra values to include in experiences. Called
            under `torch.no_grad()`; result converted to non-torch types.
            If None, `TorchPolicy.extra_action_out()` is used.
        extra_grad_process_fn (Optional[callable]): Optional callable that is
            called after gradients are computed and returns processing info.
            If None, `TorchPolicy.extra_grad_process()` is used.
        extra_learn_fetches_fn (Optional[callable]): optional function that
            returns a dict of extra values to fetch from the policy after loss
            evaluation. An empty `LEARNER_STATS_KEY` entry is added unless the
            returned dict already has one. If None,
            `TorchPolicy.extra_compute_grad_fetches()` is used.
        optimizer_fn (Optional[callable]): Optional callable that returns a
            torch optimizer given the policy and config. If None,
            `TorchPolicy.optimizer()` is used.
        validate_spaces (Optional[callable]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check for
            correctness.
        before_init (Optional[callable]): Optional callable to run at the
            beginning of `Policy.__init__` that takes the same arguments as
            the Policy constructor.
        after_init (Optional[callable]): Optional callable to run at the end of
            policy init that takes the same arguments as the policy
            constructor.
        action_sampler_fn (Optional[callable]): Optional callable returning a
            sampled action and its log-likelihood given some (obs and state)
            inputs.
        action_distribution_fn (Optional[callable]): A callable that takes
            the Policy, Model, the observation batch, an explore-flag, a
            timestep, and an is_training flag and returns a tuple of
            a) distribution inputs (parameters), b) a dist-class to generate
            an action distribution object from, and c) internal-state outputs
            (empty list if not applicable).
        make_model (Optional[callable]): Optional func that
            takes the same arguments as Policy.__init__ and returns a model
            instance. The distribution class will be determined automatically.
            Note: Only one of `make_model` or `make_model_and_action_dist`
            should be provided.
        make_model_and_action_dist (Optional[callable]): Optional func that
            takes the same arguments as Policy.__init__ and returns a tuple
            of model instance and torch action distribution class.
            Note: Only one of `make_model` or `make_model_and_action_dist`
            should be provided. If neither is given, a default model and
            action distribution are created via ModelCatalog.
        apply_gradients_fn (Optional[callable]): Optional callable that
            takes the policy and a grads list and applies these to the
            Model's parameters. If None, `TorchPolicy.apply_gradients()` is
            used.
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the TorchPolicy class.
        get_batch_divisibility_req (Optional[callable]): Optional callable that
            returns the divisibility requirement for sample batches.

    Returns:
        type: TorchPolicy child class constructed from the specified args.
    """

    # NOTE: Must remain the first statement of this function: it snapshots
    # exactly the call arguments for `with_updates()` below.
    original_kwargs = locals().copy()
    base = add_mixins(TorchPolicy, mixins)

    class policy_cls(base):
        def __init__(self, obs_space, action_space, config):
            # Shallow-merge user overrides on top of the default config.
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            if validate_spaces:
                validate_spaces(self, obs_space, action_space, self.config)

            if before_init:
                before_init(self, obs_space, action_space, self.config)

            # Model is customized (use default action dist class).
            if make_model:
                assert make_model_and_action_dist is None, \
                    "Either `make_model` or `make_model_and_action_dist`" \
                    " must be None!"
                self.model = make_model(self, obs_space, action_space, config)
                dist_class, _ = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
            # Model and action dist class are customized.
            elif make_model_and_action_dist:
                self.model, dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
            # Use default model and default action dist.
            else:
                dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=logit_dim,
                    model_config=self.config["model"],
                    framework="torch",
                    **self.config["model"].get("custom_model_config", {}))

            # Make sure, we passed in a correct Model factory.
            assert isinstance(self.model, TorchModelV2), \
                "ERROR: Generated Model must be a TorchModelV2 object!"

            TorchPolicy.__init__(
                self,
                observation_space=obs_space,
                action_space=action_space,
                config=config,
                model=self.model,
                loss=loss_fn,
                action_distribution_class=dist_class,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=config["model"]["max_seq_len"],
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            # Do all post-processing always with no_grad().
            # Not using this here will introduce a memory leak (issue #6962).
            with torch.no_grad():
                # Call super's postprocess_trajectory first.
                sample_batch = super().postprocess_trajectory(
                    convert_to_non_torch_type(sample_batch),
                    convert_to_non_torch_type(other_agent_batches), episode)
                if postprocess_fn:
                    return postprocess_fn(self, sample_batch,
                                          other_agent_batches, episode)

                return sample_batch

        @override(TorchPolicy)
        def extra_grad_process(self, optimizer, loss):
            """Called after optimizer.zero_grad() and loss.backward() calls.

            Allows for gradient processing before optimizer.step() is called.
            E.g. for gradient clipping.
            """
            if extra_grad_process_fn:
                return extra_grad_process_fn(self, optimizer, loss)
            else:
                return TorchPolicy.extra_grad_process(self, optimizer, loss)

        @override(TorchPolicy)
        def extra_compute_grad_fetches(self):
            if extra_learn_fetches_fn:
                fetches = convert_to_non_torch_type(
                    extra_learn_fetches_fn(self))
                # Auto-add empty learner stats dict if needed.
                return dict({LEARNER_STATS_KEY: {}}, **fetches)
            else:
                return TorchPolicy.extra_compute_grad_fetches(self)

        @override(TorchPolicy)
        def apply_gradients(self, gradients):
            if apply_gradients_fn:
                apply_gradients_fn(self, gradients)
            else:
                TorchPolicy.apply_gradients(self, gradients)

        @override(TorchPolicy)
        def extra_action_out(self, input_dict, state_batches, model,
                             action_dist):
            # No autograd graph is needed for these extra outputs.
            with torch.no_grad():
                if extra_action_out_fn:
                    stats_dict = extra_action_out_fn(
                        self, input_dict, state_batches, model, action_dist)
                else:
                    stats_dict = TorchPolicy.extra_action_out(
                        self, input_dict, state_batches, model, action_dist)
                return convert_to_non_torch_type(stats_dict)

        @override(TorchPolicy)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            else:
                return TorchPolicy.optimizer(self)

        @override(TorchPolicy)
        def extra_grad_info(self, train_batch):
            with torch.no_grad():
                if stats_fn:
                    stats_dict = stats_fn(self, train_batch)
                else:
                    stats_dict = TorchPolicy.extra_grad_info(self, train_batch)
                return convert_to_non_torch_type(stats_dict)

    def with_updates(**overrides):
        """Build a new TorchPolicy class from this one's settings.

        Keyword Args:
            **overrides: `build_torch_policy` arguments whose values should
                replace the ones originally used to build this class.

        Returns:
            type: A new TorchPolicy sub-class.
        """
        return build_torch_policy(**dict(original_kwargs, **overrides))

    policy_cls.with_updates = staticmethod(with_updates)
    # Give the generated class the requested name (for repr/logging).
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
def build_trainer(
        name,
        default_policy,
        default_config=None,
        validate_config=None,
        get_initial_state=None,  # DEPRECATED
        get_policy_class=None,
        before_init=None,
        make_workers=None,  # DEPRECATED
        make_policy_optimizer=None,  # DEPRECATED
        after_init=None,
        before_train_step=None,  # DEPRECATED
        after_optimizer_step=None,  # DEPRECATED
        after_train_result=None,  # DEPRECATED
        collect_metrics_fn=None,  # DEPRECATED
        before_evaluate_fn=None,
        mixins=None,
        execution_plan=default_execution_plan):
    """Helper function for defining a custom trainer.

    Functions will be run in this order to initialize the trainer:
        1. Config setup: validate_config, get_policy
        2. Worker setup: before_init, execution_plan
        3. Post setup: after_init

    Arguments:
        name (str): name of the trainer (e.g., "PPO")
        default_policy (cls): the default Policy class to use (used when
            `get_policy_class` is None).
        default_config (dict): The default config dict of the algorithm,
            otherwise uses the Trainer default config (COMMON_CONFIG).
        validate_config (func): optional callback that checks a given config
            for correctness. It may mutate the config as needed.
        get_initial_state (func): DEPRECATED. Optional callback that takes the
            trainer and returns its initial `state` dict (default: {}).
        get_policy_class (func): optional callback that takes a config and
            returns the policy class to override the default with
        before_init (func): optional function to run at the start of trainer
            init that takes the trainer instance as argument
        make_workers (func): DEPRECATED. Only used when `execution_plan` is
            falsy; takes (trainer, env_creator, policy_cls, config) and
            returns the workers.
        make_policy_optimizer (func): DEPRECATED. If given, takes
            (workers, config) and returns a policy optimizer; training then
            uses the optimizer loop instead of `execution_plan`.
        after_init (func): optional function to run at the end of trainer init
            that takes the trainer instance as argument
        before_train_step (func): DEPRECATED. Optimizer-loop only; called with
            the trainer at the start of each training iteration.
        after_optimizer_step (func): DEPRECATED. Optimizer-loop only; called
            with (trainer, fetches) after every optimizer step.
        after_train_result (func): DEPRECATED. Optimizer-loop only; called
            with (trainer, result) before the result is returned.
        collect_metrics_fn (func): DEPRECATED. Optimizer-loop only; replaces
            `trainer.collect_metrics()` for building the result dict.
        before_evaluate_fn (func): callback to run before evaluation. This
            takes the trainer instance as argument.
        mixins (list): list of any class mixins for the returned trainer class.
            These mixins will be applied in order and will have higher
            precedence than the Trainer class
        execution_plan (func): Setup the distributed execution workflow.

    Returns:
        type: A Trainer sub-class constructed from the specified args.
    """

    # NOTE: Must remain the first statement of this function: it snapshots
    # exactly the call arguments for `with_updates()` below.
    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config, env_creator):
            if validate_config:
                validate_config(config)

            if get_initial_state:
                deprecation_warning("get_initial_state", "execution_plan")
                self.state = get_initial_state(self)
            else:
                self.state = {}

            if get_policy_class is None:
                self._policy = default_policy
            else:
                self._policy = get_policy_class(config)

            if before_init:
                before_init(self)

            # Creating all workers (excluding evaluation workers).
            if make_workers and not execution_plan:
                deprecation_warning("make_workers", "execution_plan")
                self.workers = make_workers(self, env_creator, self._policy,
                                            config)
            else:
                self.workers = self._make_workers(env_creator, self._policy,
                                                  config,
                                                  self.config["num_workers"])

            self.train_exec_impl = None
            self.optimizer = None
            self.execution_plan = execution_plan

            # Exactly one of the two training paths gets set up here.
            if make_policy_optimizer:
                deprecation_warning("make_policy_optimizer", "execution_plan")
                self.optimizer = make_policy_optimizer(self.workers, config)
            else:
                assert execution_plan is not None
                self.train_exec_impl = execution_plan(self.workers, config)

            if after_init:
                after_init(self)

        @override(Trainer)
        def _train(self):
            if self.train_exec_impl:
                return self._train_exec_impl()

            # Legacy (deprecated) policy-optimizer training loop.
            if before_train_step:
                deprecation_warning("before_train_step", "execution_plan")
                before_train_step(self)
            prev_steps = self.optimizer.num_steps_sampled

            start = time.time()
            optimizer_steps_this_iter = 0
            # Warn once per `_train()` call (like the other deprecated hooks)
            # instead of once per optimizer step inside the loop below.
            if after_optimizer_step:
                deprecation_warning("after_optimizer_step", "execution_plan")
            while True:
                fetches = self.optimizer.step()
                optimizer_steps_this_iter += 1
                if after_optimizer_step:
                    after_optimizer_step(self, fetches)
                # Stop only once both the min. wall-time and the min. number
                # of sampled timesteps for this iteration have been reached.
                if (time.time() - start >= self.config["min_iter_time_s"]
                        and self.optimizer.num_steps_sampled - prev_steps >=
                        self.config["timesteps_per_iteration"]):
                    break

            if collect_metrics_fn:
                deprecation_warning("collect_metrics_fn", "execution_plan")
                res = collect_metrics_fn(self)
            else:
                res = self.collect_metrics()
            res.update(
                optimizer_steps_this_iter=optimizer_steps_this_iter,
                timesteps_this_iter=self.optimizer.num_steps_sampled -
                prev_steps,
                info=res.get("info", {}))

            if after_train_result:
                deprecation_warning("after_train_result", "execution_plan")
                after_train_result(self, res)
            return res

        def _train_exec_impl(self):
            # Pull the next result from the execution-plan iterator.
            res = next(self.train_exec_impl)
            return res

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        def __getstate__(self):
            state = Trainer.__getstate__(self)
            state["trainer_state"] = self.state.copy()
            if self.train_exec_impl:
                state["train_exec_impl"] = (
                    self.train_exec_impl.shared_metrics.get().save())
            return state

        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            self.state = state["trainer_state"].copy()
            if self.train_exec_impl:
                self.train_exec_impl.shared_metrics.get().restore(
                    state["train_exec_impl"])

    def with_updates(**overrides):
        """Build a copy of this trainer with the specified overrides.

        Arguments:
            overrides (dict): use this to override any of the arguments
                originally passed to build_trainer() for this trainer.
        """
        return build_trainer(**dict(original_kwargs, **overrides))

    trainer_cls.with_updates = staticmethod(with_updates)
    # Give the generated class the requested name (for repr/logging).
    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
"GPU": cf["num_gpus"], }, { # Different bundle (meaning: possibly different node) # for your n "remote" envs (set remote_worker_envs=True). "CPU": cf["num_envs_per_worker"], } ], strategy=config.get("placement_strategy", "PACK")) # The modified Trainer class we will use. This is the exact same # as a PPOTrainer, but with the additional default_resource_request # override, telling tune that it's ok (not mandatory) to place our # n remote envs on a different node (each env using 1 CPU). PPOTrainerRemoteInference = add_mixins(PPOTrainer, [OverrideDefaultResourceRequest]) if __name__ == "__main__": args = get_cli_args() ray.init(num_cpus=6, local_mode=args.local_mode) config = { "env": "CartPole-v0", # Force sub-envs to be ray.actor.ActorHandles, so we can step # through them in parallel. "remote_worker_envs": True, # Set the number of CPUs used by the (local) worker, aka "driver" # to match the number of ray remote envs. "num_cpus_for_driver": args.num_envs_per_worker + 1, # Use a single worker (however, with n parallelized remote envs, maybe
def build_eager_tf_policy(name, loss_fn, get_default_config=None, postprocess_fn=None, stats_fn=None, optimizer_fn=None, gradients_fn=None, apply_gradients_fn=None, grad_stats_fn=None, extra_learn_fetches_fn=None, extra_action_fetches_fn=None, validate_spaces=None, before_init=None, before_loss_init=None, after_init=None, make_model=None, action_sampler_fn=None, action_distribution_fn=None, mixins=None, obs_include_prev_action_reward=True, get_batch_divisibility_req=None): """Build an eager TF policy. An eager policy runs all operations in eager mode, which makes debugging much simpler, but has lower performance. You shouldn't need to call this directly. Rather, prefer to build a TF graph policy and use set {"framework": "tfe"} in the trainer config to have it automatically be converted to an eager policy. This has the same signature as build_tf_policy().""" base = add_mixins(Policy, mixins) class eager_policy_cls(base): def __init__(self, observation_space, action_space, config): assert tf.executing_eagerly() self.framework = config.get("framework", "tfe") Policy.__init__(self, observation_space, action_space, config) self._is_training = False self._loss_initialized = False self._sess = None if get_default_config: config = dict(get_default_config(), **config) if validate_spaces: validate_spaces(self, observation_space, action_space, config) if before_init: before_init(self, observation_space, action_space, config) self.config = config self.dist_class = None if action_sampler_fn or action_distribution_fn: if not make_model: raise ValueError( "`make_model` is required if `action_sampler_fn` OR " "`action_distribution_fn` is given") else: self.dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) if make_model: self.model = make_model(self, observation_space, action_space, config) else: self.model = ModelCatalog.get_model_v2( observation_space, action_space, logit_dim, config["model"], framework=self.framework, ) self.exploration = 
self._create_exploration() self._state_in = [ tf.convert_to_tensor([s]) for s in self.model.get_initial_state() ] input_dict = { SampleBatch.CUR_OBS: tf.convert_to_tensor(np.array([observation_space.sample()])), SampleBatch.PREV_ACTIONS: tf.convert_to_tensor( [flatten_to_single_ndarray(action_space.sample())]), SampleBatch.PREV_REWARDS: tf.convert_to_tensor([0.]), } if action_distribution_fn: dist_inputs, self.dist_class, _ = action_distribution_fn( self, self.model, input_dict[SampleBatch.CUR_OBS]) else: self.model(input_dict, self._state_in, tf.convert_to_tensor([1])) if before_loss_init: before_loss_init(self, observation_space, action_space, config) self._initialize_loss_with_dummy_batch() self._loss_initialized = True if optimizer_fn: self._optimizer = optimizer_fn(self, config) else: self._optimizer = tf.keras.optimizers.Adam(config["lr"]) if after_init: after_init(self, observation_space, action_space, config) @override(Policy) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, episode=None): assert tf.executing_eagerly() # Call super's postprocess_trajectory first. 
sample_batch = Policy.postprocess_trajectory(self, sample_batch) if postprocess_fn: return postprocess_fn(self, sample_batch, other_agent_batches, episode) return sample_batch @override(Policy) @convert_eager_inputs @convert_eager_outputs def learn_on_batch(self, samples): with tf.variable_creator_scope(_disallow_var_creation): grads_and_vars, stats = self._compute_gradients(samples) self._apply_gradients(grads_and_vars) return stats @override(Policy) @convert_eager_inputs @convert_eager_outputs def compute_gradients(self, samples): with tf.variable_creator_scope(_disallow_var_creation): grads_and_vars, stats = self._compute_gradients(samples) grads = [g for g, v in grads_and_vars] return grads, stats @override(Policy) @convert_eager_inputs @convert_eager_outputs def compute_actions(self, obs_batch, state_batches=None, prev_action_batch=None, prev_reward_batch=None, info_batch=None, episodes=None, explore=None, timestep=None, **kwargs): explore = explore if explore is not None else \ self.config["explore"] timestep = timestep if timestep is not None else \ self.global_timestep # TODO: remove python side effect to cull sources of bugs. self._is_training = False self._state_in = state_batches if not tf1.executing_eagerly(): tf1.enable_eager_execution() input_dict = { SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch), "is_training": tf.constant(False), } n = input_dict[SampleBatch.CUR_OBS].shape[0] seq_lens = tf.ones(n, dtype=tf.int32) if obs_include_prev_action_reward: if prev_action_batch is not None: input_dict[SampleBatch.PREV_ACTIONS] = \ tf.convert_to_tensor(prev_action_batch) if prev_reward_batch is not None: input_dict[SampleBatch.PREV_REWARDS] = \ tf.convert_to_tensor(prev_reward_batch) # Use Exploration object. 
with tf.variable_creator_scope(_disallow_var_creation): if action_sampler_fn: dist_inputs = None state_out = [] actions, logp = self.action_sampler_fn( self, self.model, input_dict[SampleBatch.CUR_OBS], explore=explore, timestep=timestep, episodes=episodes) else: # Exploration hook before each forward pass. self.exploration.before_compute_actions(timestep=timestep, explore=explore) if action_distribution_fn: dist_inputs, dist_class, state_out = \ action_distribution_fn( self, self.model, input_dict[SampleBatch.CUR_OBS], explore=explore, timestep=timestep, is_training=False) else: dist_class = self.dist_class dist_inputs, state_out = self.model( input_dict, state_batches, seq_lens) action_dist = dist_class(dist_inputs, self.model) # Get the exploration action from the forward results. actions, logp = self.exploration.get_exploration_action( action_distribution=action_dist, timestep=timestep, explore=explore) # Add default and custom fetches. extra_fetches = {} # Action-logp and action-prob. if logp is not None: extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp) extra_fetches[SampleBatch.ACTION_LOGP] = logp # Action-dist inputs. if dist_inputs is not None: extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs # Custom extra fetches. if extra_action_fetches_fn: extra_fetches.update(extra_action_fetches_fn(self)) # Update our global timestep by the batch size. 
self.global_timestep += len(obs_batch) return actions, state_out, extra_fetches @override(Policy) def compute_log_likelihoods(self, actions, obs_batch, state_batches=None, prev_action_batch=None, prev_reward_batch=None): if action_sampler_fn and action_distribution_fn is None: raise ValueError("Cannot compute log-prob/likelihood w/o an " "`action_distribution_fn` and a provided " "`action_sampler_fn`!") seq_lens = tf.ones(len(obs_batch), dtype=tf.int32) input_dict = { SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch), "is_training": tf.constant(False), } if obs_include_prev_action_reward: input_dict.update({ SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(prev_action_batch), SampleBatch.PREV_REWARDS: tf.convert_to_tensor(prev_reward_batch), }) # Exploration hook before each forward pass. self.exploration.before_compute_actions(explore=False) # Action dist class and inputs are generated via custom function. if action_distribution_fn: dist_inputs, dist_class, _ = action_distribution_fn( self, self.model, input_dict[SampleBatch.CUR_OBS], explore=False, is_training=False) action_dist = dist_class(dist_inputs, self.model) log_likelihoods = action_dist.logp(actions) # Default log-likelihood calculation. 
else: dist_inputs, _ = self.model(input_dict, state_batches, seq_lens) dist_class = self.dist_class action_dist = dist_class(dist_inputs, self.model) log_likelihoods = action_dist.logp(actions) return log_likelihoods @override(Policy) def apply_gradients(self, gradients): self._apply_gradients( zip([(tf.convert_to_tensor(g) if g is not None else None) for g in gradients], self.model.trainable_variables())) @override(Policy) def get_exploration_info(self): return _convert_to_numpy(self.exploration.get_info()) @override(Policy) def get_weights(self, as_dict=False): variables = self.variables() if as_dict: return {v.name: v.numpy() for v in variables} return [v.numpy() for v in variables] @override(Policy) def set_weights(self, weights): variables = self.variables() assert len(weights) == len(variables), (len(weights), len(variables)) for v, w in zip(variables, weights): v.assign(w) @override(Policy) def get_state(self): state = {"_state": super().get_state()} state["_optimizer_variables"] = self._optimizer.variables() return state @override(Policy) def set_state(self, state): state = state.copy() # shallow copy # Set optimizer vars first. optimizer_vars = state.pop("_optimizer_variables", None) if optimizer_vars and self._optimizer.variables(): logger.warning( "Cannot restore an optimizer's state for tf eager! Keras " "is not able to save the v1.x optimizers (from " "tf.compat.v1.train) since they aren't compatible with " "checkpoints.") for opt_var, value in zip(self._optimizer.variables(), optimizer_vars): opt_var.assign(value) # Then the Policy's (NN) weights. 
super().set_state(state["_state"]) def variables(self): """Return the list of all savable variables for this policy.""" return self.model.variables() @override(Policy) def is_recurrent(self): return len(self._state_in) > 0 @override(Policy) def num_state_tensors(self): return len(self._state_in) @override(Policy) def get_initial_state(self): return self.model.get_initial_state() def get_session(self): return None # None implies eager def get_placeholder(self, ph): raise ValueError( "get_placeholder() is not allowed in eager mode. Try using " "rllib.utils.tf_ops.make_tf_callable() to write " "functions that work in both graph and eager mode.") def loss_initialized(self): return self._loss_initialized @override(Policy) def export_model(self, export_dir): pass @override(Policy) def export_checkpoint(self, export_dir): pass def _get_is_training_placeholder(self): return tf.convert_to_tensor(self._is_training) def _apply_gradients(self, grads_and_vars): if apply_gradients_fn: apply_gradients_fn(self, self._optimizer, grads_and_vars) else: self._optimizer.apply_gradients(grads_and_vars) def _compute_gradients(self, samples): """Computes and returns grads as eager tensors.""" self._is_training = True with tf.GradientTape(persistent=gradients_fn is not None) as tape: # TODO: set seq len and state-in properly state_in = [] for i in range(self.num_state_tensors()): state_in.append(samples["state_in_{}".format(i)]) self._state_in = state_in self._seq_lens = None if len(state_in) > 0: self._seq_lens = tf.ones( samples[SampleBatch.CUR_OBS].shape[0], dtype=tf.int32) samples["seq_lens"] = self._seq_lens model_out, _ = self.model(samples, self._state_in, self._seq_lens) loss = loss_fn(self, self.model, self.dist_class, samples) variables = self.model.trainable_variables() if gradients_fn: class OptimizerWrapper: def __init__(self, tape): self.tape = tape def compute_gradients(self, loss, var_list): return list( zip(self.tape.gradient(loss, var_list), var_list)) grads_and_vars = 
gradients_fn(self, OptimizerWrapper(tape), loss) else: grads_and_vars = list( zip(tape.gradient(loss, variables), variables)) if log_once("grad_vars"): for _, v in grads_and_vars: logger.info("Optimizing variable {}".format(v.name)) grads = [g for g, v in grads_and_vars] stats = self._stats(self, samples, grads) return grads_and_vars, stats def _stats(self, outputs, samples, grads): fetches = {} if stats_fn: fetches[LEARNER_STATS_KEY] = { k: v for k, v in stats_fn(outputs, samples).items() } else: fetches[LEARNER_STATS_KEY] = {} if extra_learn_fetches_fn: fetches.update( {k: v for k, v in extra_learn_fetches_fn(self).items()}) if grad_stats_fn: fetches.update({ k: v for k, v in grad_stats_fn(self, samples, grads).items() }) return fetches def _initialize_loss_with_dummy_batch(self): # Dummy forward pass to initialize any policy attributes, etc. dummy_batch = { SampleBatch.CUR_OBS: np.array([self.observation_space.sample()]), SampleBatch.NEXT_OBS: np.array([self.observation_space.sample()]), SampleBatch.DONES: np.array([False], dtype=np.bool), SampleBatch.REWARDS: np.array([0], dtype=np.float32), } if isinstance(self.action_space, (Dict, Tuple)): dummy_batch[SampleBatch.ACTIONS] = [ flatten_to_single_ndarray(self.action_space.sample()) ] else: dummy_batch[SampleBatch.ACTIONS] = tf.nest.map_structure( lambda c: np.array([c]), self.action_space.sample()) if obs_include_prev_action_reward: dummy_batch.update({ SampleBatch.PREV_ACTIONS: dummy_batch[SampleBatch.ACTIONS], SampleBatch.PREV_REWARDS: dummy_batch[SampleBatch.REWARDS], }) for i, h in enumerate(self._state_in): dummy_batch["state_in_{}".format(i)] = h dummy_batch["state_out_{}".format(i)] = h if self._state_in: dummy_batch["seq_lens"] = np.array([1], dtype=np.int32) # Convert everything to tensors. dummy_batch = tf.nest.map_structure(tf1.convert_to_tensor, dummy_batch) # for IMPALA which expects a certain sample batch size. 
def tile_to(tensor, n): return tf.tile(tensor, [n] + [1 for _ in tensor.shape.as_list()[1:]]) if get_batch_divisibility_req: dummy_batch = tf.nest.map_structure( lambda c: tile_to(c, get_batch_divisibility_req(self)), dummy_batch) i = 0 self._state_in = [] while "state_in_{}".format(i) in dummy_batch: self._state_in.append(dummy_batch["state_in_{}".format(i)]) i += 1 # Execute a forward pass to get self.action_dist etc initialized, # and also obtain the extra action fetches _, _, fetches = self.compute_actions( dummy_batch[SampleBatch.CUR_OBS], self._state_in, dummy_batch.get(SampleBatch.PREV_ACTIONS), dummy_batch.get(SampleBatch.PREV_REWARDS), explore=False) # Got to reset global_timestep again after this fake run-through. self.global_timestep = 0 dummy_batch.update(fetches) postprocessed_batch = self.postprocess_trajectory( SampleBatch(dummy_batch)) # model forward pass for the loss (needed after postprocess to # overwrite any tensor state from that call) self.model.from_batch(dummy_batch) postprocessed_batch = tf.nest.map_structure( lambda c: tf.convert_to_tensor(c), postprocessed_batch.data) loss_fn(self, self.model, self.dist_class, postprocessed_batch) if stats_fn: stats_fn(self, postprocessed_batch) @classmethod def with_tracing(cls): return traced_eager_policy(cls) eager_policy_cls.__name__ = name + "_eager" eager_policy_cls.__qualname__ = name + "_eager" return eager_policy_cls
def build_tf_policy(
        name: str,
        *,
        loss_fn: Callable[
            [Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
            Union[TensorType, List[TensorType]]],
        get_default_config: Optional[Callable[[None],
                                              TrainerConfigDict]] = None,
        postprocess_fn: Optional[Callable[[
            Policy, SampleBatch, Optional[Dict[AgentID, SampleBatch]],
            Optional["Episode"]
        ], SampleBatch]] = None,
        stats_fn: Optional[Callable[[Policy, SampleBatch],
                                    Dict[str, TensorType]]] = None,
        optimizer_fn: Optional[Callable[
            [Policy, TrainerConfigDict],
            "tf.keras.optimizers.Optimizer"]] = None,
        compute_gradients_fn: Optional[Callable[
            [Policy, "tf.keras.optimizers.Optimizer", TensorType],
            ModelGradients]] = None,
        apply_gradients_fn: Optional[Callable[
            [Policy, "tf.keras.optimizers.Optimizer", ModelGradients],
            "tf.Operation"]] = None,
        grad_stats_fn: Optional[Callable[[Policy, SampleBatch, ModelGradients],
                                         Dict[str, TensorType]]] = None,
        extra_action_out_fn: Optional[Callable[[Policy],
                                               Dict[str, TensorType]]] = None,
        extra_learn_fetches_fn: Optional[Callable[
            [Policy], Dict[str, TensorType]]] = None,
        validate_spaces: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_loss_init: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], None]] = None,
        after_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        make_model: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], ModelV2]] = None,
        action_sampler_fn: Optional[Callable[[TensorType, List[
            TensorType]], Tuple[TensorType, TensorType]]] = None,
        action_distribution_fn: Optional[Callable[[
            Policy, ModelV2, TensorType, TensorType, TensorType
        ], Tuple[TensorType, type, List[TensorType]]]] = None,
        mixins: Optional[List[type]] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
        # Deprecated args.
        obs_include_prev_action_reward=DEPRECATED_VALUE,
        extra_action_fetches_fn=None,  # Use `extra_action_out_fn`.
        gradients_fn=None,  # Use `compute_gradients_fn`.
) -> Type[DynamicTFPolicy]:
    """Helper function for creating a dynamic tf policy at runtime.

    Functions will be run in this order to initialize the policy:
        1. Placeholder setup: postprocess_fn
        2. Loss init: loss_fn, stats_fn
        3. Optimizer init: optimizer_fn, compute_gradients_fn,
           apply_gradients_fn, grad_stats_fn

    This means that you can e.g., depend on any policy attributes created in
    the running of `loss_fn` in later functions such as `stats_fn`.

    In eager mode, the following functions will be run repeatedly on each
    eager execution: loss_fn, stats_fn, compute_gradients_fn,
    apply_gradients_fn, and grad_stats_fn.

    This means that these functions should not define any variables
    internally, otherwise they will fail in eager mode execution. Variables
    should only be created in make_model (if defined).

    Args:
        name (str): Name of the policy (e.g., "PPOTFPolicy"). Also used as
            the generated class' `__name__` and `__qualname__`.
        loss_fn (Callable[[
            Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
            Union[TensorType, List[TensorType]]]): Callable for calculating a
            loss tensor.
        get_default_config (Optional[Callable[[None], TrainerConfigDict]]):
            Optional callable that returns the default config to merge with
            any overrides. If None, uses only(!) the user-provided
            PartialTrainerConfigDict as dict for this Policy.
        postprocess_fn (Optional[Callable[[Policy, SampleBatch,
            Optional[Dict[AgentID, SampleBatch]], Optional["Episode"]],
            SampleBatch]]): Optional callable for post-processing experience
            batches (called after the parent class' `postprocess_trajectory`
            method). Its return value is used as the post-processed batch.
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            TF tensors to fetch given the policy and batch input tensors. If
            None, will not compute any stats.
        optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict],
            "tf.keras.optimizers.Optimizer"]]): Optional callable that returns
            a tf.Optimizer given the policy and config. If None, will call
            the base class' `optimizer()` method instead (which returns a
            tf1.train.AdamOptimizer).
        compute_gradients_fn (Optional[Callable[[Policy,
            "tf.keras.optimizers.Optimizer", TensorType], ModelGradients]]):
            Optional callable that returns a list of gradients. If None,
            this defaults to optimizer.compute_gradients([loss]).
        apply_gradients_fn (Optional[Callable[[Policy,
            "tf.keras.optimizers.Optimizer", ModelGradients],
            "tf.Operation"]]): Optional callable that returns an apply
            gradients op given policy, tf-optimizer, and grads_and_vars. If
            None, will call the base class' `build_apply_op()` method instead.
        grad_stats_fn (Optional[Callable[[Policy, SampleBatch, ModelGradients],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            TF fetches given the policy, batch input, and gradient tensors. If
            None, will not collect any gradient stats.
        extra_action_out_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns
            a dict of TF fetches given the policy object. If None, will not
            perform any extra fetches.
        extra_learn_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            extra values to fetch and return when learning on a batch. If None,
            will call the base class' `extra_compute_grad_fetches()` method
            instead.
        validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check
            the spaces for correctness. If None, no spaces checking will be
            done.
        before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the
            beginning of policy init that takes the same arguments as the
            policy constructor. If None, this step will be skipped.
        before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
            run prior to loss init. If None, this step will be skipped.
        after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the end of
            policy init. If None, this step will be skipped.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional callable
            that returns a ModelV2 object.
            All policy variables should be created in this function. If None,
            a default ModelV2 object will be created.
        action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]]): A callable returning a sampled
            action and its log-likelihood given observation and state inputs.
            If None, will either use `action_distribution_fn` or
            compute actions by calling self.model, then sampling from the
            so parameterized action distribution.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
            TensorType, TensorType],
            Tuple[TensorType, type, List[TensorType]]]]): Optional callable
            returning distribution inputs (parameters), a dist-class to
            generate an action distribution object from, and internal-state
            outputs (or an empty list if not applicable). If None, will either
            use `action_sampler_fn` or compute actions by calling self.model,
            then sampling from the so parameterized action distribution.
        mixins (Optional[List[type]]): Optional list of any class mixins for
            the returned policy class. These mixins will be applied in order
            and will have higher precedence than the DynamicTFPolicy class.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.
        obs_include_prev_action_reward: DEPRECATED. Within this builder it
            only triggers a deprecation warning when set (it is still
            forwarded by `with_updates()` and `as_eager()`).
        extra_action_fetches_fn: DEPRECATED alias of `extra_action_out_fn`.
            If given, it takes precedence over `extra_action_out_fn`.
        gradients_fn: DEPRECATED alias of `compute_gradients_fn`. If given,
            it takes precedence over `compute_gradients_fn`.

    Returns:
        Type[DynamicTFPolicy]: A child class of DynamicTFPolicy based on the
            specified args. The class also exposes the static helpers
            `with_updates(**overrides)` and `as_eager()`.
    """
    # Must stay the first statement: snapshot only the call arguments so the
    # class can later be rebuilt via `with_updates()` / `as_eager()`.
    original_kwargs = locals().copy()
    base = add_mixins(DynamicTFPolicy, mixins)

    if obs_include_prev_action_reward != DEPRECATED_VALUE:
        deprecation_warning(old="obs_include_prev_action_reward", error=False)

    # Deprecated aliases win over their new-style counterparts.
    if extra_action_fetches_fn is not None:
        deprecation_warning(
            old="extra_action_fetches_fn",
            new="extra_action_out_fn",
            error=False)
        extra_action_out_fn = extra_action_fetches_fn

    if gradients_fn is not None:
        deprecation_warning(
            old="gradients_fn", new="compute_gradients_fn", error=False)
        compute_gradients_fn = gradients_fn

    class policy_cls(base):
        def __init__(self,
                     obs_space,
                     action_space,
                     config,
                     existing_model=None,
                     existing_inputs=None):
            if get_default_config:
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, obs_space, action_space, config)

            if before_init:
                before_init(self, obs_space, action_space, config)

            def before_loss_init_wrapper(policy, obs_space, action_space,
                                         config):
                """Runs `before_loss_init`, then collects extra action fetches.

                Towers (`policy._is_tower`) get no extra action fetches.
                Fetches are merged into an existing
                `policy._extra_action_fetches` dict if one is already set.
                """
                if before_loss_init:
                    before_loss_init(policy, obs_space, action_space, config)

                if extra_action_out_fn is None or policy._is_tower:
                    extra_action_fetches = {}
                else:
                    extra_action_fetches = extra_action_out_fn(policy)

                if hasattr(policy, "_extra_action_fetches"):
                    policy._extra_action_fetches.update(extra_action_fetches)
                else:
                    policy._extra_action_fetches = extra_action_fetches

            DynamicTFPolicy.__init__(
                self,
                obs_space=obs_space,
                action_space=action_space,
                config=config,
                loss_fn=loss_fn,
                stats_fn=stats_fn,
                grad_stats_fn=grad_stats_fn,
                before_loss_init=before_loss_init_wrapper,
                make_model=make_model,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                existing_inputs=existing_inputs,
                existing_model=existing_model,
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            if after_init:
                after_init(self, obs_space, action_space, config)

            # Got to reset global_timestep again after this fake run-through.
            self.global_timestep = 0

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            # Call super's postprocess_trajectory first.
            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
            if postprocess_fn:
                return postprocess_fn(self, sample_batch, other_agent_batches,
                                      episode)
            return sample_batch

        @override(TFPolicy)
        def optimizer(self):
            if optimizer_fn:
                optimizers = optimizer_fn(self, self.config)
            else:
                optimizers = base.optimizer(self)
            optimizers = force_list(optimizers)
            # Let the exploration component (if any) adjust the optimizers.
            if getattr(self, "exploration", None):
                optimizers = self.exploration.get_exploration_optimizer(
                    optimizers)

            # No optimizers produced -> Return None.
            if not optimizers:
                return None
            # New API: Allow more than one optimizer to be returned.
            # -> Return list.
            elif self.config["_tf_policy_handles_more_than_one_loss"]:
                return optimizers
            # Old API: Return a single LocalOptimizer.
            else:
                return optimizers[0]

        @override(TFPolicy)
        def gradients(self, optimizer, loss):
            optimizers = force_list(optimizer)
            losses = force_list(loss)

            if compute_gradients_fn:
                # New API: Allow more than one optimizer -> Return a list of
                # lists of gradients.
                if self.config["_tf_policy_handles_more_than_one_loss"]:
                    return compute_gradients_fn(self, optimizers, losses)
                # Old API: Return a single List of gradients.
                else:
                    return compute_gradients_fn(self, optimizers[0], losses[0])
            else:
                return base.gradients(self, optimizers, losses)

        @override(TFPolicy)
        def build_apply_op(self, optimizer, grads_and_vars):
            if apply_gradients_fn:
                return apply_gradients_fn(self, optimizer, grads_and_vars)
            else:
                return base.build_apply_op(self, optimizer, grads_and_vars)

        @override(TFPolicy)
        def extra_compute_action_fetches(self):
            return dict(
                base.extra_compute_action_fetches(self),
                **self._extra_action_fetches)

        @override(TFPolicy)
        def extra_compute_grad_fetches(self):
            if extra_learn_fetches_fn:
                # TODO: (sven) in torch, extra_learn_fetches do not exist.
                #  Hence, things like td_error are returned by the stats_fn
                #  and end up under the LEARNER_STATS_KEY. We should
                #  change tf to do this as well. However, this will conflict
                #  with the handling of LEARNER_STATS_KEY inside the
                #  multi-GPU train op.
                # Auto-add empty learner stats dict if needed.
                return dict({
                    LEARNER_STATS_KEY: {}
                }, **extra_learn_fetches_fn(self))
            else:
                return base.extra_compute_grad_fetches(self)

    def with_updates(**overrides):
        """Allows creating a TFPolicy cls based on settings of another one.

        Keyword Args:
            **overrides: The settings (passed into `build_tf_policy`) that
                should be different from the class that this method is
                called on.

        Returns:
            type: A new TFPolicy sub-class.

        Examples:
        >> MySpecialDQNPolicyClass = DQNTFPolicy.with_updates(
        ..    name="MySpecialDQNPolicyClass",
        ..    loss_fn=some_new_loss_function,
        .. )
        """
        kwargs = dict(original_kwargs, **overrides)
        # `original_kwargs` still holds any deprecated alias this class was
        # built with. Because aliases take precedence over their new-style
        # names, an override passed under the new name would otherwise be
        # silently replaced by the stale alias during the rebuild.
        for new_name, old_name in (
                ("extra_action_out_fn", "extra_action_fetches_fn"),
                ("compute_gradients_fn", "gradients_fn"),
        ):
            if new_name in overrides and old_name not in overrides:
                kwargs[old_name] = None
        return build_tf_policy(**kwargs)

    def as_eager():
        """Builds an eager-mode policy class from the same builder args.

        Returns:
            type: The class returned by
                `eager_tf_policy.build_eager_tf_policy(**original_kwargs)`.
        """
        return eager_tf_policy.build_eager_tf_policy(**original_kwargs)

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.as_eager = staticmethod(as_eager)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
def build_torch_policy(name,
                       loss_fn,
                       get_default_config=None,
                       stats_fn=None,
                       postprocess_fn=None,
                       extra_action_out_fn=None,
                       extra_grad_process_fn=None,
                       optimizer_fn=None,
                       before_init=None,
                       after_init=None,
                       make_model_and_action_dist=None,
                       mixins=None):
    """Helper function for creating a torch policy class at runtime.

    Arguments:
        name (str): name of the policy (e.g., "PPOTorchPolicy"); also used
            as the generated class' `__name__` and `__qualname__`
        loss_fn (func): function that returns a loss tensor given the policy
            and a batch of experience. It is handed as-is to
            `TorchPolicy.__init__`, which defines the exact call signature
        get_default_config (func): optional function that returns the default
            config to merge with any overrides
        stats_fn (func): optional function that returns a dict of values
            given the policy and batch input tensors; used by
            `extra_grad_info()`
        postprocess_fn (func): optional experience postprocessing function
            that takes the same args as Policy.postprocess_trajectory(). If
            None, batches are returned unchanged
        extra_action_out_fn (func): optional function that returns a dict of
            extra values to include in experiences. It is called with
            (policy, input_dict, state_batches, model)
        extra_grad_process_fn (func): optional function that is called after
            gradients are computed and returns processing info. It is called
            with the policy as its only argument
        optimizer_fn (func): optional function that returns a torch optimizer
            given the policy and config
        before_init (func): optional function to run at the beginning of
            policy init that takes the same arguments as the policy
            constructor
        after_init (func): optional function to run at the end of policy init
            that takes the same arguments as the policy constructor
        make_model_and_action_dist (func): optional func that takes the same
            arguments as policy init and returns a tuple of model instance and
            torch action distribution class. If not specified, the default
            model and action dist from the catalog will be used
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the TorchPolicy class

    Returns:
        type: a TorchPolicy subclass (not an instance) that uses the
            specified args. It also exposes a static
            `with_updates(**overrides)` that rebuilds the class with some
            builder arguments replaced.
    """
    # Must stay the first statement: snapshot only the call arguments so
    # `with_updates()` can rebuild the class later.
    original_kwargs = locals().copy()
    base = add_mixins(TorchPolicy, mixins)

    class policy_cls(base):
        def __init__(self, obs_space, action_space, config):
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            if before_init:
                before_init(self, obs_space, action_space, config)

            if make_model_and_action_dist:
                self.model, self.dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
            else:
                # Fall back to the catalog's default dist class and model.
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], torch=True)
                self.model = ModelCatalog.get_model_v2(
                    obs_space,
                    action_space,
                    logit_dim,
                    self.config["model"],
                    framework="torch")

            TorchPolicy.__init__(self, obs_space, action_space, self.model,
                                 loss_fn, self.dist_class)

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            if not postprocess_fn:
                return sample_batch
            return postprocess_fn(self, sample_batch, other_agent_batches,
                                  episode)

        @override(TorchPolicy)
        def extra_grad_process(self):
            if extra_grad_process_fn:
                return extra_grad_process_fn(self)
            else:
                return TorchPolicy.extra_grad_process(self)

        @override(TorchPolicy)
        def extra_action_out(self, input_dict, state_batches, model):
            if extra_action_out_fn:
                return extra_action_out_fn(self, input_dict, state_batches,
                                           model)
            else:
                return TorchPolicy.extra_action_out(self, input_dict,
                                                    state_batches, model)

        @override(TorchPolicy)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            else:
                return TorchPolicy.optimizer(self)

        @override(TorchPolicy)
        def extra_grad_info(self, batch_tensors):
            if stats_fn:
                return stats_fn(self, batch_tensors)
            else:
                return TorchPolicy.extra_grad_info(self, batch_tensors)

    def with_updates(**overrides):
        """Builds a new torch policy class with some builder args replaced.

        Keyword Args:
            **overrides: The `build_torch_policy` arguments that should
                differ from the ones this class was built with.

        Returns:
            type: A new TorchPolicy sub-class.
        """
        return build_torch_policy(**dict(original_kwargs, **overrides))

    # Wrap here (rather than decorating the nested function) so the class
    # attribute behaves like a regular static method, as in build_tf_policy.
    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
def build_policy_class( name: str, framework: str, *, loss_fn: Optional[Callable[ [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch], Union[TensorType, List[TensorType]]]], get_default_config: Optional[Callable[[], TrainerConfigDict]] = None, stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]] = None, postprocess_fn: Optional[Callable[[ Policy, SampleBatch, Optional[Dict[ Any, SampleBatch]], Optional["MultiAgentEpisode"] ], SampleBatch]] = None, extra_action_out_fn: Optional[Callable[[ Policy, Dict[ str, TensorType], List[TensorType], ModelV2, TorchDistributionWrapper ], Dict[str, TensorType]]] = None, extra_grad_process_fn: Optional[ Callable[[Policy, "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]] = None, # TODO: (sven) Replace "fetches" with "process". extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None, optimizer_fn: Optional[Callable[[Policy, TrainerConfigDict], "torch.optim.Optimizer"]] = None, validate_spaces: Optional[Callable[ [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None, before_init: Optional[Callable[ [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None, before_loss_init: Optional[Callable[ [Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], None]] = None, after_init: Optional[Callable[ [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None, _after_loss_init: Optional[Callable[ [Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], None]] = None, action_sampler_fn: Optional[Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]]] = None, action_distribution_fn: Optional[ Callable[[Policy, ModelV2, TensorType, TensorType, TensorType], Tuple[TensorType, type, List[TensorType]]]] = None, make_model: Optional[Callable[ [Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], ModelV2]] = None, make_model_and_action_dist: Optional[Callable[ [Policy, gym.spaces.Space, gym.spaces.Space, 
TrainerConfigDict], Tuple[ModelV2, Type[TorchDistributionWrapper]]]] = None, apply_gradients_fn: Optional[Callable[[Policy, "torch.optim.Optimizer"], None]] = None, mixins: Optional[List[type]] = None, view_requirements_fn: Optional[Callable[[Policy], Dict[str, ViewRequirement]]] = None, get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None ) -> Type[TorchPolicy]: """Helper function for creating a new Policy class at runtime. Supports frameworks JAX and PyTorch. Args: name (str): name of the policy (e.g., "PPOTorchPolicy") framework (str): Either "jax" or "torch". loss_fn (Optional[Callable[[Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch], Union[TensorType, List[TensorType]]]]): Callable that returns a loss tensor. get_default_config (Optional[Callable[[None], TrainerConfigDict]]): Optional callable that returns the default config to merge with any overrides. If None, uses only(!) the user-provided PartialTrainerConfigDict as dict for this Policy. postprocess_fn (Optional[Callable[[Policy, SampleBatch, Optional[Dict[Any, SampleBatch]], Optional["MultiAgentEpisode"]], SampleBatch]]): Optional callable for post-processing experience batches (called after the super's `postprocess_trajectory` method). stats_fn (Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]]): Optional callable that returns a dict of values given the policy and training batch. If None, will use `TorchPolicy.extra_grad_info()` instead. The stats dict is used for logging (e.g. in TensorBoard). extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType], List[TensorType], ModelV2, TorchDistributionWrapper]], Dict[str, TensorType]]]): Optional callable that returns a dict of extra values to include in experiences. If None, no extra computations will be performed. 
extra_grad_process_fn (Optional[Callable[[Policy, "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]]): Optional callable that is called after gradients are computed and returns a processing info dict. If None, will call the `TorchPolicy.extra_grad_process()` method instead. # TODO: (sven) dissolve naming mismatch between "learn" and "compute.." extra_learn_fetches_fn (Optional[Callable[[Policy], Dict[str, TensorType]]]): Optional callable that returns a dict of extra tensors from the policy after loss evaluation. If None, will call the `TorchPolicy.extra_compute_grad_fetches()` method instead. optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict], "torch.optim.Optimizer"]]): Optional callable that returns a torch optimizer given the policy and config. If None, will call the `TorchPolicy.optimizer()` method instead (which returns a torch Adam optimizer). validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space, TrainerConfigDict], None]]): Optional callable that takes the Policy, observation_space, action_space, and config to check for correctness. If None, no spaces checking will be done. before_init (Optional[Callable[[Policy, gym.Space, gym.Space, TrainerConfigDict], None]]): Optional callable to run at the beginning of `Policy.__init__` that takes the same arguments as the Policy constructor. If None, this step will be skipped. before_loss_init (Optional[Callable[[Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to run prior to loss init. If None, this step will be skipped. after_init (Optional[Callable[[Policy, gym.Space, gym.Space, TrainerConfigDict], None]]): DEPRECATED: Use `before_loss_init` instead. _after_loss_init (Optional[Callable[[Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to run after the loss init. If None, this step will be skipped. 
This will be deprecated at some point and renamed into `after_init` to match `build_tf_policy()` behavior. action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]]]): Optional callable returning a sampled action and its log-likelihood given some (obs and state) inputs. If None, will either use `action_distribution_fn` or compute actions by calling self.model, then sampling from the so parameterized action distribution. action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType, TensorType, TensorType], Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]]]]): A callable that takes the Policy, Model, the observation batch, an explore-flag, a timestep, and an is_training flag and returns a tuple of a) distribution inputs (parameters), b) a dist-class to generate an action distribution object from, and c) internal-state outputs (empty list if not applicable). If None, will either use `action_sampler_fn` or compute actions by calling self.model, then sampling from the parameterized action distribution. make_model (Optional[Callable[[Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional callable that takes the same arguments as Policy.__init__ and returns a model instance. The distribution class will be determined automatically. Note: Only one of `make_model` or `make_model_and_action_dist` should be provided. If both are None, a default Model will be created. make_model_and_action_dist (Optional[Callable[[Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], Tuple[ModelV2, Type[TorchDistributionWrapper]]]]): Optional callable that takes the same arguments as Policy.__init__ and returns a tuple of model instance and torch action distribution class. Note: Only one of `make_model` or `make_model_and_action_dist` should be provided. If both are None, a default Model will be created. 
apply_gradients_fn (Optional[Callable[[Policy, "torch.optim.Optimizer"], None]]): Optional callable that takes a grads list and applies these to the Model's parameters. If None, will call the `TorchPolicy.apply_gradients()` method instead. mixins (Optional[List[type]]): Optional list of any class mixins for the returned policy class. These mixins will be applied in order and will have higher precedence than the TorchPolicy class. view_requirements_fn (Optional[Callable[[Policy], Dict[str, ViewRequirement]]]): An optional callable to retrieve additional train view requirements for this policy. get_batch_divisibility_req (Optional[Callable[[Policy], int]]): Optional callable that returns the divisibility requirement for sample batches. If None, will assume a value of 1. Returns: Type[TorchPolicy]: TorchPolicy child class constructed from the specified args. """ original_kwargs = locals().copy() parent_cls = TorchPolicy base = add_mixins(parent_cls, mixins) class policy_cls(base): def __init__(self, obs_space, action_space, config): # Set up the config from possible default-config fn and given # config arg. if get_default_config: config = dict(get_default_config(), **config) self.config = config # Set the DL framework for this Policy. self.framework = self.config["framework"] = framework # Validate observation- and action-spaces. if validate_spaces: validate_spaces(self, obs_space, action_space, self.config) # Do some pre-initialization steps. if before_init: before_init(self, obs_space, action_space, self.config) # Model is customized (use default action dist class). if make_model: assert make_model_and_action_dist is None, \ "Either `make_model` or `make_model_and_action_dist`" \ " must be None!" self.model = make_model(self, obs_space, action_space, config) dist_class, _ = ModelCatalog.get_action_dist( action_space, self.config["model"], framework=framework) # Model and action dist class are customized. 
elif make_model_and_action_dist: self.model, dist_class = make_model_and_action_dist( self, obs_space, action_space, config) # Use default model and default action dist. else: dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"], framework=framework) self.model = ModelCatalog.get_model_v2( obs_space=obs_space, action_space=action_space, num_outputs=logit_dim, model_config=self.config["model"], framework=framework) # Make sure, we passed in a correct Model factory. model_cls = TorchModelV2 if framework == "torch" else JAXModelV2 assert isinstance(self.model, model_cls), \ "ERROR: Generated Model must be a TorchModelV2 object!" # Call the framework-specific Policy constructor. self.parent_cls = parent_cls self.parent_cls.__init__( self, observation_space=obs_space, action_space=action_space, config=config, model=self.model, loss=loss_fn, action_distribution_class=dist_class, action_sampler_fn=action_sampler_fn, action_distribution_fn=action_distribution_fn, max_seq_len=config["model"]["max_seq_len"], get_batch_divisibility_req=get_batch_divisibility_req, ) # Update this Policy's ViewRequirements (if function given). if callable(view_requirements_fn): self.view_requirements.update(view_requirements_fn(self)) # Merge Model's view requirements into Policy's. self.view_requirements.update( self.model.inference_view_requirements) _before_loss_init = before_loss_init or after_init if _before_loss_init: _before_loss_init(self, self.observation_space, self.action_space, config) # Perform test runs through postprocessing- and loss functions. self._initialize_loss_from_dummy_batch( auto_remove_unneeded_view_reqs=True, stats_fn=stats_fn, ) if _after_loss_init: _after_loss_init(self, obs_space, action_space, config) # Got to reset global_timestep again after this fake run-through. 
self.global_timestep = 0 @override(Policy) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, episode=None): # Do all post-processing always with no_grad(). # Not using this here will introduce a memory leak # in torch (issue #6962). with self._no_grad_context(): # Call super's postprocess_trajectory first. sample_batch = super().postprocess_trajectory( sample_batch, other_agent_batches, episode) if postprocess_fn: return postprocess_fn(self, sample_batch, other_agent_batches, episode) return sample_batch @override(parent_cls) def extra_grad_process(self, optimizer, loss): """Called after optimizer.zero_grad() and loss.backward() calls. Allows for gradient processing before optimizer.step() is called. E.g. for gradient clipping. """ if extra_grad_process_fn: return extra_grad_process_fn(self, optimizer, loss) else: return parent_cls.extra_grad_process(self, optimizer, loss) @override(parent_cls) def extra_compute_grad_fetches(self): if extra_learn_fetches_fn: fetches = convert_to_non_torch_type( extra_learn_fetches_fn(self)) # Auto-add empty learner stats dict if needed. 
return dict({LEARNER_STATS_KEY: {}}, **fetches) else: return parent_cls.extra_compute_grad_fetches(self) @override(parent_cls) def apply_gradients(self, gradients): if apply_gradients_fn: apply_gradients_fn(self, gradients) else: parent_cls.apply_gradients(self, gradients) @override(parent_cls) def extra_action_out(self, input_dict, state_batches, model, action_dist): with self._no_grad_context(): if extra_action_out_fn: stats_dict = extra_action_out_fn(self, input_dict, state_batches, model, action_dist) else: stats_dict = parent_cls.extra_action_out( self, input_dict, state_batches, model, action_dist) return self._convert_to_non_torch_type(stats_dict) @override(parent_cls) def optimizer(self): if optimizer_fn: optimizers = optimizer_fn(self, self.config) else: optimizers = parent_cls.optimizer(self) optimizers = force_list(optimizers) if getattr(self, "exploration", None): optimizers = self.exploration.get_exploration_optimizer( optimizers) return optimizers @override(parent_cls) def extra_grad_info(self, train_batch): with self._no_grad_context(): if stats_fn: stats_dict = stats_fn(self, train_batch) else: stats_dict = self.parent_cls.extra_grad_info( self, train_batch) return self._convert_to_non_torch_type(stats_dict) def _no_grad_context(self): if self.framework == "torch": return torch.no_grad() return NullContextManager() def _convert_to_non_torch_type(self, data): if self.framework == "torch": return convert_to_non_torch_type(data) return data def with_updates(**overrides): """Creates a Torch|JAXPolicy cls based on settings of another one. Keyword Args: **overrides: The settings (passed into `build_torch_policy`) that should be different from the class that this method is called on. Returns: type: A new Torch|JAXPolicy sub-class. Examples: >> MySpecialDQNPolicyClass = DQNTorchPolicy.with_updates( .. name="MySpecialDQNPolicyClass", .. loss_function=[some_new_loss_function], .. 
) """ return build_policy_class(**dict(original_kwargs, **overrides)) policy_cls.with_updates = staticmethod(with_updates) policy_cls.__name__ = name policy_cls.__qualname__ = name return policy_cls
def build_eager_tf_policy(name,
                          loss_fn,
                          get_default_config=None,
                          postprocess_fn=None,
                          stats_fn=None,
                          optimizer_fn=None,
                          gradients_fn=None,
                          apply_gradients_fn=None,
                          grad_stats_fn=None,
                          extra_learn_fetches_fn=None,
                          extra_action_fetches_fn=None,
                          before_init=None,
                          before_loss_init=None,
                          after_init=None,
                          make_model=None,
                          action_sampler_fn=None,
                          mixins=None,
                          obs_include_prev_action_reward=True,
                          get_batch_divisibility_req=None):
    """Build an eager TF policy.

    An eager policy runs all operations in eager mode, which makes debugging
    much simpler, but is lower performance.

    You shouldn't need to call this directly. Rather, prefer to build a TF
    graph policy and set {"eager": true} in the trainer config to have
    it automatically be converted to an eager policy.

    This has the same signature as build_tf_policy().

    Args:
        name (str): Base name; the returned class is named `name + "_eager"`.
        loss_fn (callable): Called as
            `loss_fn(policy, model, dist_class, train_batch)`; its result is
            differentiated w.r.t. the model's trainable variables.
        get_default_config (Optional[callable]): Returns a default config
            dict that the given config is merged on top of.
        postprocess_fn (Optional[callable]): Called as
            `postprocess_fn(policy, batch, other_agent_batches, episode)`.
            If None, batches are returned unchanged.
        stats_fn (Optional[callable]): Called as `stats_fn(policy, batch)`;
            its values are converted with `.numpy()` and reported under
            LEARNER_STATS_KEY.
        optimizer_fn (Optional[callable]): Called as
            `optimizer_fn(policy, config)`. If None, an Adam optimizer with
            learning rate `config["lr"]` is used.
        gradients_fn (Optional[callable]): Called as
            `gradients_fn(policy, optimizer_like, loss)`, where
            `optimizer_like.compute_gradients(loss, var_list)` returns
            (grad, var) pairs computed from the gradient tape.
        apply_gradients_fn (Optional[callable]): Called as
            `apply_gradients_fn(policy, optimizer, grads_and_vars)`. If None,
            `optimizer.apply_gradients(grads_and_vars)` is used.
        grad_stats_fn (Optional[callable]): Called as
            `grad_stats_fn(policy, batch, grads)`; values are added to the
            learn fetches.
        extra_learn_fetches_fn (Optional[callable]): Called as
            `extra_learn_fetches_fn(policy)`; values are added to the learn
            fetches.
        extra_action_fetches_fn (Optional[callable]): Called as
            `extra_action_fetches_fn(policy)`; its dict is merged into the
            fetches returned by `compute_actions()`.
        before_init (Optional[callable]): Called with
            `(policy, obs_space, action_space, config)` after the default
            config is merged, before the model is built.
        before_loss_init (Optional[callable]): Same arguments; called after
            the model is built and before the dummy loss pass.
        after_init (Optional[callable]): Same arguments; called at the very
            end of `__init__` (after the optimizer is created).
        make_model (Optional[callable]): Called as
            `make_model(policy, obs_space, action_space, config)`. If None,
            `ModelCatalog.get_model_v2(..., framework="tf")` is used.
        action_sampler_fn (Optional[callable]): Called as
            `action_sampler_fn(policy, model, input_dict, obs_space,
            action_space, config)`; must return `(action, logp)`.
            Requires `make_model` (a ValueError is raised otherwise).
        mixins (Optional[list]): Mixin classes placed in front of Policy.
        obs_include_prev_action_reward (bool): Whether previous actions and
            rewards are fed into the model's input dict.
        get_batch_divisibility_req (Optional[callable]): Called as
            `get_batch_divisibility_req(policy)`; the dummy batch used for
            loss initialization is tiled to that many rows.

    Returns:
        type: The eager policy class.
    """

    base = add_mixins(Policy, mixins)

    class eager_policy_cls(base):
        def __init__(self, observation_space, action_space, config):
            assert tf.executing_eagerly()
            # NOTE(review): Policy.__init__ receives the config *before* the
            # defaults from `get_default_config` are merged in; self.config
            # is overwritten with the merged dict below. Confirm intended.
            Policy.__init__(self, observation_space, action_space, config)
            self._is_training = False
            self._loss_initialized = False
            self._sess = None

            if get_default_config:
                config = dict(get_default_config(), **config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config

            if action_sampler_fn:
                # A custom sampler bypasses the action distribution, so the
                # model cannot be derived from a catalog dist's logit dim.
                if not make_model:
                    raise ValueError(
                        "make_model is required if action_sampler_fn is given")
                self.dist_class = None
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"])

            if make_model:
                self.model = make_model(self, observation_space, action_space,
                                        config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework="tf",
                )

            # One forward pass on a single sampled observation, presumably so
            # the model's variables get created eagerly up front.
            self.model(
                {
                    SampleBatch.CUR_OBS: tf.convert_to_tensor(
                        np.array([observation_space.sample()])),
                    SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
                        [_flatten_action(action_space.sample())]),
                    SampleBatch.PREV_REWARDS: tf.convert_to_tensor([0.]),
                }, [
                    tf.convert_to_tensor([s])
                    for s in self.model.get_initial_state()
                ], tf.convert_to_tensor([1]))

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            self._initialize_loss_with_dummy_batch()
            self._loss_initialized = True

            if optimizer_fn:
                self._optimizer = optimizer_fn(self, config)
            else:
                # NOTE(review): `tf.train.AdamOptimizer` is the TF1 API; this
                # relies on `tf` being a v1-compatible module.
                self._optimizer = tf.train.AdamOptimizer(config["lr"])

            if after_init:
                after_init(self, observation_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   samples,
                                   other_agent_batches=None,
                                   episode=None):
            assert tf.executing_eagerly()
            if postprocess_fn:
                return postprocess_fn(self, samples, other_agent_batches,
                                      episode)
            else:
                return samples

        @override(Policy)
        def learn_on_batch(self, samples):
            # Forbid variable creation while computing gradients; applying
            # them happens outside the scope (optimizers may create slots).
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            self._apply_gradients(grads_and_vars)
            return stats

        @override(Policy)
        def compute_gradients(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            grads = [g for g, v in grads_and_vars]
            # Return numpy arrays (None for variables without a gradient).
            grads = [(g.numpy() if g is not None else None) for g in grads]
            return grads, stats

        @override(Policy)
        def compute_actions(self,
                            obs_batch,
                            state_batches,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            info_batch=None,
                            episodes=None,
                            **kwargs):

            assert tf.executing_eagerly()
            self._is_training = False

            self._seq_lens = tf.ones(len(obs_batch))
            self._input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.convert_to_tensor(False),
            }
            if obs_include_prev_action_reward:
                self._input_dict.update({
                    SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
                        prev_action_batch),
                    SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
                        prev_reward_batch),
                })
            self._state_in = state_batches
            with tf.variable_creator_scope(_disallow_var_creation):
                model_out, state_out = self.model(self._input_dict,
                                                  state_batches, self._seq_lens)

            if self.dist_class:
                action_dist = self.dist_class(model_out, self.model)
                action = action_dist.sample().numpy()
                logp = action_dist.sampled_action_logp()
            else:
                action, logp = action_sampler_fn(
                    self, self.model, self._input_dict, self.observation_space,
                    self.action_space, self.config)
                action = action.numpy()

            fetches = {}
            if logp is not None:
                fetches.update({
                    ACTION_PROB: tf.exp(logp).numpy(),
                    ACTION_LOGP: logp.numpy(),
                })
            if extra_action_fetches_fn:
                fetches.update(extra_action_fetches_fn(self))
            return action, state_out, fetches

        @override(Policy)
        def apply_gradients(self, gradients):
            self._apply_gradients(
                zip([(tf.convert_to_tensor(g) if g is not None else None)
                     for g in gradients], self.model.trainable_variables()))

        @override(Policy)
        def get_weights(self):
            variables = self.model.variables()
            return [v.numpy() for v in variables]

        @override(Policy)
        def set_weights(self, weights):
            variables = self.model.variables()
            assert len(weights) == len(variables), (len(weights),
                                                    len(variables))
            for v, w in zip(variables, weights):
                v.assign(w)

        def is_recurrent(self):
            # `_state_in` is set by compute_actions()/_compute_gradients().
            return len(self._state_in) > 0

        def num_state_tensors(self):
            return len(self._state_in)

        def get_session(self):
            return None  # None implies eager

        def get_placeholder(self, ph):
            raise ValueError(
                "get_placeholder() is not allowed in eager mode. Try using "
                "rllib.utils.tf_ops.make_tf_callable() to write "
                "functions that work in both graph and eager mode.")

        def loss_initialized(self):
            return self._loss_initialized

        def _get_is_training_placeholder(self):
            return tf.convert_to_tensor(self._is_training)

        def _apply_gradients(self, grads_and_vars):
            if apply_gradients_fn:
                apply_gradients_fn(self, self._optimizer, grads_and_vars)
            else:
                self._optimizer.apply_gradients(grads_and_vars)

        def _compute_gradients(self, samples):
            """Computes and returns grads as eager tensors."""

            self._is_training = True

            # Skip object-dtype columns, which tf.convert_to_tensor cannot
            # convert. `object` replaces the `np.object` alias, which was
            # removed in NumPy 1.24.
            samples = {
                k: tf.convert_to_tensor(v)
                for k, v in samples.items() if v.dtype != object
            }

            # The tape must be persistent when a custom gradients_fn may ask
            # it for gradients more than once.
            with tf.GradientTape(persistent=gradients_fn is not None) as tape:
                # TODO: set seq len and state in properly
                self._seq_lens = tf.ones(len(samples[SampleBatch.CUR_OBS]))
                self._state_in = []
                model_out, _ = self.model(samples, self._state_in,
                                          self._seq_lens)
                loss = loss_fn(self, self.model, self.dist_class, samples)

            variables = self.model.trainable_variables()

            if gradients_fn:

                class OptimizerWrapper(object):
                    """Exposes the tape via a `compute_gradients` method."""

                    def __init__(self, tape):
                        self.tape = tape

                    def compute_gradients(self, loss, var_list):
                        return list(
                            zip(self.tape.gradient(loss, var_list), var_list))

                grads_and_vars = gradients_fn(self, OptimizerWrapper(tape),
                                              loss)
            else:
                grads_and_vars = list(
                    zip(tape.gradient(loss, variables), variables))

            if log_once("grad_vars"):
                for _, v in grads_and_vars:
                    logger.info("Optimizing variable {}".format(v.name))

            grads = [g for g, v in grads_and_vars]
            # `self` is passed as the `outputs` argument of _stats(), which
            # forwards it as the policy argument of stats_fn.
            stats = self._stats(self, samples, grads)
            return grads_and_vars, stats

        def _stats(self, outputs, samples, grads):
            """Collects learner stats and extra fetches as numpy values."""
            assert tf.executing_eagerly()
            fetches = {}
            if stats_fn:
                fetches[LEARNER_STATS_KEY] = {
                    k: v.numpy()
                    for k, v in stats_fn(outputs, samples).items()
                }
            else:
                fetches[LEARNER_STATS_KEY] = {}
            if extra_learn_fetches_fn:
                fetches.update({
                    k: v.numpy()
                    for k, v in extra_learn_fetches_fn(self).items()
                })
            if grad_stats_fn:
                fetches.update({
                    k: v.numpy()
                    for k, v in grad_stats_fn(self, samples, grads).items()
                })
            return fetches

        def _initialize_loss_with_dummy_batch(self):
            # Dummy forward pass to initialize any policy attributes, etc.
            action_dtype, action_shape = ModelCatalog.get_action_shape(
                self.action_space)
            dummy_batch = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(
                    np.array([self.observation_space.sample()])),
                SampleBatch.NEXT_OBS: tf.convert_to_tensor(
                    np.array([self.observation_space.sample()])),
                # `np.bool_` replaces the `np.bool` alias, which was removed
                # in NumPy 1.24.
                SampleBatch.DONES: tf.convert_to_tensor(
                    np.array([False], dtype=np.bool_)),
                SampleBatch.ACTIONS: tf.convert_to_tensor(
                    np.zeros(
                        (1, ) + action_shape[1:],
                        dtype=action_dtype.as_numpy_dtype())),
                SampleBatch.REWARDS: tf.convert_to_tensor(
                    np.array([0], dtype=np.float32)),
            }
            if obs_include_prev_action_reward:
                dummy_batch.update({
                    SampleBatch.PREV_ACTIONS: dummy_batch[SampleBatch.ACTIONS],
                    SampleBatch.PREV_REWARDS: dummy_batch[SampleBatch.REWARDS],
                })
            state_init = self.get_initial_state()
            state_batches = []
            for i, h in enumerate(state_init):
                dummy_batch["state_in_{}".format(i)] = tf.convert_to_tensor(
                    np.expand_dims(h, 0))
                dummy_batch["state_out_{}".format(i)] = tf.convert_to_tensor(
                    np.expand_dims(h, 0))
                state_batches.append(
                    tf.convert_to_tensor(np.expand_dims(h, 0)))
            if state_init:
                dummy_batch["seq_lens"] = tf.convert_to_tensor(
                    np.array([1], dtype=np.int32))

            # for IMPALA which expects a certain sample batch size
            def tile_to(tensor, n):
                # Repeat along the first axis only.
                return tf.tile(tensor,
                               [n] + [1 for _ in tensor.shape.as_list()[1:]])

            if get_batch_divisibility_req:
                dummy_batch = {
                    k: tile_to(v, get_batch_divisibility_req(self))
                    for k, v in dummy_batch.items()
                }

            # Execute a forward pass to get self.action_dist etc initialized,
            # and also obtain the extra action fetches
            _, _, fetches = self.compute_actions(
                dummy_batch[SampleBatch.CUR_OBS], state_batches,
                dummy_batch.get(SampleBatch.PREV_ACTIONS),
                dummy_batch.get(SampleBatch.PREV_REWARDS))
            dummy_batch.update(fetches)

            postprocessed_batch = self.postprocess_trajectory(
                SampleBatch(dummy_batch))

            # model forward pass for the loss (needed after postprocess to
            # overwrite any tensor state from that call)
            self.model.from_batch(dummy_batch)

            postprocessed_batch = {
                k: tf.convert_to_tensor(v)
                for k, v in postprocessed_batch.items()
            }

            loss_fn(self, self.model, self.dist_class, postprocessed_batch)
            if stats_fn:
                stats_fn(self, postprocessed_batch)

    eager_policy_cls.__name__ = name + "_eager"
    eager_policy_cls.__qualname__ = name + "_eager"
    return eager_policy_cls
def build_trainer(name,
                  default_policy,
                  default_config=None,
                  validate_config=None,
                  get_initial_state=None,
                  get_policy_class=None,
                  before_init=None,
                  make_workers=None,
                  make_policy_optimizer=None,
                  after_init=None,
                  before_train_step=None,
                  after_optimizer_step=None,
                  after_train_result=None,
                  collect_metrics_fn=None,
                  before_evaluate_fn=None,
                  mixins=None,
                  training_pipeline=None):
    """Create a custom Trainer subclass from a set of hook callables.

    During trainer initialization the hooks fire in this order:
        1. Config setup: validate_config, get_initial_state, get_policy
        2. Worker setup: before_init, make_workers, make_policy_optimizer
        3. Post setup: after_init

    Arguments:
        name (str): Name of the trainer (e.g., "PPO"); also used as the
            returned class's __name__/__qualname__.
        default_policy (cls): Policy class used unless `get_policy_class`
            is given.
        default_config (Optional[dict]): Default config dict of the
            algorithm. If None, COMMON_CONFIG is used.
        validate_config (Optional[callable]): Called with the config to
            check it; it may mutate the config.
        get_initial_state (Optional[callable]): Called with the trainer;
            its return value becomes `trainer.state` (otherwise `{}`). The
            state is copied into checkpoints, so it must be serializable.
        get_policy_class (Optional[callable]): Called with the config; its
            return value replaces the default policy class.
        before_init (Optional[callable]): Called with the trainer before the
            workers are created.
        make_workers (Optional[callable]): Replaces the default worker
            creation; called as (trainer, env_creator, policy, config).
        make_policy_optimizer (Optional[callable]): Called as
            (WorkerSet, config); must return a PolicyOptimizer.
        after_init (Optional[callable]): Called with the trainer at the end
            of init.
        before_train_step (Optional[callable]): Called with the trainer at
            the start of every train() call.
        after_optimizer_step (Optional[callable]): Called with the trainer
            and the optimizer's step fetches after every optimizer step.
        after_train_result (Optional[callable]): Called with the trainer and
            the result dict at the end of every train() call; may mutate
            the result.
        collect_metrics_fn (Optional[callable]): Replaces the default metric
            collection; called with the trainer as argument.
        before_evaluate_fn (Optional[callable]): Called with the trainer
            before evaluation.
        mixins (Optional[List[class]]): Mixin classes applied in order, with
            higher precedence than Trainer.
        training_pipeline (Optional[callable]): Experimental custom training
            pipeline; called as (WorkerSet, config). When enabled (config
            "use_pipeline_impl" or env var RLLIB_USE_PIPELINE_IMPL) it takes
            precedence over `make_policy_optimizer`.

    Returns:
        type: The Trainer subclass (not an instance) built from these args.
    """
    # Must run before any other local is bound: it snapshots exactly the
    # call arguments so `with_updates()` can rebuild this trainer.
    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config, env_creator):
            if validate_config:
                validate_config(config)

            self.state = get_initial_state(self) if get_initial_state else {}

            # A user-supplied policy-class function replaces the default.
            if get_policy_class is not None:
                self._policy = get_policy_class(config)

            if before_init:
                before_init(self)

            # Build the rollout workers (evaluation workers are built
            # elsewhere).
            if not make_workers:
                self.workers = self._make_workers(env_creator, self._policy,
                                                  config,
                                                  self.config["num_workers"])
            else:
                self.workers = make_workers(self, env_creator, self._policy,
                                            config)

            self.train_pipeline = None
            self.optimizer = None

            # Short-circuits: the config key is only read when a pipeline
            # factory was actually supplied.
            use_pipeline = training_pipeline and (
                self.config["use_pipeline_impl"]
                or "RLLIB_USE_PIPELINE_IMPL" in os.environ)
            if use_pipeline:
                logger.warning("Using experimental pipeline based impl.")
                self.train_pipeline = training_pipeline(self.workers, config)
            elif make_policy_optimizer:
                self.optimizer = make_policy_optimizer(self.workers, config)
            else:
                sync_opt_kwargs = dict(
                    config["optimizer"],
                    train_batch_size=config["train_batch_size"])
                self.optimizer = SyncSamplesOptimizer(self.workers,
                                                      **sync_opt_kwargs)

            if after_init:
                after_init(self)

        @override(Trainer)
        def _train(self):
            # Pipeline mode handles the whole iteration on its own.
            if self.train_pipeline:
                return self._train_pipeline()

            if before_train_step:
                before_train_step(self)
            steps_before = self.optimizer.num_steps_sampled

            t_begin = time.time()
            while True:
                step_fetches = self.optimizer.step()
                if after_optimizer_step:
                    after_optimizer_step(self, step_fetches)
                # Stop once both the minimum wall-clock time and the minimum
                # number of newly sampled timesteps have been reached.
                enough_time = (time.time() - t_begin >=
                               self.config["min_iter_time_s"])
                if enough_time and (self.optimizer.num_steps_sampled -
                                    steps_before >=
                                    self.config["timesteps_per_iteration"]):
                    break

            result = (collect_metrics_fn(self)
                      if collect_metrics_fn else self.collect_metrics())
            result.update(
                timesteps_this_iter=self.optimizer.num_steps_sampled -
                steps_before,
                info=result.get("info", {}))

            if after_train_result:
                after_train_result(self, result)
            return result

        def _train_pipeline(self):
            if before_train_step:
                before_train_step(self)
            result = next(self.train_pipeline)
            if after_train_result:
                after_train_result(self, result)
            return result

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        def __getstate__(self):
            snapshot = Trainer.__getstate__(self)
            snapshot["trainer_state"] = self.state.copy()
            if self.train_pipeline:
                snapshot["train_pipeline"] = self.train_pipeline.metrics.save()
            return snapshot

        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            self.state = state["trainer_state"].copy()
            if self.train_pipeline:
                self.train_pipeline.metrics.restore(state["train_pipeline"])

    def with_updates(**overrides):
        """Build a copy of this trainer with the specified overrides.

        Arguments:
            overrides (dict): Keyword arguments that replace the ones
                originally passed to build_trainer() for this trainer.
        """
        return build_trainer(**dict(original_kwargs, **overrides))

    trainer_cls.with_updates = staticmethod(with_updates)
    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
def build_torch_policy(name,
                       *,
                       loss_fn,
                       get_default_config=None,
                       stats_fn=None,
                       postprocess_fn=None,
                       extra_action_out_fn=None,
                       extra_grad_process_fn=None,
                       optimizer_fn=None,
                       before_init=None,
                       after_init=None,
                       action_sampler_fn=None,
                       action_distribution_fn=None,
                       make_model_and_action_dist=None,
                       apply_gradients_fn=None,
                       mixins=None,
                       get_batch_divisibility_req=None):
    """Create a TorchPolicy subclass at runtime from the given callables.

    Arguments:
        name (str): Name of the policy class (e.g., "PPOTorchPolicy").
        loss_fn (func): Returns a loss tensor given
            (policy, model, dist_class, train_batch).
        get_default_config (func): Optional; returns a default config dict
            that the given config is merged on top of.
        stats_fn (func): Optional; returns a dict of values given the policy
            and the batch input tensors.
        postprocess_fn (func): Optional experience postprocessing function
            taking the same args as Policy.postprocess_trajectory().
        extra_action_out_fn (func): Optional; returns a dict of extra values
            to include in experiences.
        extra_grad_process_fn (func): Optional; called after gradients are
            computed and returns processing info.
        optimizer_fn (func): Optional; returns a torch optimizer given the
            policy and config.
        before_init (func): Optional; runs at the start of policy init with
            the same arguments as the policy constructor.
        after_init (func): Optional; runs at the end of policy init with the
            same arguments as the policy constructor.
        action_sampler_fn (Optional[callable]): Returns a sampled action and
            its log-likelihood given some (obs and state) inputs.
        action_distribution_fn (Optional[callable]): Returns distribution
            inputs (parameters), a dist-class to build an action
            distribution from, and internal-state outputs (or an empty
            list if not applicable).
        make_model_and_action_dist (func): Optional; takes the same arguments
            as policy init and returns (model, torch action dist class). If
            not given, the catalog's default model and dist are used.
        apply_gradients_fn (Optional[callable]): Takes a grads list and
            applies it to the Model's parameters.
        mixins (list): Mixin classes applied in order, with higher
            precedence than TorchPolicy.
        get_batch_divisibility_req (Optional[callable]): Returns the
            divisibility requirement for sample batches.

    Returns:
        type: A TorchPolicy subclass (not an instance) built from the args.
    """
    # Snapshot the call arguments before any other local exists, so that
    # `with_updates()` can rebuild the class with overrides.
    original_kwargs = locals().copy()
    base = add_mixins(TorchPolicy, mixins)

    class policy_cls(base):
        def __init__(self, obs_space, action_space, config):
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            if before_init:
                before_init(self, obs_space, action_space, config)

            if not make_model_and_action_dist:
                # Use the catalog's default torch model and action
                # distribution for this action space.
                model_config = self.config["model"]
                dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, model_config, framework="torch")
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=logit_dim,
                    model_config=model_config,
                    framework="torch",
                    **model_config.get("custom_options", {}))
            else:
                self.model, dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
                # The custom factory must produce a torch model.
                assert isinstance(self.model, TorchModelV2), \
                    "ERROR: TorchPolicy::make_model_and_action_dist must " \
                    "return a TorchModelV2 object!"

            TorchPolicy.__init__(
                self,
                observation_space=obs_space,
                action_space=action_space,
                config=config,
                model=self.model,
                loss=loss_fn,
                action_distribution_class=dist_class,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=config["model"]["max_seq_len"],
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            # Always post-process under no_grad(); without it torch leaks
            # memory here (issue #6962).
            with torch.no_grad():
                # The parent class post-processes (non-torch data) first.
                processed = super().postprocess_trajectory(
                    convert_to_non_torch_type(sample_batch),
                    convert_to_non_torch_type(other_agent_batches), episode)
                if not postprocess_fn:
                    return processed
                return postprocess_fn(self, processed, other_agent_batches,
                                      episode)

        @override(TorchPolicy)
        def extra_grad_process(self, optimizer, loss):
            """Called after optimizer.zero_grad() and loss.backward() calls.

            Allows for gradient processing before optimizer.step() is called.
            E.g. for gradient clipping.
            """
            process = extra_grad_process_fn or TorchPolicy.extra_grad_process
            return process(self, optimizer, loss)

        @override(TorchPolicy)
        def apply_gradients(self, gradients):
            apply = apply_gradients_fn or TorchPolicy.apply_gradients
            apply(self, gradients)

        @override(TorchPolicy)
        def extra_action_out(self, input_dict, state_batches, model,
                             action_dist):
            with torch.no_grad():
                compute_extra = (extra_action_out_fn
                                 or TorchPolicy.extra_action_out)
                extra_outs = compute_extra(self, input_dict, state_batches,
                                           model, action_dist)
                return convert_to_non_torch_type(extra_outs)

        @override(TorchPolicy)
        def optimizer(self):
            if not optimizer_fn:
                return TorchPolicy.optimizer(self)
            return optimizer_fn(self, self.config)

        @override(TorchPolicy)
        def extra_grad_info(self, train_batch):
            with torch.no_grad():
                compute_stats = stats_fn or TorchPolicy.extra_grad_info
                grad_info = compute_stats(self, train_batch)
                return convert_to_non_torch_type(grad_info)

    def with_updates(**overrides):
        """Rebuild this policy class with some build args overridden."""
        return build_torch_policy(**dict(original_kwargs, **overrides))

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
def build_trainer(
        name: str,
        *,
        default_config: Optional[TrainerConfigDict] = None,
        validate_config: Optional[Callable[[TrainerConfigDict], None]] = None,
        default_policy: Optional[Type[Policy]] = None,
        get_policy_class: Optional[Callable[[TrainerConfigDict], Optional[
            Type[Policy]]]] = None,
        validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
        before_init: Optional[Callable[[Trainer], None]] = None,
        after_init: Optional[Callable[[Trainer], None]] = None,
        before_evaluate_fn: Optional[Callable[[Trainer], None]] = None,
        mixins: Optional[List[type]] = None,
        execution_plan: Optional[Callable[[WorkerSet, TrainerConfigDict],
                                          Iterable[ResultDict]]] = (
                                              default_execution_plan)
) -> Type[Trainer]:
    """Create a custom Trainer subclass driven by an execution plan.

    During trainer initialization the hooks fire in this order:
        1. Config setup: validate_config, get_policy
        2. Worker setup: before_init, execution_plan
        3. Post setup: after_init

    Args:
        name (str): Name of the trainer (e.g., "PPO"); also used as the
            returned class's __name__/__qualname__.
        default_config (Optional[TrainerConfigDict]): Default config dict of
            the algorithm. If None, COMMON_CONFIG is used.
        validate_config (Optional[Callable[[TrainerConfigDict], None]]):
            Called with the config to check it; it may mutate the config.
        default_policy (Optional[Type[Policy]]): Policy class used when no
            `get_policy_class` is given or when it returns None.
        get_policy_class (Optional[Callable[
            TrainerConfigDict, Optional[Type[Policy]]]]): Called with the
            config; returns a policy class, or None to fall back to
            `default_policy` (which must then be provided).
        validate_env (Optional[Callable[[EnvType, EnvContext], None]]):
            Passed on to worker creation to validate the generated
            environment (only on worker=0).
        before_init (Optional[Callable[[Trainer], None]]): Called with the
            trainer before anything (workers, execution plan) is built.
        after_init (Optional[Callable[[Trainer], None]]): Called with the
            trainer after the workers and the execution plan are built.
        before_evaluate_fn (Optional[Callable[[Trainer], None]]): Called
            with the trainer before evaluation.
        mixins (list): Mixin classes applied in order, with higher
            precedence than Trainer.
        execution_plan (Optional[Callable[[WorkerSet, TrainerConfigDict],
            Iterable[ResultDict]]]): Called with (workers, config) to set up
            the execution workflow; each `step()` pulls the next result
            from the iterator it returns.

    Returns:
        Type[Trainer]: A Trainer sub-class configured by the specified args.
    """
    # Must come first: captures exactly the call arguments for
    # `with_updates()`.
    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy_class = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config: TrainerConfigDict,
                  env_creator: Callable[[EnvConfigDict], EnvType]):
            # Optional user-defined config validation.
            if validate_config:
                validate_config(config)

            if get_policy_class is not None:
                # Ask the function first; a None answer means "use the
                # default policy", which then has to exist.
                self._policy_class = get_policy_class(config)
                if self._policy_class is None:
                    assert default_policy is not None
                    self._policy_class = default_policy
            elif not config["multiagent"]["policies"]:
                # Single-agent without a policy-class function: the default
                # policy is mandatory. (In multi-agent mode each policy may
                # bring its own class.)
                assert default_policy is not None
                self._policy_class = default_policy

            if before_init:
                before_init(self)

            # Build the rollout workers (evaluation workers are built
            # elsewhere).
            self.workers = self._make_workers(
                env_creator=env_creator,
                validate_env=validate_env,
                policy_class=self._policy_class,
                config=config,
                num_workers=self.config["num_workers"])
            self.execution_plan = execution_plan
            self.train_exec_impl = execution_plan(self.workers, config)

            if after_init:
                after_init(self)

        @override(Trainer)
        def step(self):
            return next(self.train_exec_impl)

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        @override(Trainer)
        def __getstate__(self):
            state = Trainer.__getstate__(self)
            exec_metrics = self.train_exec_impl.shared_metrics.get()
            state["train_exec_impl"] = exec_metrics.save()
            return state

        @override(Trainer)
        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            exec_metrics = self.train_exec_impl.shared_metrics.get()
            exec_metrics.restore(state["train_exec_impl"])

        @staticmethod
        @override(Trainer)
        def with_updates(**overrides) -> Type[Trainer]:
            """Build a copy of this trainer class with the specified overrides.

            Keyword Args:
                overrides (dict): use this to override any of the arguments
                    originally passed to build_trainer() for this trainer.

            Returns:
                Type[Trainer]: The Trainer sub-class built from
                    `original_kwargs` updated with `overrides`.

            Examples:
                >>> MyClass = SomeOtherClass.with_updates(name="Mine")
                >>> issubclass(MyClass, SomeOtherClass)
                False
                >>> issubclass(MyClass, Trainer)
                True
            """
            return build_trainer(**dict(original_kwargs, **overrides))

    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
def build_trainer(name,
                  default_policy,
                  default_config=None,
                  validate_config=None,
                  get_initial_state=None,
                  get_policy_class=None,
                  before_init=None,
                  make_workers=None,
                  make_policy_optimizer=None,
                  after_init=None,
                  before_train_step=None,
                  after_optimizer_step=None,
                  after_train_result=None,
                  collect_metrics_fn=None,
                  before_evaluate_fn=None,
                  mixins=None,
                  execution_plan=None):
    """Helper function for defining a custom trainer.

    Functions will be run in this order to initialize the trainer:
        1. Config setup: validate_config, get_initial_state, get_policy
        2. Worker setup: before_init, make_workers, make_policy_optimizer
        3. Post setup: after_init

    Arguments:
        name (str): name of the trainer (e.g., "PPO")
        default_policy (cls): the default Policy class to use
        default_config (dict): The default config dict of the algorithm,
            otherwise uses the Trainer default config.
        validate_config (func): optional callback that checks a given config
            for correctness. It may mutate the config as needed.
        get_initial_state (func): optional function that returns the initial
            state dict given the trainer instance as an argument. The state
            dict must be serializable so that it can be checkpointed, and will
            be available as the `trainer.state` variable.
        get_policy_class (func): optional callback that takes a config and
            returns the policy class to override the default with
        before_init (func): optional function to run at the start of trainer
            init that takes the trainer instance as argument
        make_workers (func): override the method that creates rollout workers.
            This takes in (trainer, env_creator, policy, config) as args.
            Ignored when the execution-plan API is active.
        make_policy_optimizer (func): optional function that returns a
            PolicyOptimizer instance given (WorkerSet, config)
        after_init (func): optional function to run at the end of trainer init
            that takes the trainer instance as argument
        before_train_step (func): optional callback to run before each train()
            call. It takes the trainer instance as an argument.
        after_optimizer_step (func): optional callback to run after each
            step() call to the policy optimizer. It takes the trainer instance
            and the policy gradient fetches as arguments.
        after_train_result (func): optional callback to run at the end of each
            train() call. It takes the trainer instance and result dict as
            arguments, and may mutate the result dict as needed.
        collect_metrics_fn (func): override the method used to collect metrics.
            It takes the trainer instance as argument.
        before_evaluate_fn (func): callback to run before evaluation. This
            takes the trainer instance as argument.
        mixins (list): list of any class mixins for the returned trainer class.
            These mixins will be applied in order and will have higher
            precedence than the Trainer class
        execution_plan (func): Experimental distributed execution
            API. This overrides `make_policy_optimizer`. It is only used if
            `config["use_exec_api"]` is truthy or the `RLLIB_EXEC_API`
            environment variable is set.

    Returns:
        A Trainer sub-class (named `name`) that uses the specified args.
    """
    # Snapshot the call arguments so `with_updates` can rebuild this trainer
    # with overrides. This must stay the first statement: any local bound
    # before it would also be captured by `locals()`.
    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config, env_creator):
            """Set up state, policy class, workers and optimizer/exec plan."""
            if validate_config:
                validate_config(config)

            if get_initial_state:
                self.state = get_initial_state(self)
            else:
                self.state = {}
            if get_policy_class is None:
                self._policy = default_policy
            else:
                self._policy = get_policy_class(config)
            if before_init:
                before_init(self)
            # The exec API is only used when an execution plan was given AND
            # it is switched on (via config or environment variable).
            use_exec_api = (execution_plan
                            and (self.config["use_exec_api"]
                                 or "RLLIB_EXEC_API" in os.environ))

            # Creating all workers (excluding evaluation workers).
            if make_workers and not use_exec_api:
                self.workers = make_workers(self, env_creator, self._policy,
                                            config)
            else:
                self.workers = self._make_workers(env_creator, self._policy,
                                                  config,
                                                  self.config["num_workers"])
            self.train_exec_impl = None
            self.optimizer = None
            self.execution_plan = execution_plan

            # Exactly one of `train_exec_impl` / `optimizer` ends up set.
            if use_exec_api:
                self.train_exec_impl = execution_plan(self.workers, config)
            elif make_policy_optimizer:
                self.optimizer = make_policy_optimizer(self.workers, config)
            else:
                optimizer_config = dict(
                    config["optimizer"],
                    **{"train_batch_size": config["train_batch_size"]})
                self.optimizer = SyncSamplesOptimizer(self.workers,
                                                      **optimizer_config)
            if after_init:
                after_init(self)

        @override(Trainer)
        def _train(self):
            """Run one training iteration and return the result dict."""
            if self.train_exec_impl:
                return self._train_exec_impl()

            if before_train_step:
                before_train_step(self)
            prev_steps = self.optimizer.num_steps_sampled

            start = time.time()
            optimizer_steps_this_iter = 0
            # Keep stepping the optimizer until BOTH the minimum wall-clock
            # time and the minimum number of newly sampled timesteps for this
            # iteration have been reached.
            while True:
                fetches = self.optimizer.step()
                optimizer_steps_this_iter += 1
                if after_optimizer_step:
                    after_optimizer_step(self, fetches)
                if (time.time() - start >= self.config["min_iter_time_s"]
                        and self.optimizer.num_steps_sampled - prev_steps >=
                        self.config["timesteps_per_iteration"]):
                    break

            if collect_metrics_fn:
                res = collect_metrics_fn(self)
            else:
                res = self.collect_metrics()
            res.update(
                optimizer_steps_this_iter=optimizer_steps_this_iter,
                timesteps_this_iter=self.optimizer.num_steps_sampled -
                prev_steps,
                info=res.get("info", {}))

            if after_train_result:
                after_train_result(self, res)
            return res

        def _train_exec_impl(self):
            """Pull one result from the execution-plan iterator.

            The per-step callbacks are not supported on this path; their
            presence is only logged at debug level.
            """
            if before_train_step:
                logger.debug("Ignoring before_train_step callback")
            res = next(self.train_exec_impl)
            if after_train_result:
                logger.debug("Ignoring after_train_result callback")
            return res

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        def __getstate__(self):
            state = Trainer.__getstate__(self)
            state["trainer_state"] = self.state.copy()
            # Only written when the exec API is active at save time.
            if self.train_exec_impl:
                state["train_exec_impl"] = (
                    self.train_exec_impl.shared_metrics.get().save())
            return state

        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            self.state = state["trainer_state"].copy()
            # A checkpoint saved without the exec API carries no
            # "train_exec_impl" entry (see `__getstate__`); skip restoring
            # the iterator metrics then instead of failing with a KeyError.
            if self.train_exec_impl and "train_exec_impl" in state:
                self.train_exec_impl.shared_metrics.get().restore(
                    state["train_exec_impl"])

    def with_updates(**overrides):
        """Build a copy of this trainer with the specified overrides.

        Arguments:
            overrides (dict): use this to override any of the arguments
                originally passed to build_trainer() for this policy.
        """
        return build_trainer(**dict(original_kwargs, **overrides))

    trainer_cls.with_updates = staticmethod(with_updates)
    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
def build_eager_tf_policy(
        name,
        loss_fn,
        get_default_config=None,
        postprocess_fn=None,
        stats_fn=None,
        optimizer_fn=None,
        gradients_fn=None,
        apply_gradients_fn=None,
        grad_stats_fn=None,
        extra_learn_fetches_fn=None,
        extra_action_out_fn=None,
        validate_spaces=None,
        before_init=None,
        before_loss_init=None,
        after_init=None,
        make_model=None,
        action_sampler_fn=None,
        action_distribution_fn=None,
        mixins=None,
        obs_include_prev_action_reward=DEPRECATED_VALUE,
        get_batch_divisibility_req=None,
        # Deprecated args.
        extra_action_fetches_fn=None):
    """Build an eager TF policy.

    An eager policy runs all operations in eager mode, which makes debugging
    much simpler, but has lower performance.

    You shouldn't need to call this directly. Rather, prefer to build a TF
    graph policy and set {"framework": "tfe"} in the trainer config to have
    it automatically be converted to an eager policy.

    This has the same signature as build_tf_policy().

    Deprecated args:
        extra_action_fetches_fn: Old name of `extra_action_out_fn`. If given,
            a deprecation warning is emitted and it replaces
            `extra_action_out_fn`.
        obs_include_prev_action_reward: No longer used; passing any value
            only emits a deprecation warning.

    Returns:
        A new Policy sub-class (derived from `Policy` plus `mixins`) whose
        `__name__` and `__qualname__` are `name + "_eager"`.
    """

    base = add_mixins(Policy, mixins)

    if extra_action_fetches_fn is not None:
        deprecation_warning(
            old="extra_action_fetches_fn",
            new="extra_action_out_fn",
            error=False)
        extra_action_out_fn = extra_action_fetches_fn

    if obs_include_prev_action_reward != DEPRECATED_VALUE:
        deprecation_warning(old="obs_include_prev_action_reward", error=False)

    class eager_policy_cls(base):
        def __init__(self, observation_space, action_space, config):
            assert tf.executing_eagerly()
            self.framework = config.get("framework", "tfe")
            Policy.__init__(self, observation_space, action_space, config)

            # Log device and worker index.
            from ray.rllib.evaluation.rollout_worker import get_global_worker
            worker = get_global_worker()
            worker_idx = worker.worker_index if worker else 0
            if tf.config.list_physical_devices("GPU"):
                logger.info(
                    "TF-eager Policy (worker={}) running on GPU.".format(
                        worker_idx if worker_idx > 0 else "local"))
            else:
                logger.info(
                    "TF-eager Policy (worker={}) running on CPU.".format(
                        worker_idx if worker_idx > 0 else "local"))

            self._is_training = False
            self._loss_initialized = False
            self._sess = None

            self._loss = loss_fn
            # `get_batch_divisibility_req` may be a callable (taking the
            # policy) or a plain int; default to 1 (no requirement).
            self.batch_divisibility_req = get_batch_divisibility_req(self) if \
                callable(get_batch_divisibility_req) else \
                (get_batch_divisibility_req or 1)
            # NOTE(review): read from the incoming config, i.e. before
            # `get_default_config()` is merged in below.
            self._max_seq_len = config["model"]["max_seq_len"]

            if get_default_config:
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, observation_space, action_space, config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config
            self.dist_class = None
            if action_sampler_fn or action_distribution_fn:
                if not make_model:
                    raise ValueError(
                        "`make_model` is required if `action_sampler_fn` OR "
                        "`action_distribution_fn` is given")
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"])

            if make_model:
                self.model = make_model(self, observation_space, action_space,
                                        config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework=self.framework,
                )
            # Lock used for locking some methods on the object-level.
            # This prevents possible race conditions when calling the model
            # first, then its value function (e.g. in a loss function), in
            # between of which another model call is made (e.g. to compute an
            # action).
            self._lock = threading.RLock()

            # Auto-update model's inference view requirements, if recurrent.
            self._update_model_view_requirements_from_init_state()

            self.exploration = self._create_exploration()
            self._state_inputs = self.model.get_initial_state()
            self._is_recurrent = len(self._state_inputs) > 0

            # Combine view_requirements for Model and Policy.
            self.view_requirements.update(self.model.view_requirements)

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            if optimizer_fn:
                optimizers = optimizer_fn(self, config)
            else:
                optimizers = tf.keras.optimizers.Adam(config["lr"])
            optimizers = force_list(optimizers)
            if getattr(self, "exploration", None):
                optimizers = self.exploration.get_exploration_optimizer(
                    optimizers)
            # TODO: (sven) Allow tf policy to have more than 1 optimizer.
            #  Just like torch Policy does.
            self._optimizer = optimizers[0] if optimizers else None

            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=stats_fn,
            )
            self._loss_initialized = True

            if after_init:
                after_init(self, observation_space, action_space, config)

            # Got to reset global_timestep again after fake run-throughs.
            self.global_timestep = 0

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            assert tf.executing_eagerly()
            # Call super's postprocess_trajectory first.
            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
            if postprocess_fn:
                return postprocess_fn(self, sample_batch, other_agent_batches,
                                      episode)
            return sample_batch

        @with_lock
        @override(Policy)
        def learn_on_batch(self, postprocessed_batch):
            # Callback handling.
            learn_stats = {}
            self.callbacks.on_learn_on_batch(
                policy=self,
                train_batch=postprocessed_batch,
                result=learn_stats)

            # Pad to equal-length sequences unless the batch already is a
            # zero-padded SampleBatch.
            if not isinstance(postprocessed_batch, SampleBatch) or \
                    not postprocessed_batch.zero_padded:
                pad_batch_to_sequences_of_same_size(
                    postprocessed_batch,
                    max_seq_len=self._max_seq_len,
                    shuffle=False,
                    batch_divisibility_req=self.batch_divisibility_req,
                    view_requirements=self.view_requirements,
                )
            else:
                postprocessed_batch["seq_lens"] = postprocessed_batch.seq_lens

            self._is_training = True
            postprocessed_batch["is_training"] = True
            stats = self._learn_on_batch_eager(postprocessed_batch)
            stats.update({"custom_metrics": learn_stats})
            return stats

        @convert_eager_inputs
        @convert_eager_outputs
        def _learn_on_batch_eager(self, samples):
            # Variable creation is disallowed while computing gradients; the
            # optimizer step happens outside of that scope.
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            self._apply_gradients(grads_and_vars)
            return stats

        @override(Policy)
        def compute_gradients(self, samples):
            pad_batch_to_sequences_of_same_size(
                samples,
                shuffle=False,
                max_seq_len=self._max_seq_len,
                batch_divisibility_req=self.batch_divisibility_req)

            self._is_training = True
            samples["is_training"] = True
            return self._compute_gradients_eager(samples)

        @convert_eager_inputs
        @convert_eager_outputs
        def _compute_gradients_eager(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            grads = [g for g, v in grads_and_vars]
            return grads, stats

        @override(Policy)
        def compute_actions(self,
                            obs_batch,
                            state_batches=None,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            info_batch=None,
                            episodes=None,
                            explore=None,
                            timestep=None,
                            **kwargs):

            self._is_training = False
            self._is_recurrent = \
                state_batches is not None and state_batches != []

            if not tf1.executing_eagerly():
                tf1.enable_eager_execution()

            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.constant(False),
            }
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = \
                    tf.convert_to_tensor(prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = \
                    tf.convert_to_tensor(prev_reward_batch)

            return self._compute_action_helper(input_dict, state_batches,
                                               episodes, explore, timestep)

        @override(Policy)
        def compute_actions_from_input_dict(
                self,
                input_dict: Dict[str, TensorType],
                explore: bool = None,
                timestep: Optional[int] = None,
                **kwargs
        ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

            if not tf1.executing_eagerly():
                tf1.enable_eager_execution()

            # Pass lazy (eager) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            # Pack internal state inputs into (separate) list.
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]

            return self._compute_action_helper(input_dict, state_batches, None,
                                               explore, timestep)

        @with_lock
        @convert_eager_inputs
        @convert_eager_outputs
        def _compute_action_helper(self, input_dict, state_batches, episodes,
                                   explore, timestep):
            """Shared action computation for both `compute_actions*` APIs.

            Returns a tuple (actions, state_out, extra_fetches) and advances
            `self.global_timestep` by the batch size.
            """

            explore = explore if explore is not None else \
                self.config["explore"]
            timestep = timestep if timestep is not None else \
                self.global_timestep
            if isinstance(timestep, tf.Tensor):
                timestep = int(timestep.numpy())
            self._is_training = False
            self._state_in = state_batches or []
            # Calculate RNN sequence lengths.
            batch_size = input_dict[SampleBatch.CUR_OBS].shape[0]
            seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches \
                else None

            # Add default and custom fetches.
            extra_fetches = {}

            # Use Exploration object.
            with tf.variable_creator_scope(_disallow_var_creation):
                if action_sampler_fn:
                    dist_inputs = None
                    state_out = []
                    actions, logp = action_sampler_fn(
                        self,
                        self.model,
                        input_dict[SampleBatch.CUR_OBS],
                        explore=explore,
                        timestep=timestep,
                        episodes=episodes)
                else:
                    # Exploration hook before each forward pass.
                    self.exploration.before_compute_actions(
                        timestep=timestep, explore=explore)

                    if action_distribution_fn:

                        # Try new action_distribution_fn signature, supporting
                        # state_batches and seq_lens.
                        try:
                            dist_inputs, self.dist_class, state_out = \
                                action_distribution_fn(
                                    self,
                                    self.model,
                                    input_dict=input_dict,
                                    state_batches=state_batches,
                                    seq_lens=seq_lens,
                                    explore=explore,
                                    timestep=timestep,
                                    is_training=False)
                        # Trying the old way (to stay backward compatible).
                        # TODO: Remove in future.
                        except TypeError as e:
                            if "positional argument" in e.args[0] or \
                                    "unexpected keyword argument" in e.args[0]:
                                dist_inputs, self.dist_class, state_out = \
                                    action_distribution_fn(
                                        self, self.model,
                                        input_dict[SampleBatch.CUR_OBS],
                                        explore=explore,
                                        timestep=timestep,
                                        is_training=False)
                            else:
                                raise e
                    elif isinstance(self.model, tf.keras.Model):
                        input_dict = SampleBatch(input_dict, seq_lens=seq_lens)
                        self._lazy_tensor_dict(input_dict)
                        dist_inputs, state_out, extra_fetches = \
                            self.model(input_dict)
                    else:
                        dist_inputs, state_out = self.model(
                            input_dict, state_batches, seq_lens)

                    action_dist = self.dist_class(dist_inputs, self.model)

                    # Get the exploration action from the forward results.
                    actions, logp = self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore)

            # Action-logp and action-prob.
            if logp is not None:
                extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp
            # Action-dist inputs.
            if dist_inputs is not None:
                extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
            # Custom extra fetches.
            if extra_action_out_fn:
                extra_fetches.update(extra_action_out_fn(self))

            # Update our global timestep by the batch size.
            self.global_timestep += int(batch_size)

            return actions, state_out, extra_fetches

        @with_lock
        @override(Policy)
        def compute_log_likelihoods(self,
                                    actions,
                                    obs_batch,
                                    state_batches=None,
                                    prev_action_batch=None,
                                    prev_reward_batch=None):
            if action_sampler_fn and action_distribution_fn is None:
                raise ValueError("Cannot compute log-prob/likelihood w/o an "
                                 "`action_distribution_fn` and a provided "
                                 "`action_sampler_fn`!")

            seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.constant(False),
            }
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = \
                    tf.convert_to_tensor(prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = \
                    tf.convert_to_tensor(prev_reward_batch)

            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if action_distribution_fn:
                dist_inputs, dist_class, _ = action_distribution_fn(
                    self,
                    self.model,
                    input_dict[SampleBatch.CUR_OBS],
                    explore=False,
                    is_training=False)
            # Default log-likelihood calculation.
            else:
                dist_inputs, _ = self.model(input_dict, state_batches,
                                            seq_lens)
                dist_class = self.dist_class

            action_dist = dist_class(dist_inputs, self.model)
            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods

        @override(Policy)
        def apply_gradients(self, gradients):
            # Keras models expose `trainable_variables` as a property, while
            # other (ModelV2) models expose it as a method; handle both, just
            # like `_compute_gradients` does.
            if isinstance(self.model, tf.keras.Model):
                variables = self.model.trainable_variables
            else:
                variables = self.model.trainable_variables()
            self._apply_gradients(
                zip([(tf.convert_to_tensor(g) if g is not None else None)
                     for g in gradients], variables))

        @override(Policy)
        def get_exploration_info(self):
            return _convert_to_numpy(self.exploration.get_info())

        @override(Policy)
        def get_weights(self, as_dict=False):
            variables = self.variables()
            if as_dict:
                return {v.name: v.numpy() for v in variables}
            return [v.numpy() for v in variables]

        @override(Policy)
        def set_weights(self, weights):
            variables = self.variables()
            assert len(weights) == len(variables), (len(weights),
                                                    len(variables))
            for v, w in zip(variables, weights):
                v.assign(w)

        @override(Policy)
        def get_state(self):
            state = {"_state": super().get_state()}
            # Only include optimizer variables if the optimizer has any.
            if self._optimizer and \
                    len(self._optimizer.variables()) > 0:
                state["_optimizer_variables"] = \
                    self._optimizer.variables()
            return state

        @override(Policy)
        def set_state(self, state):
            state = state.copy()  # shallow copy
            # Set optimizer vars first.
            optimizer_vars = state.pop("_optimizer_variables", None)
            if optimizer_vars and self._optimizer.variables():
                logger.warning(
                    "Cannot restore an optimizer's state for tf eager! Keras "
                    "is not able to save the v1.x optimizers (from "
                    "tf.compat.v1.train) since they aren't compatible with "
                    "checkpoints.")
                for opt_var, value in zip(self._optimizer.variables(),
                                          optimizer_vars):
                    opt_var.assign(value)
            # Then the Policy's (NN) weights.
            super().set_state(state["_state"])

        def variables(self):
            """Return the list of all savable variables for this policy."""
            if isinstance(self.model, tf.keras.Model):
                return self.model.variables
            else:
                return self.model.variables()

        @override(Policy)
        def is_recurrent(self):
            return self._is_recurrent

        @override(Policy)
        def num_state_tensors(self):
            return len(self._state_inputs)

        @override(Policy)
        def get_initial_state(self):
            # May be called before `self.model` exists (e.g. during init).
            if hasattr(self, "model"):
                return self.model.get_initial_state()
            return []

        def get_session(self):
            return None  # None implies eager

        def get_placeholder(self, ph):
            raise ValueError(
                "get_placeholder() is not allowed in eager mode. Try using "
                "rllib.utils.tf_ops.make_tf_callable() to write "
                "functions that work in both graph and eager mode.")

        def loss_initialized(self):
            return self._loss_initialized

        @override(Policy)
        def export_model(self, export_dir):
            pass

        @override(Policy)
        def export_checkpoint(self, export_dir):
            pass

        def _get_is_training_placeholder(self):
            return tf.convert_to_tensor(self._is_training)

        def _apply_gradients(self, grads_and_vars):
            if apply_gradients_fn:
                apply_gradients_fn(self, self._optimizer, grads_and_vars)
            else:
                # Drop (None, var) pairs for variables without a gradient.
                self._optimizer.apply_gradients(
                    [(g, v) for g, v in grads_and_vars if g is not None])

        @with_lock
        def _compute_gradients(self, samples):
            """Computes and returns grads as eager tensors."""

            # A persistent tape is needed when `gradients_fn` may query the
            # tape more than once via `OptimizerWrapper.compute_gradients`.
            with tf.GradientTape(persistent=gradients_fn is not None) as tape:
                loss = loss_fn(self, self.model, self.dist_class, samples)

            if isinstance(self.model, tf.keras.Model):
                variables = self.model.trainable_variables
            else:
                variables = self.model.trainable_variables()

            if gradients_fn:

                class OptimizerWrapper:
                    def __init__(self, tape):
                        self.tape = tape

                    def compute_gradients(self, loss, var_list):
                        return list(
                            zip(self.tape.gradient(loss, var_list), var_list))

                grads_and_vars = gradients_fn(self, OptimizerWrapper(tape),
                                              loss)
            else:
                grads_and_vars = list(
                    zip(tape.gradient(loss, variables), variables))

            if log_once("grad_vars"):
                for _, v in grads_and_vars:
                    logger.info("Optimizing variable {}".format(v.name))

            grads = [g for g, v in grads_and_vars]
            stats = self._stats(self, samples, grads)
            return grads_and_vars, stats

        def _stats(self, outputs, samples, grads):
            """Collect learner stats plus any extra learn/grad fetches."""

            fetches = {}
            if stats_fn:
                fetches[LEARNER_STATS_KEY] = {
                    k: v
                    for k, v in stats_fn(outputs, samples).items()
                }
            else:
                fetches[LEARNER_STATS_KEY] = {}

            if extra_learn_fetches_fn:
                fetches.update(
                    {k: v
                     for k, v in extra_learn_fetches_fn(self).items()})
            if grad_stats_fn:
                fetches.update({
                    k: v
                    for k, v in grad_stats_fn(self, samples, grads).items()
                })
            return fetches

        def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch):
            # TODO: (sven): Keep for a while to ensure backward compatibility.
            if not isinstance(postprocessed_batch, SampleBatch):
                postprocessed_batch = SampleBatch(postprocessed_batch)
            postprocessed_batch.set_get_interceptor(_convert_to_tf)
            return postprocessed_batch

        @classmethod
        def with_tracing(cls):
            return traced_eager_policy(cls)

    eager_policy_cls.__name__ = name + "_eager"
    eager_policy_cls.__qualname__ = name + "_eager"
    return eager_policy_cls
def build_eager_tf_policy(name,
                          loss_fn,
                          get_default_config=None,
                          postprocess_fn=None,
                          stats_fn=None,
                          optimizer_fn=None,
                          gradients_fn=None,
                          apply_gradients_fn=None,
                          grad_stats_fn=None,
                          extra_learn_fetches_fn=None,
                          extra_action_fetches_fn=None,
                          before_init=None,
                          before_loss_init=None,
                          after_init=None,
                          make_model=None,
                          action_sampler_fn=None,
                          log_likelihood_fn=None,
                          mixins=None,
                          obs_include_prev_action_reward=True,
                          get_batch_divisibility_req=None):
    """Build an eager TF policy.

    An eager policy runs all operations in eager mode, which makes debugging
    much simpler, but has lower performance.

    You shouldn't need to call this directly. Rather, prefer to build a TF
    graph policy and set {"eager": True} in the trainer config to have it
    automatically be converted to an eager policy.

    This has the same signature as build_tf_policy().

    Returns:
        A new Policy sub-class (derived from `Policy` plus `mixins`) whose
        `__name__` and `__qualname__` are `name + "_eager"`.
    """

    base = add_mixins(Policy, mixins)

    class eager_policy_cls(base):
        def __init__(self, observation_space, action_space, config):
            assert tf.executing_eagerly()
            self.framework = "tf"
            Policy.__init__(self, observation_space, action_space, config)
            self._is_training = False
            self._loss_initialized = False
            self._sess = None

            if get_default_config:
                config = dict(get_default_config(), **config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config
            self.dist_class = None
            if action_sampler_fn:
                if not make_model:
                    raise ValueError("`make_model` is required if "
                                     "`action_sampler_fn` is given")
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"])

            if make_model:
                self.model = make_model(self, observation_space, action_space,
                                        config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework="tf",
                )

            # Initial RNN state, each entry wrapped into a batch of size 1.
            self._state_in = [
                tf.convert_to_tensor(np.array([s]))
                for s in self.model.get_initial_state()
            ]

            # One forward pass on a single sampled observation (with a
            # sampled prev-action and zero prev-reward) before the loss is
            # initialized.
            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(
                    np.array([observation_space.sample()])),
                SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
                    [_flatten_action(action_space.sample())]),
                SampleBatch.PREV_REWARDS: tf.convert_to_tensor([0.]),
            }
            self.model(input_dict, self._state_in, tf.convert_to_tensor([1]))

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            self._initialize_loss_with_dummy_batch()
            self._loss_initialized = True

            if optimizer_fn:
                self._optimizer = optimizer_fn(self, config)
            else:
                self._optimizer = tf.train.AdamOptimizer(config["lr"])

            if after_init:
                after_init(self, observation_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   samples,
                                   other_agent_batches=None,
                                   episode=None):
            assert tf.executing_eagerly()
            if postprocess_fn:
                return postprocess_fn(self, samples, other_agent_batches,
                                      episode)
            else:
                return samples

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def learn_on_batch(self, samples):
            # Variable creation is disallowed while computing gradients; the
            # optimizer step happens outside of that scope.
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            self._apply_gradients(grads_and_vars)
            return stats

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def compute_gradients(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            grads = [g for g, v in grads_and_vars]
            return grads, stats

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def compute_actions(self,
                            obs_batch,
                            state_batches,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            info_batch=None,
                            episodes=None,
                            explore=None,
                            timestep=None,
                            **kwargs):

            explore = explore if explore is not None else \
                self.config["explore"]

            # TODO: remove python side effect to cull sources of bugs.
            self._is_training = False
            self._state_in = state_batches

            if tf.executing_eagerly():
                n = len(obs_batch)
            else:
                n = obs_batch.shape[0]
            seq_lens = tf.ones(n, dtype=tf.int32)

            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.constant(False),
            }
            if obs_include_prev_action_reward:
                input_dict.update({
                    SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
                        prev_action_batch),
                    SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
                        prev_reward_batch),
                })

            # Custom sampler fn given (which may handle self.exploration).
            if action_sampler_fn is not None:
                state_out = []
                action, logp = action_sampler_fn(
                    self,
                    self.model,
                    input_dict,
                    self.observation_space,
                    self.action_space,
                    explore,
                    self.config,
                    timestep=timestep
                    if timestep is not None else self.global_timestep)
            # Use Exploration object.
            else:
                with tf.variable_creator_scope(_disallow_var_creation):
                    model_out, state_out = self.model(input_dict,
                                                      state_batches, seq_lens)
                    action, logp = self.exploration.get_exploration_action(
                        model_out,
                        self.dist_class,
                        self.model,
                        timestep=timestep
                        if timestep is not None else self.global_timestep,
                        explore=explore)

            extra_fetches = {}
            if logp is not None:
                extra_fetches.update({
                    ACTION_PROB: tf.exp(logp),
                    ACTION_LOGP: logp,
                })
            if extra_action_fetches_fn:
                extra_fetches.update(extra_action_fetches_fn(self))

            # Increase our global sampling timestep counter by 1.
            self.global_timestep += 1
            return action, state_out, extra_fetches

        @override(Policy)
        def compute_log_likelihoods(self,
                                    actions,
                                    obs_batch,
                                    state_batches=None,
                                    prev_action_batch=None,
                                    prev_reward_batch=None):

            seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.constant(False),
            }
            if obs_include_prev_action_reward:
                input_dict.update({
                    SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
                        prev_action_batch),
                    SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
                        prev_reward_batch),
                })

            # Custom log_likelihood function given.
            if log_likelihood_fn:
                log_likelihoods = log_likelihood_fn(
                    self, self.model, actions, input_dict,
                    self.observation_space, self.action_space, self.config)
            # Default log-likelihood calculation.
            else:
                dist_inputs, _ = self.model(input_dict, state_batches,
                                            seq_lens)
                action_dist = self.dist_class(dist_inputs, self.model)
                log_likelihoods = action_dist.logp(actions)

            return log_likelihoods

        @override(Policy)
        def apply_gradients(self, gradients):
            self._apply_gradients(
                zip([(tf.convert_to_tensor(g) if g is not None else None)
                     for g in gradients], self.model.trainable_variables()))

        @override(Policy)
        def get_exploration_info(self):
            return _convert_to_numpy(self.exploration.get_info())

        @override(Policy)
        def get_weights(self):
            variables = self.variables()
            return [v.numpy() for v in variables]

        @override(Policy)
        def set_weights(self, weights):
            variables = self.variables()
            assert len(weights) == len(variables), (len(weights),
                                                    len(variables))
            for v, w in zip(variables, weights):
                v.assign(w)

        def variables(self):
            """Return the list of all savable variables for this policy."""
            return self.model.variables()

        @override(Policy)
        def is_recurrent(self):
            return len(self._state_in) > 0

        @override(Policy)
        def num_state_tensors(self):
            return len(self._state_in)

        @override(Policy)
        def get_initial_state(self):
            return self.model.get_initial_state()

        def get_session(self):
            return None  # None implies eager

        def get_placeholder(self, ph):
            raise ValueError(
                "get_placeholder() is not allowed in eager mode. Try using "
                "rllib.utils.tf_ops.make_tf_callable() to write "
                "functions that work in both graph and eager mode.")

        def loss_initialized(self):
            return self._loss_initialized

        @override(Policy)
        def export_model(self, export_dir):
            pass

        @override(Policy)
        def export_checkpoint(self, export_dir):
            pass

        def _get_is_training_placeholder(self):
            return tf.convert_to_tensor(self._is_training)

        def _apply_gradients(self, grads_and_vars):
            if apply_gradients_fn:
                apply_gradients_fn(self, self._optimizer, grads_and_vars)
            else:
                self._optimizer.apply_gradients(grads_and_vars)

        def _compute_gradients(self, samples):
            """Computes and returns grads as eager tensors."""

            self._is_training = True

            # A persistent tape is needed when `gradients_fn` may query the
            # tape more than once via `OptimizerWrapper.compute_gradients`.
            with tf.GradientTape(persistent=gradients_fn is not None) as tape:
                # TODO: set seq len and state-in properly
                state_in = []
                for i in range(self.num_state_tensors()):
                    state_in.append(samples["state_in_{}".format(i)])
                self._state_in = state_in

                self._seq_lens = None
                if len(state_in) > 0:
                    self._seq_lens = tf.ones(
                        samples[SampleBatch.CUR_OBS].shape[0], dtype=tf.int32)
                    samples["seq_lens"] = self._seq_lens

                model_out, _ = self.model(samples, self._state_in,
                                          self._seq_lens)
                loss = loss_fn(self, self.model, self.dist_class, samples)

            variables = self.model.trainable_variables()

            if gradients_fn:

                class OptimizerWrapper:
                    def __init__(self, tape):
                        self.tape = tape

                    def compute_gradients(self, loss, var_list):
                        return list(
                            zip(self.tape.gradient(loss, var_list), var_list))

                grads_and_vars = gradients_fn(self, OptimizerWrapper(tape),
                                              loss)
            else:
                grads_and_vars = list(
                    zip(tape.gradient(loss, variables), variables))

            if log_once("grad_vars"):
                for _, v in grads_and_vars:
                    logger.info("Optimizing variable {}".format(v.name))

            grads = [g for g, v in grads_and_vars]
            stats = self._stats(self, samples, grads)
            return grads_and_vars, stats

        def _stats(self, outputs, samples, grads):
            """Collect learner stats plus any extra learn/grad fetches."""
            fetches = {}
            if stats_fn:
                fetches[LEARNER_STATS_KEY] = {
                    k: v
                    for k, v in stats_fn(outputs, samples).items()
                }
            else:
                fetches[LEARNER_STATS_KEY] = {}
            if extra_learn_fetches_fn:
                fetches.update(
                    {k: v
                     for k, v in extra_learn_fetches_fn(self).items()})
            if grad_stats_fn:
                fetches.update({
                    k: v
                    for k, v in grad_stats_fn(self, samples, grads).items()
                })
            return fetches

        def _initialize_loss_with_dummy_batch(self):
            # Dummy forward pass to initialize any policy attributes, etc.
            action_dtype, action_shape = ModelCatalog.get_action_shape(
                self.action_space)
            dummy_batch = {
                SampleBatch.CUR_OBS: np.array(
                    [self.observation_space.sample()]),
                SampleBatch.NEXT_OBS: np.array(
                    [self.observation_space.sample()]),
                # `np.bool` was deprecated in NumPy 1.20 and removed in 1.24;
                # `np.bool_` is the same dtype and works on all versions.
                SampleBatch.DONES: np.array([False], dtype=np.bool_),
                SampleBatch.ACTIONS: tf.nest.map_structure(
                    lambda c: np.array([c]), self.action_space.sample()),
                SampleBatch.REWARDS: np.array([0], dtype=np.float32),
            }
            if obs_include_prev_action_reward:
                dummy_batch.update({
                    SampleBatch.PREV_ACTIONS: dummy_batch[SampleBatch.ACTIONS],
                    SampleBatch.PREV_REWARDS: dummy_batch[SampleBatch.REWARDS],
                })
            for i, h in enumerate(self._state_in):
                dummy_batch["state_in_{}".format(i)] = h
                dummy_batch["state_out_{}".format(i)] = h

            if self._state_in:
                dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)

            # Convert everything to tensors.
            dummy_batch = tf.nest.map_structure(tf.convert_to_tensor,
                                                dummy_batch)

            # for IMPALA which expects a certain sample batch size.
            def tile_to(tensor, n):
                # Repeat along the batch (first) axis only.
                return tf.tile(tensor,
                               [n] + [1 for _ in tensor.shape.as_list()[1:]])

            if get_batch_divisibility_req:
                # Evaluate the requirement once, not once per tensor.
                divisibility_req = get_batch_divisibility_req(self)
                dummy_batch = tf.nest.map_structure(
                    lambda c: tile_to(c, divisibility_req), dummy_batch)

            # Execute a forward pass to get self.action_dist etc initialized,
            # and also obtain the extra action fetches
            _, _, fetches = self.compute_actions(
                dummy_batch[SampleBatch.CUR_OBS], self._state_in,
                dummy_batch.get(SampleBatch.PREV_ACTIONS),
                dummy_batch.get(SampleBatch.PREV_REWARDS))
            dummy_batch.update(fetches)

            postprocessed_batch = self.postprocess_trajectory(
                SampleBatch(dummy_batch))

            # model forward pass for the loss (needed after postprocess to
            # overwrite any tensor state from that call)
            self.model.from_batch(dummy_batch)

            postprocessed_batch = tf.nest.map_structure(
                lambda c: tf.convert_to_tensor(c), postprocessed_batch.data)

            loss_fn(self, self.model, self.dist_class, postprocessed_batch)
            if stats_fn:
                stats_fn(self, postprocessed_batch)

        @classmethod
        def with_tracing(cls):
            return traced_eager_policy(cls)

    eager_policy_cls.__name__ = name + "_eager"
    eager_policy_cls.__qualname__ = name + "_eager"
    return eager_policy_cls
def build_eager_tf_policy(name,
                          loss_fn,
                          get_default_config=None,
                          postprocess_fn=None,
                          stats_fn=None,
                          optimizer_fn=None,
                          gradients_fn=None,
                          apply_gradients_fn=None,
                          grad_stats_fn=None,
                          extra_learn_fetches_fn=None,
                          extra_action_fetches_fn=None,
                          validate_spaces=None,
                          before_init=None,
                          before_loss_init=None,
                          after_init=None,
                          make_model=None,
                          action_sampler_fn=None,
                          action_distribution_fn=None,
                          mixins=None,
                          view_requirements_fn=None,
                          obs_include_prev_action_reward=True,
                          get_batch_divisibility_req=None):
    """Build an eager TF policy.

    An eager policy runs all operations in eager mode, which makes debugging
    much simpler, but has lower performance.

    You shouldn't need to call this directly. Rather, prefer to build a TF
    graph policy and set {"framework": "tfe"} in the trainer config to have
    it automatically be converted to an eager policy.

    This has the same signature as build_tf_policy(). Every callable passed
    in is captured by closure and invoked from the methods of the generated
    class.

    Returns:
        A new Policy subclass whose name is ``name + "_eager"``.
    """
    base = add_mixins(Policy, mixins)

    class eager_policy_cls(base):
        def __init__(self, observation_space, action_space, config):
            assert tf.executing_eagerly()
            self.framework = config.get("framework", "tfe")
            Policy.__init__(self, observation_space, action_space, config)
            self._is_training = False
            self._loss_initialized = False
            self._sess = None

            self._loss = loss_fn
            # get_batch_divisibility_req may be a callable (evaluated on this
            # policy), a plain value, or None/falsy (treated as 1).
            self.batch_divisibility_req = get_batch_divisibility_req(self) if \
                callable(get_batch_divisibility_req) else \
                (get_batch_divisibility_req or 1)
            self._max_seq_len = config["model"]["max_seq_len"]

            if get_default_config:
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, observation_space, action_space, config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config
            self.dist_class = None
            if action_sampler_fn or action_distribution_fn:
                # Custom sampling/distribution requires a custom model.
                if not make_model:
                    raise ValueError(
                        "`make_model` is required if `action_sampler_fn` OR "
                        "`action_distribution_fn` is given")
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"])

            if make_model:
                self.model = make_model(self, observation_space, action_space,
                                        config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework=self.framework,
                )
            # Auto-update model's inference view requirements, if recurrent.
            self._update_model_inference_view_requirements_from_init_state()

            self.exploration = self._create_exploration()
            self._state_in = [
                tf.convert_to_tensor([s])
                for s in self.model.get_initial_state()
            ]

            # Update this Policy's ViewRequirements (if function given).
            if callable(view_requirements_fn):
                self.view_requirements.update(view_requirements_fn(self))
            # Combine view_requirements for Model and Policy.
            self.view_requirements.update(
                self.model.inference_view_requirements)

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=stats_fn,
            )
            self._loss_initialized = True

            if optimizer_fn:
                optimizers = optimizer_fn(self, config)
            else:
                optimizers = tf.keras.optimizers.Adam(config["lr"])
            optimizers = force_list(optimizers)
            if getattr(self, "exploration", None):
                optimizers = self.exploration.get_exploration_optimizer(
                    optimizers)
            # TODO: (sven) Allow tf policy to have more than 1 optimizer.
            #  Just like torch Policy does.
            # May be None if no optimizer was produced.
            self._optimizer = optimizers[0] if optimizers else None

            if after_init:
                after_init(self, observation_space, action_space, config)

            # Got to reset global_timestep again after this fake run-through.
            self.global_timestep = 0

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            assert tf.executing_eagerly()
            # Call super's postprocess_trajectory first.
            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
            if postprocess_fn:
                return postprocess_fn(self, sample_batch, other_agent_batches,
                                      episode)
            return sample_batch

        @override(Policy)
        def learn_on_batch(self, postprocessed_batch):
            # Callback handling.
            self.callbacks.on_learn_on_batch(
                policy=self, train_batch=postprocessed_batch)

            # Get batch ready for RNNs, if applicable.
            pad_batch_to_sequences_of_same_size(
                postprocessed_batch,
                shuffle=False,
                max_seq_len=self._max_seq_len,
                batch_divisibility_req=self.batch_divisibility_req)
            return self._learn_on_batch_eager(postprocessed_batch)

        @convert_eager_inputs
        @convert_eager_outputs
        def _learn_on_batch_eager(self, samples):
            # No new variables may be created while computing gradients.
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            self._apply_gradients(grads_and_vars)
            return stats

        @override(Policy)
        def compute_gradients(self, samples):
            # Get batch ready for RNNs, if applicable.
            pad_batch_to_sequences_of_same_size(
                samples,
                shuffle=False,
                max_seq_len=self._max_seq_len,
                batch_divisibility_req=self.batch_divisibility_req)
            return self._compute_gradients_eager(samples)

        @convert_eager_inputs
        @convert_eager_outputs
        def _compute_gradients_eager(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            grads = [g for g, _ in grads_and_vars]
            return grads, stats

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def compute_actions(self,
                            obs_batch,
                            state_batches=None,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            info_batch=None,
                            episodes=None,
                            explore=None,
                            timestep=None,
                            **kwargs):

            explore = explore if explore is not None else \
                self.config["explore"]
            timestep = timestep if timestep is not None else \
                self.global_timestep

            # TODO: remove python side effect to cull sources of bugs.
            self._is_training = False
            self._state_in = state_batches

            if not tf1.executing_eagerly():
                tf1.enable_eager_execution()

            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.constant(False),
            }
            batch_size = input_dict[SampleBatch.CUR_OBS].shape[0]
            seq_lens = tf.ones(batch_size, dtype=tf.int32)
            if obs_include_prev_action_reward:
                if prev_action_batch is not None:
                    input_dict[SampleBatch.PREV_ACTIONS] = \
                        tf.convert_to_tensor(prev_action_batch)
                if prev_reward_batch is not None:
                    input_dict[SampleBatch.PREV_REWARDS] = \
                        tf.convert_to_tensor(prev_reward_batch)

            # Use Exploration object.
            with tf.variable_creator_scope(_disallow_var_creation):
                if action_sampler_fn:
                    dist_inputs = None
                    state_out = []
                    # `action_sampler_fn` is the closure argument, not an
                    # attribute of the policy (calling `self.action_sampler_fn`
                    # raised AttributeError).
                    actions, logp = action_sampler_fn(
                        self,
                        self.model,
                        input_dict[SampleBatch.CUR_OBS],
                        explore=explore,
                        timestep=timestep,
                        episodes=episodes)
                else:
                    # Exploration hook before each forward pass.
                    self.exploration.before_compute_actions(
                        timestep=timestep, explore=explore)

                    if action_distribution_fn:
                        dist_inputs, dist_class, state_out = \
                            action_distribution_fn(
                                self,
                                self.model,
                                input_dict[SampleBatch.CUR_OBS],
                                explore=explore,
                                timestep=timestep,
                                is_training=False)
                    else:
                        dist_class = self.dist_class
                        dist_inputs, state_out = self.model(
                            input_dict, state_batches, seq_lens)

                    action_dist = dist_class(dist_inputs, self.model)

                    # Get the exploration action from the forward results.
                    actions, logp = self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore)

            # Add default and custom fetches.
            extra_fetches = {}
            # Action-logp and action-prob.
            if logp is not None:
                extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp
            # Action-dist inputs.
            if dist_inputs is not None:
                extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
            # Custom extra fetches.
            if extra_action_fetches_fn:
                extra_fetches.update(extra_action_fetches_fn(self))

            # Update our global timestep by the batch size.
            self.global_timestep += int(batch_size)

            return actions, state_out, extra_fetches

        @override(Policy)
        def compute_log_likelihoods(self,
                                    actions,
                                    obs_batch,
                                    state_batches=None,
                                    prev_action_batch=None,
                                    prev_reward_batch=None):
            if action_sampler_fn and action_distribution_fn is None:
                raise ValueError("Cannot compute log-prob/likelihood w/o an "
                                 "`action_distribution_fn` and a provided "
                                 "`action_sampler_fn`!")

            seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.constant(False),
            }
            # Skip omitted (None) prev-action/reward batches, as
            # compute_actions does; tf.convert_to_tensor(None) would raise.
            if obs_include_prev_action_reward:
                if prev_action_batch is not None:
                    input_dict[SampleBatch.PREV_ACTIONS] = \
                        tf.convert_to_tensor(prev_action_batch)
                if prev_reward_batch is not None:
                    input_dict[SampleBatch.PREV_REWARDS] = \
                        tf.convert_to_tensor(prev_reward_batch)

            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

            if action_distribution_fn:
                # Action dist class and inputs are generated via custom
                # function.
                dist_inputs, dist_class, _ = action_distribution_fn(
                    self,
                    self.model,
                    input_dict[SampleBatch.CUR_OBS],
                    explore=False,
                    is_training=False)
            else:
                # Default log-likelihood calculation.
                dist_inputs, _ = self.model(input_dict, state_batches,
                                            seq_lens)
                dist_class = self.dist_class

            action_dist = dist_class(dist_inputs, self.model)
            return action_dist.logp(actions)

        @override(Policy)
        def apply_gradients(self, gradients):
            self._apply_gradients(
                zip([(tf.convert_to_tensor(g) if g is not None else None)
                     for g in gradients], self.model.trainable_variables()))

        @override(Policy)
        def get_exploration_info(self):
            return _convert_to_numpy(self.exploration.get_info())

        @override(Policy)
        def get_weights(self, as_dict=False):
            variables = self.variables()
            if as_dict:
                return {v.name: v.numpy() for v in variables}
            return [v.numpy() for v in variables]

        @override(Policy)
        def set_weights(self, weights):
            variables = self.variables()
            assert len(weights) == len(variables), (len(weights),
                                                    len(variables))
            for v, w in zip(variables, weights):
                v.assign(w)

        @override(Policy)
        def get_state(self):
            state = {"_state": super().get_state()}
            # __init__ allows self._optimizer to be None.
            state["_optimizer_variables"] = (
                self._optimizer.variables()
                if self._optimizer is not None else [])
            return state

        @override(Policy)
        def set_state(self, state):
            state = state.copy()  # shallow copy
            # Set optimizer vars first.
            optimizer_vars = state.pop("_optimizer_variables", None)
            if optimizer_vars and self._optimizer is not None \
                    and self._optimizer.variables():
                logger.warning(
                    "Cannot restore an optimizer's state for tf eager! Keras "
                    "is not able to save the v1.x optimizers (from "
                    "tf.compat.v1.train) since they aren't compatible with "
                    "checkpoints.")
                for opt_var, value in zip(self._optimizer.variables(),
                                          optimizer_vars):
                    opt_var.assign(value)
            # Then the Policy's (NN) weights.
            super().set_state(state["_state"])

        def variables(self):
            """Return the list of all savable variables for this policy."""
            return self.model.variables()

        @override(Policy)
        def is_recurrent(self):
            return len(self._state_in) > 0

        @override(Policy)
        def num_state_tensors(self):
            return len(self._state_in)

        @override(Policy)
        def get_initial_state(self):
            return self.model.get_initial_state()

        def get_session(self):
            return None  # None implies eager

        def get_placeholder(self, ph):
            raise ValueError(
                "get_placeholder() is not allowed in eager mode. Try using "
                "rllib.utils.tf_ops.make_tf_callable() to write "
                "functions that work in both graph and eager mode.")

        def loss_initialized(self):
            return self._loss_initialized

        @override(Policy)
        def export_model(self, export_dir):
            pass

        @override(Policy)
        def export_checkpoint(self, export_dir):
            pass

        def _get_is_training_placeholder(self):
            return tf.convert_to_tensor(self._is_training)

        def _apply_gradients(self, grads_and_vars):
            if apply_gradients_fn:
                apply_gradients_fn(self, self._optimizer, grads_and_vars)
            else:
                # Variables without a gradient are skipped.
                self._optimizer.apply_gradients(
                    [(g, v) for g, v in grads_and_vars if g is not None])

        def _compute_gradients(self, samples):
            """Computes and returns grads as eager tensors."""

            self._is_training = True

            # A persistent tape lets gradients_fn call tape.gradient more
            # than once (via OptimizerWrapper below).
            with tf.GradientTape(persistent=gradients_fn is not None) as tape:
                # TODO: set seq len and state-in properly
                self._state_in = [
                    samples["state_in_{}".format(i)]
                    for i in range(self.num_state_tensors())
                ]
                # Forward pass; its outputs are unused here.
                # NOTE(review): presumably kept for model-side state that
                # loss_fn relies on — confirm before removing.
                self.model(samples, self._state_in, samples.get("seq_lens"))
                loss = loss_fn(self, self.model, self.dist_class, samples)

            variables = self.model.trainable_variables()

            if gradients_fn:

                class OptimizerWrapper:
                    """Exposes the tape via a compute_gradients() API."""

                    def __init__(self, tape):
                        self.tape = tape

                    def compute_gradients(self, loss, var_list):
                        return list(
                            zip(self.tape.gradient(loss, var_list), var_list))

                grads_and_vars = gradients_fn(self, OptimizerWrapper(tape),
                                              loss)
            else:
                grads_and_vars = list(
                    zip(tape.gradient(loss, variables), variables))

            if log_once("grad_vars"):
                for _, v in grads_and_vars:
                    logger.info("Optimizing variable {}".format(v.name))

            grads = [g for g, _ in grads_and_vars]
            stats = self._stats(self, samples, grads)
            return grads_and_vars, stats

        def _stats(self, outputs, samples, grads):
            # `outputs` is the policy itself (see _compute_gradients).
            fetches = {}
            if stats_fn:
                fetches[LEARNER_STATS_KEY] = dict(stats_fn(outputs, samples))
            else:
                fetches[LEARNER_STATS_KEY] = {}

            if extra_learn_fetches_fn:
                fetches.update(dict(extra_learn_fetches_fn(self)))
            if grad_stats_fn:
                fetches.update(dict(grad_stats_fn(self, samples, grads)))
            return fetches

        def _lazy_tensor_dict(self, postprocessed_batch):
            train_batch = UsageTrackingDict(postprocessed_batch)
            train_batch.set_get_interceptor(_convert_to_tf)
            return train_batch

        def _lazy_numpy_dict(self, postprocessed_batch):
            train_batch = UsageTrackingDict(postprocessed_batch)
            train_batch.set_get_interceptor(convert_to_non_tf_type)
            return train_batch

        @classmethod
        def with_tracing(cls):
            return traced_eager_policy(cls)

    eager_policy_cls.__name__ = name + "_eager"
    eager_policy_cls.__qualname__ = name + "_eager"
    return eager_policy_cls
def build_trainer(name,
                  default_policy,
                  default_config=None,
                  validate_config=None,
                  get_initial_state=None,
                  get_policy_class=None,
                  before_init=None,
                  make_workers=None,
                  make_policy_optimizer=None,
                  after_init=None,
                  before_train_step=None,
                  after_optimizer_step=None,
                  after_train_result=None,
                  collect_metrics_fn=None,
                  before_evaluate_fn=None,
                  mixins=None):
    """Helper function for defining a custom trainer.

    Functions will be run in this order to initialize the trainer:
        1. Config setup: validate_config, get_initial_state, get_policy
        2. Worker setup: before_init, make_workers, make_policy_optimizer
        3. Post setup: after_init

    Arguments:
        name (str): name of the trainer (e.g., "PPO")
        default_policy (cls): the default Policy class to use
        default_config (dict): the default config dict of the algorithm,
            otherwise uses the Trainer default config
        validate_config (func): optional callback that checks a given config
            for correctness. It may mutate the config as needed.
        get_initial_state (func): optional function that returns the initial
            state dict given the trainer instance as an argument. The state
            dict must be serializable so that it can be checkpointed, and will
            be available as the `trainer.state` variable.
        get_policy_class (func): optional callback that takes a config and
            returns the policy class to override the default with. If it
            returns None, `default_policy` is used.
        before_init (func): optional function to run at the start of trainer
            init that takes the trainer instance as argument
        make_workers (func): override the method that creates rollout workers.
            This takes in (trainer, env_creator, policy, config) as args.
        make_policy_optimizer (func): optional function that returns a
            PolicyOptimizer instance given (WorkerSet, config)
        after_init (func): optional function to run at the end of trainer init
            that takes the trainer instance as argument
        before_train_step (func): optional callback to run before each train()
            call. It takes the trainer instance as an argument.
        after_optimizer_step (func): optional callback to run after each
            step() call to the policy optimizer. It takes the trainer instance
            and the policy gradient fetches as arguments.
        after_train_result (func): optional callback to run at the end of each
            train() call. It takes the trainer instance and result dict as
            arguments, and may mutate the result dict as needed.
        collect_metrics_fn (func): override the method used to collect
            metrics. It takes the trainer instance as argument.
        before_evaluate_fn (func): callback to run before evaluation. This
            takes the trainer instance as argument.
        mixins (list): list of any class mixins for the returned trainer class.
            These mixins will be applied in order and will have higher
            precedence than the Trainer class

    Returns:
        a Trainer subclass (not an instance) that uses the specified args.
    """

    # Must stay the first statement so that only the arguments are captured;
    # with_updates() replays them into a fresh build_trainer() call.
    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config, env_creator):
            if validate_config:
                validate_config(config)

            if get_initial_state:
                self.state = get_initial_state(self)
            else:
                self.state = {}

            # Fall back to the default policy when no override is given or
            # the override callback returns None.
            policy = None
            if get_policy_class is not None:
                policy = get_policy_class(config)
            if policy is None:
                policy = default_policy

            if before_init:
                before_init(self)

            if make_workers:
                self.workers = make_workers(self, env_creator, policy, config)
            else:
                self.workers = self._make_workers(env_creator, policy, config,
                                                  self.config["num_workers"])

            if make_policy_optimizer:
                self.optimizer = make_policy_optimizer(self.workers, config)
            else:
                optimizer_config = dict(
                    config["optimizer"],
                    **{"train_batch_size": config["train_batch_size"]})
                self.optimizer = SyncSamplesOptimizer(self.workers,
                                                      **optimizer_config)
            # NOTE(review): a debug note here recorded self.optimizer as an
            # Override_ray.sync_replay_optimizer.SyncReplayOptimizer instance
            # in this project; _train() relies on its step() accepting an
            # extra argument — confirm for other optimizers.

            if after_init:
                after_init(self)

        @override(Trainer)
        def _train(self, attention_score_dic=None):
            """Run optimizer steps until this iteration's budget is reached.

            Args:
                attention_score_dic: optional value forwarded verbatim to
                    ``self.optimizer.step()``; supplied by the calling
                    trainable.

            Returns:
                The collected metrics result dict for this iteration.
            """
            if before_train_step:
                before_train_step(self)
            prev_steps = self.optimizer.num_steps_sampled

            start = time.time()
            while True:
                # The network is trained here.
                fetches = self.optimizer.step(attention_score_dic)
                if after_optimizer_step:
                    # Decide whether to update the target network, e.g.
                    # after_optimizer_step=update_target_if_needed.
                    after_optimizer_step(self, fetches)
                # Finish the iteration once both the minimum wall-clock time
                # and the minimum number of sampled timesteps are reached.
                if (time.time() - start >= self.config["min_iter_time_s"]
                        and self.optimizer.num_steps_sampled - prev_steps >=
                        self.config["timesteps_per_iteration"]):
                    break

            # Refine the results.
            if collect_metrics_fn:
                res = collect_metrics_fn(self)
            else:
                res = self.collect_metrics()
            res.update(
                timesteps_this_iter=self.optimizer.num_steps_sampled -
                prev_steps,
                info=res.get("info", {}))

            if after_train_result:
                after_train_result(self, res)
            return res

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        def __getstate__(self):
            state = Trainer.__getstate__(self)
            state["trainer_state"] = self.state.copy()
            return state

        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            self.state = state["trainer_state"].copy()

    # Defined at function scope (not in the class body) so the assignment
    # below can resolve the name; class-body names are not visible here.
    @staticmethod
    def with_updates(**overrides):
        """Build a copy of this trainer with the specified overrides.

        Arguments:
            overrides (dict): use this to override any of the arguments
                originally passed to build_trainer() for this policy.
        """
        return build_trainer(**dict(original_kwargs, **overrides))

    trainer_cls.with_updates = with_updates
    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls