Example #1
from freezegun import freeze_time

from ray.tune.stopper import TimeoutStopper


def test_timeout_stopper_timeout():
    with freeze_time() as frozen:
        stopper = TimeoutStopper(timeout=60)
        assert not stopper.stop_all()
        frozen.tick(40)
        assert not stopper.stop_all()
        frozen.tick(22)
        assert stopper.stop_all()
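
The test above drives TimeoutStopper directly through stop_all(). In user code the same behavior is normally reached through tune.run, either by passing a stopper instance to stop or by setting time_budget_s, which the Experiment constructor (Examples #5-#8) wraps into a TimeoutStopper. A minimal sketch, assuming a trivial trainable train_fn defined only for illustration:

from ray import tune
from ray.tune.stopper import TimeoutStopper


def train_fn(config):
    # Hypothetical trainable used only for illustration.
    for i in range(1000):
        tune.report(metric=i)


# Pass a stopper instance explicitly ...
tune.run(train_fn, stop=TimeoutStopper(60))

# ... or let tune.run build an equivalent TimeoutStopper from a time budget.
tune.run(train_fn, time_budget_s=60)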
Example #2
    def testTimeout(self):
        from ray.tune.stopper import TimeoutStopper
        import datetime

        def train(config):
            for i in range(20):
                tune.report(metric=i)
                time.sleep(1)

        register_trainable("f1", train)

        start = time.time()
        tune.run("f1", time_budget_s=5)
        diff = time.time() - start
        self.assertLess(diff, 10)

        # Metric should fire first
        start = time.time()
        tune.run("f1", stop={"metric": 3}, time_budget_s=7)
        diff = time.time() - start
        self.assertLess(diff, 7)

        # Timeout should fire first
        start = time.time()
        tune.run("f1", stop={"metric": 10}, time_budget_s=5)
        diff = time.time() - start
        self.assertLess(diff, 10)

        # Combined stopper. Shorter timeout should win.
        start = time.time()
        tune.run("f1",
                 stop=TimeoutStopper(10),
                 time_budget_s=datetime.timedelta(seconds=3))
        diff = time.time() - start
        self.assertLess(diff, 9)
Example #3
import pickle

from freezegun import freeze_time

from ray.tune.stopper import TimeoutStopper


def test_timeout_stopper_recover_before_timeout():
    """If checkpointed before timeout, continue where we left off."""
    with freeze_time() as frozen:
        stopper = TimeoutStopper(timeout=60)
        assert not stopper.stop_all()
        frozen.tick(40)
        assert not stopper.stop_all()
        checkpoint = pickle.dumps(stopper)

        # Continue sometime in the future. This is after start_time + timeout
        # but we should still continue training.
        frozen.tick(200)

        # Continue, so we shouldn't time out
        stopper = pickle.loads(checkpoint)
        assert not stopper.stop_all()
        frozen.tick(10)
        assert not stopper.stop_all()
        frozen.tick(12)
        assert stopper.stop_all()
Example #4
import pickle

from freezegun import freeze_time

from ray.tune.stopper import TimeoutStopper


def test_timeout_stopper_recover_after_timeout():
    """If checkpointed after timeout, still stop after recovery."""
    with freeze_time() as frozen:
        stopper = TimeoutStopper(timeout=60)
        assert not stopper.stop_all()
        frozen.tick(62)
        assert stopper.stop_all()
        checkpoint = pickle.dumps(stopper)

        # Continue sometime in the future
        frozen.tick(200)

        # Continue, so we should still time out.
        stopper = pickle.loads(checkpoint)
        assert stopper.stop_all()
        frozen.tick(10)
        assert stopper.stop_all()
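
Examples #3 and #4 pin down the checkpoint/recovery contract: a stopper restored from a checkpoint taken before the deadline keeps only the remaining budget, while one checkpointed after the deadline stops immediately on restore. A minimal sketch of a custom Stopper with that behavior, tracking consumed budget across pickling (illustrative only; the real TimeoutStopper implementation may differ, and the class name ElapsedBudgetStopper is hypothetical):

import time

from ray.tune.stopper import Stopper


class ElapsedBudgetStopper(Stopper):
    """Stops all trials once a wall-clock budget has been used up."""

    def __init__(self, timeout_s):
        self._budget_s = timeout_s
        self._start = time.monotonic()

    def __call__(self, trial_id, result):
        # Never stops an individual trial early.
        return False

    def stop_all(self):
        return time.monotonic() - self._start >= self._budget_s

    def __getstate__(self):
        # Persist the *remaining* budget so a restored stopper continues
        # counting from where it left off (cf. Example #3). A negative
        # remainder means the deadline already passed (cf. Example #4).
        state = self.__dict__.copy()
        state["_budget_s"] = self._budget_s - (time.monotonic() - self._start)
        del state["_start"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._start = time.monotonic()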
Example #5
    def __init__(self,
                 name,
                 run,
                 stop=None,
                 time_budget_s=None,
                 config=None,
                 resources_per_trial=None,
                 num_samples=1,
                 local_dir=None,
                 upload_dir=None,
                 trial_name_creator=None,
                 trial_dirname_creator=None,
                 loggers=None,
                 log_to_file=False,
                 sync_to_driver=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=None,
                 export_formats=None,
                 max_failures=0,
                 restore=None):

        if loggers is not None:
            # Most users won't run into this as `tune.run()` does not pass
            # the argument anymore. However, we will want to inform users
            # if they instantiate their `Experiment` objects themselves.
            raise ValueError(
                "Passing `loggers` to an `Experiment` is deprecated. Use "
                "an `ExperimentLogger` callback instead, e.g. by passing the "
                "`Logger` classes to `tune.logger.LegacyExperimentLogger` and "
                "passing this as part of the `callback` parameter to "
                "`tune.run()`.")

        config = config or {}
        if callable(run) and detect_checkpoint_function(run):
            if checkpoint_at_end:
                raise ValueError("'checkpoint_at_end' cannot be used with a "
                                 "checkpointable function. You can specify "
                                 "and register checkpoints within "
                                 "your trainable function.")
            if checkpoint_freq:
                raise ValueError(
                    "'checkpoint_freq' cannot be used with a "
                    "checkpointable function. You can specify checkpoints "
                    "within your trainable function.")
        self._run_identifier = Experiment.register_if_needed(run)
        self.name = name or self._run_identifier

        # If the name has been set explicitly, we don't want to create
        # dated directories. The same is true for string run identifiers.
        if int(os.environ.get("TUNE_DISABLE_DATED_SUBDIR", 0)) == 1 or name \
           or isinstance(run, str):
            self.dir_name = self.name
        else:
            self.dir_name = "{}_{}".format(self.name, date_str())

        if upload_dir:
            self.remote_checkpoint_dir = os.path.join(upload_dir,
                                                      self.dir_name)
        else:
            self.remote_checkpoint_dir = None

        self._stopper = None
        stopping_criteria = {}
        if not stop:
            pass
        elif isinstance(stop, dict):
            stopping_criteria = stop
        elif callable(stop):
            if FunctionStopper.is_valid_function(stop):
                self._stopper = FunctionStopper(stop)
            elif issubclass(type(stop), Stopper):
                self._stopper = stop
            else:
                raise ValueError("Provided stop object must be either a dict, "
                                 "a function, or a subclass of "
                                 "`ray.tune.Stopper`.")
        else:
            raise ValueError("Invalid stop criteria: {}. Must be a "
                             "callable or dict".format(stop))

        if time_budget_s:
            if self._stopper:
                self._stopper = CombinedStopper(self._stopper,
                                                TimeoutStopper(time_budget_s))
            else:
                self._stopper = TimeoutStopper(time_budget_s)

        _raise_on_durable(self._run_identifier, sync_to_driver, upload_dir)

        stdout_file, stderr_file = _validate_log_to_file(log_to_file)

        spec = {
            "run": self._run_identifier,
            "stop": stopping_criteria,
            "config": config,
            "resources_per_trial": resources_per_trial,
            "num_samples": num_samples,
            "local_dir": os.path.abspath(
                os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR)),
            "upload_dir": upload_dir,
            "remote_checkpoint_dir": self.remote_checkpoint_dir,
            "trial_name_creator": trial_name_creator,
            "trial_dirname_creator": trial_dirname_creator,
            "loggers": loggers,
            "log_to_file": (stdout_file, stderr_file),
            "sync_to_driver": sync_to_driver,
            "checkpoint_freq": checkpoint_freq,
            "checkpoint_at_end": checkpoint_at_end,
            "sync_on_checkpoint": sync_on_checkpoint,
            "keep_checkpoints_num": keep_checkpoints_num,
            "checkpoint_score_attr": checkpoint_score_attr,
            "export_formats": export_formats or [],
            "max_failures": max_failures,
            "restore": os.path.abspath(os.path.expanduser(restore))
            if restore else None
        }
        self.spec = spec
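
The constructor above shows how stopping criteria are reconciled: a stop dict is stored as stopping_criteria in the spec, a callable or Stopper becomes self._stopper, and time_budget_s is wrapped into a TimeoutStopper (combined with any existing stopper via CombinedStopper). A minimal usage sketch, assuming "f1" has already been registered with register_trainable as in Example #2 (the Experiment import path may vary across Ray versions, and the experiment name is hypothetical):

from ray.tune import Experiment

exp = Experiment(
    name="timeout_demo",      # hypothetical experiment name
    run="f1",                 # registered trainable, see Example #2
    stop={"metric": 10},      # stored as spec["stop"]
    time_budget_s=300,        # wrapped into a TimeoutStopper internally
    num_samples=2,
)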
Example #6
File: experiment.py  Project: zzmcdc/ray
    def __init__(self,
                 name,
                 run,
                 stop=None,
                 time_budget_s=None,
                 config=None,
                 resources_per_trial=None,
                 num_samples=1,
                 local_dir=None,
                 upload_dir=None,
                 trial_name_creator=None,
                 trial_dirname_creator=None,
                 loggers=None,
                 log_to_file=False,
                 sync_to_driver=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=None,
                 export_formats=None,
                 max_failures=0,
                 restore=None):

        config = config or {}
        if callable(run) and detect_checkpoint_function(run):
            if checkpoint_at_end:
                raise ValueError("'checkpoint_at_end' cannot be used with a "
                                 "checkpointable function. You can specify "
                                 "and register checkpoints within "
                                 "your trainable function.")
            if checkpoint_freq:
                raise ValueError(
                    "'checkpoint_freq' cannot be used with a "
                    "checkpointable function. You can specify checkpoints "
                    "within your trainable function.")
        self._run_identifier = Experiment.register_if_needed(run)
        self.name = name or self._run_identifier
        if upload_dir:
            self.remote_checkpoint_dir = os.path.join(upload_dir, self.name)
        else:
            self.remote_checkpoint_dir = None

        self._stopper = None
        stopping_criteria = {}
        if not stop:
            pass
        elif isinstance(stop, dict):
            stopping_criteria = stop
        elif callable(stop):
            if FunctionStopper.is_valid_function(stop):
                self._stopper = FunctionStopper(stop)
            elif issubclass(type(stop), Stopper):
                self._stopper = stop
            else:
                raise ValueError("Provided stop object must be either a dict, "
                                 "a function, or a subclass of "
                                 "`ray.tune.Stopper`.")
        else:
            raise ValueError("Invalid stop criteria: {}. Must be a "
                             "callable or dict".format(stop))

        if time_budget_s:
            if self._stopper:
                self._stopper = CombinedStopper(self._stopper,
                                                TimeoutStopper(time_budget_s))
            else:
                self._stopper = TimeoutStopper(time_budget_s)

        _raise_on_durable(self._run_identifier, sync_to_driver, upload_dir)

        stdout_file, stderr_file = _validate_log_to_file(log_to_file)

        spec = {
            "run": self._run_identifier,
            "stop": stopping_criteria,
            "config": config,
            "resources_per_trial": resources_per_trial,
            "num_samples": num_samples,
            "local_dir": os.path.abspath(
                os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR)),
            "upload_dir": upload_dir,
            "remote_checkpoint_dir": self.remote_checkpoint_dir,
            "trial_name_creator": trial_name_creator,
            "trial_dirname_creator": trial_dirname_creator,
            "loggers": loggers,
            "log_to_file": (stdout_file, stderr_file),
            "sync_to_driver": sync_to_driver,
            "checkpoint_freq": checkpoint_freq,
            "checkpoint_at_end": checkpoint_at_end,
            "sync_on_checkpoint": sync_on_checkpoint,
            "keep_checkpoints_num": keep_checkpoints_num,
            "checkpoint_score_attr": checkpoint_score_attr,
            "export_formats": export_formats or [],
            "max_failures": max_failures,
            "restore": os.path.abspath(os.path.expanduser(restore))
            if restore else None
        }
        self.spec = spec
Example #7
File: experiment.py  Project: eggie5/ray
    def __init__(self,
                 name,
                 run,
                 stop=None,
                 time_budget_s=None,
                 config=None,
                 resources_per_trial=None,
                 num_samples=1,
                 local_dir=None,
                 upload_dir=None,
                 trial_name_creator=None,
                 trial_dirname_creator=None,
                 log_to_file=False,
                 sync_to_driver=None,
                 sync_to_cloud=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=None,
                 export_formats=None,
                 max_failures=0,
                 restore=None):

        config = config or {}
        if callable(run) and not inspect.isclass(run) and \
                detect_checkpoint_function(run):
            if checkpoint_at_end:
                raise ValueError("'checkpoint_at_end' cannot be used with a "
                                 "checkpointable function. You can specify "
                                 "and register checkpoints within "
                                 "your trainable function.")
            if checkpoint_freq:
                raise ValueError(
                    "'checkpoint_freq' cannot be used with a "
                    "checkpointable function. You can specify checkpoints "
                    "within your trainable function.")
        self._run_identifier = Experiment.register_if_needed(run)
        self.name = name or self._run_identifier

        # If the name has been set explicitly, we don't want to create
        # dated directories. The same is true for string run identifiers.
        if int(os.environ.get("TUNE_DISABLE_DATED_SUBDIR", 0)) == 1 or name \
           or isinstance(run, str):
            self.dir_name = self.name
        else:
            self.dir_name = "{}_{}".format(self.name, date_str())

        if upload_dir:
            self.remote_checkpoint_dir = os.path.join(upload_dir,
                                                      self.dir_name)
        else:
            self.remote_checkpoint_dir = None

        self._stopper = None
        stopping_criteria = {}
        if not stop:
            pass
        elif isinstance(stop, list):
            if any(not isinstance(s, Stopper) for s in stop):
                raise ValueError(
                    "If you pass a list as the `stop` argument to "
                    "`tune.run()`, each element must be an instance of "
                    "`tune.stopper.Stopper`.")
            self._stopper = CombinedStopper(*stop)
        elif isinstance(stop, dict):
            stopping_criteria = stop
        elif callable(stop):
            if FunctionStopper.is_valid_function(stop):
                self._stopper = FunctionStopper(stop)
            elif issubclass(type(stop), Stopper):
                self._stopper = stop
            else:
                raise ValueError("Provided stop object must be either a dict, "
                                 "a function, or a subclass of "
                                 "`ray.tune.Stopper`.")
        else:
            raise ValueError("Invalid stop criteria: {}. Must be a "
                             "callable or dict".format(stop))

        if time_budget_s:
            if self._stopper:
                self._stopper = CombinedStopper(self._stopper,
                                                TimeoutStopper(time_budget_s))
            else:
                self._stopper = TimeoutStopper(time_budget_s)

        _raise_on_durable(self.is_durable_trainable, sync_to_driver,
                          upload_dir)

        stdout_file, stderr_file = _validate_log_to_file(log_to_file)

        spec = {
            "run": self._run_identifier,
            "stop": stopping_criteria,
            "config": config,
            "resources_per_trial": resources_per_trial,
            "num_samples": num_samples,
            "local_dir": os.path.abspath(
                os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR)),
            "upload_dir": upload_dir,
            "remote_checkpoint_dir": self.remote_checkpoint_dir,
            "trial_name_creator": trial_name_creator,
            "trial_dirname_creator": trial_dirname_creator,
            "log_to_file": (stdout_file, stderr_file),
            "sync_to_driver": sync_to_driver,
            "sync_to_cloud": sync_to_cloud,
            "checkpoint_freq": checkpoint_freq,
            "checkpoint_at_end": checkpoint_at_end,
            "sync_on_checkpoint": sync_on_checkpoint,
            "keep_checkpoints_num": keep_checkpoints_num,
            "checkpoint_score_attr": checkpoint_score_attr,
            "export_formats": export_formats or [],
            "max_failures": max_failures,
            "restore": os.path.abspath(os.path.expanduser(restore))
            if restore else None
        }
        self.spec = spec
Example #8
    def __init__(
        self,
        name,
        run,
        stop=None,
        time_budget_s=None,
        config=None,
        resources_per_trial=None,
        num_samples=1,
        local_dir=None,
        _experiment_checkpoint_dir: Optional[str] = None,
        sync_config=None,
        trial_name_creator=None,
        trial_dirname_creator=None,
        log_to_file=False,
        checkpoint_freq=0,
        checkpoint_at_end=False,
        keep_checkpoints_num=None,
        checkpoint_score_attr=None,
        export_formats=None,
        max_failures=0,
        restore=None,
    ):

        local_dir = _get_local_dir_with_expand_user(local_dir)
        # `_experiment_checkpoint_dir` is for internal use only for better
        # support of Tuner API.
        # If set, it should be a subpath under `local_dir`. Also deduce `dir_name`.
        self._experiment_checkpoint_dir = _experiment_checkpoint_dir
        if _experiment_checkpoint_dir:
            experiment_checkpoint_dir_path = Path(_experiment_checkpoint_dir)
            local_dir_path = Path(local_dir)
            assert local_dir_path in experiment_checkpoint_dir_path.parents
            # `dir_name` is set by `_experiment_checkpoint_dir` indirectly.
            self.dir_name = os.path.relpath(_experiment_checkpoint_dir, local_dir)

        config = config or {}
        sync_config = sync_config or SyncConfig()
        if (
            callable(run)
            and not inspect.isclass(run)
            and detect_checkpoint_function(run)
        ):
            if checkpoint_at_end:
                raise ValueError(
                    "'checkpoint_at_end' cannot be used with a "
                    "checkpointable function. You can specify "
                    "and register checkpoints within "
                    "your trainable function."
                )
            if checkpoint_freq:
                raise ValueError(
                    "'checkpoint_freq' cannot be used with a "
                    "checkpointable function. You can specify checkpoints "
                    "within your trainable function."
                )
        try:
            self._run_identifier = Experiment.register_if_needed(run)
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED:
                raise TuneError(
                    f"The Trainable/training function is too large for grpc resource "
                    f"limit. Check that its definition is not implicitly capturing a "
                    f"large array or other object in scope. "
                    f"Tip: use tune.with_parameters() to put large objects "
                    f"in the Ray object store. \n"
                    f"Original exception: {traceback.format_exc()}"
                )
            else:
                raise e

        self.name = name or self._run_identifier

        if not _experiment_checkpoint_dir:
            self.dir_name = _get_dir_name(run, name, self.name)

        assert self.dir_name

        if sync_config.upload_dir:
            self.remote_checkpoint_dir = os.path.join(
                sync_config.upload_dir, self.dir_name
            )
        else:
            self.remote_checkpoint_dir = None

        self._stopper = None
        stopping_criteria = {}
        if not stop:
            pass
        elif isinstance(stop, list):
            bad_stoppers = [s for s in stop if not isinstance(s, Stopper)]
            if bad_stoppers:
                stopper_types = [type(s) for s in stop]
                raise ValueError(
                    "If you pass a list as the `stop` argument to "
                    "`tune.run()`, each element must be an instance of "
                    f"`tune.stopper.Stopper`. Got {stopper_types}."
                )
            self._stopper = CombinedStopper(*stop)
        elif isinstance(stop, dict):
            stopping_criteria = stop
        elif callable(stop):
            if FunctionStopper.is_valid_function(stop):
                self._stopper = FunctionStopper(stop)
            elif isinstance(stop, Stopper):
                self._stopper = stop
            else:
                raise ValueError(
                    "Provided stop object must be either a dict, "
                    "a function, or a subclass of "
                    f"`ray.tune.Stopper`. Got {type(stop)}."
                )
        else:
            raise ValueError(
                f"Invalid stop criteria: {stop}. Must be a "
                f"callable or dict. Got {type(stop)}."
            )

        if time_budget_s:
            if self._stopper:
                self._stopper = CombinedStopper(
                    self._stopper, TimeoutStopper(time_budget_s)
                )
            else:
                self._stopper = TimeoutStopper(time_budget_s)

        stdout_file, stderr_file = _validate_log_to_file(log_to_file)

        spec = {
            "run": self._run_identifier,
            "stop": stopping_criteria,
            "time_budget_s": time_budget_s,
            "config": config,
            "resources_per_trial": resources_per_trial,
            "num_samples": num_samples,
            "local_dir": local_dir,
            "sync_config": sync_config,
            "remote_checkpoint_dir": self.remote_checkpoint_dir,
            "trial_name_creator": trial_name_creator,
            "trial_dirname_creator": trial_dirname_creator,
            "log_to_file": (stdout_file, stderr_file),
            "checkpoint_freq": checkpoint_freq,
            "checkpoint_at_end": checkpoint_at_end,
            "keep_checkpoints_num": keep_checkpoints_num,
            "checkpoint_score_attr": checkpoint_score_attr,
            "export_formats": export_formats or [],
            "max_failures": max_failures,
            "restore": os.path.abspath(os.path.expanduser(restore))
            if restore
            else None,
        }
        self.spec = spec
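
Examples #7 and #8 additionally accept a list for the stop argument, which the constructor folds into a CombinedStopper, so several stopping conditions can be active at once and the first one to trigger ends the run. A hedged usage sketch, assuming a Ray version recent enough to ship MaximumIterationStopper alongside TimeoutStopper, and a trainable registered as "f1" (see Example #2):

from ray import tune
from ray.tune.stopper import MaximumIterationStopper, TimeoutStopper

# Both stoppers are checked; whichever fires first stops the experiment.
tune.run(
    "f1",
    stop=[
        TimeoutStopper(600),
        MaximumIterationStopper(100),
    ],
)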