Пример #1
0
    def __init__(self,
                 name,
                 run,
                 stop=None,
                 time_budget_s=None,
                 config=None,
                 resources_per_trial=None,
                 num_samples=1,
                 local_dir=None,
                 upload_dir=None,
                 trial_name_creator=None,
                 trial_dirname_creator=None,
                 loggers=None,
                 log_to_file=False,
                 sync_to_driver=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=None,
                 export_formats=None,
                 max_failures=0,
                 restore=None):
        """Validate the arguments and build the experiment ``spec`` dict.

        Unless noted below, each argument is stored unchanged in
        ``self.spec`` under a key of the same name.

        Args:
            name: Experiment name. When falsy, the identifier returned by
                ``Experiment.register_if_needed(run)`` is used instead.
            run: Trainable to run; handed to
                ``Experiment.register_if_needed``.
            stop: Stopping criteria. A dict is stored as ``spec["stop"]``.
                A callable accepted by
                ``FunctionStopper.is_valid_function`` is wrapped in a
                ``FunctionStopper``; a ``Stopper`` instance is used as is.
                Either of the latter ends up in ``self._stopper``.
            time_budget_s: When truthy, a ``TimeoutStopper(time_budget_s)``
                is installed, combined with any stopper derived from
                ``stop``. Not stored in ``spec``.
            config: Defaults to an empty dict when falsy.
            local_dir: ``~``-expanded and made absolute; falls back to
                ``DEFAULT_RESULTS_DIR`` when falsy.
            upload_dir: When truthy, ``self.remote_checkpoint_dir`` is
                ``upload_dir`` joined with ``self.dir_name``.
            loggers: Deprecated; anything but ``None`` raises.
            log_to_file: Passed through ``_validate_log_to_file``; the
                resulting ``(stdout_file, stderr_file)`` pair is stored.
            export_formats: Defaults to an empty list when falsy.
            restore: ``~``-expanded and made absolute when truthy,
                otherwise stored as ``None``.

        Raises:
            ValueError: If ``loggers`` is given, if ``checkpoint_at_end`` or
                ``checkpoint_freq`` is combined with a callable ``run`` for
                which ``detect_checkpoint_function`` is truthy, or if
                ``stop`` has an unsupported type. The helper functions
                called here may raise as well.
        """
        if loggers is not None:
            # Most users won't run into this as `tune.run()` does not pass
            # the argument anymore. However, we will want to inform users
            # if they instantiate their `Experiment` objects themselves.
            raise ValueError(
                "Passing `loggers` to an `Experiment` is deprecated. Use "
                "an `ExperimentLogger` callback instead, e.g. by passing the "
                "`Logger` classes to `tune.logger.LegacyExperimentLogger` and "
                "passing this as part of the `callback` parameter to "
                "`tune.run()`.")

        config = config or {}

        # Checkpoint scheduling options are rejected for functions that
        # `detect_checkpoint_function` reports as checkpointable.
        if callable(run) and detect_checkpoint_function(run):
            if checkpoint_at_end:
                raise ValueError("'checkpoint_at_end' cannot be used with a "
                                 "checkpointable function. You can specify "
                                 "and register checkpoints within "
                                 "your trainable function.")
            if checkpoint_freq:
                raise ValueError(
                    "'checkpoint_freq' cannot be used with a "
                    "checkpointable function. You can specify checkpoints "
                    "within your trainable function.")

        self._run_identifier = Experiment.register_if_needed(run)
        self.name = name or self._run_identifier

        # If the name has been set explicitly, we don't want to create
        # dated directories. The same is true for string run identifiers.
        dated_subdir_disabled = (
            int(os.environ.get("TUNE_DISABLE_DATED_SUBDIR", 0)) == 1)
        if dated_subdir_disabled or name or isinstance(run, str):
            self.dir_name = self.name
        else:
            self.dir_name = "{}_{}".format(self.name, date_str())

        self.remote_checkpoint_dir = (
            os.path.join(upload_dir, self.dir_name) if upload_dir else None)

        # Resolve `stop` into either dict criteria or a stopper object.
        self._stopper = None
        stopping_criteria = {}
        if not stop:
            pass
        elif isinstance(stop, dict):
            stopping_criteria = stop
        elif callable(stop):
            # NOTE(review): `Stopper` instances are presumably callable,
            # which is why they are handled inside this branch.
            if FunctionStopper.is_valid_function(stop):
                self._stopper = FunctionStopper(stop)
            elif isinstance(stop, Stopper):
                self._stopper = stop
            else:
                raise ValueError("Provided stop object must be either a dict, "
                                 "a function, or a subclass of "
                                 "`ray.tune.Stopper`.")
        else:
            raise ValueError("Invalid stop criteria: {}. Must be a "
                             "callable or dict".format(stop))

        if time_budget_s:
            timeout_stopper = TimeoutStopper(time_budget_s)
            if self._stopper:
                self._stopper = CombinedStopper(self._stopper,
                                                timeout_stopper)
            else:
                self._stopper = timeout_stopper

        _raise_on_durable(self._run_identifier, sync_to_driver, upload_dir)

        stdout_file, stderr_file = _validate_log_to_file(log_to_file)

        self.spec = {
            "run": self._run_identifier,
            "stop": stopping_criteria,
            "config": config,
            "resources_per_trial": resources_per_trial,
            "num_samples": num_samples,
            "local_dir": os.path.abspath(
                os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR)),
            "upload_dir": upload_dir,
            "remote_checkpoint_dir": self.remote_checkpoint_dir,
            "trial_name_creator": trial_name_creator,
            "trial_dirname_creator": trial_dirname_creator,
            # Always None here because of the deprecation check above; the
            # key is kept so the shape of the spec does not change.
            "loggers": loggers,
            "log_to_file": (stdout_file, stderr_file),
            "sync_to_driver": sync_to_driver,
            "checkpoint_freq": checkpoint_freq,
            "checkpoint_at_end": checkpoint_at_end,
            "sync_on_checkpoint": sync_on_checkpoint,
            "keep_checkpoints_num": keep_checkpoints_num,
            "checkpoint_score_attr": checkpoint_score_attr,
            "export_formats": export_formats or [],
            "max_failures": max_failures,
            "restore": (os.path.abspath(os.path.expanduser(restore))
                        if restore else None),
        }
Пример #2
0
    def __init__(self,
                 name,
                 run,
                 stop=None,
                 time_budget_s=None,
                 config=None,
                 resources_per_trial=None,
                 num_samples=1,
                 local_dir=None,
                 upload_dir=None,
                 trial_name_creator=None,
                 trial_dirname_creator=None,
                 loggers=None,
                 log_to_file=False,
                 sync_to_driver=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=None,
                 export_formats=None,
                 max_failures=0,
                 restore=None):
        """Validate the arguments and build the experiment ``spec`` dict.

        Unless noted below, each argument is stored unchanged in
        ``self.spec`` under a key of the same name.

        Args:
            name: Experiment name. When falsy, the identifier returned by
                ``Experiment.register_if_needed(run)`` is used instead.
            run: Trainable to run; handed to
                ``Experiment.register_if_needed``.
            stop: Stopping criteria. A dict is stored as ``spec["stop"]``.
                A callable accepted by
                ``FunctionStopper.is_valid_function`` is wrapped in a
                ``FunctionStopper``; a ``Stopper`` instance is used as is.
                Either of the latter ends up in ``self._stopper``.
            time_budget_s: When truthy, a ``TimeoutStopper(time_budget_s)``
                is installed, combined with any stopper derived from
                ``stop``. Not stored in ``spec``.
            config: Defaults to an empty dict when falsy.
            local_dir: ``~``-expanded and made absolute; falls back to
                ``DEFAULT_RESULTS_DIR`` when falsy.
            upload_dir: When truthy, ``self.remote_checkpoint_dir`` is
                ``upload_dir`` joined with ``self.name``.
            log_to_file: Passed through ``_validate_log_to_file``; the
                resulting ``(stdout_file, stderr_file)`` pair is stored.
            export_formats: Defaults to an empty list when falsy.
            restore: ``~``-expanded and made absolute when truthy,
                otherwise stored as ``None``.

        Raises:
            ValueError: If ``checkpoint_at_end`` or ``checkpoint_freq`` is
                combined with a callable ``run`` for which
                ``detect_checkpoint_function`` is truthy, or if ``stop``
                has an unsupported type. The helper functions called here
                may raise as well.
        """
        config = config or {}

        # Checkpoint scheduling options are rejected for functions that
        # `detect_checkpoint_function` reports as checkpointable.
        if callable(run) and detect_checkpoint_function(run):
            if checkpoint_at_end:
                raise ValueError("'checkpoint_at_end' cannot be used with a "
                                 "checkpointable function. You can specify "
                                 "and register checkpoints within "
                                 "your trainable function.")
            if checkpoint_freq:
                raise ValueError(
                    "'checkpoint_freq' cannot be used with a "
                    "checkpointable function. You can specify checkpoints "
                    "within your trainable function.")

        self._run_identifier = Experiment.register_if_needed(run)
        self.name = name or self._run_identifier
        self.remote_checkpoint_dir = (
            os.path.join(upload_dir, self.name) if upload_dir else None)

        # Resolve `stop` into either dict criteria or a stopper object.
        self._stopper = None
        stopping_criteria = {}
        if not stop:
            pass
        elif isinstance(stop, dict):
            stopping_criteria = stop
        elif callable(stop):
            # NOTE(review): `Stopper` instances are presumably callable,
            # which is why they are handled inside this branch.
            if FunctionStopper.is_valid_function(stop):
                self._stopper = FunctionStopper(stop)
            elif isinstance(stop, Stopper):
                self._stopper = stop
            else:
                raise ValueError("Provided stop object must be either a dict, "
                                 "a function, or a subclass of "
                                 "`ray.tune.Stopper`.")
        else:
            raise ValueError("Invalid stop criteria: {}. Must be a "
                             "callable or dict".format(stop))

        if time_budget_s:
            timeout_stopper = TimeoutStopper(time_budget_s)
            if self._stopper:
                self._stopper = CombinedStopper(self._stopper,
                                                timeout_stopper)
            else:
                self._stopper = timeout_stopper

        _raise_on_durable(self._run_identifier, sync_to_driver, upload_dir)

        stdout_file, stderr_file = _validate_log_to_file(log_to_file)

        self.spec = {
            "run": self._run_identifier,
            "stop": stopping_criteria,
            "config": config,
            "resources_per_trial": resources_per_trial,
            "num_samples": num_samples,
            "local_dir": os.path.abspath(
                os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR)),
            "upload_dir": upload_dir,
            "remote_checkpoint_dir": self.remote_checkpoint_dir,
            "trial_name_creator": trial_name_creator,
            "trial_dirname_creator": trial_dirname_creator,
            "loggers": loggers,
            "log_to_file": (stdout_file, stderr_file),
            "sync_to_driver": sync_to_driver,
            "checkpoint_freq": checkpoint_freq,
            "checkpoint_at_end": checkpoint_at_end,
            "sync_on_checkpoint": sync_on_checkpoint,
            "keep_checkpoints_num": keep_checkpoints_num,
            "checkpoint_score_attr": checkpoint_score_attr,
            "export_formats": export_formats or [],
            "max_failures": max_failures,
            "restore": (os.path.abspath(os.path.expanduser(restore))
                        if restore else None),
        }
Пример #3
0
    def __init__(self,
                 name,
                 run,
                 stop=None,
                 time_budget_s=None,
                 config=None,
                 resources_per_trial=None,
                 num_samples=1,
                 local_dir=None,
                 upload_dir=None,
                 trial_name_creator=None,
                 trial_dirname_creator=None,
                 log_to_file=False,
                 sync_to_driver=None,
                 sync_to_cloud=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=None,
                 export_formats=None,
                 max_failures=0,
                 restore=None):
        """Validate the arguments and build the experiment ``spec`` dict.

        Unless noted below, each argument is stored unchanged in
        ``self.spec`` under a key of the same name.

        Args:
            name: Experiment name. When falsy, the identifier returned by
                ``Experiment.register_if_needed(run)`` is used instead.
            run: Trainable to run; handed to
                ``Experiment.register_if_needed``.
            stop: Stopping criteria. A dict is stored as ``spec["stop"]``.
                A list must contain only ``Stopper`` instances and is
                merged into a ``CombinedStopper``. A callable accepted by
                ``FunctionStopper.is_valid_function`` is wrapped in a
                ``FunctionStopper``; a ``Stopper`` instance is used as is.
                All non-dict forms end up in ``self._stopper``.
            time_budget_s: When truthy, a ``TimeoutStopper(time_budget_s)``
                is installed, combined with any stopper derived from
                ``stop``. Not stored in ``spec``.
            config: Defaults to an empty dict when falsy.
            local_dir: ``~``-expanded and made absolute; falls back to
                ``DEFAULT_RESULTS_DIR`` when falsy.
            upload_dir: When truthy, ``self.remote_checkpoint_dir`` is
                ``upload_dir`` joined with ``self.dir_name``.
            log_to_file: Passed through ``_validate_log_to_file``; the
                resulting ``(stdout_file, stderr_file)`` pair is stored.
            export_formats: Defaults to an empty list when falsy.
            restore: ``~``-expanded and made absolute when truthy,
                otherwise stored as ``None``.

        Raises:
            ValueError: If ``checkpoint_at_end`` or ``checkpoint_freq`` is
                combined with a non-class callable ``run`` for which
                ``detect_checkpoint_function`` is truthy, or if ``stop``
                has an unsupported type or a list element that is not a
                ``Stopper``. The helper functions called here may raise as
                well.
        """
        config = config or {}

        # Only plain callables are inspected; classes are skipped even
        # though they are callable too.
        is_function_trainable = callable(run) and not inspect.isclass(run)
        if is_function_trainable and detect_checkpoint_function(run):
            if checkpoint_at_end:
                raise ValueError("'checkpoint_at_end' cannot be used with a "
                                 "checkpointable function. You can specify "
                                 "and register checkpoints within "
                                 "your trainable function.")
            if checkpoint_freq:
                raise ValueError(
                    "'checkpoint_freq' cannot be used with a "
                    "checkpointable function. You can specify checkpoints "
                    "within your trainable function.")

        self._run_identifier = Experiment.register_if_needed(run)
        self.name = name or self._run_identifier

        # If the name has been set explicitly, we don't want to create
        # dated directories. The same is true for string run identifiers.
        dated_subdir_disabled = (
            int(os.environ.get("TUNE_DISABLE_DATED_SUBDIR", 0)) == 1)
        if dated_subdir_disabled or name or isinstance(run, str):
            self.dir_name = self.name
        else:
            self.dir_name = "{}_{}".format(self.name, date_str())

        self.remote_checkpoint_dir = (
            os.path.join(upload_dir, self.dir_name) if upload_dir else None)

        # Resolve `stop` into either dict criteria or a stopper object.
        self._stopper = None
        stopping_criteria = {}
        if not stop:
            pass
        elif isinstance(stop, list):
            if not all(isinstance(s, Stopper) for s in stop):
                raise ValueError(
                    "If you pass a list as the `stop` argument to "
                    "`tune.run()`, each element must be an instance of "
                    "`tune.stopper.Stopper`.")
            self._stopper = CombinedStopper(*stop)
        elif isinstance(stop, dict):
            stopping_criteria = stop
        elif callable(stop):
            # NOTE(review): `Stopper` instances are presumably callable,
            # which is why they are handled inside this branch.
            if FunctionStopper.is_valid_function(stop):
                self._stopper = FunctionStopper(stop)
            elif isinstance(stop, Stopper):
                self._stopper = stop
            else:
                raise ValueError("Provided stop object must be either a dict, "
                                 "a function, or a subclass of "
                                 "`ray.tune.Stopper`.")
        else:
            raise ValueError("Invalid stop criteria: {}. Must be a "
                             "callable or dict".format(stop))

        if time_budget_s:
            timeout_stopper = TimeoutStopper(time_budget_s)
            if self._stopper:
                self._stopper = CombinedStopper(self._stopper,
                                                timeout_stopper)
            else:
                self._stopper = timeout_stopper

        # `is_durable_trainable` is defined elsewhere on the class; it is
        # read only after `_run_identifier` has been assigned above.
        _raise_on_durable(self.is_durable_trainable, sync_to_driver,
                          upload_dir)

        stdout_file, stderr_file = _validate_log_to_file(log_to_file)

        self.spec = {
            "run": self._run_identifier,
            "stop": stopping_criteria,
            "config": config,
            "resources_per_trial": resources_per_trial,
            "num_samples": num_samples,
            "local_dir": os.path.abspath(
                os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR)),
            "upload_dir": upload_dir,
            "remote_checkpoint_dir": self.remote_checkpoint_dir,
            "trial_name_creator": trial_name_creator,
            "trial_dirname_creator": trial_dirname_creator,
            "log_to_file": (stdout_file, stderr_file),
            "sync_to_driver": sync_to_driver,
            "sync_to_cloud": sync_to_cloud,
            "checkpoint_freq": checkpoint_freq,
            "checkpoint_at_end": checkpoint_at_end,
            "sync_on_checkpoint": sync_on_checkpoint,
            "keep_checkpoints_num": keep_checkpoints_num,
            "checkpoint_score_attr": checkpoint_score_attr,
            "export_formats": export_formats or [],
            "max_failures": max_failures,
            "restore": (os.path.abspath(os.path.expanduser(restore))
                        if restore else None),
        }
Пример #4
0
    def __init__(self,
                 name,
                 run,
                 stop=None,
                 config=None,
                 resources_per_trial=None,
                 num_samples=1,
                 local_dir=None,
                 upload_dir=None,
                 trial_name_creator=None,
                 loggers=None,
                 sync_to_driver=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=None,
                 export_formats=None,
                 max_failures=0,
                 restore=None,
                 repeat=None,
                 trial_resources=None,
                 sync_function=None):
        """Initialize a new Experiment.

        The args here take the same meaning as the command line flags defined
        in `tune.py:run`. Unless noted below, each argument is stored
        unchanged in ``self.spec`` under a key of the same name.

        Args:
            name: Experiment name. When falsy, the identifier returned by
                ``Experiment.register_if_needed(run)`` is used instead.
            run: Trainable to run; handed to
                ``Experiment.register_if_needed``.
            stop: Stopping criteria. A dict is stored as ``spec["stop"]``.
                A callable accepted by
                ``FunctionStopper.is_valid_function`` is wrapped in a
                ``FunctionStopper``; a ``Stopper`` instance is used as is.
                Either of the latter ends up in ``self._stopper``.
            config: Defaults to an empty dict when falsy.
            local_dir: ``~``-expanded and made absolute; falls back to
                ``DEFAULT_RESULTS_DIR`` when falsy.
            upload_dir: When truthy, ``self.remote_checkpoint_dir`` is
                ``upload_dir`` joined with ``self.name``.
            export_formats: Defaults to an empty list when falsy.
            restore: ``~``-expanded and made absolute when truthy,
                otherwise stored as ``None``.
            repeat: Deprecated in favor of ``num_samples``.
            trial_resources: Deprecated in favor of ``resources_per_trial``.
            sync_function: Deprecated in favor of ``sync_to_driver``.

        Raises:
            ValueError: If ``stop`` has an unsupported type. The helper
                functions called here may raise as well.
        """
        # Truthy deprecated arguments are reported via
        # `_raise_deprecation_note`; with `soft=False` this presumably
        # raises rather than warns (the helper is defined elsewhere).
        if repeat:
            _raise_deprecation_note("repeat", "num_samples", soft=False)
        if trial_resources:
            _raise_deprecation_note("trial_resources",
                                    "resources_per_trial",
                                    soft=False)
        if sync_function:
            _raise_deprecation_note("sync_function",
                                    "sync_to_driver",
                                    soft=False)

        config = config or {}
        self._run_identifier = Experiment.register_if_needed(run)
        self.name = name or self._run_identifier
        self.remote_checkpoint_dir = (
            os.path.join(upload_dir, self.name) if upload_dir else None)

        # Resolve `stop` into either dict criteria or a stopper object.
        self._stopper = None
        stopping_criteria = {}
        if not stop:
            pass
        elif isinstance(stop, dict):
            stopping_criteria = stop
        elif callable(stop):
            # NOTE(review): `Stopper` instances are presumably callable,
            # which is why they are handled inside this branch.
            if FunctionStopper.is_valid_function(stop):
                self._stopper = FunctionStopper(stop)
            elif isinstance(stop, Stopper):
                self._stopper = stop
            else:
                raise ValueError("Provided stop object must be either a dict, "
                                 "a function, or a subclass of "
                                 "`ray.tune.Stopper`.")
        else:
            raise ValueError("Invalid stop criteria: {}. Must be a "
                             "callable or dict".format(stop))

        _raise_on_durable(self._run_identifier, sync_to_driver, upload_dir)

        self.spec = {
            "run": self._run_identifier,
            "stop": stopping_criteria,
            "config": config,
            "resources_per_trial": resources_per_trial,
            "num_samples": num_samples,
            "local_dir": os.path.abspath(
                os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR)),
            "upload_dir": upload_dir,
            "remote_checkpoint_dir": self.remote_checkpoint_dir,
            "trial_name_creator": trial_name_creator,
            "loggers": loggers,
            "sync_to_driver": sync_to_driver,
            "checkpoint_freq": checkpoint_freq,
            "checkpoint_at_end": checkpoint_at_end,
            "sync_on_checkpoint": sync_on_checkpoint,
            "keep_checkpoints_num": keep_checkpoints_num,
            "checkpoint_score_attr": checkpoint_score_attr,
            "export_formats": export_formats or [],
            "max_failures": max_failures,
            "restore": (os.path.abspath(os.path.expanduser(restore))
                        if restore else None),
        }
Пример #5
0
    def __init__(
        self,
        name,
        run,
        stop=None,
        time_budget_s=None,
        config=None,
        resources_per_trial=None,
        num_samples=1,
        local_dir=None,
        _experiment_checkpoint_dir: Optional[str] = None,
        sync_config=None,
        trial_name_creator=None,
        trial_dirname_creator=None,
        log_to_file=False,
        checkpoint_freq=0,
        checkpoint_at_end=False,
        keep_checkpoints_num=None,
        checkpoint_score_attr=None,
        export_formats=None,
        max_failures=0,
        restore=None,
    ):
        """Validate the arguments and build the experiment ``spec`` dict.

        Unless noted below, each argument is stored unchanged in
        ``self.spec`` under a key of the same name.

        Args:
            name: Experiment name. When falsy, the identifier returned by
                ``Experiment.register_if_needed(run)`` is used instead.
            run: Trainable to run; handed to
                ``Experiment.register_if_needed``.
            stop: Stopping criteria. A dict is stored as ``spec["stop"]``.
                A list must contain only ``Stopper`` instances and is
                merged into a ``CombinedStopper``. A callable accepted by
                ``FunctionStopper.is_valid_function`` is wrapped in a
                ``FunctionStopper``; a ``Stopper`` instance is used as is.
                All non-dict forms end up in ``self._stopper``.
            time_budget_s: When truthy, a ``TimeoutStopper(time_budget_s)``
                is installed, combined with any stopper derived from
                ``stop``. The raw value is also stored in ``spec``.
            config: Defaults to an empty dict when falsy.
            local_dir: Normalized by ``_get_local_dir_with_expand_user``
                before being stored.
            _experiment_checkpoint_dir: Internal use only. When set, it
                must lie below ``local_dir`` and ``self.dir_name`` is its
                path relative to ``local_dir``; otherwise ``dir_name``
                comes from ``_get_dir_name``.
            sync_config: Defaults to ``SyncConfig()`` when falsy. A truthy
                ``sync_config.upload_dir`` is joined with ``self.dir_name``
                to form ``self.remote_checkpoint_dir``.
            log_to_file: Passed through ``_validate_log_to_file``; the
                resulting ``(stdout_file, stderr_file)`` pair is stored.
            export_formats: Defaults to an empty list when falsy.
            restore: ``~``-expanded and made absolute when truthy,
                otherwise stored as ``None``.

        Raises:
            ValueError: If ``checkpoint_at_end`` or ``checkpoint_freq`` is
                combined with a non-class callable ``run`` for which
                ``detect_checkpoint_function`` is truthy, or if ``stop``
                has an unsupported type or a list element that is not a
                ``Stopper``.
            TuneError: If registering ``run`` fails with a grpc
                ``RESOURCE_EXHAUSTED`` error; the grpc error is attached
                as ``__cause__``.
            AssertionError: If ``_experiment_checkpoint_dir`` is not below
                ``local_dir`` or no ``dir_name`` could be derived.
        """
        local_dir = _get_local_dir_with_expand_user(local_dir)
        # `_experiment_checkpoint_dir` is for internal use only for better
        # support of Tuner API.
        # If set, it should be a subpath under `local_dir`. Also deduce `dir_name`.
        self._experiment_checkpoint_dir = _experiment_checkpoint_dir
        if _experiment_checkpoint_dir:
            experiment_checkpoint_dir_path = Path(_experiment_checkpoint_dir)
            local_dir_path = Path(local_dir)
            assert local_dir_path in experiment_checkpoint_dir_path.parents
            # `dir_name` is set by `_experiment_checkpoint_dir` indirectly.
            self.dir_name = os.path.relpath(_experiment_checkpoint_dir, local_dir)

        config = config or {}
        sync_config = sync_config or SyncConfig()

        # Only plain callables are inspected; classes are skipped even
        # though they are callable too.
        is_function_trainable = callable(run) and not inspect.isclass(run)
        if is_function_trainable and detect_checkpoint_function(run):
            if checkpoint_at_end:
                raise ValueError(
                    "'checkpoint_at_end' cannot be used with a "
                    "checkpointable function. You can specify "
                    "and register checkpoints within "
                    "your trainable function."
                )
            if checkpoint_freq:
                raise ValueError(
                    "'checkpoint_freq' cannot be used with a "
                    "checkpointable function. You can specify checkpoints "
                    "within your trainable function."
                )

        try:
            self._run_identifier = Experiment.register_if_needed(run)
        except grpc.RpcError as e:
            if e.code() != grpc.StatusCode.RESOURCE_EXHAUSTED:
                # Unrelated RPC failure: propagate it untouched.
                raise
            # Chain explicitly so the grpc error is recorded as the cause.
            raise TuneError(
                "The Trainable/training function is too large for grpc resource "
                "limit. Check that its definition is not implicitly capturing a "
                "large array or other object in scope. "
                "Tip: use tune.with_parameters() to put large objects "
                "in the Ray object store. \n"
                f"Original exception: {traceback.format_exc()}"
            ) from e

        self.name = name or self._run_identifier

        if not _experiment_checkpoint_dir:
            self.dir_name = _get_dir_name(run, name, self.name)

        assert self.dir_name

        if sync_config.upload_dir:
            self.remote_checkpoint_dir = os.path.join(
                sync_config.upload_dir, self.dir_name
            )
        else:
            self.remote_checkpoint_dir = None

        # Resolve `stop` into either dict criteria or a stopper object.
        self._stopper = None
        stopping_criteria = {}
        if not stop:
            pass
        elif isinstance(stop, list):
            if not all(isinstance(s, Stopper) for s in stop):
                # Report the types of all elements, not only the bad ones.
                stopper_types = [type(s) for s in stop]
                raise ValueError(
                    "If you pass a list as the `stop` argument to "
                    "`tune.run()`, each element must be an instance of "
                    f"`tune.stopper.Stopper`. Got {stopper_types}."
                )
            self._stopper = CombinedStopper(*stop)
        elif isinstance(stop, dict):
            stopping_criteria = stop
        elif callable(stop):
            # NOTE(review): `Stopper` instances are presumably callable,
            # which is why they are handled inside this branch.
            if FunctionStopper.is_valid_function(stop):
                self._stopper = FunctionStopper(stop)
            elif isinstance(stop, Stopper):
                self._stopper = stop
            else:
                raise ValueError(
                    "Provided stop object must be either a dict, "
                    "a function, or a subclass of "
                    f"`ray.tune.Stopper`. Got {type(stop)}."
                )
        else:
            raise ValueError(
                f"Invalid stop criteria: {stop}. Must be a "
                f"callable or dict. Got {type(stop)}."
            )

        if time_budget_s:
            timeout_stopper = TimeoutStopper(time_budget_s)
            if self._stopper:
                self._stopper = CombinedStopper(self._stopper, timeout_stopper)
            else:
                self._stopper = timeout_stopper

        stdout_file, stderr_file = _validate_log_to_file(log_to_file)

        self.spec = {
            "run": self._run_identifier,
            "stop": stopping_criteria,
            "time_budget_s": time_budget_s,
            "config": config,
            "resources_per_trial": resources_per_trial,
            "num_samples": num_samples,
            "local_dir": local_dir,
            "sync_config": sync_config,
            "remote_checkpoint_dir": self.remote_checkpoint_dir,
            "trial_name_creator": trial_name_creator,
            "trial_dirname_creator": trial_dirname_creator,
            "log_to_file": (stdout_file, stderr_file),
            "checkpoint_freq": checkpoint_freq,
            "checkpoint_at_end": checkpoint_at_end,
            "keep_checkpoints_num": keep_checkpoints_num,
            "checkpoint_score_attr": checkpoint_score_attr,
            "export_formats": export_formats or [],
            "max_failures": max_failures,
            "restore": (
                os.path.abspath(os.path.expanduser(restore)) if restore else None
            ),
        }