Example #1
    def _log_result(self, result):
        if self.config["callbacks"].get("on_train_result"):
            self.config["callbacks"]["on_train_result"]({
                "trainer": self,
                "result": result,
            })
        # log after the callback is invoked, so that the user has a chance
        # to mutate the result
        Trainable._log_result(self, result)
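For context, a minimal sketch of the user side of this dict-based callbacks API (the callback key and config layout follow the snippet above; the mutation inside the callback is purely illustrative):

    # Hypothetical user code for the dict-based callbacks API dispatched above.
    # The callback receives a single `info` dict carrying the trainer/agent and
    # the result; it runs before Trainable._log_result, so any mutation ends up
    # in the logged result.
    def on_train_result(info):
        info["result"]["callback_ran"] = True

    config = {
        "callbacks": {
            "on_train_result": on_train_result,
        },
    }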
Example #2
    def _log_result(self, result):
        if self.config["callbacks"].get("on_train_result"):
            self.config["callbacks"]["on_train_result"]({
                "agent": self,
                "result": result,
            })
        # log after the callback is invoked, so that the user has a chance
        # to mutate the result
        Trainable._log_result(self, result)
Example #3
    def __init__(self, config=None, env=None, logger_creator=None):
        """Initialize an RLLib agent.

        Args:
            config (dict): Algorithm-specific configuration data.
            env (str): Name of the environment to use. Note that this can also
                be specified as the `env` key in config.
            logger_creator (func): Function that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
        """

        config = config or {}

        # Agents allow env ids to be passed directly to the constructor.
        self._env_id = env or config.get("env")
        Trainable.__init__(self, config, logger_creator)
Example #4
    def train(self) -> ResultDict:
        """Overrides super.train to synchronize global vars."""

        result = None
        for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
            try:
                result = Trainable.train(self)
            except RayError as err:
                if self.config["ignore_worker_failures"]:
                    logger.exception(
                        "Error in train call, attempting to recover")
                    self._try_recover()
                else:
                    logger.info(
                        "Worker crashed during call to train(). To attempt to "
                        "continue training without the failed worker, set "
                        "`'ignore_worker_failures': True`.")
                    raise err
            except Exception as exc:
                time.sleep(0.5)  # allow log messages to propagate
                raise exc
            else:
                break
        if result is None:
            raise RuntimeError("Failed to recover from worker crash")

        if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
            self._sync_filters_if_needed(self.workers)

        return result
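A brief usage note on the retry loop above: recovery via `_try_recover()` only happens when `ignore_worker_failures` is enabled in the trainer config; otherwise the original `RayError` is re-raised. A minimal config sketch (other keys omitted):

    config = {
        "num_workers": 4,
        # Keep training by dropping crashed rollout workers instead of
        # re-raising the RayError inside train().
        "ignore_worker_failures": True,
    }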
Example #5
    def train(self):
        """Overrides super.train to synchronize global vars."""

        if hasattr(self, "optimizer") and isinstance(self.optimizer,
                                                     PolicyOptimizer):
            self.global_vars["timestep"] = self.optimizer.num_steps_sampled
            self.optimizer.local_evaluator.set_global_vars(self.global_vars)
            for ev in self.optimizer.remote_evaluators:
                ev.set_global_vars.remote(self.global_vars)
            logger.debug("updated global vars: {}".format(self.global_vars))

        if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
                and hasattr(self, "local_evaluator")):
            FilterManager.synchronize(
                self.local_evaluator.filters,
                self.remote_evaluators,
                update_remote=self.config["synchronize_filters"])
            logger.debug("synchronized filters: {}".format(
                self.local_evaluator.filters))

        result = Trainable.train(self)
        if self.config["callbacks"].get("on_train_result"):
            self.config["callbacks"]["on_train_result"]({
                "agent": self,
                "result": result,
            })
        return result
Example #6
0
    def __init__(self, config=None, env=None, logger_creator=None):
        """Initialize an RLLib trainer.

        Args:
            config (dict): Algorithm-specific configuration data.
            env (str): Name of the environment to use. Note that this can also
                be specified as the `env` key in config.
            logger_creator (func): Function that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
        """

        config = config or {}

        if tf and config.get("eager"):
            tf.enable_eager_execution()
            logger.info("Executing eagerly, with eager_tracing={}".format(
                "True" if config.get("eager_tracing") else "False"))

        if tf and not tf.executing_eagerly():
            logger.info("Tip: set 'eager': true or the --eager flag to enable "
                        "TensorFlow eager execution")

        # Vars to synchronize to workers on each train call
        self.global_vars = {"timestep": 0}

        # Trainers allow env ids to be passed directly to the constructor.
        self._env_id = self._register_if_needed(env or config.get("env"))

        # Create a default logger creator if no logger_creator is specified
        if logger_creator is None:
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            logdir_prefix = "{}_{}_{}".format(self._name, self._env_id,
                                              timestr)

            def default_logger_creator(config):
                """Creates a Unified logger with a default logdir prefix
                containing the agent name and the env id
                """
                if not os.path.exists(DEFAULT_RESULTS_DIR):
                    os.makedirs(DEFAULT_RESULTS_DIR)
                logdir = tempfile.mkdtemp(
                    prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
                return UnifiedLogger(config, logdir, loggers=None)

            logger_creator = default_logger_creator

        Trainable.__init__(self, config, logger_creator)
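As the constructor shows, the env id can be given either directly or via the `env` key in the config, and `eager`/`eager_tracing` toggle TensorFlow eager execution. A minimal usage sketch, assuming the `PPOTrainer` entry point of a similar-era RLlib and the Gym `CartPole-v0` env:

    import ray
    from ray.rllib.agents.ppo import PPOTrainer  # assumed trainer class for this era of RLlib

    ray.init()
    # Equivalent ways to select the environment:
    trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 2})
    # trainer = PPOTrainer(config={"env": "CartPole-v0", "num_workers": 2})
    result = trainer.train()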
Example #7
    def __init__(
            self, config={}, env=None, registry=get_registry(),
            logger_creator=None):
        """Initialize an RLLib agent.

        Args:
            config (dict): Algorithm-specific configuration data.
            env (str): Name of the environment to use. Note that this can also
                be specified as the `env` key in config.
            registry (obj): Object registry for user-defined envs, models, etc.
                If unspecified, the default registry will be used.
            logger_creator (func): Function that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
        """

        # Agents allow env ids to be passed directly to the constructor.
        self._env_id = env or config.get("env")
        Trainable.__init__(self, config, registry, logger_creator)
Example #8
    def train(self):
        """Overrides super.train to synchronize global vars."""

        if self._has_policy_optimizer():
            self.global_vars["timestep"] = self.optimizer.num_steps_sampled
            self.optimizer.workers.local_worker().set_global_vars(
                self.global_vars)
            for w in self.optimizer.workers.remote_workers():
                w.set_global_vars.remote(self.global_vars)
            logger.debug("updated global vars: {}".format(self.global_vars))

        result = None
        for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
            try:
                result = Trainable.train(self)
            except RayError as e:
                if self.config["ignore_worker_failures"]:
                    logger.exception(
                        "Error in train call, attempting to recover")
                    self._try_recover()
                else:
                    logger.info(
                        "Worker crashed during call to train(). To attempt to "
                        "continue training without the failed worker, set "
                        "`'ignore_worker_failures': True`.")
                    raise e
            except Exception as e:
                time.sleep(0.5)  # allow log messages to propagate
                raise e
            else:
                break
        if result is None:
            raise RuntimeError("Failed to recover from worker crash")

        if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
                and hasattr(self, "workers")
                and isinstance(self.workers, WorkerSet)):
            FilterManager.synchronize(
                self.workers.local_worker().filters,
                self.workers.remote_workers(),
                update_remote=self.config["synchronize_filters"])
            logger.debug("synchronized filters: {}".format(
                self.workers.local_worker().filters))

        if self._has_policy_optimizer():
            result["num_healthy_workers"] = len(
                self.optimizer.workers.remote_workers())

        if self.config["evaluation_interval"]:
            if self._iteration % self.config["evaluation_interval"] == 0:
                evaluation_metrics = self._evaluate()
                assert isinstance(evaluation_metrics, dict), \
                    "_evaluate() needs to return a dict."
                result.update(evaluation_metrics)

        return result
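The snippet above gates filter synchronization and periodic evaluation on config keys; a sketch of the keys it reads (values are illustrative, not defaults):

    config = {
        "observation_filter": "MeanStdFilter",  # anything but "NoFilter" enables filter syncing
        "synchronize_filters": True,            # push local filter state to remote workers
        "ignore_worker_failures": True,         # recover instead of raising on worker crash
        "evaluation_interval": 5,               # run _evaluate() every 5th train() call
    }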
Example #9
    def train(self):
        """Overrides super.train to synchronize global vars."""

        if hasattr(self, "optimizer") and isinstance(self.optimizer,
                                                     PolicyOptimizer):
            self.global_vars["timestep"] = self.optimizer.num_steps_sampled
            self.optimizer.local_evaluator.set_global_vars(self.global_vars)
            for ev in self.optimizer.remote_evaluators:
                ev.set_global_vars.remote(self.global_vars)

        return Trainable.train(self)
Example #10
    def __init__(
            self, config=None, env=None, registry=None,
            logger_creator=None):
        """Initialize an RLLib agent.

        Args:
            config (dict): Algorithm-specific configuration data.
            env (str): Name of the environment to use. Note that this can also
                be specified as the `env` key in config.
            registry (obj): Object registry for user-defined envs, models, etc.
                If unspecified, the default registry will be used.
            logger_creator (func): Function that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
        """

        config = config or {}

        # Agents allow env ids to be passed directly to the constructor.
        self._env_id = env or config.get("env")
        Trainable.__init__(self, config, registry, logger_creator)
Example #11
    def __init__(self, config=None, env=None, logger_creator=None):
        """Initialize an RLLib agent.

        Args:
            config (dict): Algorithm-specific configuration data.
            env (str): Name of the environment to use. Note that this can also
                be specified as the `env` key in config.
            logger_creator (func): Function that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
        """

        config = config or {}
        Agent._validate_config(config)

        # Vars to synchronize to evaluators on each train call
        self.global_vars = {"timestep": 0}

        # Agents allow env ids to be passed directly to the constructor.
        self._env_id = env or config.get("env")

        # Create a default logger creator if no logger_creator is specified
        if logger_creator is None:
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            logdir_prefix = "{}_{}_{}".format(self._agent_name, self._env_id,
                                              timestr)

            def default_logger_creator(config):
                """Creates a Unified logger with a default logdir prefix
                containing the agent name and the env id
                """
                if not os.path.exists(DEFAULT_RESULTS_DIR):
                    os.makedirs(DEFAULT_RESULTS_DIR)
                logdir = tempfile.mkdtemp(prefix=logdir_prefix,
                                          dir=DEFAULT_RESULTS_DIR)
                return UnifiedLogger(config, logdir, None)

            logger_creator = default_logger_creator

        Trainable.__init__(self, config, logger_creator)
Example #12
    def __init__(self, config=None, env=None, logger_creator=None):
        """Initialize an RLLib agent.

        Args:
            config (dict): Algorithm-specific configuration data.
            env (str): Name of the environment to use. Note that this can also
                be specified as the `env` key in config.
            logger_creator (func): Function that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
        """

        config = config or {}

        # Vars to synchronize to evaluators on each train call
        self.global_vars = {"timestep": 0}

        # Agents allow env ids to be passed directly to the constructor.
        self._env_id = self._register_if_needed(env or config.get("env"))

        # Create a default logger creator if no logger_creator is specified
        if logger_creator is None:
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            logdir_prefix = "{}_{}_{}".format(self._agent_name, self._env_id,
                                              timestr)

            def default_logger_creator(config):
                """Creates a Unified logger with a default logdir prefix
                containing the agent name and the env id
                """
                if not os.path.exists(DEFAULT_RESULTS_DIR):
                    os.makedirs(DEFAULT_RESULTS_DIR)
                logdir = tempfile.mkdtemp(
                    prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
                return UnifiedLogger(config, logdir, None)

            logger_creator = default_logger_creator

        Trainable.__init__(self, config, logger_creator)
Example #13
    def train(self):
        """Overrides super.train to synchronize global vars."""

        if hasattr(self, "optimizer") and isinstance(self.optimizer,
                                                     PolicyOptimizer):
            self.global_vars["timestep"] = self.optimizer.num_steps_sampled
            self.optimizer.local_evaluator.set_global_vars(self.global_vars)
            for ev in self.optimizer.remote_evaluators:
                ev.set_global_vars.remote(self.global_vars)

        if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
                and hasattr(self, "local_evaluator")):
            FilterManager.synchronize(
                self.local_evaluator.filters,
                self.remote_evaluators,
                update_remote=self.config["synchronize_filters"])

        return Trainable.train(self)
Example #14
File: agent.py  Project: zhy52/ray
    def train(self):
        """Overrides super.train to synchronize global vars."""

        if self._has_policy_optimizer():
            self.global_vars["timestep"] = self.optimizer.num_steps_sampled
            self.optimizer.local_evaluator.set_global_vars(self.global_vars)
            for ev in self.optimizer.remote_evaluators:
                ev.set_global_vars.remote(self.global_vars)
            logger.debug("updated global vars: {}".format(self.global_vars))

        result = None
        for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
            try:
                result = Trainable.train(self)
            except RayError as e:
                if self.config["ignore_worker_failures"]:
                    logger.exception(
                        "Error in train call, attempting to recover")
                    self._try_recover()
                else:
                    logger.info(
                        "Worker crashed during call to train(). To attempt to "
                        "continue training without the failed worker, set "
                        "`'ignore_worker_failures': True`.")
                    raise e
            else:
                break
        if result is None:
            raise RuntimeError("Failed to recover from worker crash")

        if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
                and hasattr(self, "local_evaluator")):
            FilterManager.synchronize(
                self.local_evaluator.filters,
                self.remote_evaluators,
                update_remote=self.config["synchronize_filters"])
            logger.debug("synchronized filters: {}".format(
                self.local_evaluator.filters))

        if self._has_policy_optimizer():
            result["num_healthy_workers"] = len(
                self.optimizer.remote_evaluators)
        return result
Example #15
    def _log_result(self, result):
        self.callbacks.on_train_result(trainer=self, result=result)
        # log after the callback is invoked, so that the user has a chance
        # to mutate the result
        Trainable._log_result(self, result)
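This last example uses the newer class-based callbacks API instead of the config dict. A minimal sketch of the user side, assuming RLlib's `DefaultCallbacks` base class from the same era; the metric name is illustrative:

    from ray.rllib.agents.callbacks import DefaultCallbacks

    class MyCallbacks(DefaultCallbacks):
        def on_train_result(self, *, trainer, result, **kwargs):
            # Runs before Trainable._log_result, so mutations are logged.
            result["custom_metric"] = result.get("episode_reward_mean", 0.0)

    config = {"callbacks": MyCallbacks}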