Example #1
    def _sync_trial_checkpoint(self, trial: "Trial",
                               checkpoint: _TuneCheckpoint):
        if checkpoint.storage == _TuneCheckpoint.MEMORY:
            return

        trial_syncer = self._get_trial_syncer(trial)
        # If the sync_function is False, syncing to driver is disabled.
        # In every other case (valid values include None, True, a Callable,
        # or a NodeSyncer), syncing to driver is enabled.
        if trial.sync_on_checkpoint and self._sync_function is not False:
            try:
                # Wait for any other syncs to finish. We need to sync again
                # after this to handle checkpoints taken mid-sync.
                trial_syncer.wait()
            except TuneError as e:
                # Errors occurring during this wait are not fatal for this
                # checkpoint, so it should just be logged.
                logger.error(f"Trial {trial}: An error occurred during the "
                             f"checkpoint pre-sync wait: {e}")
            # Force sync down and wait before tracking the new checkpoint.
            try:
                if trial_syncer.sync_down():
                    trial_syncer.wait()
                else:
                    logger.error(f"Trial {trial}: Checkpoint sync skipped. "
                                 f"This should not happen.")
            except TuneError as e:
                if trial.uses_cloud_checkpointing:
                    # Even though rsync failed the trainable can restore
                    # from remote durable storage.
                    logger.error(f"Trial {trial}: Sync error: {e}")
                else:
                    # If the trainable didn't have remote storage to upload
                    # to then this checkpoint may have been lost, so we
                    # shouldn't track it with the checkpoint_manager.
                    raise e
            if not trial.uses_cloud_checkpointing:
                if not os.path.exists(checkpoint.value):
                    raise TuneError("Trial {}: Checkpoint path {} not "
                                    "found after successful sync down. "
                                    "Are you running on a Kubernetes or "
                                    "managed cluster? rsync will not function "
                                    "due to a lack of SSH functionality. "
                                    "You'll need to use cloud-checkpointing "
                                    "if that's the case, see instructions "
                                    "here: {} .".format(
                                        trial, checkpoint.value,
                                        CLOUD_CHECKPOINTING_URL))
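The error message above points to cloud checkpointing as the workaround when rsync/SSH is unavailable. A minimal sketch of enabling it, assuming a Ray release in which tune.SyncConfig accepts an upload_dir; the trainable name and bucket path are placeholders:

from ray import tune

# Checkpoints are uploaded to durable remote storage instead of being
# rsynced back to the driver node.
tune.run(
    "my_trainable",  # assumed to be registered elsewhere
    checkpoint_freq=1,
    sync_config=tune.SyncConfig(upload_dir="s3://my-bucket/tune-results"),
)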
Example #2
def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
    """Creates a Trial object from parsing the spec.

    Arguments:
        spec (dict): A resolved experiment specification. The arguments
            here should correspond to the command line flags defined in
            ray.tune.config_parser.
        output_path (str): A specific output path within the local_dir.
            Typically the name of the experiment.
        parser (ArgumentParser): An argument parser object from
            make_parser.
        trial_kwargs: Extra keyword arguments used in instantiating the Trial.

    Returns:
        A trial object with corresponding parameters to the specification.
    """
    try:
        args, _ = parser.parse_known_args(to_argv(spec))
    except SystemExit:
        raise TuneError("Error parsing args, see above message", spec)
    if "resources_per_trial" in spec:
        trial_kwargs["resources"] = json_to_resources(
            spec["resources_per_trial"])
    return Trial(
        # Submitting trial via server in py2.7 creates Unicode, which does not
        # convert to string in a straightforward manner.
        trainable_name=spec["run"],
        # json.load leads to str -> unicode in py2.7
        config=spec.get("config", {}),
        local_dir=os.path.join(spec["local_dir"], output_path),
        # json.load leads to str -> unicode in py2.7
        stopping_criterion=spec.get("stop", {}),
        remote_checkpoint_dir=spec.get("remote_checkpoint_dir"),
        checkpoint_freq=args.checkpoint_freq,
        checkpoint_at_end=args.checkpoint_at_end,
        sync_on_checkpoint=args.sync_on_checkpoint,
        keep_checkpoints_num=args.keep_checkpoints_num,
        checkpoint_score_attr=args.checkpoint_score_attr,
        export_formats=spec.get("export_formats", []),
        # str(None) doesn't create None
        restore_path=spec.get("restore"),
        trial_name_creator=spec.get("trial_name_creator"),
        trial_dirname_creator=spec.get("trial_dirname_creator"),
        loggers=spec.get("loggers"),
        log_to_file=spec.get("log_to_file"),
        # str(None) doesn't create None
        max_failures=args.max_failures,
        **trial_kwargs)
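For reference, a sketch of the kind of resolved spec this function expects; the keys mirror the command line flags in ray.tune.config_parser and the values are purely illustrative:

spec = {
    "run": "my_trainable",               # registered trainable name (placeholder)
    "config": {"lr": 0.01},              # hyperparameters passed to the trainable
    "stop": {"training_iteration": 10},  # stopping criterion
    "local_dir": "~/ray_results",
    "resources_per_trial": {"cpu": 1, "gpu": 0},
}
# trial = create_trial_from_spec(spec, "my_experiment", make_parser())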
Example #3
    def wait_for_all(self):
        failed_syncs = {}
        for trial, sync_process in self._sync_processes.items():
            try:
                sync_process.wait()
            except Exception as e:
                failed_syncs[trial] = e

        if failed_syncs:
            sync_str = "\n".join(
                [f"  {trial}: {e}" for trial, e in failed_syncs.items()]
            )
            raise TuneError(
                f"At least one trial failed to sync down when waiting for all "
                f"trials to sync: \n{sync_str}"
            )
Example #4
    def validate(formats):
        """Validates formats.

        Raises:
            TuneError: If any format is unknown.
        """
        for i in range(len(formats)):
            formats[i] = formats[i].strip().lower()
            if formats[i] not in [
                    ExportFormat.CHECKPOINT,
                    ExportFormat.MODEL,
                    ExportFormat.ONNX,
                    ExportFormat.H5,
            ]:
                raise TuneError("Unsupported import/export format: " +
                                formats[i])
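A brief usage sketch: validate normalizes the list in place and raises TuneError for anything it does not recognize (the import path of ExportFormat has moved between Ray versions, so it is omitted here):

formats = [" Model ", "ONNX"]
ExportFormat.validate(formats)       # formats is now ["model", "onnx"]
ExportFormat.validate(["protobuf"])  # raises TuneError: unsupported format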
Example #5
    def on_checkpoint(
        self,
        iteration: int,
        trials: List["Trial"],
        trial: "Trial",
        checkpoint: _TrackedCheckpoint,
        **info,
    ):
        if checkpoint.storage_mode == CheckpointStorage.MEMORY:
            return

        if self._sync_trial_dir(
                trial, force=trial.sync_on_checkpoint,
                wait=True) and not os.path.exists(checkpoint.dir_or_data):
            raise TuneError(
                f"Trial {trial}: Checkpoint path {checkpoint.dir_or_data} not "
                "found after successful sync down.")
Example #6
    def step(self):
        """Runs one step of the trial event loop.

        Callers should typically run this method repeatedly in a loop. They
        may inspect or modify the runner's state in between calls to step().
        """
        if self.is_finished():
            raise TuneError("Called step when all trials finished?")
        with warn_if_slow("on_step_begin"):
            self.trial_executor.on_step_begin(self)
        with warn_if_slow("callbacks.on_step_begin"):
            self._callbacks.on_step_begin(
                iteration=self._iteration, trials=self._trials)
        next_trial = self._get_next_trial()  # blocking
        if next_trial is not None:
            with warn_if_slow("start_trial"):
                self.trial_executor.start_trial(next_trial)
                self._callbacks.on_trial_start(
                    iteration=self._iteration,
                    trials=self._trials,
                    trial=next_trial)
        elif self.trial_executor.get_running_trials():
            self._process_events()  # blocking
        else:
            self.trial_executor.on_no_available_trials(self)

        self._stop_experiment_if_needed()

        try:
            with warn_if_slow("experiment_checkpoint"):
                self.checkpoint()
        except Exception as e:
            logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
        self._iteration += 1

        if self._server:
            with warn_if_slow("server"):
                self._process_stop_requests()

            if self.is_finished():
                self._server.shutdown()
        with warn_if_slow("on_step_end"):
            self.trial_executor.on_step_end(self)
        with warn_if_slow("callbacks.on_step_end"):
            self._callbacks.on_step_end(
                iteration=self._iteration, trials=self._trials)
Example #7
def _try_resolve(v) -> Tuple[bool, Any]:
    if isinstance(v, Domain):
        # Domain to sample from
        return False, v
    elif isinstance(v, dict) and len(v) == 1 and "eval" in v:
        # Lambda function in eval syntax
        return False, Function(
            lambda spec: eval(v["eval"], _STANDARD_IMPORTS, {"spec": spec}))
    elif isinstance(v, dict) and len(v) == 1 and "grid_search" in v:
        # Grid search values
        grid_values = v["grid_search"]
        if not isinstance(grid_values, list):
            raise TuneError(
                "Grid search expected list of values, got: {}".format(
                    grid_values))
        return False, Categorical(grid_values).grid()
    return True, v
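The three value shapes this helper distinguishes correspond to the public search space API; a sketch of a config that would exercise each branch (tune.grid_search expands to the {"grid_search": [...]} form handled above, and the key names are illustrative):

from ray import tune

config = {
    "lr": tune.uniform(1e-4, 1e-1),               # Domain, sampled per trial
    "momentum": {"eval": "spec.config.lr * 10"},  # lambda in eval syntax
    "hidden": tune.grid_search([64, 128]),        # grid search values
    "seed": 42,                                   # plain value, resolved as-is
}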
Example #8
def _try_resolve(v):
    if isinstance(v, sample_from):
        # Function to sample from
        return False, v.func
    elif isinstance(v, dict) and len(v) == 1 and "eval" in v:
        # Lambda function in eval syntax
        return False, lambda spec: eval(
            v["eval"], _STANDARD_IMPORTS, {"spec": spec})
    elif isinstance(v, dict) and len(v) == 1 and "grid_search" in v:
        # Grid search values
        grid_values = v["grid_search"]
        if not isinstance(grid_values, list):
            raise TuneError(
                "Grid search expected list of values, got: {}".format(
                    grid_values))
        return False, grid_values
    return True, v
Example #9
    def _train(self):
        time.sleep(
            self.config.get("script_min_iter_time_s",
                            self._default_config["script_min_iter_time_s"]))
        result = self._status_reporter._get_and_clear_status()
        while result is None:
            time.sleep(1)
            result = self._status_reporter._get_and_clear_status()
        if result.timesteps_total is None:
            raise TuneError("Must specify timesteps_total in result", result)

        result = result._replace(
            timesteps_this_iter=(result.timesteps_total -
                                 self._last_reported_timestep))
        self._last_reported_timestep = result.timesteps_total

        return result
Example #10
    def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
        assert max_retries > 0
        last_error = None
        for _ in range(max_retries - 1):
            try:
                self.wait()
            except Exception as e:
                logger.error(
                    f"Caught sync error: {e}. "
                    f"Retrying after sleeping for {backoff_s} seconds...")
                last_error = e
                time.sleep(backoff_s)
                self.retry()
                continue
            return
        raise TuneError(
            f"Failed sync even after {max_retries} retries.") from last_error
Example #11
    def should_stop(self, result):
        """Whether the given result meets this trial's stopping criteria."""
        if result.get(DONE):
            return True

        for criteria, stop_value in self.stopping_criterion.items():
            if criteria not in result:
                raise TuneError(
                    "Stopping criteria {} not provided in result {}.".format(
                        criteria, result))
            elif isinstance(criteria, dict):
                raise ValueError(
                    "Stopping criteria is now flattened by default. "
                    "Use forward slashes to nest values `key1/key2/key3`.")
            elif result[criteria] >= stop_value:
                return True
        return False
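As the error above notes, nested result entries are flattened and addressed with forward slashes when used as stopping criteria. A usage sketch with the public API; the metric names are illustrative:

from ray import tune

tune.run(
    "my_trainable",  # assumed to be registered elsewhere
    stop={
        "training_iteration": 100,
        "evaluation/episode_reward_mean": 200,  # nested key, flattened with "/"
    },
)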
Example #12
    def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: Checkpoint):
        if checkpoint.storage == Checkpoint.MEMORY:
            return

        # Local import to avoid circular dependencies between syncer and
        # trainable
        from ray.tune.durable_trainable import DurableTrainable

        trial_syncer = self._get_trial_syncer(trial)
        # If the sync_function is False, syncing to driver is disabled.
        # In every other case (valid values include None, True, a Callable,
        # or a NodeSyncer), syncing to driver is enabled.
        if trial.sync_on_checkpoint and self._sync_function is not False:
            try:
                # Wait for any other syncs to finish. We need to sync again
                # after this to handle checkpoints taken mid-sync.
                trial_syncer.wait()
            except TuneError as e:
                # Errors occurring during this wait are not fatal for this
                # checkpoint, so it should just be logged.
                logger.error(
                    "Trial %s: An error occurred during the "
                    "checkpoint pre-sync wait - %s", trial, str(e))
            # Force sync down and wait before tracking the new checkpoint.
            try:
                if trial_syncer.sync_down():
                    trial_syncer.wait()
                else:
                    logger.error(
                        "Trial %s: Checkpoint sync skipped. "
                        "This should not happen.", trial)
            except TuneError as e:
                if issubclass(trial.get_trainable_cls(), DurableTrainable):
                    # Even though rsync failed the trainable can restore
                    # from remote durable storage.
                    logger.error("Trial %s: Sync error - %s", trial, str(e))
                else:
                    # If the trainable didn't have remote storage to upload
                    # to then this checkpoint may have been lost, so we
                    # shouldn't track it with the checkpoint_manager.
                    raise e
            if not issubclass(trial.get_trainable_cls(), DurableTrainable):
                if not os.path.exists(checkpoint.value):
                    raise TuneError("Trial {}: Checkpoint path {} not "
                                    "found after successful sync down.".format(
                                        trial, checkpoint.value))
Example #13
    def on_checkpoint(self, checkpoint):
        """Hook for handling checkpoints taken by the Trainable.

        Args:
            checkpoint (Checkpoint): Checkpoint taken.
        """
        if checkpoint.storage == Checkpoint.MEMORY:
            # TODO(ujvl): Handle this separately to avoid restoration failure.
            self.checkpoint_manager.on_checkpoint(checkpoint)
            return
        if self.sync_on_checkpoint:
            try:
                # Wait for any other syncs to finish. We need to sync again
                # after this to handle checkpoints taken mid-sync.
                self.result_logger.wait()
            except TuneError as e:
                # Errors occurring during this wait are not fatal for this
                # checkpoint, so it should just be logged.
                logger.error(
                    "Trial %s: An error occurred during the "
                    "checkpoint pre-sync wait - %s", self, str(e))
            # Force sync down and wait before tracking the new checkpoint.
            try:
                if self.result_logger.sync_down():
                    self.result_logger.wait()
                else:
                    logger.error(
                        "Trial %s: Checkpoint sync skipped. "
                        "This should not happen.", self)
            except TuneError as e:
                if issubclass(self.get_trainable_cls(), DurableTrainable):
                    # Even though rsync failed the trainable can restore
                    # from remote durable storage.
                    logger.error("Trial %s: Sync error - %s", self, str(e))
                else:
                    # If the trainable didn't have remote storage to upload
                    # to then this checkpoint may have been lost, so we
                    # shouldn't track it with the checkpoint_manager.
                    raise e
            if not issubclass(self.get_trainable_cls(), DurableTrainable):
                if not os.path.exists(checkpoint.value):
                    raise TuneError("Trial {}: Checkpoint path {} not "
                                    "found after successful sync down.".format(
                                        self, checkpoint.value))
        self.checkpoint_manager.on_checkpoint(checkpoint)
Example #14
def json_to_resources(data):
    if data is None or data == "null":
        return None
    if isinstance(data, string_types):
        data = json.loads(data)
    for k in data:
        if k in ["driver_cpu_limit", "driver_gpu_limit"]:
            raise TuneError(
                "The field `{}` is no longer supported. Use `extra_cpu` "
                "or `extra_gpu` instead.".format(k))
        if k not in Resources._fields:
            raise ValueError(
                "Unknown resource field {}, must be one of {}".format(
                    k, Resources._fields))
    return Resources(data.get("cpu", 1), data.get("gpu", 0),
                     data.get("extra_cpu", 0), data.get("extra_gpu", 0),
                     data.get("custom_resources"),
                     data.get("extra_custom_resources"))
Example #15
    def start_trial(self, trial, checkpoint=None):
        # Reserve node before starting trial
        resources = trial.resources

        # Check for required GPUs first
        required = resources.gpu_total()
        if required > 0:
            resource_attr = "gpu"
        else:
            # No GPU, just use CPU
            resource_attr = "cpu"
            required = resources.cpu_total()

        # Compute nodes required to fulfill trial resources request
        custom_resources = {}
        for node_id, node_resource in self._resources_by_node.items():
            # Compute resource remaining on each node
            node_capacity = node_resource.get_res_total(node_id)
            committed_capacity = self._committed_resources.get_res_total(
                node_id)
            remaining_capacity = node_capacity - committed_capacity
            node_procs = getattr(node_resource, resource_attr, 0)
            available = node_procs * remaining_capacity

            if available == 0:
                continue

            if required <= available:
                custom_resources[node_id] = required / node_procs
                required = 0
                break
            else:
                custom_resources[node_id] = remaining_capacity
                required -= available

        if required > 0:
            raise TuneError(f"Unable to start trial {trial.trial_id}. "
                            f"Not enough nodes")

        # Update trial node affinity configuration and
        # extra_custom_resources requirement
        trial.config.update(__ray_node_affinity__=custom_resources)
        trial.resources.extra_custom_resources.update(custom_resources)
        super().start_trial(trial, checkpoint)
Example #16
def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
    """Creates a Trial object from parsing the spec.

    Arguments:
        spec (dict): A resolved experiment specification. The arguments
            here should correspond to the command line flags defined in
            ray.tune.config_parser.
        output_path (str): A specific output path within the local_dir.
            Typically the name of the experiment.
        parser (ArgumentParser): An argument parser object from
            make_parser.
        trial_kwargs: Extra keyword arguments used in instantiating the Trial.

    Returns:
        A trial object with corresponding parameters to the specification.
    """
    try:
        # Special case the `env` param for RLlib by automatically
        # moving it into the `config` section.
        if "env" in spec:
            spec["config"] = spec.get("config", {})
            spec["config"]["env"] = spec["env"]
            del spec["env"]
        args = parser.parse_args(to_argv(spec))
    except SystemExit:
        raise TuneError("Error parsing args, see above message", spec)
    if "trial_resources" in spec:
        trial_kwargs["resources"] = json_to_resources(spec["trial_resources"])
    return Trial(
        # Submitting trial via server in py2.7 creates Unicode, which does not
        # convert to string in a straightforward manner.
        trainable_name=spec["run"],
        # json.load leads to str -> unicode in py2.7
        config=spec.get("config", {}),
        local_dir=os.path.join(args.local_dir, output_path),
        # json.load leads to str -> unicode in py2.7
        stopping_criterion=spec.get("stop", {}),
        checkpoint_freq=args.checkpoint_freq,
        # str(None) doesn't create None
        restore_path=spec.get("restore"),
        upload_dir=args.upload_dir,
        max_failures=args.max_failures,
        **trial_kwargs)
Example #17
def run_experiments(experiments, scheduler=None, **ray_args):
    if scheduler is None:
        scheduler = FIFOScheduler()
    runner = TrialRunner(scheduler)

    for name, spec in experiments.items():
        for trial in generate_trials(spec, name):
            runner.add_trial(trial)
    print(runner.debug_string())

    ray.init(**ray_args)

    while not runner.is_finished():
        runner.step()
        print(runner.debug_string())

    for trial in runner.get_trials():
        if trial.status != Trial.TERMINATED:
            raise TuneError("Trial did not complete", trial)

    return runner.get_trials()
Example #18
    def train(self):
        if not self._initialize_ok:
            raise ValueError(
                "Agent initialization failed, see previous errors")

        now = time.time()
        time.sleep(self.config["script_min_iter_time_s"])

        result = self._status_reporter._get_and_clear_status()
        while result is None:
            time.sleep(1)
            result = self._status_reporter._get_and_clear_status()
        if result.timesteps_total is None:
            raise TuneError("Must specify timesteps_total in result", result)

        # Include the negative loss to use as a stopping condition
        if result.mean_loss is not None:
            neg_loss = -result.mean_loss
        else:
            neg_loss = result.neg_mean_loss

        result = result._replace(
            experiment_id=self._experiment_id,
            neg_mean_loss=neg_loss,
            training_iteration=self.iteration,
            time_this_iter_s=now - self._last_reported_time,
            timesteps_this_iter=(result.timesteps_total -
                                 self._last_reported_timestep),
            time_total_s=now - self._start_time,
            pid=os.getpid(),
            hostname=os.uname()[1])

        if result.timesteps_total:
            self._last_reported_timestep = result.timesteps_total
        self._last_reported_time = now
        self._iteration += 1

        self._result_logger.on_result(result)

        return result
Example #19
def import_function(file_path, function_name):
    # strong assumption here that we're in a new process
    file_path = os.path.expanduser(file_path)
    sys.path.insert(0, os.path.dirname(file_path))
    if hasattr(importlib, "util"):
        # Python 3.4+
        spec = importlib.util.spec_from_file_location("external_file",
                                                      file_path)
        external_file = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(external_file)
    elif hasattr(importlib, "machinery"):
        # Python 3.3
        from importlib.machinery import SourceFileLoader
        external_file = SourceFileLoader("external_file",
                                         file_path).load_module()
    else:
        # Python 2.x
        import imp
        external_file = imp.load_source("external_file", file_path)
    if not external_file:
        raise TuneError("Unable to import file at {}".format(file_path))
    return getattr(external_file, function_name)
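A usage sketch with a hypothetical script path and function name:

# Assuming ~/my_experiment.py defines a function `train_fn`, this loads the
# module in the current process and returns that attribute.
train_fn = import_function("~/my_experiment.py", "train_fn")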
Example #20
def run_experiments(experiments,
                    scheduler=None,
                    with_server=False,
                    server_port=TuneServer.DEFAULT_PORT,
                    verbose=True):

    # Make sure rllib agents are registered
    from ray import rllib  # noqa # pylint: disable=unused-import

    if scheduler is None:
        scheduler = FIFOScheduler()

    runner = TrialRunner(scheduler,
                         launch_web_server=with_server,
                         server_port=server_port)

    for name, spec in experiments.items():
        for trial in generate_trials(spec, name):
            trial.set_verbose(verbose)
            runner.add_trial(trial)
    print(runner.debug_string(max_debug=99999))

    last_debug = 0
    while not runner.is_finished():
        runner.step()
        if time.time() - last_debug > DEBUG_PRINT_INTERVAL:
            print(runner.debug_string())
            last_debug = time.time()

    print(runner.debug_string(max_debug=99999))

    for trial in runner.get_trials():
        # TODO(rliaw): What about errored?
        if trial.status != Trial.TERMINATED:
            raise TuneError("Trial did not complete", trial)

    wait_for_log_sync()
    return runner.get_trials()
Example #21
    def _generate_trials(self, num_samples, unresolved_spec, output_path=""):
        """Generates Trial objects with the variant generation process.

        Uses a fixed point iteration to resolve variants. All trials
        should be able to be generated at once.

        See also: `ray.tune.suggest.variant_generator`.

        Yields:
            Trial object
        """

        if "run" not in unresolved_spec:
            raise TuneError("Must specify `run` in {}".format(unresolved_spec))
        for _ in range(num_samples):
            # Iterate over list of configs
            for unresolved_cfg in unresolved_spec["config"]:
                unresolved_spec_variant = deepcopy(unresolved_spec)
                unresolved_spec_variant["config"] = unresolved_cfg
                resolved_base_vars = (
                    CustomVariantGenerator._extract_resolved_base_vars(
                        unresolved_cfg, unresolved_spec["config"]))
                print("Resolved base cfg vars", resolved_base_vars)
                for resolved_vars, spec in generate_variants(unresolved_spec_variant):
                    resolved_vars.update(resolved_base_vars)
                    print("Resolved vars", resolved_vars)
                    trial_id = "%05d" % self._counter
                    experiment_tag = str(self._counter)
                    if resolved_vars:
                        experiment_tag += "_{}".format(
                            format_vars({k: v for k, v in resolved_vars.items() if "tag" in k}))
                    self._counter += 1
                    yield create_trial_from_spec(
                        spec,
                        output_path,
                        self._parser,
                        evaluated_params=flatten_resolved_vars(resolved_vars),
                        trial_id=trial_id,
                        experiment_tag=experiment_tag)
Example #22
    def retry(self):
        if not self._current_cmd:
            raise TuneError("No sync command set, cannot retry.")
        cmd, kwargs = self._current_cmd
        self._sync_process = _BackgroundProcess(cmd)
        self._sync_process.start(**kwargs)
Example #23
    def step(self):
        """Implements train() for a Function API.

        If the RunnerThread finishes without reporting "done",
        Tune will automatically provide a magic keyword __duplicate__
        along with a result with "done=True". The TrialRunner will handle the
        result accordingly (see tune/trial_runner.py).
        """
        if self._runner and self._runner.is_alive():
            # if started and alive, inform the reporter to continue and
            # generate the next result
            self._continue_semaphore.release()
        else:
            self._start()

        result = None
        while result is None and self._runner.is_alive():
            # fetch the next produced result
            try:
                result = self._results_queue.get(block=True,
                                                 timeout=RESULT_FETCH_TIMEOUT)
            except queue.Empty:
                pass

        # if no result was found, then the runner must no longer be alive
        if result is None:
            # Try one last time to fetch results in case results were reported
            # in between the time of the last check and the termination of the
            # thread runner.
            try:
                result = self._results_queue.get(block=False)
            except queue.Empty:
                pass

        # check if error occurred inside the thread runner
        if result is None:
            # only raise an error from the runner if all results are consumed
            self._report_thread_runner_error(block=True)

            # Under normal conditions, this code should never be reached since
            # this branch should only be visited if the runner thread raised
            # an exception. If no exception was raised, it means that the
            # runner thread never reported any results, which should not be
            # possible when wrapping functions with `wrap_function`.
            raise TuneError(
                ("Wrapped function ran until completion without reporting "
                 "results or raising an exception."))

        else:
            if not self._error_queue.empty():
                logger.warning(
                    ("Runner error waiting to be raised in main thread. "
                     "Logging all available results first."))

        # This keyword appears if the train_func using the Function API
        # finishes without "done=True". This duplicates the last result, but
        # the TrialRunner will not log this result again.
        if RESULT_DUPLICATE in result:
            new_result = self._last_result.copy()
            new_result.update(result)
            result = new_result

        self._last_result = result
        if self._status_reporter.has_new_checkpoint():
            result[SHOULD_CHECKPOINT] = True
        return result
Example #24
File: trial.py Project: ujvl/ray
    def _registration_check(cls, trainable_name):
        if not has_trainable(trainable_name):
            # Make sure rllib agents are registered
            from ray import rllib  # noqa: F401
            if not has_trainable(trainable_name):
                raise TuneError("Unknown trainable: " + trainable_name)
Example #25
    def step(self):
        """Runs one step of the trial event loop.

        Callers should typically run this method repeatedly in a loop. They
        may inspect or modify the runner's state in between calls to step().
        """
        self._updated_queue = False

        if self.is_finished():
            raise TuneError("Called step when all trials finished?")
        with warn_if_slow("on_step_begin"):
            self.trial_executor.on_step_begin(self)
        with warn_if_slow("callbacks.on_step_begin"):
            self._callbacks.on_step_begin(iteration=self._iteration,
                                          trials=self._trials)

        # This will contain the next trial to start
        next_trial = self._get_next_trial()  # blocking

        # Create pending trials. If the queue was updated before, only
        # continue updating if this was successful (next_trial is not None)
        if not self._updated_queue or (self._updated_queue and next_trial):
            num_pending_trials = len(
                [t for t in self._trials if t.status == Trial.PENDING])
            while num_pending_trials < self._max_pending_trials:
                if not self._update_trial_queue(blocking=False):
                    break
                num_pending_trials += 1

        # Update status of staged placement groups
        self.trial_executor.stage_and_update_status(self._trials)

        def _start_trial(trial: Trial) -> bool:
            """Helper function to start trial and call callbacks"""
            with warn_if_slow("start_trial"):
                if self.trial_executor.start_trial(trial):
                    self._callbacks.on_trial_start(iteration=self._iteration,
                                                   trials=self._trials,
                                                   trial=trial)
                    return True
                return False

        may_handle_events = True
        if next_trial is not None:
            if _start_trial(next_trial):
                may_handle_events = False
            elif next_trial.status != Trial.ERROR:
                # Only try to start another trial if previous trial startup
                # did not error (e.g. it just didn't start because its
                # placement group is not ready, yet).
                next_trial = self.trial_executor.get_staged_trial()
                if next_trial is not None:
                    if _start_trial(next_trial):
                        may_handle_events = False

        if may_handle_events:
            if self.trial_executor.get_running_trials():
                timeout = None
                if self.trial_executor.in_staging_grace_period():
                    timeout = 0.1
                self._process_events(timeout=timeout)  # blocking
            else:
                self.trial_executor.on_no_available_trials(self)

        self._stop_experiment_if_needed()

        try:
            self.checkpoint()
        except Exception as e:
            logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
        self._iteration += 1

        if self._server:
            with warn_if_slow("server"):
                self._process_stop_requests()

            if self.is_finished():
                self._server.shutdown()
        with warn_if_slow("on_step_end"):
            self.trial_executor.on_step_end(self)
        with warn_if_slow("callbacks.on_step_end"):
            self._callbacks.on_step_end(iteration=self._iteration,
                                        trials=self._trials)
Example #26
    def register(self, category, key, value):
        if category not in KNOWN_CATEGORIES:
            raise TuneError("Unknown category {} not among {}".format(
                category, KNOWN_CATEGORIES))
        self._all_objects[(category, key)] = value
Example #27
    def __init__(
        self,
        restore_path: str = None,
        trainable: Optional[Union[str, Callable, Type[Trainable],
                                  BaseTrainer, ]] = None,
        param_space: Optional[Dict[str, Any]] = None,
        tune_config: Optional[TuneConfig] = None,
        run_config: Optional[RunConfig] = None,
        _tuner_kwargs: Optional[Dict] = None,
    ):
        # Restored from Tuner checkpoint.
        if restore_path:
            trainable_ckpt = os.path.join(restore_path, _TRAINABLE_PKL)
            with open(trainable_ckpt, "rb") as fp:
                trainable = pickle.load(fp)

            tuner_ckpt = os.path.join(restore_path, _TUNER_PKL)
            with open(tuner_ckpt, "rb") as fp:
                tuner = pickle.load(fp)
                self.__dict__.update(tuner.__dict__)

            self._is_restored = True
            self._trainable = trainable
            self._experiment_checkpoint_dir = restore_path
            return

        # Start from fresh
        if not trainable:
            raise TuneError("You need to provide a trainable to tune.")

        # If no run config was passed to Tuner directly, use the one from the Trainer,
        # if available
        if not run_config and isinstance(trainable, BaseTrainer):
            run_config = trainable.run_config

        self._is_restored = False
        self._trainable = trainable
        self._tune_config = tune_config or TuneConfig()
        self._run_config = run_config or RunConfig()
        self._tuner_kwargs = copy.deepcopy(_tuner_kwargs) or {}
        self._experiment_checkpoint_dir = self._setup_create_experiment_checkpoint_dir(
            self._run_config)

        # Not used for restored Tuner.
        self._param_space = param_space or {}
        self._process_scaling_config()

        # This needs to happen before `tune.run()` is kicked in.
        # This is because currently tune does not exit gracefully if
        # run in ray client mode - if crash happens, it just exits immediately
        # without allowing for checkpointing tuner and trainable.
        # Thus this has to happen before tune.run() so that we can have something
        # to restore from.
        tuner_ckpt = os.path.join(self._experiment_checkpoint_dir, _TUNER_PKL)
        with open(tuner_ckpt, "wb") as fp:
            pickle.dump(self, fp)

        trainable_ckpt = os.path.join(self._experiment_checkpoint_dir,
                                      _TRAINABLE_PKL)
        with open(trainable_ckpt, "wb") as fp:
            pickle.dump(self._trainable, fp)
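A usage sketch of the two entry points the constructor above supports, a fresh run and a restore; exact signatures (Tuner.restore in particular) vary between Ray releases, and the paths shown are placeholders:

from ray import tune
from ray.tune import Tuner


def objective(config):
    tune.report(score=config["x"] ** 2)


# Fresh run: a trainable plus a parameter space.
tuner = Tuner(objective, param_space={"x": tune.uniform(0, 1)})
results = tuner.fit()

# Restore from a previous experiment checkpoint directory.
restored = Tuner.restore("~/ray_results/my_experiment")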
Example #28
def _make_scheduler(args):
    if args.scheduler in _SCHEDULERS:
        return _SCHEDULERS[args.scheduler](**args.scheduler_config)
    else:
        raise TuneError("Unknown scheduler: {}, should be one of {}".format(
            args.scheduler, _SCHEDULERS.keys()))
Example #29
    def wait(self):
        result = super(MaybeFailingProcess, self).wait()
        if self.should_fail:
            raise TuneError("Syncing failed.")
        return result
Example #30
def _tune_error(msg):
    raise TuneError(msg)