Example #1
    def __init__(self,
                 time_attr: str = "time_total_s",
                 metric: Optional[str] = None,
                 mode: Optional[str] = None,
                 perturbation_interval: float = 60.0,
                 hyperparam_bounds: Dict = None,
                 quantile_fraction: float = 0.25,
                 log_config: bool = True,
                 require_attrs: bool = True,
                 synch: bool = False):

        gpy_available, sklearn_available = import_pb2_dependencies()
        if not gpy_available:
            raise RuntimeError("Please install GPy to use PB2.")

        if not sklearn_available:
            raise RuntimeError("Please install scikit-learn to use PB2.")

        hyperparam_bounds = hyperparam_bounds or {}
        for value in hyperparam_bounds.values():
            if not isinstance(value, (list, tuple)) or len(value) != 2:
                raise ValueError("`hyperparam_bounds` values must either be "
                                 "a list or tuple of size 2, but got {} "
                                 "instead".format(value))

        if not hyperparam_bounds:
            raise TuneError("`hyperparam_bounds` must be specified to use "
                            "PB2 scheduler.")

        super(PB2, self).__init__(time_attr=time_attr,
                                  metric=metric,
                                  mode=mode,
                                  perturbation_interval=perturbation_interval,
                                  hyperparam_mutations=hyperparam_bounds,
                                  quantile_fraction=quantile_fraction,
                                  resample_probability=0,
                                  custom_explore_fn=explore,
                                  log_config=log_config,
                                  require_attrs=require_attrs,
                                  synch=synch)

        self.last_exploration_time = 0  # when we last explored
        self.data = pd.DataFrame()

        self._hyperparam_bounds = hyperparam_bounds

        # Current = trials running that have already re-started after reaching
        #           the checkpoint. When exploring we care if these trials
        #           are already in or scheduled to be in the next round.
        self.current = None
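A minimal construction sketch for the scheduler above, assuming PB2 is importable from ray.tune.schedulers.pb2; the parameter names come from the signature in Example #1 and the values are illustrative:

from ray.tune.schedulers.pb2 import PB2  # requires GPy and scikit-learn, per the checks above

pb2 = PB2(
    time_attr="time_total_s",
    metric="mean_accuracy",   # illustrative metric name
    mode="max",
    perturbation_interval=60.0,
    # each bound must be a list/tuple of size 2, or __init__ raises ValueError
    hyperparam_bounds={"lr": [1e-4, 1e-2], "train_batch_size": [1000, 60000]},
)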
Example #2
    def step(self):
        """Runs one step of the trial event loop.

        Callers should typically run this method repeatedly in a loop. They
        may inspect or modify the runner's state in between calls to step().
        """

        if self._can_launch_more():
            self._launch_trial()
        elif self._running:
            self._process_events()
        else:
            for trial in self._trials:
                if trial.status == Trial.PENDING:
                    if not self.has_resources(trial.resources):
                        raise TuneError(
                            "Insufficient cluster resources to launch trial",
                            (trial.resources, self._avail_resources))
                elif trial.status == Trial.PAUSED:
                    raise TuneError(
                        "There are paused trials, but no more pending "
                        "trials with sufficient resources.")
            raise TuneError("Called step when all trials finished?")
Example #3
    def _sync_trial_checkpoint(self, trial: "Trial",
                               checkpoint: _TrackedCheckpoint):
        if checkpoint.storage_mode == CheckpointStorage.MEMORY:
            return

        trial_syncer = self._get_trial_syncer(trial)
        # If the sync_function is False, syncing to driver is disabled.
        # In every other case (valid values include None, True, Callable,
        # NodeSyncer) syncing to driver is enabled.
        if trial.sync_on_checkpoint and self._sync_function is not False:
            try:
                # Wait for any other syncs to finish. We need to sync again
                # after this to handle checkpoints taken mid-sync.
                trial_syncer.wait()
            except TuneError as e:
                # Errors occurring during this wait are not fatal for this
                # checkpoint, so it should just be logged.
                logger.error(f"Trial {trial}: An error occurred during the "
                             f"checkpoint pre-sync wait: {e}")
            # Force sync down and wait before tracking the new checkpoint.
            try:
                if trial_syncer.sync_down():
                    trial_syncer.wait()
                else:
                    logger.error(f"Trial {trial}: Checkpoint sync skipped. "
                                 f"This should not happen.")
            except TuneError as e:
                if trial.uses_cloud_checkpointing:
                    # Even though rsync failed the trainable can restore
                    # from remote durable storage.
                    logger.error(f"Trial {trial}: Sync error: {e}")
                else:
                    # If the trainable didn't have remote storage to upload
                    # to then this checkpoint may have been lost, so we
                    # shouldn't track it with the checkpoint_manager.
                    raise e
            if not trial.uses_cloud_checkpointing:
                if not os.path.exists(checkpoint.dir_or_data):
                    raise TuneError("Trial {}: Checkpoint path {} not "
                                    "found after successful sync down. "
                                    "Are you running on a Kubernetes or "
                                    "managed cluster? rsync will not function "
                                    "due to a lack of SSH functionality. "
                                    "You'll need to use cloud-checkpointing "
                                    "if that's the case, see instructions "
                                    "here: {} .".format(
                                        trial,
                                        checkpoint.dir_or_data,
                                        CLOUD_CHECKPOINTING_URL,
                                    ))
Example #4
    def validate(formats):
        """Validates formats.

        Raises:
            TuneError if the format is unknown.
        """
        for i in range(len(formats)):
            formats[i] = formats[i].strip().lower()
            if formats[i] not in [
                    ExportFormat.CHECKPOINT, ExportFormat.MODEL,
                    ExportFormat.ONNX, ExportFormat.H5
            ]:
                raise TuneError("Unsupported import/export format: " +
                                formats[i])
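A hedged usage sketch: validate normalizes each entry in place before checking it, so unknown strings raise TuneError while known ones are kept lowercased (assuming the ExportFormat constants are the lowercase strings their names suggest):

formats = [" Checkpoint ", "MODEL"]
ExportFormat.validate(formats)      # strips and lowercases each entry in place
print(formats)                      # e.g. ["checkpoint", "model"]
ExportFormat.validate(["pickle"])   # raises TuneError: unsupported format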
Example #5
    def step(self):
        """Runs one step of the trial event loop.

        Callers should typically run this method repeatedly in a loop. They
        may inspect or modify the runner's state in between calls to step().
        """
        next_trial = self._get_next_trial()
        if next_trial is not None:
            self._launch_trial(next_trial)
        elif self._running:
            self._process_events()
        else:
            for trial in self._trials:
                if trial.status == Trial.PENDING:
                    if not self.has_resources(trial.resources):
                        raise TuneError(
                            ("Insufficient cluster resources to launch trial: "
                             "trial requested {} but the cluster only has {} "
                             "available. Pass `queue_trials=True` in "
                             "ray.tune.run_experiments() or on the command "
                             "line to queue trials until the cluster scales "
                             "up. {}").format(
                                 trial.resources.summary_string(),
                                 self._avail_resources.summary_string(),
                                 trial._get_trainable_cls().resource_help(
                                     trial.config)))
                elif trial.status == Trial.PAUSED:
                    raise TuneError(
                        "There are paused trials, but no more pending "
                        "trials with sufficient resources.")
            raise TuneError("Called step when all trials finished?")

        if self._server:
            self._process_requests()

            if self.is_finished():
                self._server.shutdown()
Example #6
def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
    """Creates a Trial object from parsing the spec.

    Arguments:
        spec (dict): A resolved experiment specification. The arguments here
            should correspond to the command line flags in
            ray.tune.config_parser.
        output_path (str): A specific output path within the local_dir.
            Typically the name of the experiment.
        parser (ArgumentParser): An argument parser object from
            make_parser.
        trial_kwargs: Extra keyword arguments used in instantiating the Trial.

    Returns:
        A trial object with corresponding parameters to the specification.
    """
    try:
        args, _ = parser.parse_known_args(to_argv(spec))
    except SystemExit:
        raise TuneError("Error parsing args, see above message", spec)
    if "resources_per_trial" in spec:
        trial_kwargs["resources"] = json_to_resources(
            spec["resources_per_trial"])
    return Trial(
        # Submitting trial via server in py2.7 creates Unicode, which does not
        # convert to string in a straightforward manner.
        trainable_name=spec["run"],
        # json.load leads to str -> unicode in py2.7
        config=spec.get("config", {}),
        local_dir=os.path.join(spec["local_dir"], output_path),
        # json.load leads to str -> unicode in py2.7
        stopping_criterion=spec.get("stop", {}),
        remote_checkpoint_dir=spec.get("remote_checkpoint_dir"),
        checkpoint_freq=args.checkpoint_freq,
        checkpoint_at_end=args.checkpoint_at_end,
        sync_on_checkpoint=args.sync_on_checkpoint,
        keep_checkpoints_num=args.keep_checkpoints_num,
        checkpoint_score_attr=args.checkpoint_score_attr,
        export_formats=spec.get("export_formats", []),
        # str(None) doesn't create None
        restore_path=spec.get("restore"),
        trial_name_creator=spec.get("trial_name_creator"),
        trial_dirname_creator=spec.get("trial_dirname_creator"),
        loggers=spec.get("loggers"),
        log_to_file=spec.get("log_to_file"),
        # str(None) doesn't create None
        sync_to_driver_fn=spec.get("sync_to_driver"),
        max_failures=args.max_failures,
        **trial_kwargs)
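A hedged sketch of the resolved spec this function expects; the keys mirror those read above, the values and trainable name are illustrative, and make_parser is the helper named in the docstring:

spec = {
    "run": "my_trainable",                        # hypothetical registered trainable
    "config": {"lr": 0.01},
    "local_dir": "~/ray_results",
    "stop": {"training_iteration": 10},
    "resources_per_trial": {"cpu": 1, "gpu": 0},  # converted via json_to_resources
}
trial = create_trial_from_spec(spec, output_path="my_experiment", parser=make_parser())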
Example #7
    def should_stop(self, result):
        """Whether the given result meets this trial's stopping criteria."""

        if result.get(DONE):
            return True

        for criteria, stop_value in self.stopping_criterion.items():
            if criteria not in result:
                raise TuneError(
                    "Stopping criteria {} not provided in result {}.".format(
                        criteria, result))
            if result[criteria] >= stop_value:
                return True

        return False
Example #8
    def wait_for_all(self):
        failed_syncs = {}
        for trial, sync_process in self._sync_processes.items():
            try:
                sync_process.wait()
            except Exception as e:
                failed_syncs[trial] = e

        if failed_syncs:
            sync_str = "\n".join(
                [f"  {trial}: {e}" for trial, e in failed_syncs.items()]
            )
            raise TuneError(
                f"At least one trial failed to sync down when waiting for all "
                f"trials to sync: \n{sync_str}"
            )
Example #9
    def _train(self):
        time.sleep(
            self.config.get("script_min_iter_time_s",
                            self._default_config["script_min_iter_time_s"]))
        result = self._status_reporter._get_and_clear_status()
        while result is None:
            time.sleep(1)
            result = self._status_reporter._get_and_clear_status()
        if result.timesteps_total is None:
            raise TuneError("Must specify timesteps_total in result", result)

        result = result._replace(
            timesteps_this_iter=(result.timesteps_total -
                                 self._last_reported_timestep))
        self._last_reported_timestep = result.timesteps_total

        return result
Example #10
    def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
        assert max_retries > 0
        last_error = None
        for _ in range(max_retries - 1):
            try:
                self.wait()
            except Exception as e:
                logger.error(
                    f"Caught sync error: {e}. "
                    f"Retrying after sleeping for {backoff_s} seconds...")
                last_error = e
                time.sleep(backoff_s)
                self.retry()
                continue
            return
        raise TuneError(
            f"Failed sync even after {max_retries} retries.") from last_error
Example #11
def _try_resolve(v):
    if isinstance(v, sample_from):
        # Function to sample from
        return False, v.func
    elif isinstance(v, dict) and len(v) == 1 and "eval" in v:
        # Lambda function in eval syntax
        return False, lambda spec: eval(
            v["eval"], _STANDARD_IMPORTS, {"spec": spec})
    elif isinstance(v, dict) and len(v) == 1 and "grid_search" in v:
        # Grid search values
        grid_values = v["grid_search"]
        if not isinstance(grid_values, list):
            raise TuneError(
                "Grid search expected list of values, got: {}".format(
                    grid_values))
        return False, grid_values
    return True, v
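For reference, a config sketch using the two single-key dict forms this helper recognizes; the key names are illustrative:

config = {
    "beta": {"grid_search": [1, 2, 3]},   # must be a list of values, else TuneError
    "alpha": {"eval": "1e-2"},            # string evaluated lazily, with the spec bound to the name spec
    "gamma": 42,                          # anything else resolves as-is: (True, 42)
}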
Example #12
    def step(self):
        """Runs one step of the trial event loop.

        Callers should typically run this method repeatedly in a loop. They
        may inspect or modify the runner's state in between calls to step().
        """
        if self.is_finished():
            raise TuneError("Called step when all trials finished?")
        with warn_if_slow("on_step_begin"):
            self.trial_executor.on_step_begin(self)
        with warn_if_slow("callbacks.on_step_begin"):
            self._callbacks.on_step_begin(
                iteration=self._iteration, trials=self._trials)
        next_trial = self._get_next_trial()  # blocking
        if next_trial is not None:
            with warn_if_slow("start_trial"):
                self.trial_executor.start_trial(next_trial)
                self._callbacks.on_trial_start(
                    iteration=self._iteration,
                    trials=self._trials,
                    trial=next_trial)
        elif self.trial_executor.get_running_trials():
            self._process_events()  # blocking
        else:
            self.trial_executor.on_no_available_trials(self)

        self._stop_experiment_if_needed()

        try:
            with warn_if_slow("experiment_checkpoint"):
                self.checkpoint()
        except Exception as e:
            logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
        self._iteration += 1

        if self._server:
            with warn_if_slow("server"):
                self._process_stop_requests()

            if self.is_finished():
                self._server.shutdown()
        with warn_if_slow("on_step_end"):
            self.trial_executor.on_step_end(self)
        with warn_if_slow("callbacks.on_step_end"):
            self._callbacks.on_step_end(
                iteration=self._iteration, trials=self._trials)
Example #13
    def should_stop(self, result):
        """Whether the given result meets this trial's stopping criteria."""
        if result.get(DONE):
            return True

        for criteria, stop_value in self.stopping_criterion.items():
            if criteria not in result:
                raise TuneError(
                    "Stopping criteria {} not provided in result {}.".format(
                        criteria, result))
            elif isinstance(criteria, dict):
                raise ValueError(
                    "Stopping criteria is now flattened by default. "
                    "Use forward slashes to nest values `key1/key2/key3`.")
            elif result[criteria] >= stop_value:
                return True
        return False
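An illustrative pairing of stopping_criterion and result for the method above; the key names are hypothetical, and per the error message nested criteria must be written with forward slashes:

trial.stopping_criterion = {"training_iteration": 100, "info/score": 0.95}
result = {"done": False, "training_iteration": 100, "info/score": 0.4}
trial.should_stop(result)   # True: training_iteration >= 100
# A result missing one of the criteria keys raises TuneError instead.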
Example #14
    def on_checkpoint(
        self,
        iteration: int,
        trials: List["Trial"],
        trial: "Trial",
        checkpoint: _TrackedCheckpoint,
        **info,
    ):
        if checkpoint.storage_mode == CheckpointStorage.MEMORY:
            return

        if self._sync_trial_dir(
                trial, force=trial.sync_on_checkpoint,
                wait=True) and not os.path.exists(checkpoint.dir_or_data):
            raise TuneError(
                f"Trial {trial}: Checkpoint path {checkpoint.dir_or_data} not "
                "found after successful sync down.")
Example #15
def _try_resolve(v) -> Tuple[bool, Any]:
    if isinstance(v, Domain):
        # Domain to sample from
        return False, v
    elif isinstance(v, dict) and len(v) == 1 and "eval" in v:
        # Lambda function in eval syntax
        return False, Function(
            lambda spec: eval(v["eval"], _STANDARD_IMPORTS, {"spec": spec}))
    elif isinstance(v, dict) and len(v) == 1 and "grid_search" in v:
        # Grid search values
        grid_values = v["grid_search"]
        if not isinstance(grid_values, list):
            raise TuneError(
                "Grid search expected list of values, got: {}".format(
                    grid_values))
        return False, Categorical(grid_values).grid()
    return True, v
Example #16
    def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: Checkpoint):
        if checkpoint.storage == Checkpoint.MEMORY:
            return

        # Local import to avoid circular dependencies between syncer and
        # trainable
        from ray.tune.durable_trainable import DurableTrainable

        trial_syncer = self._get_trial_syncer(trial)
        # If the sync_function is False, syncing to driver is disabled.
        # In every other case (valid values include None, True, Callable,
        # NodeSyncer) syncing to driver is enabled.
        if trial.sync_on_checkpoint and self._sync_function is not False:
            try:
                # Wait for any other syncs to finish. We need to sync again
                # after this to handle checkpoints taken mid-sync.
                trial_syncer.wait()
            except TuneError as e:
                # Errors occurring during this wait are not fatal for this
                # checkpoint, so it should just be logged.
                logger.error(
                    "Trial %s: An error occurred during the "
                    "checkpoint pre-sync wait - %s", trial, str(e))
            # Force sync down and wait before tracking the new checkpoint.
            try:
                if trial_syncer.sync_down():
                    trial_syncer.wait()
                else:
                    logger.error(
                        "Trial %s: Checkpoint sync skipped. "
                        "This should not happen.", trial)
            except TuneError as e:
                if issubclass(trial.get_trainable_cls(), DurableTrainable):
                    # Even though rsync failed the trainable can restore
                    # from remote durable storage.
                    logger.error("Trial %s: Sync error - %s", trial, str(e))
                else:
                    # If the trainable didn't have remote storage to upload
                    # to then this checkpoint may have been lost, so we
                    # shouldn't track it with the checkpoint_manager.
                    raise e
            if not issubclass(trial.get_trainable_cls(), DurableTrainable):
                if not os.path.exists(checkpoint.value):
                    raise TuneError("Trial {}: Checkpoint path {} not "
                                    "found after successful sync down.".format(
                                        trial, checkpoint.value))
Example #17
    def on_checkpoint(self, checkpoint):
        """Hook for handling checkpoints taken by the Trainable.

        Args:
            checkpoint (Checkpoint): Checkpoint taken.
        """
        if checkpoint.storage == Checkpoint.MEMORY:
            # TODO(ujvl): Handle this separately to avoid restoration failure.
            self.checkpoint_manager.on_checkpoint(checkpoint)
            return
        if self.sync_on_checkpoint:
            try:
                # Wait for any other syncs to finish. We need to sync again
                # after this to handle checkpoints taken mid-sync.
                self.result_logger.wait()
            except TuneError as e:
                # Errors occurring during this wait are not fatal for this
                # checkpoint, so it should just be logged.
                logger.error(
                    "Trial %s: An error occurred during the "
                    "checkpoint pre-sync wait - %s", self, str(e))
            # Force sync down and wait before tracking the new checkpoint.
            try:
                if self.result_logger.sync_down():
                    self.result_logger.wait()
                else:
                    logger.error(
                        "Trial %s: Checkpoint sync skipped. "
                        "This should not happen.", self)
            except TuneError as e:
                if issubclass(self.get_trainable_cls(), DurableTrainable):
                    # Even though rsync failed the trainable can restore
                    # from remote durable storage.
                    logger.error("Trial %s: Sync error - %s", self, str(e))
                else:
                    # If the trainable didn't have remote storage to upload
                    # to then this checkpoint may have been lost, so we
                    # shouldn't track it with the checkpoint_manager.
                    raise e
            if not issubclass(self.get_trainable_cls(), DurableTrainable):
                if not os.path.exists(checkpoint.value):
                    raise TuneError("Trial {}: Checkpoint path {} not "
                                    "found after successful sync down.".format(
                                        self, checkpoint.value))
        self.checkpoint_manager.on_checkpoint(checkpoint)
Example #18
def json_to_resources(data):
    if data is None or data == "null":
        return None
    if isinstance(data, string_types):
        data = json.loads(data)
    for k in data:
        if k in ["driver_cpu_limit", "driver_gpu_limit"]:
            raise TuneError(
                "The field `{}` is no longer supported. Use `extra_cpu` "
                "or `extra_gpu` instead.".format(k))
        if k not in Resources._fields:
            raise ValueError(
                "Unknown resource field {}, must be one of {}".format(
                    k, Resources._fields))
    return Resources(data.get("cpu", 1), data.get("gpu", 0),
                     data.get("extra_cpu", 0), data.get("extra_gpu", 0),
                     data.get("custom_resources"),
                     data.get("extra_custom_resources"))
Example #19
    def start_trial(self, trial, checkpoint=None):
        # Reserve node before starting trial
        resources = trial.resources

        # Check for required GPUs first
        required = resources.gpu_total()
        if required > 0:
            resource_attr = "gpu"
        else:
            # No GPU, just use CPU
            resource_attr = "cpu"
            required = resources.cpu_total()

        # Compute nodes required to fulfill trial resources request
        custom_resources = {}
        for node_id, node_resource in self._resources_by_node.items():
            # Compute resource remaining on each node
            node_capacity = node_resource.get_res_total(node_id)
            committed_capacity = self._committed_resources.get_res_total(
                node_id)
            remaining_capacity = node_capacity - committed_capacity
            node_procs = getattr(node_resource, resource_attr, 0)
            available = node_procs * remaining_capacity

            if available == 0:
                continue

            if required <= available:
                custom_resources[node_id] = required / node_procs
                required = 0
                break
            else:
                custom_resources[node_id] = remaining_capacity
                required -= available

        if required > 0:
            raise TuneError(f"Unable to start trial {trial.trial_id}. "
                            f"Not enough nodes")

        # Update trial node affinity configuration and
        # extra_custom_resources requirement
        trial.config.update(__ray_node_affinity__=custom_resources)
        trial.resources.extra_custom_resources.update(custom_resources)
        super().start_trial(trial, checkpoint)
Example #20
def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
    """Creates a Trial object from parsing the spec.

    Arguments:
        spec (dict): A resolved experiment specification. The arguments here
            should correspond to the command line flags in
            ray.tune.config_parser.
        output_path (str): A specific output path within the local_dir.
            Typically the name of the experiment.
        parser (ArgumentParser): An argument parser object from
            make_parser.
        trial_kwargs: Extra keyword arguments used in instantiating the Trial.

    Returns:
        A trial object with corresponding parameters to the specification.
    """
    try:
        # Special case the `env` param for RLlib by automatically
        # moving it into the `config` section.
        if "env" in spec:
            spec["config"] = spec.get("config", {})
            spec["config"]["env"] = spec["env"]
            del spec["env"]
        args = parser.parse_args(to_argv(spec))
    except SystemExit:
        raise TuneError("Error parsing args, see above message", spec)
    if "trial_resources" in spec:
        trial_kwargs["resources"] = json_to_resources(spec["trial_resources"])
    return Trial(
        # Submitting trial via server in py2.7 creates Unicode, which does not
        # convert to string in a straightforward manner.
        trainable_name=spec["run"],
        # json.load leads to str -> unicode in py2.7
        config=spec.get("config", {}),
        local_dir=os.path.join(args.local_dir, output_path),
        # json.load leads to str -> unicode in py2.7
        stopping_criterion=spec.get("stop", {}),
        checkpoint_freq=args.checkpoint_freq,
        # str(None) doesn't create None
        restore_path=spec.get("restore"),
        upload_dir=args.upload_dir,
        max_failures=args.max_failures,
        **trial_kwargs)
Example #21
def run_experiments(experiments, scheduler=None, **ray_args):
    if scheduler is None:
        scheduler = FIFOScheduler()
    runner = TrialRunner(scheduler)

    for name, spec in experiments.items():
        for trial in generate_trials(spec, name):
            runner.add_trial(trial)
    print(runner.debug_string())

    ray.init(**ray_args)

    while not runner.is_finished():
        runner.step()
        print(runner.debug_string())

    for trial in runner.get_trials():
        if trial.status != Trial.TERMINATED:
            raise TuneError("Trial did not complete", trial)

    return runner.get_trials()
Example #22
    def train(self):
        if not self._initialize_ok:
            raise ValueError(
                "Agent initialization failed, see previous errors")

        now = time.time()
        time.sleep(self.config["script_min_iter_time_s"])

        result = self._status_reporter._get_and_clear_status()
        while result is None:
            time.sleep(1)
            result = self._status_reporter._get_and_clear_status()
        if result.timesteps_total is None:
            raise TuneError("Must specify timesteps_total in result", result)

        # Include the negative loss to use as a stopping condition
        if result.mean_loss is not None:
            neg_loss = -result.mean_loss
        else:
            neg_loss = result.neg_mean_loss

        result = result._replace(
            experiment_id=self._experiment_id,
            neg_mean_loss=neg_loss,
            training_iteration=self.iteration,
            time_this_iter_s=now - self._last_reported_time,
            timesteps_this_iter=(result.timesteps_total -
                                 self._last_reported_timestep),
            time_total_s=now - self._start_time,
            pid=os.getpid(),
            hostname=os.uname()[1])

        if result.timesteps_total:
            self._last_reported_timestep = result.timesteps_total
        self._last_reported_time = now
        self._iteration += 1

        self._result_logger.on_result(result)

        return result
Example #23
def import_function(file_path, function_name):
    # strong assumption here that we're in a new process
    file_path = os.path.expanduser(file_path)
    sys.path.insert(0, os.path.dirname(file_path))
    if hasattr(importlib, "util"):
        # Python 3.4+
        spec = importlib.util.spec_from_file_location("external_file",
                                                      file_path)
        external_file = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(external_file)
    elif hasattr(importlib, "machinery"):
        # Python 3.3
        from importlib.machinery import SourceFileLoader
        external_file = SourceFileLoader("external_file",
                                         file_path).load_module()
    else:
        # Python 2.x
        import imp
        external_file = imp.load_source("external_file", file_path)
    if not external_file:
        raise TuneError("Unable to import file at {}".format(file_path))
    return getattr(external_file, function_name)
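A hedged call sketch for the helper above; the path and attribute name are hypothetical:

# loads ~/my_stoppers.py as a throwaway module and returns its stop_fn attribute
stop_fn = import_function("~/my_stoppers.py", "stop_fn")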
Example #24
def run_experiments(experiments,
                    scheduler=None,
                    with_server=False,
                    server_port=TuneServer.DEFAULT_PORT,
                    verbose=True):

    # Make sure rllib agents are registered
    from ray import rllib  # noqa # pylint: disable=unused-import

    if scheduler is None:
        scheduler = FIFOScheduler()

    runner = TrialRunner(scheduler,
                         launch_web_server=with_server,
                         server_port=server_port)

    for name, spec in experiments.items():
        for trial in generate_trials(spec, name):
            trial.set_verbose(verbose)
            runner.add_trial(trial)
    print(runner.debug_string(max_debug=99999))

    last_debug = 0
    while not runner.is_finished():
        runner.step()
        if time.time() - last_debug > DEBUG_PRINT_INTERVAL:
            print(runner.debug_string())
            last_debug = time.time()

    print(runner.debug_string(max_debug=99999))

    for trial in runner.get_trials():
        # TODO(rliaw): What about errored?
        if trial.status != Trial.TERMINATED:
            raise TuneError("Trial did not complete", trial)

    wait_for_log_sync()
    return runner.get_trials()
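A hedged call sketch matching the signature above; the experiment name, trainable, and spec keys are illustrative:

trials = run_experiments(
    {
        "my_experiment": {
            "run": "my_trainable",              # hypothetical registered trainable
            "config": {"lr": 0.01},
            "stop": {"training_iteration": 5},
        }
    },
    with_server=False,
    verbose=True,
)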
Example #25
    def _generate_trials(self, num_samples, unresolved_spec, output_path=""):
        """Generates Trial objects with the variant generation process.

        Uses a fixed point iteration to resolve variants. All trials
        should be able to be generated at once.

        See also: `ray.tune.suggest.variant_generator`.

        Yields:
            Trial object
        """

        if "run" not in unresolved_spec:
            raise TuneError("Must specify `run` in {}".format(unresolved_spec))
        for _ in range(num_samples):
            # Iterate over list of configs
            for unresolved_cfg in unresolved_spec["config"]:
                unresolved_spec_variant = deepcopy(unresolved_spec)
                unresolved_spec_variant["config"] = unresolved_cfg
                resolved_base_vars = CustomVariantGenerator._extract_resolved_base_vars(unresolved_cfg,
                                                                                        unresolved_spec["config"])
                print("Resolved base cfg vars", resolved_base_vars)
                for resolved_vars, spec in generate_variants(unresolved_spec_variant):
                    resolved_vars.update(resolved_base_vars)
                    print("Resolved vars", resolved_vars)
                    trial_id = "%05d" % self._counter
                    experiment_tag = str(self._counter)
                    if resolved_vars:
                        experiment_tag += "_{}".format(
                            format_vars({k: v for k, v in resolved_vars.items() if "tag" in k}))
                    self._counter += 1
                    yield create_trial_from_spec(
                        spec,
                        output_path,
                        self._parser,
                        evaluated_params=flatten_resolved_vars(resolved_vars),
                        trial_id=trial_id,
                        experiment_tag=experiment_tag)
Example #26
    def step(self):
        """Implements train() for a Function API.

        If the RunnerThread finishes without reporting "done",
        Tune will automatically provide a magic keyword __duplicate__
        along with a result with "done=True". The TrialRunner will handle the
        result accordingly (see tune/trial_runner.py).
        """
        if self._runner and self._runner.is_alive():
            # if started and alive, inform the reporter to continue and
            # generate the next result
            self._continue_semaphore.release()
        else:
            self._start()

        result = None
        while result is None and self._runner.is_alive():
            # fetch the next produced result
            try:
                result = self._results_queue.get(block=True,
                                                 timeout=RESULT_FETCH_TIMEOUT)
            except queue.Empty:
                pass

        # if no result was found, then the runner must no longer be alive
        if result is None:
            # Try one last time to fetch results in case results were reported
            # in between the time of the last check and the termination of the
            # thread runner.
            try:
                result = self._results_queue.get(block=False)
            except queue.Empty:
                pass

        # check if error occurred inside the thread runner
        if result is None:
            # only raise an error from the runner if all results are consumed
            self._report_thread_runner_error(block=True)

            # Under normal conditions, this code should never be reached since
            # this branch should only be visited if the runner thread raised
            # an exception. If no exception was raised, it means that the
            # runner thread never reported any results which should not be
            # possible when wrapping functions with `wrap_function`.
            raise TuneError(
                ("Wrapped function ran until completion without reporting "
                 "results or raising an exception."))

        else:
            if not self._error_queue.empty():
                logger.warning(
                    ("Runner error waiting to be raised in main thread. "
                     "Logging all available results first."))

        # This keyword appears if the train_func using the Function API
        # finishes without "done=True". This duplicates the last result, but
        # the TrialRunner will not log this result again.
        if RESULT_DUPLICATE in result:
            new_result = self._last_result.copy()
            new_result.update(result)
            result = new_result

        self._last_result = result
        if self._status_reporter.has_new_checkpoint():
            result[SHOULD_CHECKPOINT] = True
        return result
Example #27
def _tune_error(msg):
    raise TuneError(msg)
Example #28
    def register(self, category, key, value):
        if category not in KNOWN_CATEGORIES:
            raise TuneError("Unknown category {} not among {}".format(
                category, KNOWN_CATEGORIES))
        self._all_objects[(category, key)] = value
Example #29
    def step(self):
        """Runs one step of the trial event loop.

        Callers should typically run this method repeatedly in a loop. They
        may inspect or modify the runner's state in between calls to step().
        """
        self._updated_queue = False

        if self.is_finished():
            raise TuneError("Called step when all trials finished?")
        with warn_if_slow("on_step_begin"):
            self.trial_executor.on_step_begin(self)
        with warn_if_slow("callbacks.on_step_begin"):
            self._callbacks.on_step_begin(iteration=self._iteration,
                                          trials=self._trials)

        # This will contain the next trial to start
        next_trial = self._get_next_trial()  # blocking

        # Create pending trials. If the queue was updated before, only
        # continue updating if this was successful (next_trial is not None)
        if not self._updated_queue or (self._updated_queue and next_trial):
            num_pending_trials = len(
                [t for t in self._trials if t.status == Trial.PENDING])
            while num_pending_trials < self._max_pending_trials:
                if not self._update_trial_queue(blocking=False):
                    break
                num_pending_trials += 1

        # Update status of staged placement groups
        self.trial_executor.stage_and_update_status(self._trials)

        def _start_trial(trial: Trial) -> bool:
            """Helper function to start trial and call callbacks"""
            with warn_if_slow("start_trial"):
                if self.trial_executor.start_trial(trial):
                    self._callbacks.on_trial_start(iteration=self._iteration,
                                                   trials=self._trials,
                                                   trial=trial)
                    return True
                return False

        may_handle_events = True
        if next_trial is not None:
            if _start_trial(next_trial):
                may_handle_events = False
            elif next_trial.status != Trial.ERROR:
                # Only try to start another trial if previous trial startup
                # did not error (e.g. it just didn't start because its
                # placement group is not ready, yet).
                next_trial = self.trial_executor.get_staged_trial()
                if next_trial is not None:
                    if _start_trial(next_trial):
                        may_handle_events = False

        if may_handle_events:
            if self.trial_executor.get_running_trials():
                timeout = None
                if self.trial_executor.in_staging_grace_period():
                    timeout = 0.1
                self._process_events(timeout=timeout)  # blocking
            else:
                self.trial_executor.on_no_available_trials(self)

        self._stop_experiment_if_needed()

        try:
            self.checkpoint()
        except Exception as e:
            logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
        self._iteration += 1

        if self._server:
            with warn_if_slow("server"):
                self._process_stop_requests()

            if self.is_finished():
                self._server.shutdown()
        with warn_if_slow("on_step_end"):
            self.trial_executor.on_step_end(self)
        with warn_if_slow("callbacks.on_step_end"):
            self._callbacks.on_step_end(iteration=self._iteration,
                                        trials=self._trials)
Example #30
File: trial.py  Project: ujvl/ray
    def _registration_check(cls, trainable_name):
        if not has_trainable(trainable_name):
            # Make sure rllib agents are registered
            from ray import rllib  # noqa: F401
            if not has_trainable(trainable_name):
                raise TuneError("Unknown trainable: " + trainable_name)