Example #1
def record_env_wrapper(env, record_env, log_dir, policy_config):
    if record_env:
        path_ = record_env if isinstance(record_env, str) else log_dir
        # Relative path: Prepend the log_dir here; otherwise this would
        # not work for non-local workers.
        if not os.path.isabs(path_):
            path_ = os.path.join(log_dir, path_)
        print(f"Setting the path for recording to {path_}")
        wrapper_cls = VideoMonitor if isinstance(env, MultiAgentEnv) \
            else wrappers.Monitor
        wrapper_cls = add_mixins(wrapper_cls, [MultiAgentEnv], reversed=True)
        env = wrapper_cls(
            env,
            path_,
            resume=True,
            force=True,
            video_callable=lambda _: True,
            mode="evaluation"
            if policy_config["in_evaluation"] else "training")
    return env
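
A minimal usage sketch for the wrapper above, assuming `gym` is installed and `record_env_wrapper` is in scope; the directory values are illustrative only:

import gym

env = gym.make("CartPole-v0")
env = record_env_wrapper(
    env,
    record_env="videos",                     # relative -> joined with log_dir
    log_dir="/tmp/rllib_logs",               # illustrative log directory
    policy_config={"in_evaluation": False},  # -> Monitor mode="training"
)
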
Example #2
def build_trainer(
    name: str,
    *,
    default_config: Optional[TrainerConfigDict] = None,
    validate_config: Optional[Callable[[TrainerConfigDict], None]] = None,
    default_policy: Optional[Type[Policy]] = None,
    get_policy_class: Optional[Callable[[TrainerConfigDict],
                                        Optional[Type[Policy]]]] = None,
    validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
    before_init: Optional[Callable[[Trainer], None]] = None,
    after_init: Optional[Callable[[Trainer], None]] = None,
    before_evaluate_fn: Optional[Callable[[Trainer], None]] = None,
    mixins: Optional[List[type]] = None,
    execution_plan: Optional[
        Callable[[WorkerSet, TrainerConfigDict],
                 Iterable[ResultDict]]] = default_execution_plan
) -> Type[Trainer]:
    """Helper function for defining a custom trainer.

    Functions will be run in this order to initialize the trainer:
        1. Config setup: validate_config, get_policy_class
        2. Worker setup: before_init, execution_plan
        3. Post setup: after_init

    Args:
        name (str): name of the trainer (e.g., "PPO")
        default_config (Optional[TrainerConfigDict]): The default config dict
            of the algorithm, otherwise uses the Trainer default config.
        validate_config (Optional[Callable[[TrainerConfigDict], None]]):
            Optional callable that takes the config to check for correctness.
            It may mutate the config as needed.
        default_policy (Optional[Type[Policy]]): The default Policy class to
            use if `get_policy_class` returns None.
        get_policy_class (Optional[Callable[
            TrainerConfigDict, Optional[Type[Policy]]]]): Optional callable
            that takes a config and returns the policy class or None. If None
            is returned, will use `default_policy` (which must be provided
            then).
        validate_env (Optional[Callable[[EnvType, EnvContext], None]]):
            Optional callable to validate the generated environment (only
            on worker=0).
        before_init (Optional[Callable[[Trainer], None]]): Optional callable to
            run before anything is constructed inside Trainer (Workers with
            Policies, execution plan, etc.). Takes the Trainer instance as
            argument.
        after_init (Optional[Callable[[Trainer], None]]): Optional callable to
            run at the end of trainer init (after all Workers and the exec.
            plan have been constructed). Takes the Trainer instance as
            argument.
        before_evaluate_fn (Optional[Callable[[Trainer], None]]): Callback to
            run before evaluation. This takes the trainer instance as argument.
        mixins (list): list of any class mixins for the returned trainer class.
            These mixins will be applied in order and will have higher
            precedence than the Trainer class.
        execution_plan (Optional[Callable[[WorkerSet, TrainerConfigDict],
            Iterable[ResultDict]]]): Optional callable that sets up the
            distributed execution workflow.

    Returns:
        Type[Trainer]: A Trainer sub-class configured by the specified args.
    """

    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy_class = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config: TrainerConfigDict,
                  env_creator: Callable[[EnvConfigDict], EnvType]):

            # No `get_policy_class` function.
            if get_policy_class is None:
                # `default_policy` must be provided (unless in multi-agent mode,
                # where each policy can have its own default policy class).
                if not config["multiagent"]["policies"]:
                    assert default_policy is not None
                self._policy_class = default_policy
            # Query the function for a class to use.
            else:
                self._policy_class = get_policy_class(config)
                # If None returned, use default policy (must be provided).
                if self._policy_class is None:
                    assert default_policy is not None
                    self._policy_class = default_policy

            if before_init:
                before_init(self)

            # Creating all workers (excluding evaluation workers).
            self.workers = self._make_workers(
                env_creator=env_creator,
                validate_env=validate_env,
                policy_class=self._policy_class,
                config=config,
                num_workers=self.config["num_workers"])
            self.execution_plan = execution_plan
            self.train_exec_impl = execution_plan(self.workers, config)

            if after_init:
                after_init(self)

        @override(Trainer)
        def step(self):
            # self._iteration gets incremented after this function returns,
            # meaning that, e.g., the first time this function is called,
            # self._iteration will be 0.
            evaluate_this_iter = \
                self.config["evaluation_interval"] and \
                (self._iteration + 1) % self.config["evaluation_interval"] == 0

            # No evaluation necessary.
            if not evaluate_this_iter:
                res = next(self.train_exec_impl)
            # We have to evaluate in this training iteration.
            else:
                # No parallelism.
                if not self.config["evaluation_parallel_to_training"]:
                    res = next(self.train_exec_impl)

                # Kick off evaluation-loop (and parallel train() call,
                # if requested).
                # Parallel eval + training.
                if self.config["evaluation_parallel_to_training"]:
                    with concurrent.futures.ThreadPoolExecutor() as executor:
                        eval_future = executor.submit(self.evaluate)
                        res = next(self.train_exec_impl)
                        evaluation_metrics = eval_future.result()
                # Sequential: train (already done above), then eval.
                else:
                    evaluation_metrics = self.evaluate()

                assert isinstance(evaluation_metrics, dict), \
                    "_evaluate() needs to return a dict."
                res.update(evaluation_metrics)

            # Check `env_task_fn` for possible update of the env's task.
            if self.config["env_task_fn"] is not None:
                if not callable(self.config["env_task_fn"]):
                    raise ValueError(
                        "`env_task_fn` must be None or a callable taking "
                        "[train_results, env, env_ctx] as args!")

                def fn(env, env_context, task_fn):
                    new_task = task_fn(res, env, env_context)
                    cur_task = env.get_task()
                    if cur_task != new_task:
                        env.set_task(new_task)

                fn = partial(fn, task_fn=self.config["env_task_fn"])
                self.workers.foreach_env_with_context(fn)

            return res

        @staticmethod
        @override(Trainer)
        def _validate_config(config: PartialTrainerConfigDict,
                             trainer_obj_or_none: Optional["Trainer"] = None):
            # Call super (Trainer) validation method first.
            Trainer._validate_config(config, trainer_obj_or_none)
            # Then call user defined one, if any.
            if validate_config is not None:
                validate_config(config)

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        @override(Trainer)
        def __getstate__(self):
            state = Trainer.__getstate__(self)
            state["train_exec_impl"] = (
                self.train_exec_impl.shared_metrics.get().save())
            return state

        @override(Trainer)
        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            self.train_exec_impl.shared_metrics.get().restore(
                state["train_exec_impl"])

        @staticmethod
        @override(Trainer)
        def with_updates(**overrides) -> Type[Trainer]:
            """Build a copy of this trainer class with the specified overrides.

            Keyword Args:
                overrides (dict): use this to override any of the arguments
                    originally passed to build_trainer() for this trainer.

            Returns:
                Type[Trainer]: The Trainer sub-class using `original_kwargs`
                    and `overrides`.

            Examples:
                >>> MyClass = SomeOtherClass.with_updates(name="Mine")
                >>> issubclass(MyClass, SomeOtherClass)
                False
                >>> issubclass(MyClass, Trainer)
                True
            """
            return build_trainer(**dict(original_kwargs, **overrides))

        def __repr__(self):
            return self._name

    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
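
A usage sketch for the factory above. `MyTFPolicy` and the config values are illustrative placeholders (any Policy subclass would do), not part of this module:

def my_validate_config(config):
    # Example sanity check; may also mutate the config in place.
    if config["train_batch_size"] < 1:
        raise ValueError("`train_batch_size` must be >= 1!")

MyTrainer = build_trainer(
    name="MyAlgo",
    default_policy=MyTFPolicy,          # hypothetical Policy subclass
    validate_config=my_validate_config,
)

trainer = MyTrainer(config={"num_workers": 0}, env="CartPole-v0")
print(trainer.train())  # runs one training iteration, returns a ResultDict
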
Example #3
def build_trainer(
    name: str,
    *,
    default_config: Optional[TrainerConfigDict] = None,
    validate_config: Optional[Callable[[TrainerConfigDict], None]] = None,
    default_policy: Optional[Type[Policy]] = None,
    get_policy_class: Optional[Callable[[TrainerConfigDict],
                                        Optional[Type[Policy]]]] = None,
    validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
    before_init: Optional[Callable[[Trainer], None]] = None,
    after_init: Optional[Callable[[Trainer], None]] = None,
    before_evaluate_fn: Optional[Callable[[Trainer], None]] = None,
    mixins: Optional[List[type]] = None,
    execution_plan: Optional[
        Union[Callable[[WorkerSet, TrainerConfigDict], Iterable[ResultDict]],
              Callable[[Trainer, WorkerSet, TrainerConfigDict],
                       Iterable[ResultDict]]]] = None,
    allow_unknown_configs: bool = False,
    allow_unknown_subkeys: Optional[List[str]] = None,
    override_all_subkeys_if_type_changes: Optional[List[str]] = None,
) -> Type[Trainer]:
    """Helper function for defining a custom Trainer class.

    Functions will be run in this order to initialize the trainer:
        1. Config setup: validate_config, get_policy_class.
        2. Worker setup: before_init, execution_plan.
        3. Post setup: after_init.

    Args:
        name: name of the trainer (e.g., "PPO")
        default_config: The default config dict of the algorithm,
            otherwise uses the Trainer default config.
        validate_config: Optional callable that takes the config to check
            for correctness. It may mutate the config as needed.
        default_policy: The default Policy class to use if `get_policy_class`
            returns None.
        get_policy_class: Optional callable that takes a config and returns
            the policy class or None. If None is returned, will use
            `default_policy` (which must be provided then).
        validate_env: Optional callable to validate the generated environment
            (only on worker=0).
        before_init: Optional callable to run before anything is constructed
            inside Trainer (Workers with Policies, execution plan, etc.).
            Takes the Trainer instance as argument.
        after_init: Optional callable to run at the end of trainer init
            (after all Workers and the exec. plan have been constructed).
            Takes the Trainer instance as argument.
        before_evaluate_fn: Callback to run before evaluation. This takes
            the trainer instance as argument.
        mixins: List of any class mixins for the returned trainer class.
            These mixins will be applied in order and will have higher
            precedence than the Trainer class.
        execution_plan: Optional callable that sets up the
            distributed execution workflow.
        allow_unknown_configs: Whether to allow unknown top-level config keys.
        allow_unknown_subkeys: List of top-level keys
            with value=dict, for which new sub-keys are allowed to be added to
            the value dict. Appends to Trainer class defaults.
        override_all_subkeys_if_type_changes: List of top level keys with
            value=dict, for which we always override the entire value (dict),
            iff the "type" key in that value dict changes. Appends to Trainer
            class defaults.

    Returns:
        A Trainer sub-class configured by the specified args.
    """

    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy_class = default_policy

        def __init__(self,
                     config: TrainerConfigDict = None,
                     env: Union[str, EnvType, None] = None,
                     logger_creator: Callable[[], Logger] = None,
                     remote_checkpoint_dir: Optional[str] = None,
                     sync_function_tpl: Optional[str] = None):
            Trainer.__init__(self, config, env, logger_creator,
                             remote_checkpoint_dir, sync_function_tpl)

        @override(base)
        def setup(self, config: PartialTrainerConfigDict):
            if allow_unknown_subkeys is not None:
                self._allow_unknown_subkeys += allow_unknown_subkeys
            self._allow_unknown_configs = allow_unknown_configs
            if override_all_subkeys_if_type_changes is not None:
                self._override_all_subkeys_if_type_changes += \
                    override_all_subkeys_if_type_changes
            Trainer.setup(self, config)

        def _init(self, config: TrainerConfigDict,
                  env_creator: Callable[[EnvConfigDict], EnvType]):

            # No `get_policy_class` function.
            if get_policy_class is None:
                # Default_policy must be provided (unless in multi-agent mode,
                # where each policy can have its own default policy class).
                if not config["multiagent"]["policies"]:
                    assert default_policy is not None
            # Query the function for a class to use.
            else:
                self._policy_class = get_policy_class(config)
                # If None returned, use default policy (must be provided).
                if self._policy_class is None:
                    assert default_policy is not None
                    self._policy_class = default_policy

            if before_init:
                before_init(self)

            # Creating all workers (excluding evaluation workers).
            self.workers = self._make_workers(
                env_creator=env_creator,
                validate_env=validate_env,
                policy_class=self._policy_class,
                config=config,
                num_workers=self.config["num_workers"])

            self.train_exec_impl = self.execution_plan(
                self.workers, config, **self._kwargs_for_execution_plan())

            if after_init:
                after_init(self)

        @override(Trainer)
        def validate_config(self, config: PartialTrainerConfigDict):
            # Call super (Trainer) validation method first.
            Trainer.validate_config(self, config)
            # Then call user defined one, if any.
            if validate_config is not None:
                validate_config(config)

        @staticmethod
        @override(Trainer)
        def execution_plan(workers, config, **kwargs):
            # `execution_plan` is provided, use it inside
            # `self.execution_plan()`.
            if execution_plan is not None:
                return execution_plan(workers, config, **kwargs)
            # If `execution_plan` is not provided (None), the Trainer will use
            # its already existing default `execution_plan()` static method
            # instead.
            else:
                return Trainer.execution_plan(workers, config, **kwargs)

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        @staticmethod
        @override(Trainer)
        def with_updates(**overrides) -> Type[Trainer]:
            """Build a copy of this trainer class with the specified overrides.

            Keyword Args:
                overrides (dict): use this to override any of the arguments
                    originally passed to build_trainer() for this trainer.

            Returns:
                Type[Trainer]: The Trainer sub-class using `original_kwargs`
                    and `overrides`.

            Examples:
                >>> from ray.rllib.agents.ppo import PPOTrainer
                >>> MyPPOClass = PPOTrainer.with_updates(name="MyPPO")
                >>> issubclass(MyPPOClass, PPOTrainer)
                False
                >>> issubclass(MyPPOClass, Trainer)
                True
                >>> trainer = MyPPOClass()
                >>> print(trainer)
                MyPPO
            """
            return build_trainer(**dict(original_kwargs, **overrides))

        def __repr__(self):
            return self._name

    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
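
A sketch of a custom `execution_plan` for the factory above, mirroring RLlib's default plan; the operator import paths (`ray.rllib.execution.*`) are assumed from Ray 1.x, and `MyTFPolicy` is a hypothetical Policy subclass:

from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep
from ray.rllib.execution.metric_ops import StandardMetricsReporting

def my_execution_plan(workers, config, **kwargs):
    # Collect sample batches from all rollout workers (bulk-synchronous).
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    # Apply one learner update per collected batch.
    train_op = rollouts.for_each(TrainOneStep(workers))
    # Attach standard sampling/training metrics reporting.
    return StandardMetricsReporting(train_op, workers, config)

MyTrainer = build_trainer(
    name="MyAlgo",
    default_policy=MyTFPolicy,
    execution_plan=my_execution_plan,
)
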
Example #4
def build_trainer(
    name: str,
    default_policy: Optional[Policy],
    *,
    default_config: TrainerConfigDict = None,
    validate_config: Callable[[TrainerConfigDict], None] = None,
    get_policy_class: Callable[[TrainerConfigDict], Policy] = None,
    before_init: Callable[[Trainer], None] = None,
    after_init: Callable[[Trainer], None] = None,
    before_evaluate_fn: Callable[[Trainer], None] = None,
    mixins: List[type] = None,
    execution_plan: Callable[[WorkerSet, TrainerConfigDict],
                             Iterable[ResultDict]] = default_execution_plan):
    """Helper function for defining a custom trainer.

    Functions will be run in this order to initialize the trainer:
        1. Config setup: validate_config, get_policy_class
        2. Worker setup: before_init, execution_plan
        3. Post setup: after_init

    Arguments:
        name (str): name of the trainer (e.g., "PPO")
        default_policy (cls): the default Policy class to use
        default_config (dict): The default config dict of the algorithm,
            otherwise uses the Trainer default config.
        validate_config (Optional[callable]): Optional callable that takes the
            config to check for correctness. It may mutate the config as
            needed.
        get_policy_class (Optional[callable]): Optional callable that takes a
            config and returns the policy class to override the default with.
        before_init (Optional[callable]): Optional callable to run at the start
            of trainer init that takes the trainer instance as argument.
        after_init (Optional[callable]): Optional callable to run at the end of
            trainer init that takes the trainer instance as argument.
        before_evaluate_fn (Optional[callable]): callback to run before
            evaluation. This takes the trainer instance as argument.
        mixins (list): list of any class mixins for the returned trainer class.
            These mixins will be applied in order and will have higher
            precedence than the Trainer class.
        execution_plan (func): Setup the distributed execution workflow.

    Returns:
        a Trainer sub-class that uses the specified args.
    """

    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config, env_creator):
            if validate_config:
                validate_config(config)

            if get_policy_class is None:
                self._policy = default_policy
            else:
                self._policy = get_policy_class(config)
            if before_init:
                before_init(self)
            # Creating all workers (excluding evaluation workers).
            self.workers = self._make_workers(env_creator, self._policy,
                                              config,
                                              self.config["num_workers"])
            self.execution_plan = execution_plan
            self.train_exec_impl = execution_plan(self.workers, config)
            if after_init:
                after_init(self)

        @override(Trainer)
        def step(self):
            res = next(self.train_exec_impl)
            return res

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        def __getstate__(self):
            state = Trainer.__getstate__(self)
            state["train_exec_impl"] = (
                self.train_exec_impl.shared_metrics.get().save())
            return state

        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            self.train_exec_impl.shared_metrics.get().restore(
                state["train_exec_impl"])

    def with_updates(**overrides):
        """Build a copy of this trainer with the specified overrides.

        Arguments:
            overrides (dict): use this to override any of the arguments
                originally passed to build_trainer() for this trainer.
        """
        return build_trainer(**dict(original_kwargs, **overrides))

    trainer_cls.with_updates = staticmethod(with_updates)
    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
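
A sketch of deriving a variant via the generated class's `with_updates()` static method; `MyTrainer` is assumed to have been produced by build_trainer() as above:

# Rebuild the class from the original build_trainer() kwargs plus overrides.
MyDebugTrainer = MyTrainer.with_updates(
    name="MyAlgoDebug",
    default_config={**COMMON_CONFIG, "num_workers": 0},
)
assert issubclass(MyDebugTrainer, Trainer)
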
Example #5
def build_eager_tf_policy(
    name,
    loss_fn,
    get_default_config=None,
    postprocess_fn=None,
    stats_fn=None,
    optimizer_fn=None,
    compute_gradients_fn=None,
    apply_gradients_fn=None,
    grad_stats_fn=None,
    extra_learn_fetches_fn=None,
    extra_action_out_fn=None,
    validate_spaces=None,
    before_init=None,
    before_loss_init=None,
    after_init=None,
    make_model=None,
    action_sampler_fn=None,
    action_distribution_fn=None,
    mixins=None,
    get_batch_divisibility_req=None,
    # Deprecated args.
    obs_include_prev_action_reward=DEPRECATED_VALUE,
    extra_action_fetches_fn=None,
    gradients_fn=None,
):
    """Build an eager TF policy.

    An eager policy runs all operations in eager mode, which makes debugging
    much simpler, but has lower performance.

    You shouldn't need to call this directly. Rather, prefer to build a TF
    graph policy and set {"framework": "tfe"} in the trainer config to have
    it automatically converted to an eager policy.

    This has the same signature as build_tf_policy()."""

    base = add_mixins(Policy, mixins)

    if obs_include_prev_action_reward != DEPRECATED_VALUE:
        deprecation_warning(old="obs_include_prev_action_reward", error=False)

    if extra_action_fetches_fn is not None:
        deprecation_warning(old="extra_action_fetches_fn",
                            new="extra_action_out_fn",
                            error=False)
        extra_action_out_fn = extra_action_fetches_fn

    if gradients_fn is not None:
        deprecation_warning(old="gradients_fn",
                            new="compute_gradients_fn",
                            error=False)
        compute_gradients_fn = gradients_fn

    class eager_policy_cls(base):
        def __init__(self, observation_space, action_space, config):
            # If this class runs as a @ray.remote actor, eager mode may not
            # have been activated yet.
            if not tf1.executing_eagerly():
                tf1.enable_eager_execution()
            self.framework = config.get("framework", "tfe")
            Policy.__init__(self, observation_space, action_space, config)

            # Log device and worker index.
            from ray.rllib.evaluation.rollout_worker import get_global_worker

            worker = get_global_worker()
            worker_idx = worker.worker_index if worker else 0
            if get_gpu_devices():
                logger.info(
                    "TF-eager Policy (worker={}) running on GPU.".format(
                        worker_idx if worker_idx > 0 else "local"))
            else:
                logger.info(
                    "TF-eager Policy (worker={}) running on CPU.".format(
                        worker_idx if worker_idx > 0 else "local"))

            self._is_training = False

            # Only for `config.eager_tracing=True`: A counter to keep track of
            # how many times an eager-traced method (e.g.
            # `self._compute_actions_helper`) has been re-traced by tensorflow.
            # We will raise an error if more than n re-tracings have been
            # detected, since this would considerably slow down execution.
            # The variable below should only get incremented during the
            # tf.function trace operations, never when calling the already
            # traced function after that.
            self._re_trace_counter = 0

            self._loss_initialized = False
            # To ensure backward compatibility:
            # Old way: If `loss` provided here, use as-is (as a function).
            if loss_fn is not None:
                self._loss = loss_fn
            # New way: Convert the overridden `self.loss` into a plain
            # function, so it can be called the same way as `loss` would
            # be, ensuring backward compatibility.
            elif self.loss.__func__.__qualname__ != "Policy.loss":
                self._loss = self.loss.__func__
            # `loss` not provided nor overridden from Policy -> Set to None.
            else:
                self._loss = None

            self.batch_divisibility_req = (get_batch_divisibility_req(self) if
                                           callable(get_batch_divisibility_req)
                                           else
                                           (get_batch_divisibility_req or 1))
            self._max_seq_len = config["model"]["max_seq_len"]

            if get_default_config:
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, observation_space, action_space, config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config
            self.dist_class = None
            if action_sampler_fn or action_distribution_fn:
                if not make_model:
                    raise ValueError(
                        "`make_model` is required if `action_sampler_fn` OR "
                        "`action_distribution_fn` is given")
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"])

            if make_model:
                self.model = make_model(self, observation_space, action_space,
                                        config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework=self.framework,
                )
            # Lock used for locking some methods on the object-level.
            # This prevents possible race conditions when calling the model
            # first, then its value function (e.g. in a loss function), in
            # between of which another model call is made (e.g. to compute an
            # action).
            self._lock = threading.RLock()

            # Auto-update model's inference view requirements, if recurrent.
            self._update_model_view_requirements_from_init_state()

            self.exploration = self._create_exploration()
            self._state_inputs = self.model.get_initial_state()
            self._is_recurrent = len(self._state_inputs) > 0

            # Combine view_requirements for Model and Policy.
            self.view_requirements.update(self.model.view_requirements)

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            if optimizer_fn:
                optimizers = optimizer_fn(self, config)
            else:
                optimizers = tf.keras.optimizers.Adam(config["lr"])
            optimizers = force_list(optimizers)
            if getattr(self, "exploration", None):
                optimizers = self.exploration.get_exploration_optimizer(
                    optimizers)

            # The list of local (tf) optimizers (one per loss term).
            self._optimizers: List[LocalOptimizer] = optimizers
            # Backward compatibility: A user's policy may only support a single
            # loss term and optimizer (no lists).
            self._optimizer: LocalOptimizer = optimizers[
                0] if optimizers else None

            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=stats_fn,
            )
            self._loss_initialized = True

            if after_init:
                after_init(self, observation_space, action_space, config)

            # Got to reset global_timestep again after fake run-throughs.
            self.global_timestep = 0

        @override(Policy)
        def compute_actions_from_input_dict(
            self,
            input_dict: Dict[str, TensorType],
            explore: bool = None,
            timestep: Optional[int] = None,
            episodes: Optional[List[Episode]] = None,
            **kwargs,
        ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

            if not self.config.get(
                    "eager_tracing") and not tf1.executing_eagerly():
                tf1.enable_eager_execution()

            self._is_training = False

            explore = explore if explore is not None else self.config["explore"]
            timestep = timestep if timestep is not None else self.global_timestep
            if isinstance(timestep, tf.Tensor):
                timestep = int(timestep.numpy())

            # Pass lazy (eager) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            input_dict.set_training(False)

            # Pack internal state inputs into (separate) list.
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            self._state_in = state_batches
            self._is_recurrent = state_batches != []

            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(timestep=timestep,
                                                    explore=explore,
                                                    tf_sess=self.get_session())

            ret = self._compute_actions_helper(
                input_dict,
                state_batches,
                # TODO: Passing episodes into a traced method does not work.
                None if self.config["eager_tracing"] else episodes,
                explore,
                timestep,
            )
            # Update our global timestep by the batch size.
            self.global_timestep += int(tree.flatten(ret[0])[0].shape[0])
            return convert_to_numpy(ret)

        @override(Policy)
        def compute_actions(
            self,
            obs_batch,
            state_batches=None,
            prev_action_batch=None,
            prev_reward_batch=None,
            info_batch=None,
            episodes=None,
            explore=None,
            timestep=None,
            **kwargs,
        ):

            # Create input dict to simply pass the entire call to
            # self.compute_actions_from_input_dict().
            input_dict = SampleBatch(
                {
                    SampleBatch.CUR_OBS: obs_batch,
                },
                _is_training=tf.constant(False),
            )
            if state_batches is not None:
                # Unpack the enumerate index and use an f-string for the key.
                for i, s in enumerate(state_batches):
                    input_dict[f"state_in_{i}"] = s
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            if info_batch is not None:
                input_dict[SampleBatch.INFOS] = info_batch

            return self.compute_actions_from_input_dict(
                input_dict=input_dict,
                explore=explore,
                timestep=timestep,
                episodes=episodes,
                **kwargs,
            )

        @with_lock
        @override(Policy)
        def compute_log_likelihoods(
            self,
            actions,
            obs_batch,
            state_batches=None,
            prev_action_batch=None,
            prev_reward_batch=None,
            actions_normalized=True,
        ):
            if action_sampler_fn and action_distribution_fn is None:
                raise ValueError("Cannot compute log-prob/likelihood w/o an "
                                 "`action_distribution_fn` and a provided "
                                 "`action_sampler_fn`!")

            seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
            input_batch = SampleBatch(
                {SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch)},
                _is_training=False,
            )
            if prev_action_batch is not None:
                input_batch[SampleBatch.PREV_ACTIONS] = tf.convert_to_tensor(
                    prev_action_batch)
            if prev_reward_batch is not None:
                input_batch[SampleBatch.PREV_REWARDS] = tf.convert_to_tensor(
                    prev_reward_batch)

            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if action_distribution_fn:
                dist_inputs, dist_class, _ = action_distribution_fn(
                    self,
                    self.model,
                    input_batch,
                    explore=False,
                    is_training=False)
            # Default log-likelihood calculation.
            else:
                dist_inputs, _ = self.model(input_batch, state_batches,
                                            seq_lens)
                dist_class = self.dist_class

            action_dist = dist_class(dist_inputs, self.model)

            # Normalize actions if necessary.
            if not actions_normalized and self.config["normalize_actions"]:
                actions = normalize_action(actions, self.action_space_struct)

            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            assert tf.executing_eagerly()
            # Call super's postprocess_trajectory first.
            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
            if postprocess_fn:
                return postprocess_fn(self, sample_batch, other_agent_batches,
                                      episode)
            return sample_batch

        @with_lock
        @override(Policy)
        def learn_on_batch(self, postprocessed_batch):
            # Callback handling.
            learn_stats = {}
            self.callbacks.on_learn_on_batch(policy=self,
                                             train_batch=postprocessed_batch,
                                             result=learn_stats)

            pad_batch_to_sequences_of_same_size(
                postprocessed_batch,
                max_seq_len=self._max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )

            self._is_training = True
            postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch)
            postprocessed_batch.set_training(True)
            stats = self._learn_on_batch_helper(postprocessed_batch)
            stats.update({
                "custom_metrics": learn_stats,
                NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
            })
            return convert_to_numpy(stats)

        @override(Policy)
        def compute_gradients(
            self, postprocessed_batch: SampleBatch
        ) -> Tuple[ModelGradients, Dict[str, TensorType]]:

            pad_batch_to_sequences_of_same_size(
                postprocessed_batch,
                shuffle=False,
                max_seq_len=self._max_seq_len,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )

            self._is_training = True
            self._lazy_tensor_dict(postprocessed_batch)
            postprocessed_batch.set_training(True)
            grads_and_vars, grads, stats = self._compute_gradients_helper(
                postprocessed_batch)
            return convert_to_numpy((grads, stats))

        @override(Policy)
        def apply_gradients(self, gradients: ModelGradients) -> None:
            self._apply_gradients_helper(
                list(
                    zip(
                        [(tf.convert_to_tensor(g) if g is not None else None)
                         for g in gradients],
                        self.model.trainable_variables(),
                    )))

        @override(Policy)
        def get_weights(self, as_dict=False):
            variables = self.variables()
            if as_dict:
                return {v.name: v.numpy() for v in variables}
            return [v.numpy() for v in variables]

        @override(Policy)
        def set_weights(self, weights):
            variables = self.variables()
            assert len(weights) == len(variables), (len(weights),
                                                    len(variables))
            for v, w in zip(variables, weights):
                v.assign(w)

        @override(Policy)
        def get_exploration_state(self):
            return convert_to_numpy(self.exploration.get_state())

        @override(Policy)
        def is_recurrent(self):
            return self._is_recurrent

        @override(Policy)
        def num_state_tensors(self):
            return len(self._state_inputs)

        @override(Policy)
        def get_initial_state(self):
            if hasattr(self, "model"):
                return self.model.get_initial_state()
            return []

        @override(Policy)
        def get_state(self):
            state = super().get_state()
            if self._optimizer and len(self._optimizer.variables()) > 0:
                state["_optimizer_variables"] = self._optimizer.variables()
            # Add exploration state.
            state["_exploration_state"] = self.exploration.get_state()
            return state

        @override(Policy)
        def set_state(self, state):
            state = state.copy()  # shallow copy
            # Set optimizer vars first.
            optimizer_vars = state.get("_optimizer_variables", None)
            if optimizer_vars and self._optimizer.variables():
                logger.warning(
                    "Cannot restore an optimizer's state for tf eager! Keras "
                    "is not able to save the v1.x optimizers (from "
                    "tf.compat.v1.train) since they aren't compatible with "
                    "checkpoints.")
                for opt_var, value in zip(self._optimizer.variables(),
                                          optimizer_vars):
                    opt_var.assign(value)
            # Set exploration's state.
            if hasattr(self, "exploration") and "_exploration_state" in state:
                self.exploration.set_state(state=state["_exploration_state"])
            # Then the Policy's (NN) weights.
            super().set_state(state)

        @override(Policy)
        def export_checkpoint(self, export_dir):
            raise NotImplementedError  # TODO: implement this

        @override(Policy)
        def export_model(self, export_dir):
            raise NotImplementedError  # TODO: implement this

        def variables(self):
            """Return the list of all savable variables for this policy."""
            if isinstance(self.model, tf.keras.Model):
                return self.model.variables
            else:
                return self.model.variables()

        def loss_initialized(self):
            return self._loss_initialized

        @with_lock
        def _compute_actions_helper(self, input_dict, state_batches, episodes,
                                    explore, timestep):
            # Increase the tracing counter to make sure we don't re-trace too
            # often. If eager_tracing=True, this counter should only get
            # incremented during the @tf.function trace operations, never when
            # calling the already traced function after that.
            self._re_trace_counter += 1

            # Calculate RNN sequence lengths.
            batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
            seq_lens = tf.ones(batch_size,
                               dtype=tf.int32) if state_batches else None

            # Add default and custom fetches.
            extra_fetches = {}

            # Use Exploration object.
            with tf.variable_creator_scope(_disallow_var_creation):
                if action_sampler_fn:
                    dist_inputs = None
                    state_out = []
                    actions, logp = action_sampler_fn(
                        self,
                        self.model,
                        input_dict[SampleBatch.CUR_OBS],
                        explore=explore,
                        timestep=timestep,
                        episodes=episodes,
                    )
                else:
                    if action_distribution_fn:

                        # Try new action_distribution_fn signature, supporting
                        # state_batches and seq_lens.
                        try:
                            (
                                dist_inputs,
                                self.dist_class,
                                state_out,
                            ) = action_distribution_fn(
                                self,
                                self.model,
                                input_dict=input_dict,
                                state_batches=state_batches,
                                seq_lens=seq_lens,
                                explore=explore,
                                timestep=timestep,
                                is_training=False,
                            )
                        # Trying the old way (to stay backward compatible).
                        # TODO: Remove in future.
                        except TypeError as e:
                            if ("positional argument" in e.args[0]
                                    or "unexpected keyword argument"
                                    in e.args[0]):
                                (
                                    dist_inputs,
                                    self.dist_class,
                                    state_out,
                                ) = action_distribution_fn(
                                    self,
                                    self.model,
                                    input_dict[SampleBatch.OBS],
                                    explore=explore,
                                    timestep=timestep,
                                    is_training=False,
                                )
                            else:
                                raise e
                    elif isinstance(self.model, tf.keras.Model):
                        input_dict = SampleBatch(input_dict, seq_lens=seq_lens)
                        if state_batches and "state_in_0" not in input_dict:
                            for i, s in enumerate(state_batches):
                                input_dict[f"state_in_{i}"] = s
                        self._lazy_tensor_dict(input_dict)
                        dist_inputs, state_out, extra_fetches = self.model(
                            input_dict)
                    else:
                        dist_inputs, state_out = self.model(
                            input_dict, state_batches, seq_lens)

                    action_dist = self.dist_class(dist_inputs, self.model)

                    # Get the exploration action from the forward results.
                    actions, logp = self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore,
                    )

            # Action-logp and action-prob.
            if logp is not None:
                extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp
            # Action-dist inputs.
            if dist_inputs is not None:
                extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
            # Custom extra fetches.
            if extra_action_out_fn:
                extra_fetches.update(extra_action_out_fn(self))

            return actions, state_out, extra_fetches

        def _learn_on_batch_helper(self, samples):
            # Increase the tracing counter to make sure we don't re-trace too
            # often. If eager_tracing=True, this counter should only get
            # incremented during the @tf.function trace operations, never when
            # calling the already traced function after that.
            self._re_trace_counter += 1

            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, _, stats = self._compute_gradients_helper(
                    samples)
            self._apply_gradients_helper(grads_and_vars)
            return stats

        def _get_is_training_placeholder(self):
            return tf.convert_to_tensor(self._is_training)

        @with_lock
        def _compute_gradients_helper(self, samples):
            """Computes and returns grads as eager tensors."""

            # Increase the tracing counter to make sure we don't re-trace too
            # often. If eager_tracing=True, this counter should only get
            # incremented during the @tf.function trace operations, never when
            # calling the already traced function after that.
            self._re_trace_counter += 1

            # Gather all variables for which to calculate losses.
            if isinstance(self.model, tf.keras.Model):
                variables = self.model.trainable_variables
            else:
                variables = self.model.trainable_variables()

            # Calculate the loss(es) inside a tf GradientTape.
            with tf.GradientTape(
                    persistent=compute_gradients_fn is not None) as tape:
                losses = self._loss(self, self.model, self.dist_class, samples)
            losses = force_list(losses)

            # User provided a compute_gradients_fn.
            if compute_gradients_fn:
                # Wrap our tape inside a wrapper, such that the resulting
                # object looks like a "classic" tf.optimizer. This way, custom
                # compute_gradients_fn will work on both tf static graph
                # and tf-eager.
                optimizer = OptimizerWrapper(tape)
                # More than one loss terms/optimizers.
                if self.config["_tf_policy_handles_more_than_one_loss"]:
                    grads_and_vars = compute_gradients_fn(
                        self, [optimizer] * len(losses), losses)
                # Only one loss and one optimizer.
                else:
                    grads_and_vars = [
                        compute_gradients_fn(self, optimizer, losses[0])
                    ]
            # Default: Compute gradients using the above tape.
            else:
                grads_and_vars = [
                    list(zip(tape.gradient(loss, variables), variables))
                    for loss in losses
                ]

            if log_once("grad_vars"):
                for g_and_v in grads_and_vars:
                    for g, v in g_and_v:
                        if g is not None:
                            logger.info(f"Optimizing variable {v.name}")

            # `grads_and_vars` is returned a list (len=num optimizers/losses)
            # of lists of (grad, var) tuples.
            if self.config["_tf_policy_handles_more_than_one_loss"]:
                grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars]
            # `grads_and_vars` is returned as a list of (grad, var) tuples.
            else:
                grads_and_vars = grads_and_vars[0]
                grads = [g for g, _ in grads_and_vars]

            stats = self._stats(self, samples, grads)
            return grads_and_vars, grads, stats

        def _apply_gradients_helper(self, grads_and_vars):
            # Increase the tracing counter to make sure we don't re-trace too
            # often. If eager_tracing=True, this counter should only get
            # incremented during the @tf.function trace operations, never when
            # calling the already traced function after that.
            self._re_trace_counter += 1

            if apply_gradients_fn:
                if self.config["_tf_policy_handles_more_than_one_loss"]:
                    apply_gradients_fn(self, self._optimizers, grads_and_vars)
                else:
                    apply_gradients_fn(self, self._optimizer, grads_and_vars)
            else:
                if self.config["_tf_policy_handles_more_than_one_loss"]:
                    for i, o in enumerate(self._optimizers):
                        o.apply_gradients([(g, v) for g, v in grads_and_vars[i]
                                           if g is not None])
                else:
                    self._optimizer.apply_gradients([(g, v)
                                                     for g, v in grads_and_vars
                                                     if g is not None])

        def _stats(self, outputs, samples, grads):

            fetches = {}
            if stats_fn:
                fetches[LEARNER_STATS_KEY] = {
                    k: v
                    for k, v in stats_fn(outputs, samples).items()
                }
            else:
                fetches[LEARNER_STATS_KEY] = {}

            if extra_learn_fetches_fn:
                fetches.update(
                    {k: v
                     for k, v in extra_learn_fetches_fn(self).items()})
            if grad_stats_fn:
                fetches.update({
                    k: v
                    for k, v in grad_stats_fn(self, samples, grads).items()
                })
            return fetches

        def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch):
            # TODO: (sven): Keep for a while to ensure backward compatibility.
            if not isinstance(postprocessed_batch, SampleBatch):
                postprocessed_batch = SampleBatch(postprocessed_batch)
            postprocessed_batch.set_get_interceptor(_convert_to_tf)
            return postprocessed_batch

        @classmethod
        def with_tracing(cls):
            return traced_eager_policy(cls)

    eager_policy_cls.__name__ = name + "_eager"
    eager_policy_cls.__qualname__ = name + "_eager"
    return eager_policy_cls
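
As the docstring above notes, this builder is normally not called directly; a graph-mode TF policy gets converted to its eager variant by setting the framework in the trainer config. A sketch, assuming the Ray 1.x PPOTrainer import path:

from ray.rllib.agents.ppo import PPOTrainer

trainer = PPOTrainer(
    env="CartPole-v0",
    config={
        "framework": "tfe",     # run the TF policy in eager mode
        "eager_tracing": True,  # optionally trace helpers via tf.function
        "num_workers": 0,
    },
)
print(trainer.train())
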
Example #6
def build_torch_policy(name,
                       loss_fn,
                       get_default_config=None,
                       stats_fn=None,
                       postprocess_fn=None,
                       extra_action_out_fn=None,
                       extra_grad_process_fn=None,
                       optimizer_fn=None,
                       before_init=None,
                       after_init=None,
                       make_model_and_action_dist=None,
                       mixins=None):
    """Helper function for creating a torch policy at runtime.

    Arguments:
        name (str): name of the policy (e.g., "PPOTorchPolicy")
        loss_fn (func): function that returns a loss tensor given the
            arguments (policy, model, dist_class, train_batch)
        get_default_config (func): optional function that returns the default
            config to merge with any overrides
        stats_fn (func): optional function that returns a dict of
            values given the policy and batch input tensors
        postprocess_fn (func): optional experience postprocessing function
            that takes the same args as Policy.postprocess_trajectory()
        extra_action_out_fn (func): optional function that returns
            a dict of extra values to include in experiences
        extra_grad_process_fn (func): optional function that is called after
            gradients are computed and returns processing info
        optimizer_fn (func): optional function that returns a torch optimizer
            given the policy and config
        before_init (func): optional function to run at the beginning of
            policy init that takes the same arguments as the policy constructor
        after_init (func): optional function to run at the end of policy init
            that takes the same arguments as the policy constructor
        make_model_and_action_dist (func): optional func that takes the same
            arguments as policy init and returns a tuple of model instance and
            torch action distribution class. If not specified, the default
            model and action dist from the catalog will be used
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the TorchPolicy class

    Returns:
        a TorchPolicy sub-class that uses the specified args
    """

    original_kwargs = locals().copy()
    base = add_mixins(TorchPolicy, mixins)

    class policy_cls(base):
        def __init__(self, obs_space, action_space, config):
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            if before_init:
                before_init(self, obs_space, action_space, config)

            if make_model_and_action_dist:
                self.model, self.dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
                # Make sure we passed in a correct Model factory.
                assert isinstance(self.model, TorchModelV2), \
                    "ERROR: TorchPolicy::make_model_and_action_dist must " \
                    "return a TorchModelV2 object!"
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
                self.model = ModelCatalog.get_model_v2(obs_space,
                                                       action_space,
                                                       logit_dim,
                                                       self.config["model"],
                                                       framework="torch")

            TorchPolicy.__init__(self, obs_space, action_space, config,
                                 self.model, loss_fn, self.dist_class)

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            if not postprocess_fn:
                return sample_batch

            # Do all post-processing always with no_grad().
            # Not using this here will introduce a memory leak (issue #6962).
            with torch.no_grad():
                return postprocess_fn(
                    self, convert_to_non_torch_type(sample_batch),
                    convert_to_non_torch_type(other_agent_batches), episode)

        @override(TorchPolicy)
        def extra_grad_process(self):
            if extra_grad_process_fn:
                return extra_grad_process_fn(self)
            else:
                return TorchPolicy.extra_grad_process(self)

        @override(TorchPolicy)
        def extra_action_out(self,
                             input_dict,
                             state_batches,
                             model,
                             action_dist=None):
            with torch.no_grad():
                if extra_action_out_fn:
                    stats_dict = extra_action_out_fn(self, input_dict,
                                                     state_batches, model,
                                                     action_dist)
                else:
                    stats_dict = TorchPolicy.extra_action_out(
                        self, input_dict, state_batches, model, action_dist)
                return convert_to_non_torch_type(stats_dict)

        @override(TorchPolicy)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            else:
                return TorchPolicy.optimizer(self)

        @override(TorchPolicy)
        def extra_grad_info(self, train_batch):
            with torch.no_grad():
                if stats_fn:
                    stats_dict = stats_fn(self, train_batch)
                else:
                    stats_dict = TorchPolicy.extra_grad_info(self, train_batch)
                return convert_to_non_torch_type(stats_dict)

    def with_updates(**overrides):
        return build_torch_policy(**dict(original_kwargs, **overrides))

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
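For orientation, here is a minimal sketch of how a template like the one above might be used. It assumes the historical ray.rllib.policy.torch_policy_template import path and RLlib's ModelV2.from_batch() / distribution-wrapper conventions; the vanilla policy-gradient loss and the policy name are purely illustrative and not part of the listing above.

import torch
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy


def simple_pg_loss(policy, model, dist_class, train_batch):
    # Forward pass to get the action-distribution inputs for this batch.
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    # Vanilla policy-gradient surrogate: -E[log pi(a|s) * R].
    log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
    return -torch.mean(log_probs * train_batch[SampleBatch.REWARDS])


MySimplePGTorchPolicy = build_torch_policy(
    name="MySimplePGTorchPolicy", loss_fn=simple_pg_loss)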
Example #7
def build_tf_policy(name,
                    loss_fn,
                    get_default_config=None,
                    postprocess_fn=None,
                    stats_fn=None,
                    optimizer_fn=None,
                    gradients_fn=None,
                    apply_gradients_fn=None,
                    grad_stats_fn=None,
                    extra_action_fetches_fn=None,
                    extra_action_feed_fn=None,
                    extra_learn_fetches_fn=None,
                    extra_learn_feed_fn=None,
                    before_init=None,
                    before_loss_init=None,
                    after_init=None,
                    make_model=None,
                    action_sampler_fn=None,
                    mixins=None,
                    get_batch_divisibility_req=None,
                    obs_include_prev_action_reward=True):
    """Helper function for creating a dynamic tf policy at runtime.

    Functions will be run in this order to initialize the policy:
        1. Placeholder setup: postprocess_fn
        2. Loss init: loss_fn, stats_fn
        3. Optimizer init: optimizer_fn, gradients_fn, apply_gradients_fn,
                           grad_stats_fn

    This means that you can, e.g., depend on any policy attributes created in
    the running of `loss_fn` in later functions such as `stats_fn`.

    In eager mode (to be implemented), the following functions will be run
    repeatedly on each eager execution: loss_fn, stats_fn

    This means that these functions should not define any variables internally,
    otherwise they will fail in eager mode execution. Variables should only
    be created in make_model (if defined).

    Arguments:
        name (str): name of the policy (e.g., "PPOTFPolicy")
        loss_fn (func): function that returns a loss tensor given the policy
            and a dict of experience tensor placeholders
        get_default_config (func): optional function that returns the default
            config to merge with any overrides
        postprocess_fn (func): optional experience postprocessing function
            that takes the same args as Policy.postprocess_trajectory()
        stats_fn (func): optional function that returns a dict of
            TF fetches given the policy and batch input tensors
        optimizer_fn (func): optional function that returns a tf.Optimizer
            given the policy and config
        gradients_fn (func): optional function that returns a list of gradients
            given (policy, optimizer, loss). If not specified, this
            defaults to optimizer.compute_gradients(loss)
        apply_gradients_fn (func): optional function that returns an apply
            gradients op given (policy, optimizer, grads_and_vars)
        grad_stats_fn (func): optional function that returns a dict of
            TF fetches given the policy and loss gradient tensors
        extra_action_fetches_fn (func): optional function that returns
            a dict of TF fetches given the policy object
        extra_action_feed_fn (func): optional function that returns a feed dict
            to also feed to TF when computing actions
        extra_learn_fetches_fn (func): optional function that returns a dict of
            extra values to fetch and return when learning on a batch
        extra_learn_feed_fn (func): optional function that returns a feed dict
            to also feed to TF when learning on a batch
        before_init (func): optional function to run at the beginning of
            policy init that takes the same arguments as the policy constructor
        before_loss_init (func): optional function to run prior to loss
            init that takes the same arguments as the policy constructor
        after_init (func): optional function to run at the end of policy init
            that takes the same arguments as the policy constructor
        make_model (func): optional function that returns a ModelV2 object
            given (policy, obs_space, action_space, config).
            All policy variables should be created in this function. If not
            specified, a default model will be created.
        action_sampler_fn (func): optional function that returns a
            tuple of action and action prob tensors given
            (policy, model, input_dict, obs_space, action_space, config).
            If not specified, a default action distribution will be used.
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the DynamicTFPolicy class
        get_batch_divisibility_req (func): optional function that returns
            the divisibility requirement for sample batches
        obs_include_prev_action_reward (bool): whether to include the
            previous action and reward in the model input

    Returns:
        a DynamicTFPolicy child class constructed from the specified args
    """

    original_kwargs = locals().copy()
    base = add_mixins(DynamicTFPolicy, mixins)

    class policy_cls(base):
        def __init__(self,
                     obs_space,
                     action_space,
                     config,
                     existing_model=None,
                     existing_inputs=None):
            if get_default_config:
                config = dict(get_default_config(), **config)

            if before_init:
                before_init(self, obs_space, action_space, config)

            def before_loss_init_wrapper(policy, obs_space, action_space,
                                         config):
                if before_loss_init:
                    before_loss_init(policy, obs_space, action_space, config)
                if extra_action_fetches_fn is None:
                    self._extra_action_fetches = {}
                else:
                    self._extra_action_fetches = extra_action_fetches_fn(self)

            DynamicTFPolicy.__init__(
                self,
                obs_space,
                action_space,
                config,
                loss_fn,
                stats_fn=stats_fn,
                grad_stats_fn=grad_stats_fn,
                before_loss_init=before_loss_init_wrapper,
                make_model=make_model,
                action_sampler_fn=action_sampler_fn,
                existing_model=existing_model,
                existing_inputs=existing_inputs,
                get_batch_divisibility_req=get_batch_divisibility_req,
                obs_include_prev_action_reward=obs_include_prev_action_reward)

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            if not postprocess_fn:
                return sample_batch
            return postprocess_fn(self, sample_batch, other_agent_batches,
                                  episode)

        @override(TFPolicy)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            else:
                return TFPolicy.optimizer(self)

        @override(TFPolicy)
        def gradients(self, optimizer, loss):
            if gradients_fn:
                return gradients_fn(self, optimizer, loss)
            else:
                return TFPolicy.gradients(self, optimizer, loss)

        @override(TFPolicy)
        def build_apply_op(self, optimizer, grads_and_vars):
            if apply_gradients_fn:
                return apply_gradients_fn(self, optimizer, grads_and_vars)
            else:
                return TFPolicy.build_apply_op(self, optimizer, grads_and_vars)

        @override(TFPolicy)
        def extra_compute_action_fetches(self):
            return dict(TFPolicy.extra_compute_action_fetches(self),
                        **self._extra_action_fetches)

        @override(TFPolicy)
        def extra_compute_action_feed_dict(self):
            if extra_action_feed_fn:
                return extra_action_feed_fn(self)
            else:
                return TFPolicy.extra_compute_action_feed_dict(self)

        @override(TFPolicy)
        def extra_compute_grad_fetches(self):
            if extra_learn_fetches_fn:
                # auto-add empty learner stats dict if needed
                return dict({LEARNER_STATS_KEY: {}},
                            **extra_learn_fetches_fn(self))
            else:
                return TFPolicy.extra_compute_grad_fetches(self)

        @override(TFPolicy)
        def extra_compute_grad_feed_dict(self):
            if extra_learn_feed_fn:
                return extra_learn_feed_fn(self)
            else:
                return TFPolicy.extra_compute_grad_feed_dict(self)

    def with_updates(**overrides):
        return build_tf_policy(**dict(original_kwargs, **overrides))

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
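A hedged usage sketch for this older TF template, following the docstring's description of loss_fn as taking the policy and a dict of experience tensor placeholders. The ray.rllib.policy.tf_policy_template import path, the policy.action_dist attribute (assumed to be populated by DynamicTFPolicy before loss init in the RLlib versions this listing appears to come from), and the toy loss are assumptions for illustration.

import tensorflow as tf
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy


def simple_pg_loss(policy, batch_tensors):
    # The template wires these placeholders up from the postprocessed batch.
    actions = batch_tensors[SampleBatch.ACTIONS]
    rewards = batch_tensors[SampleBatch.REWARDS]
    # Assumption: DynamicTFPolicy exposes the current action distribution
    # as `policy.action_dist` in this version.
    return -tf.reduce_mean(policy.action_dist.logp(actions) * rewards)


MySimplePGTFPolicy = build_tf_policy(
    name="MySimplePGTFPolicy", loss_fn=simple_pg_loss)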
Example #8
def build_tf_policy(name: str,
                    *,
                    loss_fn: Callable[
                        [Policy, ModelV2, type, SampleBatch], TensorType],
                    get_default_config: Optional[
                        Callable[[], TrainerConfigDict]] = None,
                    postprocess_fn: Optional[Callable[
                        [Policy, SampleBatch, List[SampleBatch],
                         "MultiAgentEpisode"], None]] = None,
                    stats_fn: Optional[Callable[
                        [Policy, SampleBatch], Dict[str, TensorType]]] = None,
                    optimizer_fn: Optional[Callable[
                        [Policy, TrainerConfigDict],
                        "tf.keras.optimizers.Optimizer"]] = None,
                    gradients_fn: Optional[Callable[
                        [Policy, "tf.keras.optimizers.Optimizer",
                         TensorType], ModelGradients]] = None,
                    apply_gradients_fn: Optional[Callable[
                        [Policy, "tf.keras.optimizers.Optimizer",
                         ModelGradients], "tf.Operation"]] = None,
                    grad_stats_fn: Optional[Callable[
                        [Policy, SampleBatch, ModelGradients],
                        Dict[str, TensorType]]] = None,
                    extra_action_fetches_fn: Optional[Callable[
                        [Policy], Dict[str, TensorType]]] = None,
                    extra_learn_fetches_fn: Optional[Callable[
                        [Policy], Dict[str, TensorType]]] = None,
                    validate_spaces: Optional[Callable[
                        [Policy, gym.Space, gym.Space, TrainerConfigDict],
                        None]] = None,
                    before_init: Optional[Callable[
                        [Policy, gym.Space, gym.Space, TrainerConfigDict],
                        None]] = None,
                    before_loss_init: Optional[Callable[
                        [Policy, gym.spaces.Space, gym.spaces.Space,
                         TrainerConfigDict], None]] = None,
                    after_init: Optional[Callable[
                        [Policy, gym.Space, gym.Space, TrainerConfigDict],
                        None]] = None,
                    make_model: Optional[Callable[
                        [Policy, gym.spaces.Space, gym.spaces.Space,
                         TrainerConfigDict], ModelV2]] = None,
                    action_sampler_fn: Optional[Callable[
                        [TensorType, List[TensorType]], Tuple[
                            TensorType, TensorType]]] = None,
                    action_distribution_fn: Optional[Callable[
                        [Policy, ModelV2, TensorType, TensorType, TensorType],
                        Tuple[TensorType, type, List[TensorType]]]] = None,
                    mixins: Optional[List[type]] = None,
                    get_batch_divisibility_req: Optional[Callable[
                        [Policy], int]] = None,
                    obs_include_prev_action_reward: bool = True):
    """Helper function for creating a dynamic tf policy at runtime.

    Functions will be run in this order to initialize the policy:
        1. Placeholder setup: postprocess_fn
        2. Loss init: loss_fn, stats_fn
        3. Optimizer init: optimizer_fn, gradients_fn, apply_gradients_fn,
                           grad_stats_fn

    This means that you can, e.g., depend on any policy attributes created in
    the running of `loss_fn` in later functions such as `stats_fn`.

    In eager mode, the following functions will be run repeatedly on each
    eager execution: loss_fn, stats_fn, gradients_fn, apply_gradients_fn,
    and grad_stats_fn.

    This means that these functions should not define any variables internally,
    otherwise they will fail in eager mode execution. Variables should only
    be created in make_model (if defined).

    Args:
        name (str): Name of the policy (e.g., "PPOTFPolicy").
        loss_fn (Callable[[Policy, ModelV2, type, SampleBatch], TensorType]):
            Callable for calculating a loss tensor.
        get_default_config (Optional[Callable[[], TrainerConfigDict]]):
            Optional callable that returns the default config to merge with any
            overrides. If None, uses only(!) the user-provided
            PartialTrainerConfigDict as dict for this Policy.
        postprocess_fn (Optional[Callable[[Policy, SampleBatch,
            List[SampleBatch], MultiAgentEpisode], None]]): Optional callable
            for post-processing experience batches (called after the
            super's `postprocess_trajectory` method).
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            TF tensors to fetch given the policy and batch input tensors. If
            None, will not compute any stats.
        optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict],
            "tf.keras.optimizers.Optimizer"]]): Optional callable that returns
            a tf.Optimizer given the policy and config. If None, will call
            the base class' `optimizer()` method instead (which returns a
            tf1.train.AdamOptimizer).
        gradients_fn (Optional[Callable[[Policy,
            "tf.keras.optimizers.Optimizer", TensorType], ModelGradients]]):
            Optional callable that returns a list of gradients. If None,
            this defaults to optimizer.compute_gradients([loss]).
        apply_gradients_fn (Optional[Callable[[Policy,
            "tf.keras.optimizers.Optimizer", ModelGradients],
            "tf.Operation"]]): Optional callable that returns an apply
            gradients op given policy, tf-optimizer, and grads_and_vars. If
            None, will call the base class' `build_apply_op()` method instead.
        grad_stats_fn (Optional[Callable[[Policy, SampleBatch, ModelGradients],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            TF fetches given the policy, batch input, and gradient tensors. If
            None, will not collect any gradient stats.
        extra_action_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns
            a dict of TF fetches given the policy object. If None, will not
            perform any extra fetches.
        extra_learn_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            extra values to fetch and return when learning on a batch. If None,
            will call the base class' `extra_compute_grad_fetches()` method
            instead.
        validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check
            the spaces for correctness. If None, no spaces checking will be
            done.
        before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the
            beginning of policy init that takes the same arguments as the
            policy constructor. If None, this step will be skipped.
        before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
            run prior to loss init. If None, this step will be skipped.
        after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the end of
            policy init. If None, this step will be skipped.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional callable
            that returns a ModelV2 object.
            All policy variables should be created in this function. If None,
            a default ModelV2 object will be created.
        action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]]): A callable returning a sampled
            action and its log-likelihood given observation and state inputs.
            If None, will either use `action_distribution_fn` or
            compute actions by calling self.model, then sampling from the
            so parameterized action distribution.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
            TensorType, TensorType],
            Tuple[TensorType, type, List[TensorType]]]]): Optional callable
            returning distribution inputs (parameters), a dist-class to
            generate an action distribution object from, and internal-state
            outputs (or an empty list if not applicable). If None, will either
            use `action_sampler_fn` or compute actions by calling self.model,
            then sampling from the so parameterized action distribution.
        mixins (Optional[List[type]]): Optional list of any class mixins for
            the returned policy class. These mixins will be applied in order
            and will have higher precedence than the DynamicTFPolicy class.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.
        obs_include_prev_action_reward (bool): Whether to include the
            previous action and reward in the model input.

    Returns:
        a DynamicTFPolicy child class constructed from the specified args
    """
    original_kwargs = locals().copy()
    base = add_mixins(DynamicTFPolicy, mixins)

    class policy_cls(base):
        def __init__(self,
                     obs_space,
                     action_space,
                     config,
                     existing_model=None,
                     existing_inputs=None):
            if get_default_config:
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, obs_space, action_space, config)

            if before_init:
                before_init(self, obs_space, action_space, config)

            def before_loss_init_wrapper(policy, obs_space, action_space,
                                         config):
                if before_loss_init:
                    before_loss_init(policy, obs_space, action_space, config)
                if extra_action_fetches_fn is None:
                    self._extra_action_fetches = {}
                else:
                    self._extra_action_fetches = extra_action_fetches_fn(self)

            DynamicTFPolicy.__init__(
                self,
                obs_space=obs_space,
                action_space=action_space,
                config=config,
                loss_fn=loss_fn,
                stats_fn=stats_fn,
                grad_stats_fn=grad_stats_fn,
                before_loss_init=before_loss_init_wrapper,
                make_model=make_model,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                existing_model=existing_model,
                existing_inputs=existing_inputs,
                get_batch_divisibility_req=get_batch_divisibility_req,
                obs_include_prev_action_reward=obs_include_prev_action_reward)

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            # Call super's postprocess_trajectory first.
            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
            if postprocess_fn:
                return postprocess_fn(self, sample_batch, other_agent_batches,
                                      episode)
            return sample_batch

        @override(TFPolicy)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            else:
                return base.optimizer(self)

        @override(TFPolicy)
        def gradients(self, optimizer, loss):
            if gradients_fn:
                return gradients_fn(self, optimizer, loss)
            else:
                return base.gradients(self, optimizer, loss)

        @override(TFPolicy)
        def build_apply_op(self, optimizer, grads_and_vars):
            if apply_gradients_fn:
                return apply_gradients_fn(self, optimizer, grads_and_vars)
            else:
                return base.build_apply_op(self, optimizer, grads_and_vars)

        @override(TFPolicy)
        def extra_compute_action_fetches(self):
            return dict(
                base.extra_compute_action_fetches(self),
                **self._extra_action_fetches)

        @override(TFPolicy)
        def extra_compute_grad_fetches(self):
            if extra_learn_fetches_fn:
                # Auto-add empty learner stats dict if needed.
                return dict({
                    LEARNER_STATS_KEY: {}
                }, **extra_learn_fetches_fn(self))
            else:
                return base.extra_compute_grad_fetches(self)

    def with_updates(**overrides):
        """Allows creating a TFPolicy cls based on settings of another one.

        Keyword Args:
            **overrides: The settings (passed into `build_tf_policy`) that
                should be different from the class that this method is called
                on.

        Returns:
            type: A new TFPolicy sub-class.

        Examples:
        >>> MySpecialDQNPolicyClass = DQNTFPolicy.with_updates(
        ...     name="MySpecialDQNPolicyClass",
        ...     loss_fn=some_new_loss_function,
        ... )
        """
        return build_tf_policy(**dict(original_kwargs, **overrides))

    def as_eager():
        return eager_tf_policy.build_eager_tf_policy(**original_kwargs)

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.as_eager = staticmethod(as_eager)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
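A sketch of how this typed TF variant might be used, including the with_updates and as_eager hooks defined above. The import paths and the toy loss/stats callables are assumptions; loss_fn here follows the documented (policy, model, dist_class, train_batch) signature.

import tensorflow as tf
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy


def simple_pg_loss(policy, model, dist_class, train_batch):
    # Build the action distribution from the model's outputs.
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    return -tf.reduce_mean(
        action_dist.logp(train_batch[SampleBatch.ACTIONS]) *
        train_batch[SampleBatch.REWARDS])


def batch_stats(policy, train_batch):
    # Anything returned by `stats_fn` shows up under the learner stats.
    return {
        "mean_batch_reward": tf.reduce_mean(train_batch[SampleBatch.REWARDS])
    }


MyTFPolicy = build_tf_policy(
    name="MyTFPolicy", loss_fn=simple_pg_loss, stats_fn=batch_stats)

# Derive a variant without repeating every argument, or switch the same
# definition over to eager execution.
MyTunedTFPolicy = MyTFPolicy.with_updates(name="MyTunedTFPolicy")
MyEagerTFPolicy = MyTFPolicy.as_eager()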
def build_torch_policy(
        name: str,
        *,
        loss_fn: Callable[[Policy, ModelV2, type, SampleBatch], TensorType],
        get_default_config: Optional[Callable[[], TrainerConfigDict]] = None,
        stats_fn: Optional[Callable[[Policy, SampleBatch],
                                    Dict[str, TensorType]]] = None,
        postprocess_fn: Optional[Callable[
            [Policy, SampleBatch, List[SampleBatch], "MultiAgentEpisode"],
            None]] = None,
        extra_action_out_fn: Optional[Callable[[
            Policy, Dict[str, TensorType], List[TensorType], ModelV2,
            TorchDistributionWrapper
        ], Dict[str, TensorType]]] = None,
        extra_grad_process_fn: Optional[
            Callable[[Policy, "torch.optim.Optimizer", TensorType],
                     Dict[str, TensorType]]] = None,
        # TODO: (sven) Replace "fetches" with "process".
        extra_learn_fetches_fn: Optional[Callable[[Policy],
                                                  Dict[str,
                                                       TensorType]]] = None,
        optimizer_fn: Optional[Callable[[Policy, TrainerConfigDict],
                                        "torch.optim.Optimizer"]] = None,
        validate_spaces: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        after_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        action_sampler_fn: Optional[Callable[[TensorType, List[TensorType]],
                                             Tuple[TensorType,
                                                   TensorType]]] = None,
        action_distribution_fn: Optional[
            Callable[[Policy, ModelV2, TensorType, TensorType, TensorType],
                     Tuple[TensorType, type, List[TensorType]]]] = None,
        make_model: Optional[Callable[
            [Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
            ModelV2]] = None,
        make_model_and_action_dist: Optional[Callable[
            [Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
            Tuple[ModelV2, TorchDistributionWrapper]]] = None,
        apply_gradients_fn: Optional[Callable[
            [Policy, "torch.optim.Optimizer"], None]] = None,
        mixins: Optional[List[type]] = None,
        training_view_requirements_fn: Optional[Callable[[], Dict[
            str, ViewRequirement]]] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None):
    """Helper function for creating a torch policy class at runtime.

    Args:
        name (str): name of the policy (e.g., "PPOTorchPolicy")
        loss_fn (Callable[[Policy, ModelV2, type, SampleBatch], TensorType]):
            Callable that returns a loss tensor.
        get_default_config (Optional[Callable[[], TrainerConfigDict]]):
            Optional callable that returns the default config to merge with any
            overrides. If None, uses only(!) the user-provided
            PartialTrainerConfigDict as dict for this Policy.
        postprocess_fn (Optional[Callable[[Policy, SampleBatch,
            List[SampleBatch], MultiAgentEpisode], None]]): Optional callable
            for post-processing experience batches (called after the
            super's `postprocess_trajectory` method).
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            values given the policy and batch input tensors. If None,
            will use `TorchPolicy.extra_grad_info()` instead.
        extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType],
            List[TensorType], ModelV2, TorchDistributionWrapper], Dict[str,
            TensorType]]]): Optional callable that returns a dict of extra
            values to include in experiences. If None, no extra computations
            will be performed.
        extra_grad_process_fn (Optional[Callable[[Policy,
            "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]]):
            Optional callable that is called after gradients are computed and
            returns a processing info dict. If None, will call the
            `TorchPolicy.extra_grad_process()` method instead.
        # TODO: (sven) dissolve naming mismatch between "learn" and "compute.."
        extra_learn_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            extra tensors from the policy after loss evaluation. If None,
            will call the `TorchPolicy.extra_compute_grad_fetches()` method
            instead.
        optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict],
            "torch.optim.Optimizer"]]): Optional callable that returns a
            torch optimizer given the policy and config. If None, will call
            the `TorchPolicy.optimizer()` method instead (which returns a
            torch Adam optimizer).
        validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check for
            correctness. If None, no spaces checking will be done.
        before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the
            beginning of `Policy.__init__` that takes the same arguments as
            the Policy constructor. If None, this step will be skipped.
        after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the end of
            policy init that takes the same arguments as the policy
            constructor. If None, this step will be skipped.
        action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]]): Optional callable returning a
            sampled action and its log-likelihood given some (obs and state)
            inputs. If None, will either use `action_distribution_fn` or
            compute actions by calling self.model, then sampling from the
            so parameterized action distribution.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
            TensorType, TensorType], Tuple[TensorType, type,
            List[TensorType]]]]): A callable that takes
            the Policy, Model, the observation batch, an explore-flag, a
            timestep, and an is_training flag and returns a tuple of
            a) distribution inputs (parameters), b) a dist-class to generate
            an action distribution object from, and c) internal-state outputs
            (empty list if not applicable). If None, will either use
            `action_sampler_fn` or compute actions by calling self.model,
            then sampling from the parameterized action distribution.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional callable
            that takes the same arguments as Policy.__init__ and returns a
            model instance. The distribution class will be determined
            automatically. Note: Only one of `make_model` or
            `make_model_and_action_dist` should be provided. If both are None,
            a default Model will be created.
        make_model_and_action_dist (Optional[Callable[[Policy,
            gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
            Tuple[ModelV2, TorchDistributionWrapper]]]): Optional callable that
            takes the same arguments as Policy.__init__ and returns a tuple
            of model instance and torch action distribution class.
            Note: Only one of `make_model` or `make_model_and_action_dist`
            should be provided. If both are None, a default Model will be
            created.
        apply_gradients_fn (Optional[Callable[[Policy,
            "torch.optim.Optimizer"], None]]): Optional callable that
            takes a grads list and applies these to the Model's parameters.
            If None, will call the `TorchPolicy.apply_gradients()` method
            instead.
        mixins (Optional[List[type]]): Optional list of any class mixins for
            the returned policy class. These mixins will be applied in order
            and will have higher precedence than the TorchPolicy class.
        training_view_requirements_fn (Optional[Callable[[],
            Dict[str, ViewRequirement]]]): Optional callable to retrieve
            additional train view requirements for this policy.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.

    Returns:
        type: TorchPolicy child class constructed from the specified args.
    """

    original_kwargs = locals().copy()
    base = add_mixins(TorchPolicy, mixins)

    class policy_cls(base):
        def __init__(self, obs_space, action_space, config):
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            if validate_spaces:
                validate_spaces(self, obs_space, action_space, self.config)

            if before_init:
                before_init(self, obs_space, action_space, self.config)

            # Model is customized (use default action dist class).
            if make_model:
                assert make_model_and_action_dist is None, \
                    "Either `make_model` or `make_model_and_action_dist`" \
                    " must be None!"
                self.model = make_model(self, obs_space, action_space, config)
                dist_class, _ = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
            # Model and action dist class are customized.
            elif make_model_and_action_dist:
                self.model, dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
            # Use default model and default action dist.
            else:
                dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=logit_dim,
                    model_config=self.config["model"],
                    framework="torch",
                    **self.config["model"].get("custom_model_config", {}))

            # Make sure we passed in a correct Model factory.
            assert isinstance(self.model, TorchModelV2), \
                "ERROR: Generated Model must be a TorchModelV2 object!"

            TorchPolicy.__init__(
                self,
                observation_space=obs_space,
                action_space=action_space,
                config=config,
                model=self.model,
                loss=loss_fn,
                action_distribution_class=dist_class,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=config["model"]["max_seq_len"],
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(TorchPolicy)
        def training_view_requirements(self):
            req = super().training_view_requirements()
            if callable(training_view_requirements_fn):
                req.update(training_view_requirements_fn(self))
            return req

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            # Do all post-processing always with no_grad().
            # Not using this here will introduce a memory leak (issue #6962).
            with torch.no_grad():
                # Call super's postprocess_trajectory first.
                sample_batch = super().postprocess_trajectory(
                    convert_to_non_torch_type(sample_batch),
                    convert_to_non_torch_type(other_agent_batches), episode)
                if postprocess_fn:
                    return postprocess_fn(self, sample_batch,
                                          other_agent_batches, episode)

                return sample_batch

        @override(TorchPolicy)
        def extra_grad_process(self, optimizer, loss):
            """Called after optimizer.zero_grad() and loss.backward() calls.

            Allows for gradient processing before optimizer.step() is called.
            E.g. for gradient clipping.
            """
            if extra_grad_process_fn:
                return extra_grad_process_fn(self, optimizer, loss)
            else:
                return TorchPolicy.extra_grad_process(self, optimizer, loss)

        @override(TorchPolicy)
        def extra_compute_grad_fetches(self):
            if extra_learn_fetches_fn:
                fetches = convert_to_non_torch_type(
                    extra_learn_fetches_fn(self))
                # Auto-add empty learner stats dict if needed.
                return dict({LEARNER_STATS_KEY: {}}, **fetches)
            else:
                return TorchPolicy.extra_compute_grad_fetches(self)

        @override(TorchPolicy)
        def apply_gradients(self, gradients):
            if apply_gradients_fn:
                apply_gradients_fn(self, gradients)
            else:
                TorchPolicy.apply_gradients(self, gradients)

        @override(TorchPolicy)
        def extra_action_out(self, input_dict, state_batches, model,
                             action_dist):
            with torch.no_grad():
                if extra_action_out_fn:
                    stats_dict = extra_action_out_fn(self, input_dict,
                                                     state_batches, model,
                                                     action_dist)
                else:
                    stats_dict = TorchPolicy.extra_action_out(
                        self, input_dict, state_batches, model, action_dist)
                return convert_to_non_torch_type(stats_dict)

        @override(TorchPolicy)
        def optimizer(self):
            if optimizer_fn:
                optimizers = optimizer_fn(self, self.config)
            else:
                optimizers = TorchPolicy.optimizer(self)
            optimizers = force_list(optimizers)
            if hasattr(self, "exploration"):
                exploration_optimizers = force_list(
                    self.exploration.get_exploration_optimizer(self.config))
                optimizers.extend(exploration_optimizers)
            return optimizers

        @override(TorchPolicy)
        def extra_grad_info(self, train_batch):
            with torch.no_grad():
                if stats_fn:
                    stats_dict = stats_fn(self, train_batch)
                else:
                    stats_dict = TorchPolicy.extra_grad_info(self, train_batch)
                return convert_to_non_torch_type(stats_dict)

    def with_updates(**overrides):
        """Allows creating a TorchPolicy cls based on settings of another one.

        Keyword Args:
            **overrides: The settings (passed into `build_torch_policy`) that
                should be different from the class that this method is called
                on.

        Returns:
            type: A new TorchPolicy sub-class.

        Examples:
        >>> MySpecialDQNPolicyClass = DQNTorchPolicy.with_updates(
        ...     name="MySpecialDQNPolicyClass",
        ...     loss_fn=some_new_loss_function,
        ... )
        """
        return build_torch_policy(**dict(original_kwargs, **overrides))

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
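To show where the individual hooks plug in, here is a hedged sketch that overrides optimizer_fn and stats_fn on top of a toy loss. The import path, the loss, and the assumption that the merged config carries an "lr" key are illustrative only.

import torch
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy


def simple_pg_loss(policy, model, dist_class, train_batch):
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    return -torch.mean(
        action_dist.logp(train_batch[SampleBatch.ACTIONS]) *
        train_batch[SampleBatch.REWARDS])


def adam_from_config(policy, config):
    # Plugged in via `optimizer_fn`; assumes an "lr" key in the merged config.
    return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])


def batch_stats(policy, train_batch):
    # Plugged in via `stats_fn`; the dict ends up under the learner stats.
    return {
        "mean_batch_reward": torch.mean(train_batch[SampleBatch.REWARDS])
    }


MyTorchPolicy = build_torch_policy(
    name="MyTorchPolicy",
    loss_fn=simple_pg_loss,
    optimizer_fn=adam_from_config,
    stats_fn=batch_stats)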
Example #10
def build_torch_policy(name,
                       *,
                       loss_fn,
                       get_default_config=None,
                       stats_fn=None,
                       postprocess_fn=None,
                       extra_action_out_fn=None,
                       extra_grad_process_fn=None,
                       extra_learn_fetches_fn=None,
                       optimizer_fn=None,
                       validate_spaces=None,
                       before_init=None,
                       after_init=None,
                       action_sampler_fn=None,
                       action_distribution_fn=None,
                       make_model=None,
                       make_model_and_action_dist=None,
                       apply_gradients_fn=None,
                       mixins=None,
                       get_batch_divisibility_req=None):
    """Helper function for creating a torch policy class at runtime.

    Arguments:
        name (str): name of the policy (e.g., "PPOTorchPolicy")
        loss_fn (callable): Callable that returns a loss tensor given the
            arguments (policy, model, dist_class, train_batch).
        get_default_config (Optional[callable]): Optional callable that returns
            the default config to merge with any overrides.
        stats_fn (Optional[callable]): Optional callable that returns a dict of
            values given the policy and batch input tensors.
        postprocess_fn (Optional[callable]): Optional experience postprocessing
            function that takes the same args as
            Policy.postprocess_trajectory().
        extra_action_out_fn (Optional[callable]): Optional callable that
            returns a dict of extra values to include in experiences.
        extra_grad_process_fn (Optional[callable]): Optional callable that is
            called after gradients are computed and returns processing info.
        extra_learn_fetches_fn (Optional[callable]): Optional callable that
            returns a dict of extra values to fetch from the policy after loss
            evaluation.
        optimizer_fn (Optional[callable]): Optional callable that returns a
            torch optimizer given the policy and config.
        validate_spaces (Optional[callable]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check for
            correctness.
        before_init (Optional[callable]): Optional callable to run at the
            beginning of `Policy.__init__` that takes the same arguments as
            the Policy constructor.
        after_init (Optional[callable]): Optional callable to run at the end of
            policy init that takes the same arguments as the policy
            constructor.
        action_sampler_fn (Optional[callable]): Optional callable returning a
            sampled action and its log-likelihood given some (obs and state)
            inputs.
        action_distribution_fn (Optional[callable]): A callable that takes
            the Policy, Model, the observation batch, an explore-flag, a
            timestep, and an is_training flag and returns a tuple of
            a) distribution inputs (parameters), b) a dist-class to generate
            an action distribution object from, and c) internal-state outputs
            (empty list if not applicable).
        make_model (Optional[callable]): Optional func that
            takes the same arguments as Policy.__init__ and returns a model
            instance. The distribution class will be determined automatically.
            Note: Only one of `make_model` or `make_model_and_action_dist`
            should be provided.
        make_model_and_action_dist (Optional[callable]): Optional func that
            takes the same arguments as Policy.__init__ and returns a tuple
            of model instance and torch action distribution class.
            Note: Only one of `make_model` or `make_model_and_action_dist`
            should be provided.
        apply_gradients_fn (Optional[callable]): Optional callable that
            takes a grads list and applies these to the Model's parameters.
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the TorchPolicy class.
        get_batch_divisibility_req (Optional[callable]): Optional callable that
            returns the divisibility requirement for sample batches.

    Returns:
        type: TorchPolicy child class constructed from the specified args.
    """

    original_kwargs = locals().copy()
    base = add_mixins(TorchPolicy, mixins)

    class policy_cls(base):
        def __init__(self, obs_space, action_space, config):
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            if validate_spaces:
                validate_spaces(self, obs_space, action_space, self.config)

            if before_init:
                before_init(self, obs_space, action_space, self.config)

            # Model is customized (use default action dist class).
            if make_model:
                assert make_model_and_action_dist is None, \
                    "Either `make_model` or `make_model_and_action_dist`" \
                    " must be None!"
                self.model = make_model(self, obs_space, action_space, config)
                dist_class, _ = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
            # Model and action dist class are customized.
            elif make_model_and_action_dist:
                self.model, dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
            # Use default model and default action dist.
            else:
                dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=logit_dim,
                    model_config=self.config["model"],
                    framework="torch",
                    **self.config["model"].get("custom_model_config", {}))

            # Make sure we passed in a correct Model factory.
            assert isinstance(self.model, TorchModelV2), \
                "ERROR: Generated Model must be a TorchModelV2 object!"

            TorchPolicy.__init__(
                self,
                observation_space=obs_space,
                action_space=action_space,
                config=config,
                model=self.model,
                loss=loss_fn,
                action_distribution_class=dist_class,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=config["model"]["max_seq_len"],
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            # Do all post-processing always with no_grad().
            # Not using this here will introduce a memory leak (issue #6962).
            with torch.no_grad():
                # Call super's postprocess_trajectory first.
                sample_batch = super().postprocess_trajectory(
                    convert_to_non_torch_type(sample_batch),
                    convert_to_non_torch_type(other_agent_batches), episode)
                if postprocess_fn:
                    return postprocess_fn(self, sample_batch,
                                          other_agent_batches, episode)

                return sample_batch

        @override(TorchPolicy)
        def extra_grad_process(self, optimizer, loss):
            """Called after optimizer.zero_grad() and loss.backward() calls.

            Allows for gradient processing before optimizer.step() is called.
            E.g. for gradient clipping.
            """
            if extra_grad_process_fn:
                return extra_grad_process_fn(self, optimizer, loss)
            else:
                return TorchPolicy.extra_grad_process(self, optimizer, loss)

        @override(TorchPolicy)
        def extra_compute_grad_fetches(self):
            if extra_learn_fetches_fn:
                fetches = convert_to_non_torch_type(
                    extra_learn_fetches_fn(self))
                # Auto-add empty learner stats dict if needed.
                return dict({LEARNER_STATS_KEY: {}}, **fetches)
            else:
                return TorchPolicy.extra_compute_grad_fetches(self)

        @override(TorchPolicy)
        def apply_gradients(self, gradients):
            if apply_gradients_fn:
                apply_gradients_fn(self, gradients)
            else:
                TorchPolicy.apply_gradients(self, gradients)

        @override(TorchPolicy)
        def extra_action_out(self, input_dict, state_batches, model,
                             action_dist):
            with torch.no_grad():
                if extra_action_out_fn:
                    stats_dict = extra_action_out_fn(
                        self, input_dict, state_batches, model, action_dist)
                else:
                    stats_dict = TorchPolicy.extra_action_out(
                        self, input_dict, state_batches, model, action_dist)
                return convert_to_non_torch_type(stats_dict)

        @override(TorchPolicy)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            else:
                return TorchPolicy.optimizer(self)

        @override(TorchPolicy)
        def extra_grad_info(self, train_batch):
            with torch.no_grad():
                if stats_fn:
                    stats_dict = stats_fn(self, train_batch)
                else:
                    stats_dict = TorchPolicy.extra_grad_info(self, train_batch)
                return convert_to_non_torch_type(stats_dict)

    def with_updates(**overrides):
        return build_torch_policy(**dict(original_kwargs, **overrides))

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
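Finally, a sketch of the space-validation and with_updates hooks on this variant. The Discrete-only check, the class names, the toy loss, and the import paths are hypothetical and only illustrate where these callables plug in.

import gym
import torch
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy


def simple_pg_loss(policy, model, dist_class, train_batch):
    logits, _ = model.from_batch(train_batch)
    return -torch.mean(
        dist_class(logits, model).logp(train_batch[SampleBatch.ACTIONS]) *
        train_batch[SampleBatch.REWARDS])


def require_discrete_actions(policy, obs_space, action_space, config):
    # Plugged in via `validate_spaces`: fail fast on unsupported spaces.
    if not isinstance(action_space, gym.spaces.Discrete):
        raise ValueError("This illustrative policy only supports "
                         "Discrete action spaces.")


MyDiscreteOnlyPolicy = build_torch_policy(
    name="MyDiscreteOnlyPolicy",
    loss_fn=simple_pg_loss,
    validate_spaces=require_discrete_actions)

# `with_updates` derives a close variant without repeating every argument.
MyRelaxedPolicy = MyDiscreteOnlyPolicy.with_updates(
    name="MyRelaxedPolicy", validate_spaces=None)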
def build_trainer(
        name,
        default_policy,
        default_config=None,
        validate_config=None,
        get_initial_state=None,  # DEPRECATED
        get_policy_class=None,
        before_init=None,
        make_workers=None,  # DEPRECATED
        make_policy_optimizer=None,  # DEPRECATED
        after_init=None,
        before_train_step=None,  # DEPRECATED
        after_optimizer_step=None,  # DEPRECATED
        after_train_result=None,  # DEPRECATED
        collect_metrics_fn=None,  # DEPRECATED
        before_evaluate_fn=None,
        mixins=None,
        execution_plan=default_execution_plan):
    """Helper function for defining a custom trainer.

    Functions will be run in this order to initialize the trainer:
        1. Config setup: validate_config, get_policy
        2. Worker setup: before_init, execution_plan
        3. Post setup: after_init

    Arguments:
        name (str): name of the trainer (e.g., "PPO")
        default_policy (cls): the default Policy class to use
        default_config (dict): The default config dict of the algorithm,
            otherwise uses the Trainer default config.
        validate_config (func): optional callback that checks a given config
            for correctness. It may mutate the config as needed.
        get_policy_class (func): optional callback that takes a config and
            returns the policy class to override the default with
        before_init (func): optional function to run at the start of trainer
            init that takes the trainer instance as argument
        after_init (func): optional function to run at the end of trainer init
            that takes the trainer instance as argument
        before_evaluate_fn (func): callback to run before evaluation. This
            takes the trainer instance as argument.
        mixins (list): list of any class mixins for the returned trainer class.
            These mixins will be applied in order and will have higher
            precedence than the Trainer class
        execution_plan (func): Setup the distributed execution workflow.

    Returns:
        a Trainer child class constructed from the specified args.
    """

    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config, env_creator):
            if validate_config:
                validate_config(config)

            if get_initial_state:
                deprecation_warning("get_initial_state", "execution_plan")
                self.state = get_initial_state(self)
            else:
                self.state = {}
            if get_policy_class is None:
                self._policy = default_policy
            else:
                self._policy = get_policy_class(config)
            if before_init:
                before_init(self)
            # Creating all workers (excluding evaluation workers).
            if make_workers and not execution_plan:
                deprecation_warning("make_workers", "execution_plan")
                self.workers = make_workers(self, env_creator, self._policy,
                                            config)
            else:
                self.workers = self._make_workers(env_creator, self._policy,
                                                  config,
                                                  self.config["num_workers"])
            self.train_exec_impl = None
            self.optimizer = None
            self.execution_plan = execution_plan

            if make_policy_optimizer:
                deprecation_warning("make_policy_optimizer", "execution_plan")
                self.optimizer = make_policy_optimizer(self.workers, config)
            else:
                assert execution_plan is not None
                self.train_exec_impl = execution_plan(self.workers, config)
            if after_init:
                after_init(self)

        @override(Trainer)
        def _train(self):
            if self.train_exec_impl:
                return self._train_exec_impl()

            if before_train_step:
                deprecation_warning("before_train_step", "execution_plan")
                before_train_step(self)
            prev_steps = self.optimizer.num_steps_sampled

            start = time.time()
            optimizer_steps_this_iter = 0
            while True:
                fetches = self.optimizer.step()
                optimizer_steps_this_iter += 1
                if after_optimizer_step:
                    deprecation_warning("after_optimizer_step",
                                        "execution_plan")
                    after_optimizer_step(self, fetches)
                if (time.time() - start >= self.config["min_iter_time_s"]
                        and self.optimizer.num_steps_sampled - prev_steps >=
                        self.config["timesteps_per_iteration"]):
                    break

            if collect_metrics_fn:
                deprecation_warning("collect_metrics_fn", "execution_plan")
                res = collect_metrics_fn(self)
            else:
                res = self.collect_metrics()
            res.update(
                optimizer_steps_this_iter=optimizer_steps_this_iter,
                timesteps_this_iter=self.optimizer.num_steps_sampled -
                prev_steps,
                info=res.get("info", {}))

            if after_train_result:
                deprecation_warning("after_train_result", "execution_plan")
                after_train_result(self, res)
            return res

        def _train_exec_impl(self):
            res = next(self.train_exec_impl)
            return res

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        def __getstate__(self):
            state = Trainer.__getstate__(self)
            state["trainer_state"] = self.state.copy()
            if self.train_exec_impl:
                state["train_exec_impl"] = (
                    self.train_exec_impl.shared_metrics.get().save())
            return state

        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            self.state = state["trainer_state"].copy()
            if self.train_exec_impl:
                self.train_exec_impl.shared_metrics.get().restore(
                    state["train_exec_impl"])

    def with_updates(**overrides):
        """Build a copy of this trainer with the specified overrides.

        Arguments:
            overrides (dict): use this to override any of the arguments
                originally passed to build_trainer() for this trainer.
        """
        return build_trainer(**dict(original_kwargs, **overrides))

    trainer_cls.with_updates = staticmethod(with_updates)
    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
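
A minimal usage sketch of the build_trainer() helper above. All "My*" names and
MY_DEFAULT_CONFIG are hypothetical placeholders (not part of the source), and
only arguments documented in the signature above are used.

def my_validate_config(config):
    # Illustrative sanity check; may mutate the config or raise as needed.
    if config["num_workers"] < 0:
        raise ValueError("`num_workers` must be >= 0!")

MyTrainer = build_trainer(
    name="MyAlgo",
    default_policy=MyTFPolicy,         # assumed Policy class, defined elsewhere
    default_config=MY_DEFAULT_CONFIG,  # assumed dict; COMMON_CONFIG is used if None
    validate_config=my_validate_config,
)

# with_updates() then derives variants without repeating every argument:
MyDebugTrainer = MyTrainer.with_updates(name="MyAlgoDebug")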
Example #12
                    "GPU": cf["num_gpus"],
                },
                {
                    # Different bundle (meaning: possibly different node)
                    # for your n "remote" envs (set remote_worker_envs=True).
                    "CPU": cf["num_envs_per_worker"],
                }
            ],
            strategy=config.get("placement_strategy", "PACK"))


# The modified Trainer class we will use. This is exactly the same as a
# PPOTrainer, but with the additional default_resource_request override,
# telling Tune that it is ok (not mandatory) to place our n remote envs on a
# different node (each env using 1 CPU).
PPOTrainerRemoteInference = add_mixins(PPOTrainer,
                                       [OverrideDefaultResourceRequest])

if __name__ == "__main__":
    args = get_cli_args()

    ray.init(num_cpus=6, local_mode=args.local_mode)

    config = {
        "env": "CartPole-v0",
        # Force sub-envs to be ray.actor.ActorHandles, so we can step
        # through them in parallel.
        "remote_worker_envs": True,
        # Set the number of CPUs used by the (local) worker, aka "driver"
        # to match the number of ray remote envs.
        "num_cpus_for_driver": args.num_envs_per_worker + 1,
        # Use a single worker (however, with n parallelized remote envs, maybe
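
The snippet above is truncated. As a hedged sketch (the stop criteria and CLI
attribute are illustrative, not from the original), such a trainer could then
be launched through Tune roughly as follows:

    stop = {"training_iteration": args.stop_iters}  # hypothetical CLI arg
    results = tune.run(
        PPOTrainerRemoteInference,  # the mixin-ed PPOTrainer defined above
        config=config,
        stop=stop)
    ray.shutdown()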
Example #13
def build_eager_tf_policy(name,
                          loss_fn,
                          get_default_config=None,
                          postprocess_fn=None,
                          stats_fn=None,
                          optimizer_fn=None,
                          gradients_fn=None,
                          apply_gradients_fn=None,
                          grad_stats_fn=None,
                          extra_learn_fetches_fn=None,
                          extra_action_fetches_fn=None,
                          validate_spaces=None,
                          before_init=None,
                          before_loss_init=None,
                          after_init=None,
                          make_model=None,
                          action_sampler_fn=None,
                          action_distribution_fn=None,
                          mixins=None,
                          obs_include_prev_action_reward=True,
                          get_batch_divisibility_req=None):
    """Build an eager TF policy.

    An eager policy runs all operations in eager mode, which makes debugging
    much simpler, but has lower performance.

    You shouldn't need to call this directly. Rather, prefer to build a TF
    graph policy and set {"framework": "tfe"} in the trainer config to have
    it automatically converted to an eager policy.

    This has the same signature as build_tf_policy()."""

    base = add_mixins(Policy, mixins)

    class eager_policy_cls(base):
        def __init__(self, observation_space, action_space, config):
            assert tf.executing_eagerly()
            self.framework = config.get("framework", "tfe")
            Policy.__init__(self, observation_space, action_space, config)
            self._is_training = False
            self._loss_initialized = False
            self._sess = None

            if get_default_config:
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, observation_space, action_space, config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config
            self.dist_class = None
            if action_sampler_fn or action_distribution_fn:
                if not make_model:
                    raise ValueError(
                        "`make_model` is required if `action_sampler_fn` OR "
                        "`action_distribution_fn` is given")
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"])

            if make_model:
                self.model = make_model(self, observation_space, action_space,
                                        config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework=self.framework,
                )
            self.exploration = self._create_exploration()
            self._state_in = [
                tf.convert_to_tensor([s])
                for s in self.model.get_initial_state()
            ]
            input_dict = {
                SampleBatch.CUR_OBS:
                tf.convert_to_tensor(np.array([observation_space.sample()])),
                SampleBatch.PREV_ACTIONS:
                tf.convert_to_tensor(
                    [flatten_to_single_ndarray(action_space.sample())]),
                SampleBatch.PREV_REWARDS:
                tf.convert_to_tensor([0.]),
            }

            if action_distribution_fn:
                dist_inputs, self.dist_class, _ = action_distribution_fn(
                    self, self.model, input_dict[SampleBatch.CUR_OBS])
            else:
                self.model(input_dict, self._state_in,
                           tf.convert_to_tensor([1]))

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            self._initialize_loss_with_dummy_batch()
            self._loss_initialized = True

            if optimizer_fn:
                self._optimizer = optimizer_fn(self, config)
            else:
                self._optimizer = tf.keras.optimizers.Adam(config["lr"])

            if after_init:
                after_init(self, observation_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            assert tf.executing_eagerly()
            # Call super's postprocess_trajectory first.
            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
            if postprocess_fn:
                return postprocess_fn(self, sample_batch, other_agent_batches,
                                      episode)
            return sample_batch

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def learn_on_batch(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            self._apply_gradients(grads_and_vars)
            return stats

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def compute_gradients(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            grads = [g for g, v in grads_and_vars]
            return grads, stats

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def compute_actions(self,
                            obs_batch,
                            state_batches=None,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            info_batch=None,
                            episodes=None,
                            explore=None,
                            timestep=None,
                            **kwargs):

            explore = explore if explore is not None else \
                self.config["explore"]
            timestep = timestep if timestep is not None else \
                self.global_timestep

            # TODO: remove python side effect to cull sources of bugs.
            self._is_training = False
            self._state_in = state_batches

            if not tf1.executing_eagerly():
                tf1.enable_eager_execution()

            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.constant(False),
            }
            n = input_dict[SampleBatch.CUR_OBS].shape[0]
            seq_lens = tf.ones(n, dtype=tf.int32)
            if obs_include_prev_action_reward:
                if prev_action_batch is not None:
                    input_dict[SampleBatch.PREV_ACTIONS] = \
                        tf.convert_to_tensor(prev_action_batch)
                if prev_reward_batch is not None:
                    input_dict[SampleBatch.PREV_REWARDS] = \
                        tf.convert_to_tensor(prev_reward_batch)

            # Use Exploration object.
            with tf.variable_creator_scope(_disallow_var_creation):
                if action_sampler_fn:
                    dist_inputs = None
                    state_out = []
                    actions, logp = action_sampler_fn(
                        self,
                        self.model,
                        input_dict[SampleBatch.CUR_OBS],
                        explore=explore,
                        timestep=timestep,
                        episodes=episodes)
                else:
                    # Exploration hook before each forward pass.
                    self.exploration.before_compute_actions(timestep=timestep,
                                                            explore=explore)

                    if action_distribution_fn:
                        dist_inputs, dist_class, state_out = \
                            action_distribution_fn(
                                self, self.model,
                                input_dict[SampleBatch.CUR_OBS],
                                explore=explore,
                                timestep=timestep,
                                is_training=False)
                    else:
                        dist_class = self.dist_class
                        dist_inputs, state_out = self.model(
                            input_dict, state_batches, seq_lens)

                    action_dist = dist_class(dist_inputs, self.model)

                    # Get the exploration action from the forward results.
                    actions, logp = self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore)

            # Add default and custom fetches.
            extra_fetches = {}
            # Action-logp and action-prob.
            if logp is not None:
                extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp
            # Action-dist inputs.
            if dist_inputs is not None:
                extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
            # Custom extra fetches.
            if extra_action_fetches_fn:
                extra_fetches.update(extra_action_fetches_fn(self))

            # Update our global timestep by the batch size.
            self.global_timestep += len(obs_batch)

            return actions, state_out, extra_fetches

        @override(Policy)
        def compute_log_likelihoods(self,
                                    actions,
                                    obs_batch,
                                    state_batches=None,
                                    prev_action_batch=None,
                                    prev_reward_batch=None):
            if action_sampler_fn and action_distribution_fn is None:
                raise ValueError("Cannot compute log-prob/likelihood w/o an "
                                 "`action_distribution_fn` and a provided "
                                 "`action_sampler_fn`!")

            seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.constant(False),
            }
            if obs_include_prev_action_reward:
                input_dict.update({
                    SampleBatch.PREV_ACTIONS:
                    tf.convert_to_tensor(prev_action_batch),
                    SampleBatch.PREV_REWARDS:
                    tf.convert_to_tensor(prev_reward_batch),
                })

            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if action_distribution_fn:
                dist_inputs, dist_class, _ = action_distribution_fn(
                    self,
                    self.model,
                    input_dict[SampleBatch.CUR_OBS],
                    explore=False,
                    is_training=False)
            # Default log-likelihood calculation.
            else:
                dist_inputs, _ = self.model(input_dict, state_batches,
                                            seq_lens)
                dist_class = self.dist_class

            action_dist = dist_class(dist_inputs, self.model)
            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods

        @override(Policy)
        def apply_gradients(self, gradients):
            self._apply_gradients(
                zip([(tf.convert_to_tensor(g) if g is not None else None)
                     for g in gradients], self.model.trainable_variables()))

        @override(Policy)
        def get_exploration_info(self):
            return _convert_to_numpy(self.exploration.get_info())

        @override(Policy)
        def get_weights(self, as_dict=False):
            variables = self.variables()
            if as_dict:
                return {v.name: v.numpy() for v in variables}
            return [v.numpy() for v in variables]

        @override(Policy)
        def set_weights(self, weights):
            variables = self.variables()
            assert len(weights) == len(variables), (len(weights),
                                                    len(variables))
            for v, w in zip(variables, weights):
                v.assign(w)

        @override(Policy)
        def get_state(self):
            state = {"_state": super().get_state()}
            state["_optimizer_variables"] = self._optimizer.variables()
            return state

        @override(Policy)
        def set_state(self, state):
            state = state.copy()  # shallow copy
            # Set optimizer vars first.
            optimizer_vars = state.pop("_optimizer_variables", None)
            if optimizer_vars and self._optimizer.variables():
                logger.warning(
                    "Cannot restore an optimizer's state for tf eager! Keras "
                    "is not able to save the v1.x optimizers (from "
                    "tf.compat.v1.train) since they aren't compatible with "
                    "checkpoints.")
                for opt_var, value in zip(self._optimizer.variables(),
                                          optimizer_vars):
                    opt_var.assign(value)
            # Then the Policy's (NN) weights.
            super().set_state(state["_state"])

        def variables(self):
            """Return the list of all savable variables for this policy."""
            return self.model.variables()

        @override(Policy)
        def is_recurrent(self):
            return len(self._state_in) > 0

        @override(Policy)
        def num_state_tensors(self):
            return len(self._state_in)

        @override(Policy)
        def get_initial_state(self):
            return self.model.get_initial_state()

        def get_session(self):
            return None  # None implies eager

        def get_placeholder(self, ph):
            raise ValueError(
                "get_placeholder() is not allowed in eager mode. Try using "
                "rllib.utils.tf_ops.make_tf_callable() to write "
                "functions that work in both graph and eager mode.")

        def loss_initialized(self):
            return self._loss_initialized

        @override(Policy)
        def export_model(self, export_dir):
            pass

        @override(Policy)
        def export_checkpoint(self, export_dir):
            pass

        def _get_is_training_placeholder(self):
            return tf.convert_to_tensor(self._is_training)

        def _apply_gradients(self, grads_and_vars):
            if apply_gradients_fn:
                apply_gradients_fn(self, self._optimizer, grads_and_vars)
            else:
                self._optimizer.apply_gradients(grads_and_vars)

        def _compute_gradients(self, samples):
            """Computes and returns grads as eager tensors."""

            self._is_training = True

            with tf.GradientTape(persistent=gradients_fn is not None) as tape:
                # TODO: set seq len and state-in properly
                state_in = []
                for i in range(self.num_state_tensors()):
                    state_in.append(samples["state_in_{}".format(i)])
                self._state_in = state_in

                self._seq_lens = None
                if len(state_in) > 0:
                    self._seq_lens = tf.ones(
                        samples[SampleBatch.CUR_OBS].shape[0], dtype=tf.int32)
                    samples["seq_lens"] = self._seq_lens

                model_out, _ = self.model(samples, self._state_in,
                                          self._seq_lens)
                loss = loss_fn(self, self.model, self.dist_class, samples)

            variables = self.model.trainable_variables()

            if gradients_fn:

                class OptimizerWrapper:
                    def __init__(self, tape):
                        self.tape = tape

                    def compute_gradients(self, loss, var_list):
                        return list(
                            zip(self.tape.gradient(loss, var_list), var_list))

                grads_and_vars = gradients_fn(self, OptimizerWrapper(tape),
                                              loss)
            else:
                grads_and_vars = list(
                    zip(tape.gradient(loss, variables), variables))

            if log_once("grad_vars"):
                for _, v in grads_and_vars:
                    logger.info("Optimizing variable {}".format(v.name))

            grads = [g for g, v in grads_and_vars]
            stats = self._stats(self, samples, grads)
            return grads_and_vars, stats

        def _stats(self, outputs, samples, grads):

            fetches = {}
            if stats_fn:
                fetches[LEARNER_STATS_KEY] = {
                    k: v
                    for k, v in stats_fn(outputs, samples).items()
                }
            else:
                fetches[LEARNER_STATS_KEY] = {}

            if extra_learn_fetches_fn:
                fetches.update(
                    {k: v
                     for k, v in extra_learn_fetches_fn(self).items()})
            if grad_stats_fn:
                fetches.update({
                    k: v
                    for k, v in grad_stats_fn(self, samples, grads).items()
                })
            return fetches

        def _initialize_loss_with_dummy_batch(self):
            # Dummy forward pass to initialize any policy attributes, etc.
            dummy_batch = {
                SampleBatch.CUR_OBS:
                np.array([self.observation_space.sample()]),
                SampleBatch.NEXT_OBS:
                np.array([self.observation_space.sample()]),
                SampleBatch.DONES: np.array([False], dtype=np.bool),
                SampleBatch.REWARDS: np.array([0], dtype=np.float32),
            }
            if isinstance(self.action_space, (Dict, Tuple)):
                dummy_batch[SampleBatch.ACTIONS] = [
                    flatten_to_single_ndarray(self.action_space.sample())
                ]
            else:
                dummy_batch[SampleBatch.ACTIONS] = tf.nest.map_structure(
                    lambda c: np.array([c]), self.action_space.sample())

            if obs_include_prev_action_reward:
                dummy_batch.update({
                    SampleBatch.PREV_ACTIONS:
                    dummy_batch[SampleBatch.ACTIONS],
                    SampleBatch.PREV_REWARDS:
                    dummy_batch[SampleBatch.REWARDS],
                })
            for i, h in enumerate(self._state_in):
                dummy_batch["state_in_{}".format(i)] = h
                dummy_batch["state_out_{}".format(i)] = h

            if self._state_in:
                dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)

            # Convert everything to tensors.
            dummy_batch = tf.nest.map_structure(tf1.convert_to_tensor,
                                                dummy_batch)

            # for IMPALA which expects a certain sample batch size.
            def tile_to(tensor, n):
                return tf.tile(tensor,
                               [n] + [1 for _ in tensor.shape.as_list()[1:]])

            if get_batch_divisibility_req:
                dummy_batch = tf.nest.map_structure(
                    lambda c: tile_to(c, get_batch_divisibility_req(self)),
                    dummy_batch)
            i = 0
            self._state_in = []
            while "state_in_{}".format(i) in dummy_batch:
                self._state_in.append(dummy_batch["state_in_{}".format(i)])
                i += 1

            # Execute a forward pass to get self.action_dist etc initialized,
            # and also obtain the extra action fetches
            _, _, fetches = self.compute_actions(
                dummy_batch[SampleBatch.CUR_OBS],
                self._state_in,
                dummy_batch.get(SampleBatch.PREV_ACTIONS),
                dummy_batch.get(SampleBatch.PREV_REWARDS),
                explore=False)
            # Got to reset global_timestep again after this fake run-through.
            self.global_timestep = 0
            dummy_batch.update(fetches)

            postprocessed_batch = self.postprocess_trajectory(
                SampleBatch(dummy_batch))

            # model forward pass for the loss (needed after postprocess to
            # overwrite any tensor state from that call)
            self.model.from_batch(dummy_batch)

            postprocessed_batch = tf.nest.map_structure(
                lambda c: tf.convert_to_tensor(c), postprocessed_batch.data)

            loss_fn(self, self.model, self.dist_class, postprocessed_batch)
            if stats_fn:
                stats_fn(self, postprocessed_batch)

        @classmethod
        def with_tracing(cls):
            return traced_eager_policy(cls)

    eager_policy_cls.__name__ = name + "_eager"
    eager_policy_cls.__qualname__ = name + "_eager"
    return eager_policy_cls
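
As the docstring above notes, build_eager_tf_policy() is normally not called
directly. A hedged sketch of the preferred route (PPOTrainer and the config
values are illustrative):

config = {
    "env": "CartPole-v0",
    # Build the usual TF graph policy and let RLlib auto-convert it to an
    # eager policy (the mechanism this helper implements).
    "framework": "tfe",
}
# trainer = PPOTrainer(config=config)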
Example #14
def build_tf_policy(
        name: str,
        *,
        loss_fn: Callable[
            [Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
            Union[TensorType, List[TensorType]]],
        get_default_config: Optional[Callable[[None],
                                              TrainerConfigDict]] = None,
        postprocess_fn: Optional[Callable[[
            Policy, SampleBatch, Optional[Dict[
                AgentID, SampleBatch]], Optional["Episode"]
        ], SampleBatch]] = None,
        stats_fn: Optional[Callable[[Policy, SampleBatch],
                                    Dict[str, TensorType]]] = None,
        optimizer_fn: Optional[
            Callable[[Policy, TrainerConfigDict],
                     "tf.keras.optimizers.Optimizer"]] = None,
        compute_gradients_fn: Optional[
            Callable[[Policy, "tf.keras.optimizers.Optimizer", TensorType],
                     ModelGradients]] = None,
        apply_gradients_fn: Optional[
            Callable[[Policy, "tf.keras.optimizers.Optimizer", ModelGradients],
                     "tf.Operation"]] = None,
        grad_stats_fn: Optional[Callable[[Policy, SampleBatch, ModelGradients],
                                         Dict[str, TensorType]]] = None,
        extra_action_out_fn: Optional[Callable[[Policy],
                                               Dict[str, TensorType]]] = None,
        extra_learn_fetches_fn: Optional[Callable[[Policy],
                                                  Dict[str,
                                                       TensorType]]] = None,
        validate_spaces: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_loss_init: Optional[Callable[
            [Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
            None]] = None,
        after_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        make_model: Optional[Callable[
            [Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
            ModelV2]] = None,
        action_sampler_fn: Optional[Callable[[TensorType, List[TensorType]],
                                             Tuple[TensorType,
                                                   TensorType]]] = None,
        action_distribution_fn: Optional[
            Callable[[Policy, ModelV2, TensorType, TensorType, TensorType],
                     Tuple[TensorType, type, List[TensorType]]]] = None,
        mixins: Optional[List[type]] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
        # Deprecated args.
        obs_include_prev_action_reward=DEPRECATED_VALUE,
        extra_action_fetches_fn=None,  # Use `extra_action_out_fn`.
        gradients_fn=None,  # Use `compute_gradients_fn`.
) -> Type[DynamicTFPolicy]:
    """Helper function for creating a dynamic tf policy at runtime.

    Functions will be run in this order to initialize the policy:
        1. Placeholder setup: postprocess_fn
        2. Loss init: loss_fn, stats_fn
        3. Optimizer init: optimizer_fn, gradients_fn, apply_gradients_fn,
                           grad_stats_fn

    This means that you can e.g., depend on any policy attributes created in
    the running of `loss_fn` in later functions such as `stats_fn`.

    In eager mode, the following functions will be run repeatedly on each
    eager execution: loss_fn, stats_fn, gradients_fn, apply_gradients_fn,
    and grad_stats_fn.

    This means that these functions should not define any variables internally,
    otherwise they will fail in eager mode execution. Variables should only
    be created in make_model (if defined).

    Args:
        name (str): Name of the policy (e.g., "PPOTFPolicy").
        loss_fn (Callable[[
            Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
            Union[TensorType, List[TensorType]]]): Callable for calculating a
            loss tensor.
        get_default_config (Optional[Callable[[None], TrainerConfigDict]]):
            Optional callable that returns the default config to merge with any
            overrides. If None, uses only(!) the user-provided
            PartialTrainerConfigDict as dict for this Policy.
        postprocess_fn (Optional[Callable[[Policy, SampleBatch,
            Optional[Dict[AgentID, SampleBatch]], Episode], None]]):
            Optional callable for post-processing experience batches (called
            after the parent class' `postprocess_trajectory` method).
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            TF tensors to fetch given the policy and batch input tensors. If
            None, will not compute any stats.
        optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict],
            "tf.keras.optimizers.Optimizer"]]): Optional callable that returns
            a tf.Optimizer given the policy and config. If None, will call
            the base class' `optimizer()` method instead (which returns a
            tf1.train.AdamOptimizer).
        compute_gradients_fn (Optional[Callable[[Policy,
            "tf.keras.optimizers.Optimizer", TensorType], ModelGradients]]):
            Optional callable that returns a list of gradients. If None,
            this defaults to optimizer.compute_gradients([loss]).
        apply_gradients_fn (Optional[Callable[[Policy,
            "tf.keras.optimizers.Optimizer", ModelGradients],
            "tf.Operation"]]): Optional callable that returns an apply
            gradients op given policy, tf-optimizer, and grads_and_vars. If
            None, will call the base class' `build_apply_op()` method instead.
        grad_stats_fn (Optional[Callable[[Policy, SampleBatch, ModelGradients],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            TF fetches given the policy, batch input, and gradient tensors. If
            None, will not collect any gradient stats.
        extra_action_out_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns
            a dict of TF fetches given the policy object. If None, will not
            perform any extra fetches.
        extra_learn_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            extra values to fetch and return when learning on a batch. If None,
            will call the base class' `extra_compute_grad_fetches()` method
            instead.
        validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check
            the spaces for correctness. If None, no spaces checking will be
            done.
        before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the
            beginning of policy init that takes the same arguments as the
            policy constructor. If None, this step will be skipped.
        before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
            run prior to loss init. If None, this step will be skipped.
        after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the end of
            policy init. If None, this step will be skipped.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional callable
            that returns a ModelV2 object.
            All policy variables should be created in this function. If None,
            a default ModelV2 object will be created.
        action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]]): A callable returning a sampled
            action and its log-likelihood given observation and state inputs.
            If None, will either use `action_distribution_fn` or
            compute actions by calling self.model, then sampling from the
            so parameterized action distribution.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
            TensorType, TensorType],
            Tuple[TensorType, type, List[TensorType]]]]): Optional callable
            returning distribution inputs (parameters), a dist-class to
            generate an action distribution object from, and internal-state
            outputs (or an empty list if not applicable). If None, will either
            use `action_sampler_fn` or compute actions by calling self.model,
            then sampling from the so parameterized action distribution.
        mixins (Optional[List[type]]): Optional list of any class mixins for
            the returned policy class. These mixins will be applied in order
            and will have higher precedence than the DynamicTFPolicy class.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.

    Returns:
        Type[DynamicTFPolicy]: A child class of DynamicTFPolicy based on the
            specified args.
    """
    original_kwargs = locals().copy()
    base = add_mixins(DynamicTFPolicy, mixins)

    if obs_include_prev_action_reward != DEPRECATED_VALUE:
        deprecation_warning(old="obs_include_prev_action_reward", error=False)

    if extra_action_fetches_fn is not None:
        deprecation_warning(old="extra_action_fetches_fn",
                            new="extra_action_out_fn",
                            error=False)
        extra_action_out_fn = extra_action_fetches_fn

    if gradients_fn is not None:
        deprecation_warning(old="gradients_fn",
                            new="compute_gradients_fn",
                            error=False)
        compute_gradients_fn = gradients_fn

    class policy_cls(base):
        def __init__(self,
                     obs_space,
                     action_space,
                     config,
                     existing_model=None,
                     existing_inputs=None):
            if get_default_config:
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, obs_space, action_space, config)

            if before_init:
                before_init(self, obs_space, action_space, config)

            def before_loss_init_wrapper(policy, obs_space, action_space,
                                         config):
                if before_loss_init:
                    before_loss_init(policy, obs_space, action_space, config)

                if extra_action_out_fn is None or policy._is_tower:
                    extra_action_fetches = {}
                else:
                    extra_action_fetches = extra_action_out_fn(policy)

                if hasattr(policy, "_extra_action_fetches"):
                    policy._extra_action_fetches.update(extra_action_fetches)
                else:
                    policy._extra_action_fetches = extra_action_fetches

            DynamicTFPolicy.__init__(
                self,
                obs_space=obs_space,
                action_space=action_space,
                config=config,
                loss_fn=loss_fn,
                stats_fn=stats_fn,
                grad_stats_fn=grad_stats_fn,
                before_loss_init=before_loss_init_wrapper,
                make_model=make_model,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                existing_inputs=existing_inputs,
                existing_model=existing_model,
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            if after_init:
                after_init(self, obs_space, action_space, config)

            # Got to reset global_timestep again after this fake run-through.
            self.global_timestep = 0

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            # Call super's postprocess_trajectory first.
            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
            if postprocess_fn:
                return postprocess_fn(self, sample_batch, other_agent_batches,
                                      episode)
            return sample_batch

        @override(TFPolicy)
        def optimizer(self):
            if optimizer_fn:
                optimizers = optimizer_fn(self, self.config)
            else:
                optimizers = base.optimizer(self)
            optimizers = force_list(optimizers)
            if getattr(self, "exploration", None):
                optimizers = self.exploration.get_exploration_optimizer(
                    optimizers)

            # No optimizers produced -> Return None.
            if not optimizers:
                return None
            # New API: Allow more than one optimizer to be returned.
            # -> Return list.
            elif self.config["_tf_policy_handles_more_than_one_loss"]:
                return optimizers
            # Old API: Return a single LocalOptimizer.
            else:
                return optimizers[0]

        @override(TFPolicy)
        def gradients(self, optimizer, loss):
            optimizers = force_list(optimizer)
            losses = force_list(loss)

            if compute_gradients_fn:
                # New API: Allow more than one optimizer -> Return a list of
                # lists of gradients.
                if self.config["_tf_policy_handles_more_than_one_loss"]:
                    return compute_gradients_fn(self, optimizers, losses)
                # Old API: Return a single List of gradients.
                else:
                    return compute_gradients_fn(self, optimizers[0], losses[0])
            else:
                return base.gradients(self, optimizers, losses)

        @override(TFPolicy)
        def build_apply_op(self, optimizer, grads_and_vars):
            if apply_gradients_fn:
                return apply_gradients_fn(self, optimizer, grads_and_vars)
            else:
                return base.build_apply_op(self, optimizer, grads_and_vars)

        @override(TFPolicy)
        def extra_compute_action_fetches(self):
            return dict(base.extra_compute_action_fetches(self),
                        **self._extra_action_fetches)

        @override(TFPolicy)
        def extra_compute_grad_fetches(self):
            if extra_learn_fetches_fn:
                # TODO: (sven) in torch, extra_learn_fetches do not exist.
                #  Hence, things like td_error are returned by the stats_fn
                #  and end up under the LEARNER_STATS_KEY. We should
                #  change tf to do this as well. However, this will conflict
                #  with the handling of LEARNER_STATS_KEY inside the
                #  multi-GPU train op.
                # Auto-add empty learner stats dict if needed.
                return dict({LEARNER_STATS_KEY: {}},
                            **extra_learn_fetches_fn(self))
            else:
                return base.extra_compute_grad_fetches(self)

    def with_updates(**overrides):
        """Allows creating a TFPolicy cls based on settings of another one.

        Keyword Args:
            **overrides: The settings (passed into `build_tf_policy`) that
                should be different from the class that this method is called
                on.

        Returns:
            type: A new TFPolicy sub-class.

        Examples:
            >>> MySpecialDQNPolicyClass = DQNTFPolicy.with_updates(
            ...     name="MySpecialDQNPolicyClass",
            ...     loss_fn=some_new_loss_function,
            ... )
        """
        return build_tf_policy(**dict(original_kwargs, **overrides))

    def as_eager():
        return eager_tf_policy.build_eager_tf_policy(**original_kwargs)

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.as_eager = staticmethod(as_eager)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
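
A hedged usage sketch of build_tf_policy(). my_loss and MyTFPolicy are
hypothetical names; the imports (tf, SampleBatch) and the ModelV2/SampleBatch
calls are the same ones assumed by the examples above.

def my_loss(policy, model, dist_class, train_batch):
    # Toy policy-gradient style loss over the sampled actions and rewards.
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    logp = action_dist.logp(train_batch[SampleBatch.ACTIONS])
    return -tf.reduce_mean(logp * train_batch[SampleBatch.REWARDS])

MyTFPolicy = build_tf_policy(
    name="MyTFPolicy",
    loss_fn=my_loss,
)

# The same settings can also produce an eager variant:
MyEagerTFPolicy = MyTFPolicy.as_eager()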
Example #15
def build_torch_policy(name,
                       loss_fn,
                       get_default_config=None,
                       stats_fn=None,
                       postprocess_fn=None,
                       extra_action_out_fn=None,
                       extra_grad_process_fn=None,
                       optimizer_fn=None,
                       before_init=None,
                       after_init=None,
                       make_model_and_action_dist=None,
                       mixins=None):
    """Helper function for creating a torch policy at runtime.

    Arguments:
        name (str): name of the policy (e.g., "PPOTorchPolicy")
        loss_fn (func): function that returns a loss tensor given the policy
            and a dict of experience tensor placeholders
        get_default_config (func): optional function that returns the default
            config to merge with any overrides
        stats_fn (func): optional function that returns a dict of
            values given the policy and batch input tensors
        postprocess_fn (func): optional experience postprocessing function
            that takes the same args as Policy.postprocess_trajectory()
        extra_action_out_fn (func): optional function that returns
            a dict of extra values to include in experiences
        extra_grad_process_fn (func): optional function that is called after
            gradients are computed and returns processing info
        optimizer_fn (func): optional function that returns a torch optimizer
            given the policy and config
        before_init (func): optional function to run at the beginning of
            policy init that takes the same arguments as the policy constructor
        after_init (func): optional function to run at the end of policy init
            that takes the same arguments as the policy constructor
        make_model_and_action_dist (func): optional func that takes the same
            arguments as policy init and returns a tuple of model instance and
            torch action distribution class. If not specified, the default
            model and action dist from the catalog will be used
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the TorchPolicy class

    Returns:
        a TorchPolicy subclass that uses the specified args
    """

    original_kwargs = locals().copy()
    base = add_mixins(TorchPolicy, mixins)

    class policy_cls(base):
        def __init__(self, obs_space, action_space, config):
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            if before_init:
                before_init(self, obs_space, action_space, config)

            if make_model_and_action_dist:
                self.model, self.dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], torch=True)
                self.model = ModelCatalog.get_model_v2(
                    obs_space,
                    action_space,
                    logit_dim,
                    self.config["model"],
                    framework="torch")

            TorchPolicy.__init__(self, obs_space, action_space, self.model,
                                 loss_fn, self.dist_class)

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            if not postprocess_fn:
                return sample_batch
            return postprocess_fn(self, sample_batch, other_agent_batches,
                                  episode)

        @override(TorchPolicy)
        def extra_grad_process(self):
            if extra_grad_process_fn:
                return extra_grad_process_fn(self)
            else:
                return TorchPolicy.extra_grad_process(self)

        @override(TorchPolicy)
        def extra_action_out(self, input_dict, state_batches, model):
            if extra_action_out_fn:
                return extra_action_out_fn(self, input_dict, state_batches,
                                           model)
            else:
                return TorchPolicy.extra_action_out(self, input_dict,
                                                    state_batches, model)

        @override(TorchPolicy)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            else:
                return TorchPolicy.optimizer(self)

        @override(TorchPolicy)
        def extra_grad_info(self, batch_tensors):
            if stats_fn:
                return stats_fn(self, batch_tensors)
            else:
                return TorchPolicy.extra_grad_info(self, batch_tensors)

    def with_updates(**overrides):
        return build_torch_policy(**dict(original_kwargs, **overrides))

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
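
A hedged usage sketch of this (older) build_torch_policy() variant.
my_torch_loss is assumed to be defined elsewhere with the loss_fn signature
described in the docstring above; my_adam is a hypothetical optimizer_fn, and
"lr" is assumed to be present in the config.

import torch

def my_adam(policy, config):
    # optimizer_fn: return a torch optimizer given the policy and config.
    return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])

MyTorchPolicy = build_torch_policy(
    name="MyTorchPolicy",
    loss_fn=my_torch_loss,  # assumed to exist; see loss_fn docs above
    optimizer_fn=my_adam,
)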
Example #16
def build_policy_class(
    name: str,
    framework: str,
    *,
    loss_fn: Optional[Callable[
        [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch],
        Union[TensorType, List[TensorType]]]],
    get_default_config: Optional[Callable[[], TrainerConfigDict]] = None,
    stats_fn: Optional[Callable[[Policy, SampleBatch],
                                Dict[str, TensorType]]] = None,
    postprocess_fn: Optional[Callable[[
        Policy, SampleBatch, Optional[Dict[
            Any, SampleBatch]], Optional["MultiAgentEpisode"]
    ], SampleBatch]] = None,
    extra_action_out_fn: Optional[Callable[[
        Policy, Dict[
            str,
            TensorType], List[TensorType], ModelV2, TorchDistributionWrapper
    ], Dict[str, TensorType]]] = None,
    extra_grad_process_fn: Optional[
        Callable[[Policy, "torch.optim.Optimizer", TensorType],
                 Dict[str, TensorType]]] = None,
    # TODO: (sven) Replace "fetches" with "process".
    extra_learn_fetches_fn: Optional[Callable[[Policy],
                                              Dict[str, TensorType]]] = None,
    optimizer_fn: Optional[Callable[[Policy, TrainerConfigDict],
                                    "torch.optim.Optimizer"]] = None,
    validate_spaces: Optional[Callable[
        [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
    before_init: Optional[Callable[
        [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
    before_loss_init: Optional[Callable[
        [Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
        None]] = None,
    after_init: Optional[Callable[
        [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
    _after_loss_init: Optional[Callable[
        [Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
        None]] = None,
    action_sampler_fn: Optional[Callable[[TensorType, List[TensorType]],
                                         Tuple[TensorType,
                                               TensorType]]] = None,
    action_distribution_fn: Optional[
        Callable[[Policy, ModelV2, TensorType, TensorType, TensorType],
                 Tuple[TensorType, type, List[TensorType]]]] = None,
    make_model: Optional[Callable[
        [Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
        ModelV2]] = None,
    make_model_and_action_dist: Optional[Callable[
        [Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
        Tuple[ModelV2, Type[TorchDistributionWrapper]]]] = None,
    apply_gradients_fn: Optional[Callable[[Policy, "torch.optim.Optimizer"],
                                          None]] = None,
    mixins: Optional[List[type]] = None,
    view_requirements_fn: Optional[Callable[[Policy],
                                            Dict[str,
                                                 ViewRequirement]]] = None,
    get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
) -> Type[TorchPolicy]:
    """Helper function for creating a new Policy class at runtime.

    Supports frameworks JAX and PyTorch.

    Args:
        name (str): name of the policy (e.g., "PPOTorchPolicy")
        framework (str): Either "jax" or "torch".
        loss_fn (Optional[Callable[[Policy, ModelV2,
            Type[TorchDistributionWrapper], SampleBatch], Union[TensorType,
            List[TensorType]]]]): Callable that returns a loss tensor.
        get_default_config (Optional[Callable[[None], TrainerConfigDict]]):
            Optional callable that returns the default config to merge with any
            overrides. If None, uses only(!) the user-provided
            PartialTrainerConfigDict as dict for this Policy.
        postprocess_fn (Optional[Callable[[Policy, SampleBatch,
            Optional[Dict[Any, SampleBatch]], Optional["MultiAgentEpisode"]],
            SampleBatch]]): Optional callable for post-processing experience
            batches (called after the super's `postprocess_trajectory` method).
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            values given the policy and training batch. If None,
            will use `TorchPolicy.extra_grad_info()` instead. The stats dict is
            used for logging (e.g. in TensorBoard).
        extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType],
            List[TensorType], ModelV2, TorchDistributionWrapper]], Dict[str,
            TensorType]]]): Optional callable that returns a dict of extra
            values to include in experiences. If None, no extra computations
            will be performed.
        extra_grad_process_fn (Optional[Callable[[Policy,
            "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]]):
            Optional callable that is called after gradients are computed and
            returns a processing info dict. If None, will call the
            `TorchPolicy.extra_grad_process()` method instead.
        # TODO: (sven) dissolve naming mismatch between "learn" and "compute.."
        extra_learn_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            extra tensors from the policy after loss evaluation. If None,
            will call the `TorchPolicy.extra_compute_grad_fetches()` method
            instead.
        optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict],
            "torch.optim.Optimizer"]]): Optional callable that returns a
            torch optimizer given the policy and config. If None, will call
            the `TorchPolicy.optimizer()` method instead (which returns a
            torch Adam optimizer).
        validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check for
            correctness. If None, no spaces checking will be done.
        before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the
            beginning of `Policy.__init__` that takes the same arguments as
            the Policy constructor. If None, this step will be skipped.
        before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
            run prior to loss init. If None, this step will be skipped.
        after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): DEPRECATED: Use `before_loss_init`
            instead.
        _after_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
            run after the loss init. If None, this step will be skipped.
            This will be deprecated at some point and renamed into `after_init`
            to match `build_tf_policy()` behavior.
        action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]]): Optional callable returning a
            sampled action and its log-likelihood given some (obs and state)
            inputs. If None, will either use `action_distribution_fn` or
            compute actions by calling self.model, then sampling from the
            parameterized action distribution.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
            TensorType, TensorType], Tuple[TensorType,
            Type[TorchDistributionWrapper], List[TensorType]]]]): A callable
            that takes the Policy, Model, the observation batch, an
            explore-flag, a timestep, and an is_training flag and returns a
            tuple of a) distribution inputs (parameters), b) a dist-class to
            generate an action distribution object from, and c) internal-state
            outputs (empty list if not applicable). If None, will either use
            `action_sampler_fn` or compute actions by calling self.model,
            then sampling from the parameterized action distribution.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional callable
            that takes the same arguments as Policy.__init__ and returns a
            model instance. The distribution class will be determined
            automatically. Note: Only one of `make_model` or
            `make_model_and_action_dist` should be provided. If both are None,
            a default Model will be created.
        make_model_and_action_dist (Optional[Callable[[Policy,
            gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
            Tuple[ModelV2, Type[TorchDistributionWrapper]]]]): Optional
            callable that takes the same arguments as Policy.__init__ and
            returns a tuple of model instance and torch action distribution
            class.
            Note: Only one of `make_model` or `make_model_and_action_dist`
            should be provided. If both are None, a default Model will be
            created.
        apply_gradients_fn (Optional[Callable[[Policy,
            "torch.optim.Optimizer"], None]]): Optional callable that
            takes a grads list and applies these to the Model's parameters.
            If None, will call the `TorchPolicy.apply_gradients()` method
            instead.
        mixins (Optional[List[type]]): Optional list of any class mixins for
            the returned policy class. These mixins will be applied in order
            and will have higher precedence than the TorchPolicy class.
        view_requirements_fn (Optional[Callable[[Policy],
            Dict[str, ViewRequirement]]]): An optional callable to retrieve
            additional train view requirements for this policy.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.

    Returns:
        Type[TorchPolicy]: TorchPolicy child class constructed from the
            specified args.
    """

    original_kwargs = locals().copy()
    parent_cls = TorchPolicy
    base = add_mixins(parent_cls, mixins)

    class policy_cls(base):
        def __init__(self, obs_space, action_space, config):
            # Set up the config from possible default-config fn and given
            # config arg.
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            # Set the DL framework for this Policy.
            self.framework = self.config["framework"] = framework

            # Validate observation- and action-spaces.
            if validate_spaces:
                validate_spaces(self, obs_space, action_space, self.config)

            # Do some pre-initialization steps.
            if before_init:
                before_init(self, obs_space, action_space, self.config)

            # Model is customized (use default action dist class).
            if make_model:
                assert make_model_and_action_dist is None, \
                    "Either `make_model` or `make_model_and_action_dist`" \
                    " must be None!"
                self.model = make_model(self, obs_space, action_space, config)
                dist_class, _ = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework=framework)
            # Model and action dist class are customized.
            elif make_model_and_action_dist:
                self.model, dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
            # Use default model and default action dist.
            else:
                dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework=framework)
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=logit_dim,
                    model_config=self.config["model"],
                    framework=framework)

            # Make sure we passed in a correct Model factory.
            model_cls = TorchModelV2 if framework == "torch" else JAXModelV2
            assert isinstance(self.model, model_cls), \
                f"ERROR: Generated Model must be a {model_cls.__name__} " \
                "object!"

            # Call the framework-specific Policy constructor.
            self.parent_cls = parent_cls
            self.parent_cls.__init__(
                self,
                observation_space=obs_space,
                action_space=action_space,
                config=config,
                model=self.model,
                loss=loss_fn,
                action_distribution_class=dist_class,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=config["model"]["max_seq_len"],
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            # Update this Policy's ViewRequirements (if function given).
            if callable(view_requirements_fn):
                self.view_requirements.update(view_requirements_fn(self))
            # Merge Model's view requirements into Policy's.
            self.view_requirements.update(
                self.model.inference_view_requirements)

            _before_loss_init = before_loss_init or after_init
            if _before_loss_init:
                _before_loss_init(self, self.observation_space,
                                  self.action_space, config)

            # Perform test runs through postprocessing- and loss functions.
            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=stats_fn,
            )

            if _after_loss_init:
                _after_loss_init(self, obs_space, action_space, config)

            # Reset global_timestep again after this fake run-through.
            self.global_timestep = 0

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            # Do all post-processing always with no_grad().
            # Not using this here will introduce a memory leak
            # in torch (issue #6962).
            with self._no_grad_context():
                # Call super's postprocess_trajectory first.
                sample_batch = super().postprocess_trajectory(
                    sample_batch, other_agent_batches, episode)
                if postprocess_fn:
                    return postprocess_fn(self, sample_batch,
                                          other_agent_batches, episode)

                return sample_batch

        @override(parent_cls)
        def extra_grad_process(self, optimizer, loss):
            """Called after optimizer.zero_grad() and loss.backward() calls.

            Allows for gradient processing before optimizer.step() is called.
            E.g. for gradient clipping.
            """
            if extra_grad_process_fn:
                return extra_grad_process_fn(self, optimizer, loss)
            else:
                return parent_cls.extra_grad_process(self, optimizer, loss)

        @override(parent_cls)
        def extra_compute_grad_fetches(self):
            if extra_learn_fetches_fn:
                fetches = convert_to_non_torch_type(
                    extra_learn_fetches_fn(self))
                # Auto-add empty learner stats dict if needed.
                return dict({LEARNER_STATS_KEY: {}}, **fetches)
            else:
                return parent_cls.extra_compute_grad_fetches(self)

        @override(parent_cls)
        def apply_gradients(self, gradients):
            if apply_gradients_fn:
                apply_gradients_fn(self, gradients)
            else:
                parent_cls.apply_gradients(self, gradients)

        @override(parent_cls)
        def extra_action_out(self, input_dict, state_batches, model,
                             action_dist):
            with self._no_grad_context():
                if extra_action_out_fn:
                    stats_dict = extra_action_out_fn(self, input_dict,
                                                     state_batches, model,
                                                     action_dist)
                else:
                    stats_dict = parent_cls.extra_action_out(
                        self, input_dict, state_batches, model, action_dist)
                return self._convert_to_non_torch_type(stats_dict)

        @override(parent_cls)
        def optimizer(self):
            if optimizer_fn:
                optimizers = optimizer_fn(self, self.config)
            else:
                optimizers = parent_cls.optimizer(self)
            optimizers = force_list(optimizers)
            if getattr(self, "exploration", None):
                optimizers = self.exploration.get_exploration_optimizer(
                    optimizers)
            return optimizers

        @override(parent_cls)
        def extra_grad_info(self, train_batch):
            with self._no_grad_context():
                if stats_fn:
                    stats_dict = stats_fn(self, train_batch)
                else:
                    stats_dict = self.parent_cls.extra_grad_info(
                        self, train_batch)
                return self._convert_to_non_torch_type(stats_dict)

        def _no_grad_context(self):
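            # Torch needs `torch.no_grad()` so that gradients are not tracked
            # during action computation and postprocessing; other frameworks
            # (e.g. JAX) get a no-op context manager instead.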
            if self.framework == "torch":
                return torch.no_grad()
            return NullContextManager()

        def _convert_to_non_torch_type(self, data):
            if self.framework == "torch":
                return convert_to_non_torch_type(data)
            return data

    def with_updates(**overrides):
        """Creates a Torch|JAXPolicy cls based on settings of another one.

        Keyword Args:
            **overrides: The settings (passed into `build_torch_policy`) that
                should be different from the class that this method is called
                on.

        Returns:
            type: A new Torch|JAXPolicy sub-class.

        Examples:
            >>> MySpecialDQNPolicyClass = DQNTorchPolicy.with_updates(
            ...     name="MySpecialDQNPolicyClass",
            ...     loss_fn=some_new_loss_function,
            ... )
        """
        return build_policy_class(**dict(original_kwargs, **overrides))

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
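
For orientation, a minimal usage sketch of this builder (the import path, the dummy loss function, and the policy name are assumptions for illustration, not part of the example above):

import torch
from ray.rllib.policy.policy_template import build_policy_class


def dummy_loss_fn(policy, model, dist_class, train_batch):
    # Stand-in objective: run the model on the batch and minimize the mean
    # of its outputs (a real policy would build a distribution from
    # dist_class and compute an RL loss here).
    model_out, _ = model.from_batch(train_batch)
    return torch.mean(model_out)


MyTorchPolicy = build_policy_class(
    name="MyTorchPolicy",
    framework="torch",
    loss_fn=dummy_loss_fn)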
Example #17
def build_eager_tf_policy(name,
                          loss_fn,
                          get_default_config=None,
                          postprocess_fn=None,
                          stats_fn=None,
                          optimizer_fn=None,
                          gradients_fn=None,
                          apply_gradients_fn=None,
                          grad_stats_fn=None,
                          extra_learn_fetches_fn=None,
                          extra_action_fetches_fn=None,
                          before_init=None,
                          before_loss_init=None,
                          after_init=None,
                          make_model=None,
                          action_sampler_fn=None,
                          mixins=None,
                          obs_include_prev_action_reward=True,
                          get_batch_divisibility_req=None):
    """Build an eager TF policy.

    An eager policy runs all operations in eager mode, which makes debugging
    much simpler, but has lower performance.

    You shouldn't need to call this directly. Rather, prefer to build a TF
    graph policy and set {"eager": true} in the trainer config to have it
    automatically converted to an eager policy.

    This has the same signature as build_tf_policy()."""

    base = add_mixins(Policy, mixins)

    class eager_policy_cls(base):
        def __init__(self, observation_space, action_space, config):
            assert tf.executing_eagerly()
            Policy.__init__(self, observation_space, action_space, config)
            self._is_training = False
            self._loss_initialized = False
            self._sess = None

            if get_default_config:
                config = dict(get_default_config(), **config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config

            if action_sampler_fn:
                if not make_model:
                    raise ValueError(
                        "make_model is required if action_sampler_fn is given")
                self.dist_class = None
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"])

            if make_model:
                self.model = make_model(self, observation_space, action_space,
                                        config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework="tf",
                )

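            # Dummy forward pass through the model with a sampled observation,
            # so that all of its variables get created eagerly up front.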
            self.model(
                {
                    SampleBatch.CUR_OBS:
                    tf.convert_to_tensor(np.array([observation_space.sample()
                                                   ])),
                    SampleBatch.PREV_ACTIONS:
                    tf.convert_to_tensor(
                        [_flatten_action(action_space.sample())]),
                    SampleBatch.PREV_REWARDS:
                    tf.convert_to_tensor([0.]),
                }, [
                    tf.convert_to_tensor([s])
                    for s in self.model.get_initial_state()
                ], tf.convert_to_tensor([1]))

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            self._initialize_loss_with_dummy_batch()
            self._loss_initialized = True

            if optimizer_fn:
                self._optimizer = optimizer_fn(self, config)
            else:
                self._optimizer = tf.train.AdamOptimizer(config["lr"])

            if after_init:
                after_init(self, observation_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   samples,
                                   other_agent_batches=None,
                                   episode=None):
            assert tf.executing_eagerly()
            if postprocess_fn:
                return postprocess_fn(self, samples, other_agent_batches,
                                      episode)
            else:
                return samples

        @override(Policy)
        def learn_on_batch(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            self._apply_gradients(grads_and_vars)
            return stats

        @override(Policy)
        def compute_gradients(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            grads = [g for g, v in grads_and_vars]
            grads = [(g.numpy() if g is not None else None) for g in grads]
            return grads, stats

        @override(Policy)
        def compute_actions(self,
                            obs_batch,
                            state_batches,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            info_batch=None,
                            episodes=None,
                            **kwargs):

            assert tf.executing_eagerly()
            self._is_training = False

            self._seq_lens = tf.ones(len(obs_batch))
            self._input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.convert_to_tensor(False),
            }
            if obs_include_prev_action_reward:
                self._input_dict.update({
                    SampleBatch.PREV_ACTIONS:
                    tf.convert_to_tensor(prev_action_batch),
                    SampleBatch.PREV_REWARDS:
                    tf.convert_to_tensor(prev_reward_batch),
                })
            self._state_in = state_batches
            with tf.variable_creator_scope(_disallow_var_creation):
                model_out, state_out = self.model(self._input_dict,
                                                  state_batches,
                                                  self._seq_lens)

            if self.dist_class:
                action_dist = self.dist_class(model_out, self.model)
                action = action_dist.sample().numpy()
                logp = action_dist.sampled_action_logp()
            else:
                action, logp = action_sampler_fn(self, self.model,
                                                 self._input_dict,
                                                 self.observation_space,
                                                 self.action_space,
                                                 self.config)
                action = action.numpy()

            fetches = {}
            if logp is not None:
                fetches.update({
                    ACTION_PROB: tf.exp(logp).numpy(),
                    ACTION_LOGP: logp.numpy(),
                })
            if extra_action_fetches_fn:
                fetches.update(extra_action_fetches_fn(self))
            return action, state_out, fetches

        @override(Policy)
        def apply_gradients(self, gradients):
            self._apply_gradients(
                zip([(tf.convert_to_tensor(g) if g is not None else None)
                     for g in gradients], self.model.trainable_variables()))

        @override(Policy)
        def get_weights(self):
            variables = self.model.variables()
            return [v.numpy() for v in variables]

        @override(Policy)
        def set_weights(self, weights):
            variables = self.model.variables()
            assert len(weights) == len(variables), (len(weights),
                                                    len(variables))
            for v, w in zip(variables, weights):
                v.assign(w)

        def is_recurrent(self):
            return len(self._state_in) > 0

        def num_state_tensors(self):
            return len(self._state_in)

        def get_session(self):
            return None  # None implies eager

        def get_placeholder(self, ph):
            raise ValueError(
                "get_placeholder() is not allowed in eager mode. Try using "
                "rllib.utils.tf_ops.make_tf_callable() to write "
                "functions that work in both graph and eager mode.")

        def loss_initialized(self):
            return self._loss_initialized

        def _get_is_training_placeholder(self):
            return tf.convert_to_tensor(self._is_training)

        def _apply_gradients(self, grads_and_vars):
            if apply_gradients_fn:
                apply_gradients_fn(self, self._optimizer, grads_and_vars)
            else:
                self._optimizer.apply_gradients(grads_and_vars)

        def _compute_gradients(self, samples):
            """Computes and returns grads as eager tensors."""

            self._is_training = True

            samples = {
                k: tf.convert_to_tensor(v)
                for k, v in samples.items() if v.dtype != np.object
            }

            with tf.GradientTape(persistent=gradients_fn is not None) as tape:
                # TODO: set seq len and state in properly
                self._seq_lens = tf.ones(len(samples[SampleBatch.CUR_OBS]))
                self._state_in = []
                model_out, _ = self.model(samples, self._state_in,
                                          self._seq_lens)
                loss = loss_fn(self, self.model, self.dist_class, samples)

            variables = self.model.trainable_variables()

            if gradients_fn:

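                # Wrap the GradientTape so that a user-supplied `gradients_fn`
                # (written against the graph-mode optimizer API) can call
                # `compute_gradients(loss, var_list)` on it.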
                class OptimizerWrapper(object):
                    def __init__(self, tape):
                        self.tape = tape

                    def compute_gradients(self, loss, var_list):
                        return list(
                            zip(self.tape.gradient(loss, var_list), var_list))

                grads_and_vars = gradients_fn(self, OptimizerWrapper(tape),
                                              loss)
            else:
                grads_and_vars = list(
                    zip(tape.gradient(loss, variables), variables))

            if log_once("grad_vars"):
                for _, v in grads_and_vars:
                    logger.info("Optimizing variable {}".format(v.name))

            grads = [g for g, v in grads_and_vars]
            stats = self._stats(self, samples, grads)
            return grads_and_vars, stats

        def _stats(self, outputs, samples, grads):
            assert tf.executing_eagerly()
            fetches = {}
            if stats_fn:
                fetches[LEARNER_STATS_KEY] = {
                    k: v.numpy()
                    for k, v in stats_fn(outputs, samples).items()
                }
            else:
                fetches[LEARNER_STATS_KEY] = {}
            if extra_learn_fetches_fn:
                fetches.update({
                    k: v.numpy()
                    for k, v in extra_learn_fetches_fn(self).items()
                })
            if grad_stats_fn:
                fetches.update({
                    k: v.numpy()
                    for k, v in grad_stats_fn(self, samples, grads).items()
                })
            return fetches

        def _initialize_loss_with_dummy_batch(self):
            # Dummy forward pass to initialize any policy attributes, etc.
            action_dtype, action_shape = ModelCatalog.get_action_shape(
                self.action_space)
            dummy_batch = {
                SampleBatch.CUR_OBS:
                tf.convert_to_tensor(
                    np.array([self.observation_space.sample()])),
                SampleBatch.NEXT_OBS:
                tf.convert_to_tensor(
                    np.array([self.observation_space.sample()])),
                SampleBatch.DONES:
                tf.convert_to_tensor(np.array([False], dtype=np.bool)),
                SampleBatch.ACTIONS:
                tf.convert_to_tensor(
                    np.zeros((1, ) + action_shape[1:],
                             dtype=action_dtype.as_numpy_dtype())),
                SampleBatch.REWARDS:
                tf.convert_to_tensor(np.array([0], dtype=np.float32)),
            }
            if obs_include_prev_action_reward:
                dummy_batch.update({
                    SampleBatch.PREV_ACTIONS:
                    dummy_batch[SampleBatch.ACTIONS],
                    SampleBatch.PREV_REWARDS:
                    dummy_batch[SampleBatch.REWARDS],
                })
            state_init = self.get_initial_state()
            state_batches = []
            for i, h in enumerate(state_init):
                dummy_batch["state_in_{}".format(i)] = tf.convert_to_tensor(
                    np.expand_dims(h, 0))
                dummy_batch["state_out_{}".format(i)] = tf.convert_to_tensor(
                    np.expand_dims(h, 0))
                state_batches.append(tf.convert_to_tensor(np.expand_dims(h,
                                                                         0)))
            if state_init:
                dummy_batch["seq_lens"] = tf.convert_to_tensor(
                    np.array([1], dtype=np.int32))

            # for IMPALA which expects a certain sample batch size
            def tile_to(tensor, n):
                return tf.tile(tensor,
                               [n] + [1 for _ in tensor.shape.as_list()[1:]])

            if get_batch_divisibility_req:
                dummy_batch = {
                    k: tile_to(v, get_batch_divisibility_req(self))
                    for k, v in dummy_batch.items()
                }

            # Execute a forward pass to get self.action_dist etc initialized,
            # and also obtain the extra action fetches
            _, _, fetches = self.compute_actions(
                dummy_batch[SampleBatch.CUR_OBS], state_batches,
                dummy_batch.get(SampleBatch.PREV_ACTIONS),
                dummy_batch.get(SampleBatch.PREV_REWARDS))
            dummy_batch.update(fetches)

            postprocessed_batch = self.postprocess_trajectory(
                SampleBatch(dummy_batch))

            # model forward pass for the loss (needed after postprocess to
            # overwrite any tensor state from that call)
            self.model.from_batch(dummy_batch)

            postprocessed_batch = {
                k: tf.convert_to_tensor(v)
                for k, v in postprocessed_batch.items()
            }

            loss_fn(self, self.model, self.dist_class, postprocessed_batch)
            if stats_fn:
                stats_fn(self, postprocessed_batch)

    eager_policy_cls.__name__ = name + "_eager"
    eager_policy_cls.__qualname__ = name + "_eager"
    return eager_policy_cls
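
As the docstring notes, this builder is normally not called directly; eager mode is instead requested via the trainer config. A hedged sketch of that route (the PPOTrainer import and the CartPole env are arbitrary choices, not taken from the example):

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
# Request the eager policy path via the config flag mentioned above.
trainer = PPOTrainer(env="CartPole-v0", config={"eager": True})
print(trainer.train())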
Example #18
def build_trainer(name,
                  default_policy,
                  default_config=None,
                  validate_config=None,
                  get_initial_state=None,
                  get_policy_class=None,
                  before_init=None,
                  make_workers=None,
                  make_policy_optimizer=None,
                  after_init=None,
                  before_train_step=None,
                  after_optimizer_step=None,
                  after_train_result=None,
                  collect_metrics_fn=None,
                  before_evaluate_fn=None,
                  mixins=None,
                  training_pipeline=None):
    """Helper function for defining a custom trainer.

    Functions will be run in this order to initialize the trainer:
        1. Config setup: validate_config, get_initial_state, get_policy
        2. Worker setup: before_init, make_workers, make_policy_optimizer
        3. Post setup: after_init

    Arguments:
        name (str): name of the trainer (e.g., "PPO")
        default_policy (cls): the default Policy class to use
        default_config (Optional[dict]): The default config dict of the
            algorithm. If None, uses the Trainer default config.
        validate_config (Optional[callable]): Optional callback that checks a
            given config for correctness. It may mutate the config as needed.
        get_initial_state (Optional[callable]): Optional callable that returns
            the initial state dict given the trainer instance as an argument.
            The state dict must be serializable so that it can be checkpointed,
            and will be available as the `trainer.state` variable.
        get_policy_class (Optional[callable]): Optional callable that takes a
            Trainer config and returns the policy class to override the default
            with.
        before_init (Optional[callable]): Optional callable to run at the start
            of trainer init that takes the trainer instance as argument.
        make_workers (Optional[callable]): Override the default method that
            creates rollout workers. This takes in (trainer, env_creator,
            policy, config) as args.
        make_policy_optimizer (Optional[callable]): Optional callable that
            returns a PolicyOptimizer instance given (WorkerSet, config).
        after_init (Optional[callable]): Optional callable to run at the end of
            trainer init that takes the trainer instance as argument.
        before_train_step (Optional[callable]): Optional callable to run before
            each train() call. It takes the trainer instance as an argument.
        after_optimizer_step (Optional[callable]): Optional callable to run
            after each step() call to the policy optimizer. It takes the
            trainer instance and the policy gradient fetches as arguments.
        after_train_result (Optional[callable]): Optional callable to run at
            the end of each train() call. It takes the trainer instance and
            result dict as arguments, and may mutate the result dict as needed.
        collect_metrics_fn (Optional[callable]): Optional callable to override
            the default method used to collect metrics. Takes the trainer
            instance as argument.
        before_evaluate_fn (Optional[callable]): Optional callable to run
            before evaluation. Takes the trainer instance as argument.
        mixins (Optional[List[class]]): Optional list of mixin class(es) for
            the returned trainer class. These mixins will be applied in order
            and will have higher precedence than the Trainer class.
        training_pipeline (Optional[callable]): Experimental support for custom
            training pipelines. This overrides `make_policy_optimizer`.

    Returns:
        a Trainer sub-class that uses the specified args.
    """

    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config, env_creator):
            if validate_config:
                validate_config(config)

            if get_initial_state:
                self.state = get_initial_state(self)
            else:
                self.state = {}

            # Override default policy if `get_policy_class` is provided.
            if get_policy_class is not None:
                self._policy = get_policy_class(config)

            if before_init:
                before_init(self)

            # Creating all workers (excluding evaluation workers).
            if make_workers:
                self.workers = make_workers(self, env_creator, self._policy,
                                            config)
            else:
                self.workers = self._make_workers(env_creator, self._policy,
                                                  config,
                                                  self.config["num_workers"])
            self.train_pipeline = None
            self.optimizer = None

            if training_pipeline and (self.config["use_pipeline_impl"] or
                                      "RLLIB_USE_PIPELINE_IMPL" in os.environ):
                logger.warning("Using experimental pipeline based impl.")
                self.train_pipeline = training_pipeline(self.workers, config)
            elif make_policy_optimizer:
                self.optimizer = make_policy_optimizer(self.workers, config)
            else:
                optimizer_config = dict(
                    config["optimizer"],
                    **{"train_batch_size": config["train_batch_size"]})
                self.optimizer = SyncSamplesOptimizer(self.workers,
                                                      **optimizer_config)
            if after_init:
                after_init(self)

        @override(Trainer)
        def _train(self):
            if self.train_pipeline:
                return self._train_pipeline()

            if before_train_step:
                before_train_step(self)
            prev_steps = self.optimizer.num_steps_sampled

            start = time.time()
            while True:
                fetches = self.optimizer.step()
                if after_optimizer_step:
                    after_optimizer_step(self, fetches)
                if (time.time() - start >= self.config["min_iter_time_s"]
                        and self.optimizer.num_steps_sampled - prev_steps >=
                        self.config["timesteps_per_iteration"]):
                    break

            if collect_metrics_fn:
                res = collect_metrics_fn(self)
            else:
                res = self.collect_metrics()
            res.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
                       prev_steps,
                       info=res.get("info", {}))

            if after_train_result:
                after_train_result(self, res)
            return res

        def _train_pipeline(self):
            if before_train_step:
                before_train_step(self)
            res = next(self.train_pipeline)
            if after_train_result:
                after_train_result(self, res)
            return res

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        def __getstate__(self):
            state = Trainer.__getstate__(self)
            state["trainer_state"] = self.state.copy()
            if self.train_pipeline:
                state["train_pipeline"] = self.train_pipeline.metrics.save()
            return state

        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            self.state = state["trainer_state"].copy()
            if self.train_pipeline:
                self.train_pipeline.metrics.restore(state["train_pipeline"])

    def with_updates(**overrides):
        """Build a copy of this trainer with the specified overrides.

        Arguments:
            overrides (dict): use this to override any of the arguments
                originally passed to build_trainer() for this trainer.
        """
        return build_trainer(**dict(original_kwargs, **overrides))

    trainer_cls.with_updates = staticmethod(with_updates)
    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
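
A minimal sketch of how such a generated trainer class would typically be run (MyTorchPolicy is a placeholder policy class, and the Tune call and config values are illustrative assumptions):

from ray import tune

# Hypothetical: wrap an existing policy class into a runnable trainer.
MyTrainer = build_trainer(
    name="MyTrainer",
    default_policy=MyTorchPolicy)

tune.run(
    MyTrainer,
    config={"env": "CartPole-v0", "num_workers": 2},
    stop={"training_iteration": 3})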
Example #19
def build_torch_policy(name,
                       *,
                       loss_fn,
                       get_default_config=None,
                       stats_fn=None,
                       postprocess_fn=None,
                       extra_action_out_fn=None,
                       extra_grad_process_fn=None,
                       optimizer_fn=None,
                       before_init=None,
                       after_init=None,
                       action_sampler_fn=None,
                       action_distribution_fn=None,
                       make_model_and_action_dist=None,
                       apply_gradients_fn=None,
                       mixins=None,
                       get_batch_divisibility_req=None):
    """Helper function for creating a torch policy at runtime.

    Arguments:
        name (str): name of the policy (e.g., "PPOTorchPolicy")
        loss_fn (func): function that takes (policy, model, dist_class,
            train_batch) as arguments and returns a loss tensor
        get_default_config (func): optional function that returns the default
            config to merge with any overrides
        stats_fn (func): optional function that returns a dict of
            values given the policy and batch input tensors
        postprocess_fn (func): optional experience postprocessing function
            that takes the same args as Policy.postprocess_trajectory()
        extra_action_out_fn (func): optional function that returns
            a dict of extra values to include in experiences
        extra_grad_process_fn (func): optional function that is called after
            gradients are computed and returns processing info
        optimizer_fn (func): optional function that returns a torch optimizer
            given the policy and config
        before_init (func): optional function to run at the beginning of
            policy init that takes the same arguments as the policy constructor
        after_init (func): optional function to run at the end of policy init
            that takes the same arguments as the policy constructor
        action_sampler_fn (Optional[callable]): A callable returning a sampled
            action and its log-likelihood given some (obs and state) inputs.
        action_distribution_fn (Optional[callable]): A callable returning
            distribution inputs (parameters), a dist-class to generate an
            action distribution object from, and internal-state outputs (or an
            empty list if not applicable).
        make_model_and_action_dist (func): optional func that takes the same
            arguments as policy init and returns a tuple of model instance and
            torch action distribution class. If not specified, the default
            model and action dist from the catalog will be used
        apply_gradients_fn (Optional[callable]): An optional callable that
            takes a grads list and applies these to the Model's parameters.
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the TorchPolicy class
        get_batch_divisibility_req (Optional[callable]): Optional callable that
            returns the divisibility requirement for sample batches.

    Returns:
        a TorchPolicy sub-class that uses the specified args
    """

    original_kwargs = locals().copy()
    base = add_mixins(TorchPolicy, mixins)

    class policy_cls(base):
        def __init__(self, obs_space, action_space, config):
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            if before_init:
                before_init(self, obs_space, action_space, config)

            if make_model_and_action_dist:
                self.model, dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
                # Make sure we passed in a correct Model factory.
                assert isinstance(self.model, TorchModelV2), \
                    "ERROR: TorchPolicy::make_model_and_action_dist must " \
                    "return a TorchModelV2 object!"
            else:
                dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=logit_dim,
                    model_config=self.config["model"],
                    framework="torch",
                    **self.config["model"].get("custom_options", {}))

            TorchPolicy.__init__(
                self,
                observation_space=obs_space,
                action_space=action_space,
                config=config,
                model=self.model,
                loss=loss_fn,
                action_distribution_class=dist_class,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=config["model"]["max_seq_len"],
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            # Do all post-processing always with no_grad().
            # Not using this here will introduce a memory leak (issue #6962).
            with torch.no_grad():
                # Call super's postprocess_trajectory first.
                sample_batch = super().postprocess_trajectory(
                    convert_to_non_torch_type(sample_batch),
                    convert_to_non_torch_type(other_agent_batches), episode)
                if postprocess_fn:
                    return postprocess_fn(self, sample_batch,
                                          other_agent_batches, episode)

                return sample_batch

        @override(TorchPolicy)
        def extra_grad_process(self, optimizer, loss):
            """Called after optimizer.zero_grad() and loss.backward() calls.

            Allows for gradient processing before optimizer.step() is called.
            E.g. for gradient clipping.
            """
            if extra_grad_process_fn:
                return extra_grad_process_fn(self, optimizer, loss)
            else:
                return TorchPolicy.extra_grad_process(self, optimizer, loss)

        @override(TorchPolicy)
        def apply_gradients(self, gradients):
            if apply_gradients_fn:
                apply_gradients_fn(self, gradients)
            else:
                TorchPolicy.apply_gradients(self, gradients)

        @override(TorchPolicy)
        def extra_action_out(self, input_dict, state_batches, model,
                             action_dist):
            with torch.no_grad():
                if extra_action_out_fn:
                    stats_dict = extra_action_out_fn(
                        self, input_dict, state_batches, model, action_dist)
                else:
                    stats_dict = TorchPolicy.extra_action_out(
                        self, input_dict, state_batches, model, action_dist)
                return convert_to_non_torch_type(stats_dict)

        @override(TorchPolicy)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            else:
                return TorchPolicy.optimizer(self)

        @override(TorchPolicy)
        def extra_grad_info(self, train_batch):
            with torch.no_grad():
                if stats_fn:
                    stats_dict = stats_fn(self, train_batch)
                else:
                    stats_dict = TorchPolicy.extra_grad_info(self, train_batch)
                return convert_to_non_torch_type(stats_dict)

    def with_updates(**overrides):
        return build_torch_policy(**dict(original_kwargs, **overrides))

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
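
A sketch of plugging a custom hook into this builder, here gradient clipping via `extra_grad_process_fn`, the use case mentioned in the docstring above (the clip value, the names, and `dummy_loss_fn` are assumptions):

import torch


def clip_gradients(policy, optimizer, loss):
    # Clip the global gradient norm before optimizer.step() is called.
    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.model.parameters(), max_norm=40.0)
    return {"grad_gnorm": grad_norm}


MyClippedTorchPolicy = build_torch_policy(
    name="MyClippedTorchPolicy",
    loss_fn=dummy_loss_fn,  # e.g. the loss sketched earlier
    extra_grad_process_fn=clip_gradients)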
Example #20
def build_trainer(
    name: str,
    *,
    default_config: Optional[TrainerConfigDict] = None,
    validate_config: Optional[Callable[[TrainerConfigDict], None]] = None,
    default_policy: Optional[Type[Policy]] = None,
    get_policy_class: Optional[Callable[[TrainerConfigDict],
                                        Optional[Type[Policy]]]] = None,
    validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
    before_init: Optional[Callable[[Trainer], None]] = None,
    after_init: Optional[Callable[[Trainer], None]] = None,
    before_evaluate_fn: Optional[Callable[[Trainer], None]] = None,
    mixins: Optional[List[type]] = None,
    execution_plan: Optional[
        Callable[[WorkerSet, TrainerConfigDict],
                 Iterable[ResultDict]]] = default_execution_plan
) -> Type[Trainer]:
    """Helper function for defining a custom trainer.

    Functions will be run in this order to initialize the trainer:
        1. Config setup: validate_config, get_policy
        2. Worker setup: before_init, execution_plan
        3. Post setup: after_init

    Args:
        name (str): name of the trainer (e.g., "PPO")
        default_config (Optional[TrainerConfigDict]): The default config dict
            of the algorithm, otherwise uses the Trainer default config.
        validate_config (Optional[Callable[[TrainerConfigDict], None]]):
            Optional callable that takes the config to check for correctness.
            It may mutate the config as needed.
        default_policy (Optional[Type[Policy]]): The default Policy class to
            use if `get_policy_class` returns None.
        get_policy_class (Optional[Callable[
            TrainerConfigDict, Optional[Type[Policy]]]]): Optional callable
            that takes a config and returns the policy class or None. If None
            is returned, will use `default_policy` (which must be provided
            then).
        validate_env (Optional[Callable[[EnvType, EnvContext], None]]):
            Optional callable to validate the generated environment (only
            on worker=0).
        before_init (Optional[Callable[[Trainer], None]]): Optional callable to
            run before anything is constructed inside Trainer (Workers with
            Policies, execution plan, etc..). Takes the Trainer instance as
            argument.
        after_init (Optional[Callable[[Trainer], None]]): Optional callable to
            run at the end of trainer init (after all Workers and the exec.
            plan have been constructed). Takes the Trainer instance as
            argument.
        before_evaluate_fn (Optional[Callable[[Trainer], None]]): Callback to
            run before evaluation. This takes the trainer instance as argument.
        mixins (list): list of any class mixins for the returned trainer class.
            These mixins will be applied in order and will have higher
            precedence than the Trainer class.
        execution_plan (Optional[Callable[[WorkerSet, TrainerConfigDict],
            Iterable[ResultDict]]]): Optional callable that sets up the
            distributed execution workflow.

    Returns:
        Type[Trainer]: A Trainer sub-class configured by the specified args.
    """

    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy_class = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config: TrainerConfigDict,
                  env_creator: Callable[[EnvConfigDict], EnvType]):
            # Validate config via custom validation function.
            if validate_config:
                validate_config(config)

            # No `get_policy_class` function.
            if get_policy_class is None:
                # `default_policy` must be provided (unless in multi-agent
                # mode, where each policy can have its own policy class).
                if not config["multiagent"]["policies"]:
                    assert default_policy is not None
                self._policy_class = default_policy
            # Query the function for a class to use.
            else:
                self._policy_class = get_policy_class(config)
                # If None returned, use default policy (must be provided).
                if self._policy_class is None:
                    assert default_policy is not None
                    self._policy_class = default_policy

            if before_init:
                before_init(self)

            # Creating all workers (excluding evaluation workers).
            self.workers = self._make_workers(
                env_creator=env_creator,
                validate_env=validate_env,
                policy_class=self._policy_class,
                config=config,
                num_workers=self.config["num_workers"])
            self.execution_plan = execution_plan
            self.train_exec_impl = execution_plan(self.workers, config)

            if after_init:
                after_init(self)

        @override(Trainer)
        def step(self):
            res = next(self.train_exec_impl)
            return res

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        @override(Trainer)
        def __getstate__(self):
            state = Trainer.__getstate__(self)
            state["train_exec_impl"] = (
                self.train_exec_impl.shared_metrics.get().save())
            return state

        @override(Trainer)
        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            self.train_exec_impl.shared_metrics.get().restore(
                state["train_exec_impl"])

        @staticmethod
        @override(Trainer)
        def with_updates(**overrides) -> Type[Trainer]:
            """Build a copy of this trainer class with the specified overrides.

            Keyword Args:
                overrides (dict): use this to override any of the arguments
                    originally passed to build_trainer() for this trainer.

            Returns:
                Type[Trainer]: The Trainer sub-class created from
                    `original_kwargs`, updated with `overrides`.

            Examples:
                >>> MyClass = SomeOtherClass.with_updates(name="Mine")
                >>> issubclass(MyClass, SomeOtherClass)
                False
                >>> issubclass(MyClass, Trainer)
                True
            """
            return build_trainer(**dict(original_kwargs, **overrides))

    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
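
A hedged sketch of pairing this builder with a custom `execution_plan`; the execution-op import paths mirror RLlib's distributed execution API of the same era and should be treated as assumptions:

from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep


def my_execution_plan(workers, config):
    # Sample in parallel, train on each batch, then report metrics --
    # essentially what a default execution plan does.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    train_op = rollouts.for_each(TrainOneStep(workers))
    return StandardMetricsReporting(train_op, workers, config)


MyTrainer = build_trainer(
    name="MyTrainer",
    default_policy=MyTorchPolicy,  # placeholder policy class
    execution_plan=my_execution_plan)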
Example #21
def build_trainer(name,
                  default_policy,
                  default_config=None,
                  validate_config=None,
                  get_initial_state=None,
                  get_policy_class=None,
                  before_init=None,
                  make_workers=None,
                  make_policy_optimizer=None,
                  after_init=None,
                  before_train_step=None,
                  after_optimizer_step=None,
                  after_train_result=None,
                  collect_metrics_fn=None,
                  before_evaluate_fn=None,
                  mixins=None,
                  execution_plan=None):
    """Helper function for defining a custom trainer.

    Functions will be run in this order to initialize the trainer:
        1. Config setup: validate_config, get_initial_state, get_policy
        2. Worker setup: before_init, make_workers, make_policy_optimizer
        3. Post setup: after_init

    Arguments:
        name (str): name of the trainer (e.g., "PPO")
        default_policy (cls): the default Policy class to use
        default_config (dict): The default config dict of the algorithm,
            otherwise uses the Trainer default config.
        validate_config (func): optional callback that checks a given config
            for correctness. It may mutate the config as needed.
        get_initial_state (func): optional function that returns the initial
            state dict given the trainer instance as an argument. The state
            dict must be serializable so that it can be checkpointed, and will
            be available as the `trainer.state` variable.
        get_policy_class (func): optional callback that takes a config and
            returns the policy class to override the default with
        before_init (func): optional function to run at the start of trainer
            init that takes the trainer instance as argument
        make_workers (func): override the method that creates rollout workers.
            This takes in (trainer, env_creator, policy, config) as args.
        make_policy_optimizer (func): optional function that returns a
            PolicyOptimizer instance given (WorkerSet, config)
        after_init (func): optional function to run at the end of trainer init
            that takes the trainer instance as argument
        before_train_step (func): optional callback to run before each train()
            call. It takes the trainer instance as an argument.
        after_optimizer_step (func): optional callback to run after each
            step() call to the policy optimizer. It takes the trainer instance
            and the policy gradient fetches as arguments.
        after_train_result (func): optional callback to run at the end of each
            train() call. It takes the trainer instance and result dict as
            arguments, and may mutate the result dict as needed.
        collect_metrics_fn (func): override the method used to collect metrics.
            It takes the trainer instance as argument.
        before_evaluate_fn (func): callback to run before evaluation. This
            takes the trainer instance as argument.
        mixins (list): list of any class mixins for the returned trainer class.
            These mixins will be applied in order and will have higher
            precedence than the Trainer class
        execution_plan (func): Experimental distributed execution
            API. This overrides `make_policy_optimizer`.

    Returns:
        a Trainer sub-class that uses the specified args.
    """

    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config, env_creator):
            if validate_config:
                validate_config(config)

            if get_initial_state:
                self.state = get_initial_state(self)
            else:
                self.state = {}
            if get_policy_class is None:
                self._policy = default_policy
            else:
                self._policy = get_policy_class(config)
            if before_init:
                before_init(self)
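            # The distributed execution API path is only taken if an
            # `execution_plan` was provided AND it was enabled via the config
            # or the RLLIB_EXEC_API environment variable.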
            use_exec_api = (execution_plan
                            and (self.config["use_exec_api"]
                                 or "RLLIB_EXEC_API" in os.environ))

            # Creating all workers (excluding evaluation workers).
            if make_workers and not use_exec_api:
                self.workers = make_workers(self, env_creator, self._policy,
                                            config)
            else:
                self.workers = self._make_workers(env_creator, self._policy,
                                                  config,
                                                  self.config["num_workers"])
            self.train_exec_impl = None
            self.optimizer = None
            self.execution_plan = execution_plan

            if use_exec_api:
                self.train_exec_impl = execution_plan(self.workers, config)
            elif make_policy_optimizer:
                self.optimizer = make_policy_optimizer(self.workers, config)
            else:
                optimizer_config = dict(
                    config["optimizer"],
                    **{"train_batch_size": config["train_batch_size"]})
                self.optimizer = SyncSamplesOptimizer(self.workers,
                                                      **optimizer_config)
            if after_init:
                after_init(self)

        @override(Trainer)
        def _train(self):
            if self.train_exec_impl:
                return self._train_exec_impl()

            if before_train_step:
                before_train_step(self)
            prev_steps = self.optimizer.num_steps_sampled

            start = time.time()
            optimizer_steps_this_iter = 0
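            # Step the optimizer repeatedly until both the minimum iteration
            # time and the per-iteration timestep budget have been reached.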
            while True:
                fetches = self.optimizer.step()
                optimizer_steps_this_iter += 1
                if after_optimizer_step:
                    after_optimizer_step(self, fetches)
                if (time.time() - start >= self.config["min_iter_time_s"]
                        and self.optimizer.num_steps_sampled - prev_steps >=
                        self.config["timesteps_per_iteration"]):
                    break

            if collect_metrics_fn:
                res = collect_metrics_fn(self)
            else:
                res = self.collect_metrics()
            res.update(optimizer_steps_this_iter=optimizer_steps_this_iter,
                       timesteps_this_iter=self.optimizer.num_steps_sampled -
                       prev_steps,
                       info=res.get("info", {}))

            if after_train_result:
                after_train_result(self, res)
            return res

        def _train_exec_impl(self):
            if before_train_step:
                logger.debug("Ignoring before_train_step callback")
            res = next(self.train_exec_impl)
            if after_train_result:
                logger.debug("Ignoring after_train_result callback")
            return res

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        def __getstate__(self):
            state = Trainer.__getstate__(self)
            state["trainer_state"] = self.state.copy()
            if self.train_exec_impl:
                state["train_exec_impl"] = (
                    self.train_exec_impl.shared_metrics.get().save())
            return state

        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            self.state = state["trainer_state"].copy()
            if self.train_exec_impl:
                self.train_exec_impl.shared_metrics.get().restore(
                    state["train_exec_impl"])

    def with_updates(**overrides):
        """Build a copy of this trainer with the specified overrides.

        Args:
            overrides (dict): use this to override any of the arguments
                originally passed to build_trainer() for this policy.
        """
        return build_trainer(**dict(original_kwargs, **overrides))

    trainer_cls.with_updates = staticmethod(with_updates)
    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
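
A minimal usage sketch of the builder above, assuming a hypothetical `MyTFPolicy` policy class and illustrative callbacks (none of these names come from the original listing):

def my_validate_config(config):
    # Fail fast on obviously bad settings; the callback may also mutate config.
    if config["train_batch_size"] <= 0:
        raise ValueError("`train_batch_size` must be > 0")


def my_after_train_result(trainer, result):
    # Runs at the end of each train() call and may mutate the result dict.
    result.setdefault("custom_metrics", {})["iteration"] = trainer.iteration


MyTrainer = build_trainer(
    name="MyAlgo",
    default_policy=MyTFPolicy,  # hypothetical Policy subclass
    validate_config=my_validate_config,
    after_train_result=my_after_train_result,
)

# MyTrainer is a Trainer sub-class and is used like any other trainer, e.g.:
#   trainer = MyTrainer(config={"num_workers": 2}, env="CartPole-v0")
#   print(trainer.train())
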
Example #22
def build_eager_tf_policy(
        name,
        loss_fn,
        get_default_config=None,
        postprocess_fn=None,
        stats_fn=None,
        optimizer_fn=None,
        gradients_fn=None,
        apply_gradients_fn=None,
        grad_stats_fn=None,
        extra_learn_fetches_fn=None,
        extra_action_out_fn=None,
        validate_spaces=None,
        before_init=None,
        before_loss_init=None,
        after_init=None,
        make_model=None,
        action_sampler_fn=None,
        action_distribution_fn=None,
        mixins=None,
        obs_include_prev_action_reward=DEPRECATED_VALUE,
        get_batch_divisibility_req=None,
        # Deprecated args.
        extra_action_fetches_fn=None):
    """Build an eager TF policy.

    An eager policy runs all operations in eager mode, which makes debugging
    much simpler, but has lower performance.

    You shouldn't need to call this directly. Rather, prefer to build a TF
    graph policy and set {"framework": "tfe"} in the trainer config to have
    it automatically converted to an eager policy.

    This has the same signature as build_tf_policy()."""

    base = add_mixins(Policy, mixins)

    if extra_action_fetches_fn is not None:
        deprecation_warning(
            old="extra_action_fetches_fn",
            new="extra_action_out_fn",
            error=False)
        extra_action_out_fn = extra_action_fetches_fn

    if obs_include_prev_action_reward != DEPRECATED_VALUE:
        deprecation_warning(old="obs_include_prev_action_reward", error=False)

    class eager_policy_cls(base):
        def __init__(self, observation_space, action_space, config):
            assert tf.executing_eagerly()
            self.framework = config.get("framework", "tfe")
            Policy.__init__(self, observation_space, action_space, config)

            # Log device and worker index.
            from ray.rllib.evaluation.rollout_worker import get_global_worker
            worker = get_global_worker()
            worker_idx = worker.worker_index if worker else 0
            if tf.config.list_physical_devices("GPU"):
                logger.info(
                    "TF-eager Policy (worker={}) running on GPU.".format(
                        worker_idx if worker_idx > 0 else "local"))
            else:
                logger.info(
                    "TF-eager Policy (worker={}) running on CPU.".format(
                        worker_idx if worker_idx > 0 else "local"))

            self._is_training = False
            self._loss_initialized = False
            self._sess = None

            self._loss = loss_fn
            self.batch_divisibility_req = get_batch_divisibility_req(self) if \
                callable(get_batch_divisibility_req) else \
                (get_batch_divisibility_req or 1)
            self._max_seq_len = config["model"]["max_seq_len"]

            if get_default_config:
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, observation_space, action_space, config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config
            self.dist_class = None
            if action_sampler_fn or action_distribution_fn:
                if not make_model:
                    raise ValueError(
                        "`make_model` is required if `action_sampler_fn` OR "
                        "`action_distribution_fn` is given")
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"])

            if make_model:
                self.model = make_model(self, observation_space, action_space,
                                        config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework=self.framework,
                )
            # Lock used for locking some methods on the object-level.
            # This prevents possible race conditions when calling the model
            # first, then its value function (e.g. in a loss function), in
            # between of which another model call is made (e.g. to compute an
            # action).
            self._lock = threading.RLock()

            # Auto-update model's inference view requirements, if recurrent.
            self._update_model_view_requirements_from_init_state()

            self.exploration = self._create_exploration()
            self._state_inputs = self.model.get_initial_state()
            self._is_recurrent = len(self._state_inputs) > 0

            # Combine view_requirements for Model and Policy.
            self.view_requirements.update(self.model.view_requirements)

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            if optimizer_fn:
                optimizers = optimizer_fn(self, config)
            else:
                optimizers = tf.keras.optimizers.Adam(config["lr"])
            optimizers = force_list(optimizers)
            if getattr(self, "exploration", None):
                optimizers = self.exploration.get_exploration_optimizer(
                    optimizers)
            # TODO: (sven) Allow tf policy to have more than 1 optimizer.
            #  Just like torch Policy does.
            self._optimizer = optimizers[0] if optimizers else None

            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=stats_fn,
            )
            self._loss_initialized = True

            if after_init:
                after_init(self, observation_space, action_space, config)

            # Need to reset global_timestep again after the fake run-throughs.
            self.global_timestep = 0

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            assert tf.executing_eagerly()
            # Call super's postprocess_trajectory first.
            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
            if postprocess_fn:
                return postprocess_fn(self, sample_batch, other_agent_batches,
                                      episode)
            return sample_batch

        @with_lock
        @override(Policy)
        def learn_on_batch(self, postprocessed_batch):
            # Callback handling.
            learn_stats = {}
            self.callbacks.on_learn_on_batch(
                policy=self,
                train_batch=postprocessed_batch,
                result=learn_stats)

            if not isinstance(postprocessed_batch, SampleBatch) or \
                    not postprocessed_batch.zero_padded:
                pad_batch_to_sequences_of_same_size(
                    postprocessed_batch,
                    max_seq_len=self._max_seq_len,
                    shuffle=False,
                    batch_divisibility_req=self.batch_divisibility_req,
                    view_requirements=self.view_requirements,
                )
            else:
                postprocessed_batch["seq_lens"] = postprocessed_batch.seq_lens

            self._is_training = True
            postprocessed_batch["is_training"] = True
            stats = self._learn_on_batch_eager(postprocessed_batch)
            stats.update({"custom_metrics": learn_stats})
            return stats

        @convert_eager_inputs
        @convert_eager_outputs
        def _learn_on_batch_eager(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            self._apply_gradients(grads_and_vars)
            return stats

        @override(Policy)
        def compute_gradients(self, samples):
            pad_batch_to_sequences_of_same_size(
                samples,
                shuffle=False,
                max_seq_len=self._max_seq_len,
                batch_divisibility_req=self.batch_divisibility_req)

            self._is_training = True
            samples["is_training"] = True
            return self._compute_gradients_eager(samples)

        @convert_eager_inputs
        @convert_eager_outputs
        def _compute_gradients_eager(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            grads = [g for g, v in grads_and_vars]
            return grads, stats

        @override(Policy)
        def compute_actions(self,
                            obs_batch,
                            state_batches=None,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            info_batch=None,
                            episodes=None,
                            explore=None,
                            timestep=None,
                            **kwargs):

            self._is_training = False
            self._is_recurrent = \
                state_batches is not None and state_batches != []

            if not tf1.executing_eagerly():
                tf1.enable_eager_execution()

            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.constant(False),
            }
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = \
                    tf.convert_to_tensor(prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = \
                    tf.convert_to_tensor(prev_reward_batch)

            return self._compute_action_helper(input_dict, state_batches,
                                               episodes, explore, timestep)

        @override(Policy)
        def compute_actions_from_input_dict(
                self,
                input_dict: Dict[str, TensorType],
                explore: bool = None,
                timestep: Optional[int] = None,
                **kwargs
        ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

            if not tf1.executing_eagerly():
                tf1.enable_eager_execution()

            # Pass lazy (eager) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            # Pack internal state inputs into (separate) list.
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]

            return self._compute_action_helper(input_dict, state_batches, None,
                                               explore, timestep)

        @with_lock
        @convert_eager_inputs
        @convert_eager_outputs
        def _compute_action_helper(self, input_dict, state_batches, episodes,
                                   explore, timestep):

            explore = explore if explore is not None else \
                self.config["explore"]
            timestep = timestep if timestep is not None else \
                self.global_timestep
            if isinstance(timestep, tf.Tensor):
                timestep = int(timestep.numpy())
            self._is_training = False
            self._state_in = state_batches or []
            # Calculate RNN sequence lengths.
            batch_size = input_dict[SampleBatch.CUR_OBS].shape[0]
            seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches \
                else None

            # Add default and custom fetches.
            extra_fetches = {}

            # Use Exploration object.
            with tf.variable_creator_scope(_disallow_var_creation):
                if action_sampler_fn:
                    dist_inputs = None
                    state_out = []
                    actions, logp = action_sampler_fn(
                        self,
                        self.model,
                        input_dict[SampleBatch.CUR_OBS],
                        explore=explore,
                        timestep=timestep,
                        episodes=episodes)
                else:
                    # Exploration hook before each forward pass.
                    self.exploration.before_compute_actions(
                        timestep=timestep, explore=explore)

                    if action_distribution_fn:

                        # Try new action_distribution_fn signature, supporting
                        # state_batches and seq_lens.
                        try:
                            dist_inputs, self.dist_class, state_out = \
                                action_distribution_fn(
                                    self,
                                    self.model,
                                    input_dict=input_dict,
                                    state_batches=state_batches,
                                    seq_lens=seq_lens,
                                    explore=explore,
                                    timestep=timestep,
                                    is_training=False)
                        # Trying the old way (to stay backward compatible).
                        # TODO: Remove in future.
                        except TypeError as e:
                            if "positional argument" in e.args[0] or \
                                    "unexpected keyword argument" in e.args[0]:
                                dist_inputs, self.dist_class, state_out = \
                                    action_distribution_fn(
                                        self, self.model,
                                        input_dict[SampleBatch.CUR_OBS],
                                        explore=explore,
                                        timestep=timestep,
                                        is_training=False)
                            else:
                                raise e
                    elif isinstance(self.model, tf.keras.Model):
                        input_dict = SampleBatch(input_dict, seq_lens=seq_lens)
                        self._lazy_tensor_dict(input_dict)
                        dist_inputs, state_out, extra_fetches = \
                            self.model(input_dict)
                    else:
                        dist_inputs, state_out = self.model(
                            input_dict, state_batches, seq_lens)

                    action_dist = self.dist_class(dist_inputs, self.model)

                    # Get the exploration action from the forward results.
                    actions, logp = self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore)

            # Action-logp and action-prob.
            if logp is not None:
                extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp
            # Action-dist inputs.
            if dist_inputs is not None:
                extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
            # Custom extra fetches.
            if extra_action_out_fn:
                extra_fetches.update(extra_action_out_fn(self))

            # Update our global timestep by the batch size.
            self.global_timestep += int(batch_size)

            return actions, state_out, extra_fetches

        @with_lock
        @override(Policy)
        def compute_log_likelihoods(self,
                                    actions,
                                    obs_batch,
                                    state_batches=None,
                                    prev_action_batch=None,
                                    prev_reward_batch=None):
            if action_sampler_fn and action_distribution_fn is None:
                raise ValueError("Cannot compute log-prob/likelihood w/o an "
                                 "`action_distribution_fn` and a provided "
                                 "`action_sampler_fn`!")

            seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.constant(False),
            }
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = \
                    tf.convert_to_tensor(prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = \
                    tf.convert_to_tensor(prev_reward_batch)

            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if action_distribution_fn:
                dist_inputs, dist_class, _ = action_distribution_fn(
                    self,
                    self.model,
                    input_dict[SampleBatch.CUR_OBS],
                    explore=False,
                    is_training=False)
            # Default log-likelihood calculation.
            else:
                dist_inputs, _ = self.model(input_dict, state_batches,
                                            seq_lens)
                dist_class = self.dist_class

            action_dist = dist_class(dist_inputs, self.model)
            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods

        @override(Policy)
        def apply_gradients(self, gradients):
            self._apply_gradients(
                zip([(tf.convert_to_tensor(g) if g is not None else None)
                     for g in gradients], self.model.trainable_variables()))

        @override(Policy)
        def get_exploration_info(self):
            return _convert_to_numpy(self.exploration.get_info())

        @override(Policy)
        def get_weights(self, as_dict=False):
            variables = self.variables()
            if as_dict:
                return {v.name: v.numpy() for v in variables}
            return [v.numpy() for v in variables]

        @override(Policy)
        def set_weights(self, weights):
            variables = self.variables()
            assert len(weights) == len(variables), (len(weights),
                                                    len(variables))
            for v, w in zip(variables, weights):
                v.assign(w)

        @override(Policy)
        def get_state(self):
            state = {"_state": super().get_state()}
            if self._optimizer and \
                    len(self._optimizer.variables()) > 0:
                state["_optimizer_variables"] = \
                    self._optimizer.variables()
            return state

        @override(Policy)
        def set_state(self, state):
            state = state.copy()  # shallow copy
            # Set optimizer vars first.
            optimizer_vars = state.pop("_optimizer_variables", None)
            if optimizer_vars and self._optimizer.variables():
                logger.warning(
                    "Cannot restore an optimizer's state for tf eager! Keras "
                    "is not able to save the v1.x optimizers (from "
                    "tf.compat.v1.train) since they aren't compatible with "
                    "checkpoints.")
                for opt_var, value in zip(self._optimizer.variables(),
                                          optimizer_vars):
                    opt_var.assign(value)
            # Then the Policy's (NN) weights.
            super().set_state(state["_state"])

        def variables(self):
            """Return the list of all savable variables for this policy."""
            if isinstance(self.model, tf.keras.Model):
                return self.model.variables
            else:
                return self.model.variables()

        @override(Policy)
        def is_recurrent(self):
            return self._is_recurrent

        @override(Policy)
        def num_state_tensors(self):
            return len(self._state_inputs)

        @override(Policy)
        def get_initial_state(self):
            if hasattr(self, "model"):
                return self.model.get_initial_state()
            return []

        def get_session(self):
            return None  # None implies eager

        def get_placeholder(self, ph):
            raise ValueError(
                "get_placeholder() is not allowed in eager mode. Try using "
                "rllib.utils.tf_ops.make_tf_callable() to write "
                "functions that work in both graph and eager mode.")

        def loss_initialized(self):
            return self._loss_initialized

        @override(Policy)
        def export_model(self, export_dir):
            pass

        @override(Policy)
        def export_checkpoint(self, export_dir):
            pass

        def _get_is_training_placeholder(self):
            return tf.convert_to_tensor(self._is_training)

        def _apply_gradients(self, grads_and_vars):
            if apply_gradients_fn:
                apply_gradients_fn(self, self._optimizer, grads_and_vars)
            else:
                self._optimizer.apply_gradients(
                    [(g, v) for g, v in grads_and_vars if g is not None])

        @with_lock
        def _compute_gradients(self, samples):
            """Computes and returns grads as eager tensors."""

            with tf.GradientTape(persistent=gradients_fn is not None) as tape:
                loss = loss_fn(self, self.model, self.dist_class, samples)

            if isinstance(self.model, tf.keras.Model):
                variables = self.model.trainable_variables
            else:
                variables = self.model.trainable_variables()

            if gradients_fn:

                class OptimizerWrapper:
                    def __init__(self, tape):
                        self.tape = tape

                    def compute_gradients(self, loss, var_list):
                        return list(
                            zip(self.tape.gradient(loss, var_list), var_list))

                grads_and_vars = gradients_fn(self, OptimizerWrapper(tape),
                                              loss)
            else:
                grads_and_vars = list(
                    zip(tape.gradient(loss, variables), variables))

            if log_once("grad_vars"):
                for _, v in grads_and_vars:
                    logger.info("Optimizing variable {}".format(v.name))

            grads = [g for g, v in grads_and_vars]
            stats = self._stats(self, samples, grads)
            return grads_and_vars, stats

        def _stats(self, outputs, samples, grads):

            fetches = {}
            if stats_fn:
                fetches[LEARNER_STATS_KEY] = {
                    k: v
                    for k, v in stats_fn(outputs, samples).items()
                }
            else:
                fetches[LEARNER_STATS_KEY] = {}

            if extra_learn_fetches_fn:
                fetches.update(
                    {k: v
                     for k, v in extra_learn_fetches_fn(self).items()})
            if grad_stats_fn:
                fetches.update({
                    k: v
                    for k, v in grad_stats_fn(self, samples, grads).items()
                })
            return fetches

        def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch):
            # TODO: (sven): Keep for a while to ensure backward compatibility.
            if not isinstance(postprocessed_batch, SampleBatch):
                postprocessed_batch = SampleBatch(postprocessed_batch)
            postprocessed_batch.set_get_interceptor(_convert_to_tf)
            return postprocessed_batch

        @classmethod
        def with_tracing(cls):
            return traced_eager_policy(cls)

    eager_policy_cls.__name__ = name + "_eager"
    eager_policy_cls.__qualname__ = name + "_eager"
    return eager_policy_cls
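
A hedged sketch of feeding this builder a loss function directly (illustrative names; as the docstring notes, the usual route is build_tf_policy() plus {"framework": "tfe"} in the trainer config):

import tensorflow as tf
from ray.rllib.policy.sample_batch import SampleBatch


def my_loss_fn(policy, model, dist_class, train_batch):
    # Simple policy-gradient style objective over the postprocessed batch.
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
    return -tf.reduce_mean(log_probs * train_batch[SampleBatch.REWARDS])


MyEagerPolicy = build_eager_tf_policy(
    name="MyEagerTFPolicy",
    loss_fn=my_loss_fn,
)
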
Example #23
def build_eager_tf_policy(name,
                          loss_fn,
                          get_default_config=None,
                          postprocess_fn=None,
                          stats_fn=None,
                          optimizer_fn=None,
                          gradients_fn=None,
                          apply_gradients_fn=None,
                          grad_stats_fn=None,
                          extra_learn_fetches_fn=None,
                          extra_action_fetches_fn=None,
                          before_init=None,
                          before_loss_init=None,
                          after_init=None,
                          make_model=None,
                          action_sampler_fn=None,
                          log_likelihood_fn=None,
                          mixins=None,
                          obs_include_prev_action_reward=True,
                          get_batch_divisibility_req=None):
    """Build an eager TF policy.

    An eager policy runs all operations in eager mode, which makes debugging
    much simpler, but has lower performance.

    You shouldn't need to call this directly. Rather, prefer to build a TF
    graph policy and set {"eager": true} in the trainer config to have
    it automatically converted to an eager policy.

    This has the same signature as build_tf_policy()."""

    base = add_mixins(Policy, mixins)

    class eager_policy_cls(base):
        def __init__(self, observation_space, action_space, config):
            assert tf.executing_eagerly()
            self.framework = "tf"
            Policy.__init__(self, observation_space, action_space, config)
            self._is_training = False
            self._loss_initialized = False
            self._sess = None

            if get_default_config:
                config = dict(get_default_config(), **config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config
            self.dist_class = None

            if action_sampler_fn:
                if not make_model:
                    raise ValueError("`make_model` is required if "
                                     "`action_sampler_fn` is given")
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"])

            if make_model:
                self.model = make_model(self, observation_space, action_space,
                                        config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework="tf",
                )

            self._state_in = [
                tf.convert_to_tensor(np.array([s]))
                for s in self.model.get_initial_state()
            ]

            input_dict = {
                SampleBatch.CUR_OBS:
                tf.convert_to_tensor(np.array([observation_space.sample()])),
                SampleBatch.PREV_ACTIONS:
                tf.convert_to_tensor([_flatten_action(action_space.sample())]),
                SampleBatch.PREV_REWARDS:
                tf.convert_to_tensor([0.]),
            }
            self.model(input_dict, self._state_in, tf.convert_to_tensor([1]))

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            self._initialize_loss_with_dummy_batch()
            self._loss_initialized = True

            if optimizer_fn:
                self._optimizer = optimizer_fn(self, config)
            else:
                self._optimizer = tf.train.AdamOptimizer(config["lr"])

            if after_init:
                after_init(self, observation_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   samples,
                                   other_agent_batches=None,
                                   episode=None):
            assert tf.executing_eagerly()
            if postprocess_fn:
                return postprocess_fn(self, samples, other_agent_batches,
                                      episode)
            else:
                return samples

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def learn_on_batch(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            self._apply_gradients(grads_and_vars)
            return stats

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def compute_gradients(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            grads = [g for g, v in grads_and_vars]
            return grads, stats

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def compute_actions(self,
                            obs_batch,
                            state_batches,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            info_batch=None,
                            episodes=None,
                            explore=None,
                            timestep=None,
                            **kwargs):

            explore = explore if explore is not None else \
                self.config["explore"]

            # TODO: remove python side effect to cull sources of bugs.
            self._is_training = False
            self._state_in = state_batches

            if tf.executing_eagerly():
                n = len(obs_batch)
            else:
                n = obs_batch.shape[0]
            seq_lens = tf.ones(n, dtype=tf.int32)

            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.constant(False),
            }
            if obs_include_prev_action_reward:
                input_dict.update({
                    SampleBatch.PREV_ACTIONS:
                    tf.convert_to_tensor(prev_action_batch),
                    SampleBatch.PREV_REWARDS:
                    tf.convert_to_tensor(prev_reward_batch),
                })

            # Custom sampler fn given (which may handle self.exploration).
            if action_sampler_fn is not None:
                state_out = []
                action, logp = action_sampler_fn(
                    self,
                    self.model,
                    input_dict,
                    self.observation_space,
                    self.action_space,
                    explore,
                    self.config,
                    timestep=timestep
                    if timestep is not None else self.global_timestep)
            # Use Exploration object.
            else:
                with tf.variable_creator_scope(_disallow_var_creation):
                    model_out, state_out = self.model(input_dict,
                                                      state_batches, seq_lens)
                    action, logp = self.exploration.get_exploration_action(
                        model_out,
                        self.dist_class,
                        self.model,
                        timestep=timestep
                        if timestep is not None else self.global_timestep,
                        explore=explore)

            extra_fetches = {}
            if logp is not None:
                extra_fetches.update({
                    ACTION_PROB: tf.exp(logp),
                    ACTION_LOGP: logp,
                })
            if extra_action_fetches_fn:
                extra_fetches.update(extra_action_fetches_fn(self))

            # Increase our global sampling timestep counter by 1.
            self.global_timestep += 1

            return action, state_out, extra_fetches

        @override(Policy)
        def compute_log_likelihoods(self,
                                    actions,
                                    obs_batch,
                                    state_batches=None,
                                    prev_action_batch=None,
                                    prev_reward_batch=None):

            seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.constant(False),
            }
            if obs_include_prev_action_reward:
                input_dict.update({
                    SampleBatch.PREV_ACTIONS:
                    tf.convert_to_tensor(prev_action_batch),
                    SampleBatch.PREV_REWARDS:
                    tf.convert_to_tensor(prev_reward_batch),
                })

            # Custom log_likelihood function given.
            if log_likelihood_fn:
                log_likelihoods = log_likelihood_fn(self, self.model, actions,
                                                    input_dict,
                                                    self.observation_space,
                                                    self.action_space,
                                                    self.config)
            # Default log-likelihood calculation.
            else:
                dist_inputs, _ = self.model(input_dict, state_batches,
                                            seq_lens)
                action_dist = self.dist_class(dist_inputs, self.model)
                log_likelihoods = action_dist.logp(actions)

            return log_likelihoods

        @override(Policy)
        def apply_gradients(self, gradients):
            self._apply_gradients(
                zip([(tf.convert_to_tensor(g) if g is not None else None)
                     for g in gradients], self.model.trainable_variables()))

        @override(Policy)
        def get_exploration_info(self):
            return _convert_to_numpy(self.exploration.get_info())

        @override(Policy)
        def get_weights(self):
            variables = self.variables()
            return [v.numpy() for v in variables]

        @override(Policy)
        def set_weights(self, weights):
            variables = self.variables()
            assert len(weights) == len(variables), (len(weights),
                                                    len(variables))
            for v, w in zip(variables, weights):
                v.assign(w)

        def variables(self):
            """Return the list of all savable variables for this policy."""
            return self.model.variables()

        @override(Policy)
        def is_recurrent(self):
            return len(self._state_in) > 0

        @override(Policy)
        def num_state_tensors(self):
            return len(self._state_in)

        @override(Policy)
        def get_initial_state(self):
            return self.model.get_initial_state()

        def get_session(self):
            return None  # None implies eager

        def get_placeholder(self, ph):
            raise ValueError(
                "get_placeholder() is not allowed in eager mode. Try using "
                "rllib.utils.tf_ops.make_tf_callable() to write "
                "functions that work in both graph and eager mode.")

        def loss_initialized(self):
            return self._loss_initialized

        @override(Policy)
        def export_model(self, export_dir):
            pass

        @override(Policy)
        def export_checkpoint(self, export_dir):
            pass

        def _get_is_training_placeholder(self):
            return tf.convert_to_tensor(self._is_training)

        def _apply_gradients(self, grads_and_vars):
            if apply_gradients_fn:
                apply_gradients_fn(self, self._optimizer, grads_and_vars)
            else:
                self._optimizer.apply_gradients(grads_and_vars)

        def _compute_gradients(self, samples):
            """Computes and returns grads as eager tensors."""

            self._is_training = True

            with tf.GradientTape(persistent=gradients_fn is not None) as tape:
                # TODO: set seq len and state-in properly
                state_in = []
                for i in range(self.num_state_tensors()):
                    state_in.append(samples["state_in_{}".format(i)])
                self._state_in = state_in

                self._seq_lens = None
                if len(state_in) > 0:
                    self._seq_lens = tf.ones(
                        samples[SampleBatch.CUR_OBS].shape[0], dtype=tf.int32)
                    samples["seq_lens"] = self._seq_lens

                model_out, _ = self.model(samples, self._state_in,
                                          self._seq_lens)
                loss = loss_fn(self, self.model, self.dist_class, samples)

            variables = self.model.trainable_variables()

            if gradients_fn:

                class OptimizerWrapper:
                    def __init__(self, tape):
                        self.tape = tape

                    def compute_gradients(self, loss, var_list):
                        return list(
                            zip(self.tape.gradient(loss, var_list), var_list))

                grads_and_vars = gradients_fn(self, OptimizerWrapper(tape),
                                              loss)
            else:
                grads_and_vars = list(
                    zip(tape.gradient(loss, variables), variables))

            if log_once("grad_vars"):
                for _, v in grads_and_vars:
                    logger.info("Optimizing variable {}".format(v.name))

            grads = [g for g, v in grads_and_vars]
            stats = self._stats(self, samples, grads)
            return grads_and_vars, stats

        def _stats(self, outputs, samples, grads):

            fetches = {}
            if stats_fn:
                fetches[LEARNER_STATS_KEY] = {
                    k: v
                    for k, v in stats_fn(outputs, samples).items()
                }
            else:
                fetches[LEARNER_STATS_KEY] = {}
            if extra_learn_fetches_fn:
                fetches.update(
                    {k: v
                     for k, v in extra_learn_fetches_fn(self).items()})
            if grad_stats_fn:
                fetches.update({
                    k: v
                    for k, v in grad_stats_fn(self, samples, grads).items()
                })
            return fetches

        def _initialize_loss_with_dummy_batch(self):
            # Dummy forward pass to initialize any policy attributes, etc.
            action_dtype, action_shape = ModelCatalog.get_action_shape(
                self.action_space)
            dummy_batch = {
                SampleBatch.CUR_OBS:
                np.array([self.observation_space.sample()]),
                SampleBatch.NEXT_OBS:
                np.array([self.observation_space.sample()]),
                SampleBatch.DONES:
                np.array([False], dtype=np.bool),
                SampleBatch.ACTIONS:
                tf.nest.map_structure(lambda c: np.array([c]),
                                      self.action_space.sample()),
                SampleBatch.REWARDS:
                np.array([0], dtype=np.float32),
            }
            if obs_include_prev_action_reward:
                dummy_batch.update({
                    SampleBatch.PREV_ACTIONS:
                    dummy_batch[SampleBatch.ACTIONS],
                    SampleBatch.PREV_REWARDS:
                    dummy_batch[SampleBatch.REWARDS],
                })
            for i, h in enumerate(self._state_in):
                dummy_batch["state_in_{}".format(i)] = h
                dummy_batch["state_out_{}".format(i)] = h

            if self._state_in:
                dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)

            # Convert everything to tensors.
            dummy_batch = tf.nest.map_structure(tf.convert_to_tensor,
                                                dummy_batch)

            # for IMPALA which expects a certain sample batch size.
            def tile_to(tensor, n):
                return tf.tile(tensor,
                               [n] + [1 for _ in tensor.shape.as_list()[1:]])

            if get_batch_divisibility_req:
                dummy_batch = tf.nest.map_structure(
                    lambda c: tile_to(c, get_batch_divisibility_req(self)),
                    dummy_batch)

            # Execute a forward pass to get self.action_dist etc initialized,
            # and also obtain the extra action fetches
            _, _, fetches = self.compute_actions(
                dummy_batch[SampleBatch.CUR_OBS], self._state_in,
                dummy_batch.get(SampleBatch.PREV_ACTIONS),
                dummy_batch.get(SampleBatch.PREV_REWARDS))
            dummy_batch.update(fetches)

            postprocessed_batch = self.postprocess_trajectory(
                SampleBatch(dummy_batch))

            # model forward pass for the loss (needed after postprocess to
            # overwrite any tensor state from that call)
            self.model.from_batch(dummy_batch)

            postprocessed_batch = tf.nest.map_structure(
                lambda c: tf.convert_to_tensor(c), postprocessed_batch.data)

            loss_fn(self, self.model, self.dist_class, postprocessed_batch)
            if stats_fn:
                stats_fn(self, postprocessed_batch)

        @classmethod
        def with_tracing(cls):
            return traced_eager_policy(cls)

    eager_policy_cls.__name__ = name + "_eager"
    eager_policy_cls.__qualname__ = name + "_eager"
    return eager_policy_cls
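
For the older variant above, a hedged sketch of supplying a custom v1-style optimizer and a stats function (illustrative names; `my_loss_fn` stands in for any loss with the signature shown after Example #22):

import tensorflow as tf


def my_optimizer_fn(policy, config):
    # Same role as the tf.train.AdamOptimizer default above, using RMSProp.
    return tf.train.RMSPropOptimizer(learning_rate=config["lr"])


def my_stats_fn(policy, train_batch):
    # Returned values end up under LEARNER_STATS_KEY in learn_on_batch().
    return {"train_batch_size": tf.shape(train_batch["obs"])[0]}


MyOldEagerPolicy = build_eager_tf_policy(
    name="MyOldEagerTFPolicy",
    loss_fn=my_loss_fn,
    optimizer_fn=my_optimizer_fn,
    stats_fn=my_stats_fn,
)
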
Example #24
def build_eager_tf_policy(name,
                          loss_fn,
                          get_default_config=None,
                          postprocess_fn=None,
                          stats_fn=None,
                          optimizer_fn=None,
                          gradients_fn=None,
                          apply_gradients_fn=None,
                          grad_stats_fn=None,
                          extra_learn_fetches_fn=None,
                          extra_action_fetches_fn=None,
                          validate_spaces=None,
                          before_init=None,
                          before_loss_init=None,
                          after_init=None,
                          make_model=None,
                          action_sampler_fn=None,
                          action_distribution_fn=None,
                          mixins=None,
                          view_requirements_fn=None,
                          obs_include_prev_action_reward=True,
                          get_batch_divisibility_req=None):
    """Build an eager TF policy.

    An eager policy runs all operations in eager mode, which makes debugging
    much simpler, but has lower performance.

    You shouldn't need to call this directly. Rather, prefer to build a TF
    graph policy and set {"framework": "tfe"} in the trainer config to have
    it automatically converted to an eager policy.

    This has the same signature as build_tf_policy()."""

    base = add_mixins(Policy, mixins)

    class eager_policy_cls(base):
        def __init__(self, observation_space, action_space, config):
            assert tf.executing_eagerly()
            self.framework = config.get("framework", "tfe")
            Policy.__init__(self, observation_space, action_space, config)
            self._is_training = False
            self._loss_initialized = False
            self._sess = None

            self._loss = loss_fn
            self.batch_divisibility_req = get_batch_divisibility_req(self) if \
                callable(get_batch_divisibility_req) else \
                (get_batch_divisibility_req or 1)
            self._max_seq_len = config["model"]["max_seq_len"]

            if get_default_config:
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, observation_space, action_space, config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config
            self.dist_class = None
            if action_sampler_fn or action_distribution_fn:
                if not make_model:
                    raise ValueError(
                        "`make_model` is required if `action_sampler_fn` OR "
                        "`action_distribution_fn` is given")
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"])

            if make_model:
                self.model = make_model(self, observation_space, action_space,
                                        config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework=self.framework,
                )
            # Auto-update model's inference view requirements, if recurrent.
            self._update_model_inference_view_requirements_from_init_state()

            self.exploration = self._create_exploration()
            self._state_in = [
                tf.convert_to_tensor([s])
                for s in self.model.get_initial_state()
            ]

            # Update this Policy's ViewRequirements (if function given).
            if callable(view_requirements_fn):
                self.view_requirements.update(view_requirements_fn(self))
            # Combine view_requirements for Model and Policy.
            self.view_requirements.update(
                self.model.inference_view_requirements)

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=stats_fn,
            )
            self._loss_initialized = True

            if optimizer_fn:
                optimizers = optimizer_fn(self, config)
            else:
                optimizers = tf.keras.optimizers.Adam(config["lr"])
            optimizers = force_list(optimizers)
            if getattr(self, "exploration", None):
                optimizers = self.exploration.get_exploration_optimizer(
                    optimizers)
            # TODO: (sven) Allow tf policy to have more than 1 optimizer.
            #  Just like torch Policy does.
            self._optimizer = optimizers[0] if optimizers else None

            if after_init:
                after_init(self, observation_space, action_space, config)

            # Need to reset global_timestep again after this fake run-through.
            self.global_timestep = 0

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            assert tf.executing_eagerly()
            # Call super's postprocess_trajectory first.
            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
            if postprocess_fn:
                return postprocess_fn(self, sample_batch, other_agent_batches,
                                      episode)
            return sample_batch

        @override(Policy)
        def learn_on_batch(self, postprocessed_batch):
            # Callback handling.
            self.callbacks.on_learn_on_batch(policy=self,
                                             train_batch=postprocessed_batch)

            # Get batch ready for RNNs, if applicable.
            pad_batch_to_sequences_of_same_size(
                postprocessed_batch,
                shuffle=False,
                max_seq_len=self._max_seq_len,
                batch_divisibility_req=self.batch_divisibility_req)
            return self._learn_on_batch_eager(postprocessed_batch)

        @convert_eager_inputs
        @convert_eager_outputs
        def _learn_on_batch_eager(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            self._apply_gradients(grads_and_vars)
            return stats

        @override(Policy)
        def compute_gradients(self, samples):
            # Get batch ready for RNNs, if applicable.
            pad_batch_to_sequences_of_same_size(
                samples,
                shuffle=False,
                max_seq_len=self._max_seq_len,
                batch_divisibility_req=self.batch_divisibility_req)
            return self._compute_gradients_eager(samples)

        @convert_eager_inputs
        @convert_eager_outputs
        def _compute_gradients_eager(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            grads = [g for g, v in grads_and_vars]
            return grads, stats

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def compute_actions(self,
                            obs_batch,
                            state_batches=None,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            info_batch=None,
                            episodes=None,
                            explore=None,
                            timestep=None,
                            **kwargs):

            explore = explore if explore is not None else \
                self.config["explore"]
            timestep = timestep if timestep is not None else \
                self.global_timestep

            # TODO: remove python side effect to cull sources of bugs.
            self._is_training = False
            self._state_in = state_batches

            if not tf1.executing_eagerly():
                tf1.enable_eager_execution()

            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.constant(False),
            }
            batch_size = input_dict[SampleBatch.CUR_OBS].shape[0]
            seq_lens = tf.ones(batch_size, dtype=tf.int32)
            if obs_include_prev_action_reward:
                if prev_action_batch is not None:
                    input_dict[SampleBatch.PREV_ACTIONS] = \
                        tf.convert_to_tensor(prev_action_batch)
                if prev_reward_batch is not None:
                    input_dict[SampleBatch.PREV_REWARDS] = \
                        tf.convert_to_tensor(prev_reward_batch)

            # Use Exploration object.
            with tf.variable_creator_scope(_disallow_var_creation):
                if action_sampler_fn:
                    dist_inputs = None
                    state_out = []
                    actions, logp = action_sampler_fn(
                        self,
                        self.model,
                        input_dict[SampleBatch.CUR_OBS],
                        explore=explore,
                        timestep=timestep,
                        episodes=episodes)
                else:
                    # Exploration hook before each forward pass.
                    self.exploration.before_compute_actions(timestep=timestep,
                                                            explore=explore)

                    if action_distribution_fn:
                        dist_inputs, dist_class, state_out = \
                            action_distribution_fn(
                                self, self.model,
                                input_dict[SampleBatch.CUR_OBS],
                                explore=explore,
                                timestep=timestep,
                                is_training=False)
                    else:
                        dist_class = self.dist_class
                        dist_inputs, state_out = self.model(
                            input_dict, state_batches, seq_lens)

                    action_dist = dist_class(dist_inputs, self.model)

                    # Get the exploration action from the forward results.
                    actions, logp = self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore)

            # Add default and custom fetches.
            extra_fetches = {}
            # Action-logp and action-prob.
            if logp is not None:
                extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp
            # Action-dist inputs.
            if dist_inputs is not None:
                extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
            # Custom extra fetches.
            if extra_action_fetches_fn:
                extra_fetches.update(extra_action_fetches_fn(self))

            # Update our global timestep by the batch size.
            self.global_timestep += int(batch_size)

            return actions, state_out, extra_fetches

        @override(Policy)
        def compute_log_likelihoods(self,
                                    actions,
                                    obs_batch,
                                    state_batches=None,
                                    prev_action_batch=None,
                                    prev_reward_batch=None):
            if action_sampler_fn and action_distribution_fn is None:
                raise ValueError("Cannot compute log-prob/likelihood w/o an "
                                 "`action_distribution_fn` and a provided "
                                 "`action_sampler_fn`!")

            seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.constant(False),
            }
            if obs_include_prev_action_reward:
                input_dict.update({
                    SampleBatch.PREV_ACTIONS:
                    tf.convert_to_tensor(prev_action_batch),
                    SampleBatch.PREV_REWARDS:
                    tf.convert_to_tensor(prev_reward_batch),
                })

            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if action_distribution_fn:
                dist_inputs, dist_class, _ = action_distribution_fn(
                    self,
                    self.model,
                    input_dict[SampleBatch.CUR_OBS],
                    explore=False,
                    is_training=False)
            # Default log-likelihood calculation.
            else:
                dist_inputs, _ = self.model(input_dict, state_batches,
                                            seq_lens)
                dist_class = self.dist_class

            action_dist = dist_class(dist_inputs, self.model)
            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods

        @override(Policy)
        def apply_gradients(self, gradients):
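            # Convert each gradient to a tensor, but keep None entries so the
            # list stays aligned with the model's trainable variables; the
            # Nones are filtered out in _apply_gradients() below.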
            self._apply_gradients(
                zip([(tf.convert_to_tensor(g) if g is not None else None)
                     for g in gradients], self.model.trainable_variables()))

        @override(Policy)
        def get_exploration_info(self):
            return _convert_to_numpy(self.exploration.get_info())

        @override(Policy)
        def get_weights(self, as_dict=False):
            variables = self.variables()
            if as_dict:
                return {v.name: v.numpy() for v in variables}
            return [v.numpy() for v in variables]

        @override(Policy)
        def set_weights(self, weights):
            variables = self.variables()
            assert len(weights) == len(variables), (len(weights),
                                                    len(variables))
            for v, w in zip(variables, weights):
                v.assign(w)

        @override(Policy)
        def get_state(self):
            state = {"_state": super().get_state()}
            state["_optimizer_variables"] = self._optimizer.variables()
            return state

        @override(Policy)
        def set_state(self, state):
            state = state.copy()  # shallow copy
            # Set optimizer vars first.
            optimizer_vars = state.pop("_optimizer_variables", None)
            if optimizer_vars and self._optimizer.variables():
                logger.warning(
                    "Cannot restore an optimizer's state for tf eager! Keras "
                    "is not able to save the v1.x optimizers (from "
                    "tf.compat.v1.train) since they aren't compatible with "
                    "checkpoints.")
                for opt_var, value in zip(self._optimizer.variables(),
                                          optimizer_vars):
                    opt_var.assign(value)
            # Then the Policy's (NN) weights.
            super().set_state(state["_state"])

        def variables(self):
            """Return the list of all savable variables for this policy."""
            return self.model.variables()

        @override(Policy)
        def is_recurrent(self):
            return len(self._state_in) > 0

        @override(Policy)
        def num_state_tensors(self):
            return len(self._state_in)

        @override(Policy)
        def get_initial_state(self):
            return self.model.get_initial_state()

        def get_session(self):
            return None  # None implies eager

        def get_placeholder(self, ph):
            raise ValueError(
                "get_placeholder() is not allowed in eager mode. Try using "
                "rllib.utils.tf_ops.make_tf_callable() to write "
                "functions that work in both graph and eager mode.")

        def loss_initialized(self):
            return self._loss_initialized

        @override(Policy)
        def export_model(self, export_dir):
            pass

        @override(Policy)
        def export_checkpoint(self, export_dir):
            pass

        def _get_is_training_placeholder(self):
            return tf.convert_to_tensor(self._is_training)

        def _apply_gradients(self, grads_and_vars):
            if apply_gradients_fn:
                apply_gradients_fn(self, self._optimizer, grads_and_vars)
            else:
                self._optimizer.apply_gradients([(g, v)
                                                 for g, v in grads_and_vars
                                                 if g is not None])

        def _compute_gradients(self, samples):
            """Computes and returns grads as eager tensors."""

            self._is_training = True

            # Use a persistent tape when a custom gradients_fn is given, since
            # that function may call tape.gradient() itself (possibly more
            # than once).
            with tf.GradientTape(persistent=gradients_fn is not None) as tape:
                # TODO: set seq len and state-in properly
                state_in = []
                for i in range(self.num_state_tensors()):
                    state_in.append(samples["state_in_{}".format(i)])
                self._state_in = state_in

                model_out, _ = self.model(samples, self._state_in,
                                          samples.get("seq_lens"))
                loss = loss_fn(self, self.model, self.dist_class, samples)

            variables = self.model.trainable_variables()

            if gradients_fn:

                # Shim exposing a tf1-style optimizer.compute_gradients() API
                # on top of the GradientTape, so graph-mode gradients_fn
                # implementations can be reused here in eager mode.
                class OptimizerWrapper:
                    def __init__(self, tape):
                        self.tape = tape

                    def compute_gradients(self, loss, var_list):
                        return list(
                            zip(self.tape.gradient(loss, var_list), var_list))

                grads_and_vars = gradients_fn(self, OptimizerWrapper(tape),
                                              loss)
            else:
                grads_and_vars = list(
                    zip(tape.gradient(loss, variables), variables))

            if log_once("grad_vars"):
                for _, v in grads_and_vars:
                    logger.info("Optimizing variable {}".format(v.name))

            grads = [g for g, v in grads_and_vars]
            stats = self._stats(self, samples, grads)
            return grads_and_vars, stats

        def _stats(self, outputs, samples, grads):

            fetches = {}
            if stats_fn:
                fetches[LEARNER_STATS_KEY] = {
                    k: v
                    for k, v in stats_fn(outputs, samples).items()
                }
            else:
                fetches[LEARNER_STATS_KEY] = {}

            if extra_learn_fetches_fn:
                fetches.update(
                    {k: v
                     for k, v in extra_learn_fetches_fn(self).items()})
            if grad_stats_fn:
                fetches.update({
                    k: v
                    for k, v in grad_stats_fn(self, samples, grads).items()
                })
            return fetches

        def _lazy_tensor_dict(self, postprocessed_batch):
            train_batch = UsageTrackingDict(postprocessed_batch)
            train_batch.set_get_interceptor(_convert_to_tf)
            return train_batch

        def _lazy_numpy_dict(self, postprocessed_batch):
            train_batch = UsageTrackingDict(postprocessed_batch)
            train_batch.set_get_interceptor(convert_to_non_tf_type)
            return train_batch

        @classmethod
        def with_tracing(cls):
            return traced_eager_policy(cls)

    eager_policy_cls.__name__ = name + "_eager"
    eager_policy_cls.__qualname__ = name + "_eager"
    return eager_policy_cls
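For context on the gradient path above: `_compute_gradients` records the forward pass on a `tf.GradientTape` and, when a custom `gradients_fn` is supplied, hands it an `OptimizerWrapper` that mimics a tf1 optimizer's `compute_gradients()` API. Below is a minimal, self-contained sketch of that pattern in plain TensorFlow 2; the model, loss, and `my_gradients_fn` here are hypothetical stand-ins, not part of the example above.

import tensorflow as tf

# Hypothetical stand-ins for the policy's model and optimizer.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build(input_shape=(None, 4))
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)


class OptimizerWrapper:
    """Mimics a tf1 optimizer's compute_gradients() on top of a GradientTape."""

    def __init__(self, tape):
        self.tape = tape

    def compute_gradients(self, loss, var_list):
        return list(zip(self.tape.gradient(loss, var_list), var_list))


def my_gradients_fn(wrapper, loss):
    # A custom gradients_fn could clip or rescale gradients here.
    return wrapper.compute_gradients(loss, model.trainable_variables)


x = tf.random.normal([8, 4])
y = tf.zeros([8, 1])
# Persistent tape: the custom gradients_fn calls tape.gradient() itself.
with tf.GradientTape(persistent=True) as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))

grads_and_vars = my_gradients_fn(OptimizerWrapper(tape), loss)
optimizer.apply_gradients(
    [(g, v) for g, v in grads_and_vars if g is not None])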
Example #25
0
def build_trainer(name,
                  default_policy,
                  default_config=None,
                  validate_config=None,
                  get_initial_state=None,
                  get_policy_class=None,
                  before_init=None,
                  make_workers=None,
                  make_policy_optimizer=None,
                  after_init=None,
                  before_train_step=None,
                  after_optimizer_step=None,
                  after_train_result=None,
                  collect_metrics_fn=None,
                  before_evaluate_fn=None,
                  mixins=None):
    """Helper function for defining a custom trainer.

    Functions will be run in this order to initialize the trainer:
        1. Config setup: validate_config, get_initial_state, get_policy
        2. Worker setup: before_init, make_workers, make_policy_optimizer
        3. Post setup: after_init

    Arguments:
        name (str): name of the trainer (e.g., "PPO")
        default_policy (cls): the default Policy class to use
        default_config (dict): the default config dict of the algorithm,
            otherwise uses the Trainer default config
        validate_config (func): optional callback that checks a given config
            for correctness. It may mutate the config as needed.
        get_initial_state (func): optional function that returns the initial
            state dict given the trainer instance as an argument. The state
            dict must be serializable so that it can be checkpointed, and will
            be available as the `trainer.state` variable.
        get_policy_class (func): optional callback that takes a config and
            returns the policy class to override the default with
        before_init (func): optional function to run at the start of trainer
            init that takes the trainer instance as argument
        make_workers (func): override the method that creates rollout workers.
            This takes in (trainer, env_creator, policy, config) as args.
        make_policy_optimizer (func): optional function that returns a
            PolicyOptimizer instance given (WorkerSet, config)
        after_init (func): optional function to run at the end of trainer init
            that takes the trainer instance as argument
        before_train_step (func): optional callback to run before each train()
            call. It takes the trainer instance as an argument.
        after_optimizer_step (func): optional callback to run after each
            step() call to the policy optimizer. It takes the trainer instance
            and the policy gradient fetches as arguments.
        after_train_result (func): optional callback to run at the end of each
            train() call. It takes the trainer instance and result dict as
            arguments, and may mutate the result dict as needed.
        collect_metrics_fn (func): override the method used to collect metrics.
            It takes the trainer instance as argument.
        before_evaluate_fn (func): callback to run before evaluation. This
            takes the trainer instance as argument.
        mixins (list): list of any class mixins for the returned trainer class.
            These mixins will be applied in order and will have higher
            precedence than the Trainer class

    Returns:
        a Trainer subclass that uses the specified args.
    """

    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config, env_creator):
            if validate_config:
                validate_config(config)
            if get_initial_state:
                self.state = get_initial_state(self)
            else:
                self.state = {}
            if get_policy_class is None:
                policy = default_policy
            else:
                policy = get_policy_class(config)
            if before_init:
                before_init(self)
            if make_workers:
                self.workers = make_workers(self, env_creator, policy, config)
            else:
                self.workers = self._make_workers(env_creator, policy, config,
                                                  self.config["num_workers"])
            if make_policy_optimizer:
                self.optimizer = make_policy_optimizer(self.workers, config)
            else:
                optimizer_config = dict(
                    config["optimizer"],
                    **{"train_batch_size": config["train_batch_size"]})
                self.optimizer = SyncSamplesOptimizer(self.workers,
                                                      **optimizer_config)
            # self.optimizer: <Override_ray.sync_replay_optimizer.SyncReplayOptimizer object at 0x7f7424799d90>

            if after_init:
                after_init(self)

        @override(Trainer)
        def _train(self, attention_score_dic=None):
            """

            :param attention_score_dic: Call from trainable
            :return:
            """
            if before_train_step:
                before_train_step(self)
            prev_steps = self.optimizer.num_steps_sampled

            start = time.time()
            while True:
                # The network is trained here.
                fetches = self.optimizer.step(attention_score_dic)
                if after_optimizer_step:
                    # Decide whether to update the target network here;
                    # e.g., after_optimizer_step=update_target_if_needed.
                    after_optimizer_step(self, fetches)
                if (time.time() - start >= self.config["min_iter_time_s"]
                        and self.optimizer.num_steps_sampled - prev_steps >=
                        self.config["timesteps_per_iteration"]):
                    # Finish the iteration once enough wall-clock time and
                    # sampled timesteps have accumulated.
                    break
            # Collect and refine the results for this iteration.
            if collect_metrics_fn:  # collect_metrics_fn=collect_metrics
                res = collect_metrics_fn(self)
            else:
                res = self.collect_metrics()
            res.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
                       prev_steps,
                       info=res.get("info", {}))

            if after_train_result:  # after_train_result=add_trainer_metrics
                after_train_result(self, res)
            return res

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        def __getstate__(self):
            state = Trainer.__getstate__(self)
            state["trainer_state"] = self.state.copy()
            return state

        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            self.state = state["trainer_state"].copy()

    @staticmethod
    def with_updates(**overrides):
        """Build a copy of this trainer with the specified overrides.

        Arguments:
            overrides (dict): use this to override any of the arguments
                originally passed to build_trainer() for this policy.
        """
        return build_trainer(**dict(original_kwargs, **overrides))

    trainer_cls.with_updates = with_updates
    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
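
As a side note, the `with_updates()` helper works by capturing the builder's arguments with `locals().copy()` at the top of `build_trainer()` and replaying them with the overrides merged on top. Here is a minimal, self-contained sketch of that pattern; `build_thing` and its arguments are illustrative only and not RLlib APIs.

def build_thing(name, default_config=None, validate_config=None):
    # Capture the builder's arguments, exactly like build_trainer() above.
    original_kwargs = locals().copy()

    class thing_cls:
        _name = name
        _default_config = default_config or {}

    @staticmethod
    def with_updates(**overrides):
        # Rebuild with the original kwargs, overridden where requested.
        return build_thing(**dict(original_kwargs, **overrides))

    thing_cls.with_updates = with_updates
    thing_cls.__name__ = name
    thing_cls.__qualname__ = name
    return thing_cls


Base = build_thing("Base", default_config={"lr": 1e-3})
Derived = Base.with_updates(name="Derived", default_config={"lr": 1e-4})
assert Derived.__name__ == "Derived"
assert Derived._default_config["lr"] == 1e-4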