Example 1
    def testSameCheckpoint(self):
        checkpoint_manager = CheckpointManager(
            1, "i", delete_fn=lambda c: os.remove(c.value))

        tmpfiles = []
        for i in range(3):
            _, tmpfile = tempfile.mkstemp()
            with open(tmpfile, "wt") as fp:
                fp.write("")
            tmpfiles.append(tmpfile)

        checkpoints = [
            Checkpoint(Checkpoint.PERSISTENT, tmpfiles[0],
                       self.mock_result(5)),
            Checkpoint(Checkpoint.PERSISTENT, tmpfiles[1],
                       self.mock_result(10)),
            Checkpoint(Checkpoint.PERSISTENT, tmpfiles[2],
                       self.mock_result(0)),
            Checkpoint(Checkpoint.PERSISTENT, tmpfiles[1],
                       self.mock_result(20))
        ]
        for checkpoint in checkpoints:
            checkpoint_manager.on_checkpoint(checkpoint)
            self.assertTrue(os.path.exists(checkpoint.value))

        for tmpfile in tmpfiles:
            if os.path.exists(tmpfile):
                os.remove(tmpfile)
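
Several snippets on this page call self.mock_result(...), which is defined elsewhere in the test class and not shown here. A plausible stand-in, consistent with the "i" score attribute used above (this is an assumption, not the actual helper), would be:

    @staticmethod
    def mock_result(i):
        # Assumed shape: the "i" key is the checkpoint_score_attr passed to
        # CheckpointManager above; "training_iteration" mimics the field a
        # real Tune result would carry.
        return {"i": i, "training_iteration": i}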
Example 2
    def testOnCheckpointUnavailableAttribute(self):
        """
        Tests that an error is logged when the associated result of the
        checkpoint has no checkpoint score attribute.
        """
        keep_checkpoints_num = 1
        checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")

        no_attr_checkpoint = Checkpoint(Checkpoint.MEMORY, 0, {})
        with patch.object(logger, "error") as log_error_mock:
            checkpoint_manager.on_checkpoint(no_attr_checkpoint)
            log_error_mock.assert_called_once()
            # The newest checkpoint should still be set despite this error.
            assert checkpoint_manager.newest_checkpoint == no_attr_checkpoint
Example 3
    def testBestCheckpoints(self):
        """
        Tests that the best checkpoints are tracked and ordered correctly.
        """
        keep_checkpoints_num = 4
        checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
        checkpoints = [
            Checkpoint(Checkpoint.MEMORY, i, self.mock_result(i))
            for i in range(16)
        ]
        random.shuffle(checkpoints)

        for checkpoint in checkpoints:
            checkpoint_manager.on_checkpoint(checkpoint)

        best_checkpoints = checkpoint_manager.best_checkpoints()
        self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
        for i in range(len(best_checkpoints)):
            self.assertEqual(best_checkpoints[i].value, i + 12)
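
For reference, here is the same pattern outside a test class. The import path ray.tune.checkpoint_manager is an assumption matching the Ray versions these snippets come from, and the expected output follows the ordering asserted in Example 3 (worst to best):

from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager

# Keep the two best checkpoints, scored by the "i" key of each result dict.
manager = CheckpointManager(2, "i", delete_fn=lambda c: None)
for i in range(5):
    manager.on_checkpoint(
        Checkpoint(Checkpoint.MEMORY, i, {"i": i, "training_iteration": i}))

# best_checkpoints() returns the kept checkpoints sorted by score, so this
# should print [3, 4] for the Ray version these snippets target.
print([c.value for c in manager.best_checkpoints()])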
Example 4
    def testOnCheckpointUnordered(self):
        """
        Tests priorities that aren't inserted in ascending order. Also tests
        that the worst checkpoints are deleted when necessary.
        """
        keep_checkpoints_num = 2
        checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
        checkpoints = [
            Checkpoint(Checkpoint.DISK, {i}, self.mock_result(i))
            for i in range(3, -1, -1)
        ]

        with patch("shutil.rmtree") as rmtree_mock, patch("os.path"):
            for j in range(0, len(checkpoints)):
                checkpoint_manager.on_checkpoint(checkpoints[j])
                expected_deletes = 0 if j != 3 else 1
                self.assertEqual(rmtree_mock.call_count, expected_deletes)
                self.assertEqual(checkpoint_manager.newest_checkpoint,
                                 checkpoints[j])

        best_checkpoints = checkpoint_manager.best_checkpoints()
        self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
        self.assertIn(checkpoints[0], best_checkpoints)
        self.assertIn(checkpoints[1], best_checkpoints)
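
Example 4 infers deletions by counting shutil.rmtree calls, which presumes the deletion path removes checkpoint data via shutil.rmtree. A minimal, self-contained illustration of that mocking pattern, independent of Tune:

import shutil
from unittest.mock import patch


def cleanup(path):
    # Stand-in for a deletion hook that removes a checkpoint directory.
    shutil.rmtree(path, ignore_errors=True)


with patch("shutil.rmtree") as rmtree_mock:
    cleanup("/tmp/does-not-matter")  # nothing is actually deleted
    assert rmtree_mock.call_count == 1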
Example 5
    def __init__(self,
                 trainable_name,
                 config=None,
                 trial_id=None,
                 local_dir=DEFAULT_RESULTS_DIR,
                 evaluated_params=None,
                 experiment_tag="",
                 resources=None,
                 stopping_criterion=None,
                 remote_checkpoint_dir=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=TRAINING_ITERATION,
                 export_formats=None,
                 restore_path=None,
                 trial_name_creator=None,
                 loggers=None,
                 sync_to_driver_fn=None,
                 max_failures=0):
        """Initialize a new trial.

        The args here take the same meaning as the command line flags defined
        in ray.tune.config_parser.
        """
        validate_trainable(trainable_name)
        # Trial config
        self.trainable_name = trainable_name
        self.trial_id = Trial.generate_id() if trial_id is None else trial_id
        self.config = config or {}
        self.local_dir = local_dir  # This remains unexpanded for syncing.

        #: Parameters that Tune varies across searches.
        self.evaluated_params = evaluated_params or {}
        self.experiment_tag = experiment_tag
        trainable_cls = self.get_trainable_cls()
        if trainable_cls and hasattr(trainable_cls,
                                     "default_resource_request"):
            default_resources = trainable_cls.default_resource_request(
                self.config)
            if default_resources:
                if resources:
                    raise ValueError(
                        "Resources for {} have been automatically set to {} "
                        "by its `default_resource_request()` method. Please "
                        "clear the `resources_per_trial` option.".format(
                            trainable_cls, default_resources))
                resources = default_resources
        self.location = Location()
        self.resources = resources or Resources(cpu=1, gpu=0)
        self.stopping_criterion = stopping_criterion or {}
        self.loggers = loggers
        self.sync_to_driver_fn = sync_to_driver_fn
        self.verbose = True
        self.max_failures = max_failures

        # Local trial state that is updated during the run
        self.last_result = {}
        self.last_update_time = -float("inf")

        # stores in memory max/min/last result for each metric by trial
        self.metric_analysis = {}

        self.export_formats = export_formats
        self.status = Trial.PENDING
        self.start_time = None
        self.logdir = None
        self.runner = None
        self.result_logger = None
        self.last_debug = 0
        self.error_file = None
        self.error_msg = None
        self.custom_trial_name = None

        # Checkpointing fields
        if remote_checkpoint_dir:
            self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
        else:
            self.remote_checkpoint_dir_prefix = None
        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_at_end = checkpoint_at_end
        self.sync_on_checkpoint = sync_on_checkpoint
        newest_checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path)
        self.checkpoint_manager = CheckpointManager(
            keep_checkpoints_num, checkpoint_score_attr,
            checkpoint_deleter(str(self), self.runner))
        self.checkpoint_manager.newest_checkpoint = newest_checkpoint

        # Restoration fields
        self.restoring_from = None
        self.num_failures = 0
        self.num_consecutive_start_attempts = 0

        # AutoML fields
        self.results = None
        self.best_result = None
        self.param_config = None
        self.extra_arg = None

        self._nonjson_fields = [
            "loggers",
            "sync_to_driver_fn",
            "results",
            "best_result",
            "param_config",
            "extra_arg",
        ]
        if trial_name_creator:
            self.custom_trial_name = trial_name_creator(self)
Example 6
class Trial:
    """A trial object holds the state for one model training run.

    Trials are themselves managed by the TrialRunner class, which implements
    the event loop for submitting trial runs to a Ray cluster.

    Trials start in the PENDING state, and transition to RUNNING once started.
    On error, a trial transitions to ERROR; on success, to TERMINATED.
    """

    PENDING = "PENDING"
    RUNNING = "RUNNING"
    PAUSED = "PAUSED"
    TERMINATED = "TERMINATED"
    ERROR = "ERROR"

    def __init__(self,
                 trainable_name,
                 config=None,
                 trial_id=None,
                 local_dir=DEFAULT_RESULTS_DIR,
                 evaluated_params=None,
                 experiment_tag="",
                 resources=None,
                 stopping_criterion=None,
                 remote_checkpoint_dir=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=TRAINING_ITERATION,
                 export_formats=None,
                 restore_path=None,
                 trial_name_creator=None,
                 loggers=None,
                 sync_to_driver_fn=None,
                 max_failures=0):
        """Initialize a new trial.

        The args here take the same meaning as the command line flags defined
        in ray.tune.config_parser.
        """
        validate_trainable(trainable_name)
        # Trial config
        self.trainable_name = trainable_name
        self.trial_id = Trial.generate_id() if trial_id is None else trial_id
        self.config = config or {}
        self.local_dir = local_dir  # This remains unexpanded for syncing.

        #: Parameters that Tune varies across searches.
        self.evaluated_params = evaluated_params or {}
        self.experiment_tag = experiment_tag
        trainable_cls = self.get_trainable_cls()
        if trainable_cls and hasattr(trainable_cls,
                                     "default_resource_request"):
            default_resources = trainable_cls.default_resource_request(
                self.config)
            if default_resources:
                if resources:
                    raise ValueError(
                        "Resources for {} have been automatically set to {} "
                        "by its `default_resource_request()` method. Please "
                        "clear the `resources_per_trial` option.".format(
                            trainable_cls, default_resources))
                resources = default_resources
        self.location = Location()
        self.resources = resources or Resources(cpu=1, gpu=0)
        self.stopping_criterion = stopping_criterion or {}
        self.loggers = loggers
        self.sync_to_driver_fn = sync_to_driver_fn
        self.verbose = True
        self.max_failures = max_failures

        # Local trial state that is updated during the run
        self.last_result = {}
        self.last_update_time = -float("inf")

        # stores in memory max/min/last result for each metric by trial
        self.metric_analysis = {}

        self.export_formats = export_formats
        self.status = Trial.PENDING
        self.start_time = None
        self.logdir = None
        self.runner = None
        self.result_logger = None
        self.last_debug = 0
        self.error_file = None
        self.error_msg = None
        self.custom_trial_name = None

        # Checkpointing fields
        if remote_checkpoint_dir:
            self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
        else:
            self.remote_checkpoint_dir_prefix = None
        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_at_end = checkpoint_at_end
        self.sync_on_checkpoint = sync_on_checkpoint
        newest_checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path)
        self.checkpoint_manager = CheckpointManager(
            keep_checkpoints_num, checkpoint_score_attr,
            checkpoint_deleter(str(self), self.runner))
        self.checkpoint_manager.newest_checkpoint = newest_checkpoint

        # Restoration fields
        self.restoring_from = None
        self.num_failures = 0
        self.num_consecutive_start_attempts = 0

        # AutoML fields
        self.results = None
        self.best_result = None
        self.param_config = None
        self.extra_arg = None

        self._nonjson_fields = [
            "loggers",
            "sync_to_driver_fn",
            "results",
            "best_result",
            "param_config",
            "extra_arg",
        ]
        if trial_name_creator:
            self.custom_trial_name = trial_name_creator(self)

    @property
    def node_ip(self):
        return self.location.hostname

    @property
    def checkpoint(self):
        return self.checkpoint_manager.newest_checkpoint

    @classmethod
    def generate_id(cls):
        return str(uuid.uuid1().hex)[:8]

    @property
    def remote_checkpoint_dir(self):
        assert self.logdir, "Trial {}: logdir not initialized.".format(self)
        if not self.remote_checkpoint_dir_prefix:
            return None
        logdir_name = os.path.basename(self.logdir)
        return os.path.join(self.remote_checkpoint_dir_prefix, logdir_name)

    @classmethod
    def create_logdir(cls, identifier, local_dir):
        local_dir = os.path.expanduser(local_dir)
        os.makedirs(local_dir, exist_ok=True)
        return tempfile.mkdtemp(prefix="{}_{}".format(
            identifier[:MAX_LEN_IDENTIFIER], date_str()),
                                dir=local_dir)

    def init_logger(self):
        """Init logger."""
        if not self.result_logger:
            if not self.logdir:
                self.logdir = Trial.create_logdir(str(self), self.local_dir)
            else:
                os.makedirs(self.logdir, exist_ok=True)

            self.result_logger = UnifiedLogger(
                self.config,
                self.logdir,
                trial=self,
                loggers=self.loggers,
                sync_function=self.sync_to_driver_fn)

    def update_resources(self, cpu, gpu, **kwargs):
        """EXPERIMENTAL: Updates the resource requirements.

        Should only be called when the trial is not running.

        Raises:
            ValueError: If the trial status is running.
        """
        if self.status is Trial.RUNNING:
            raise ValueError("Cannot update resources while Trial is running.")
        self.resources = Resources(cpu, gpu, **kwargs)

    def set_runner(self, runner):
        self.runner = runner
        self.checkpoint_manager.delete = checkpoint_deleter(str(self), runner)

    def set_location(self, location):
        """Sets the location of the trial."""
        self.location = location

    def set_status(self, status):
        """Sets the status of the trial."""
        self.status = status
        if status == Trial.RUNNING:
            if self.start_time is None:
                self.start_time = time.time()

    def close_logger(self):
        """Closes logger."""
        if self.result_logger:
            self.result_logger.close()
            self.result_logger = None

    def write_error_log(self, error_msg):
        if error_msg and self.logdir:
            self.num_failures += 1
            self.error_file = os.path.join(self.logdir, "error.txt")
            with open(self.error_file, "a+") as f:
                f.write("Failure # {} (occurred at {})\n".format(
                    self.num_failures, date_str()))
                f.write(error_msg + "\n")
            self.error_msg = error_msg

    def should_stop(self, result):
        """Whether the given result meets this trial's stopping criteria."""
        if result.get(DONE):
            return True

        if callable(self.stopping_criterion):
            return self.stopping_criterion(self.trial_id, result)

        for criteria, stop_value in self.stopping_criterion.items():
            if criteria not in result:
                raise TuneError(
                    "Stopping criteria {} not provided in result {}.".format(
                        criteria, result))
            elif isinstance(criteria, dict):
                raise ValueError(
                    "Stopping criteria is now flattened by default. "
                    "Use forward slashes to nest values `key1/key2/key3`.")
            elif result[criteria] >= stop_value:
                return True
        return False

    def should_checkpoint(self):
        """Whether this trial is due for checkpointing."""
        result = self.last_result or {}
        if result.get(DONE) and self.checkpoint_at_end:
            return True
        return (self.checkpoint_freq and
                result.get(TRAINING_ITERATION, 0) % self.checkpoint_freq == 0)

    def has_checkpoint(self):
        return self.checkpoint.value is not None

    def clear_checkpoint(self):
        self.checkpoint.value = None
        self.restoring_from = None

    def on_checkpoint(self, checkpoint):
        """Hook for handling checkpoints taken by the Trainable.

        Args:
            checkpoint (Checkpoint): Checkpoint taken.
        """
        if checkpoint.storage == Checkpoint.MEMORY:
            # TODO(ujvl): Handle this separately to avoid restoration failure.
            self.checkpoint_manager.on_checkpoint(checkpoint)
            return
        if self.sync_on_checkpoint:
            try:
                # Wait for any other syncs to finish. We need to sync again
                # after this to handle checkpoints taken mid-sync.
                self.result_logger.wait()
            except TuneError as e:
                # Errors occurring during this wait are not fatal for this
                # checkpoint, so it should just be logged.
                logger.error(
                    "Trial %s: An error occurred during the "
                    "checkpoint pre-sync wait: %s", self, str(e))
            # Force sync down and wait before tracking the new checkpoint.
            try:
                if self.result_logger.sync_down():
                    self.result_logger.wait()
                else:
                    logger.error(
                        "Trial %s: Checkpoint sync skipped. "
                        "This should not happen.", self)
            except TuneError as e:
                if issubclass(self.get_trainable_cls(), DurableTrainable):
                    # Even though rsync failed the trainable can restore
                    # from remote durable storage.
                    logger.error("Trial %s: Sync error - %s", self, str(e))
                else:
                    # If the trainable didn't have remote storage to upload
                    # to then this checkpoint may have been lost, so we
                    # shouldn't track it with the checkpoint_manager.
                    raise e
            if not issubclass(self.get_trainable_cls(), DurableTrainable):
                if not os.path.exists(checkpoint.value):
                    raise TuneError("Trial {}: Checkpoint path {} not "
                                    "found after successful sync down.".format(
                                        self, checkpoint.value))
        self.checkpoint_manager.on_checkpoint(checkpoint)

    def on_restore(self):
        """Handles restoration completion."""
        assert self.is_restoring
        self.last_result = self.restoring_from.result
        self.restoring_from = None

    def should_recover(self):
        """Returns whether the trial qualifies for retrying.

        This is if the trial has not failed more than max_failures. Note this
        may return true even when there is no checkpoint, either because
        `self.checkpoint_freq` is `0` or because the trial failed before
        a checkpoint has been made.
        """
        return self.num_failures < self.max_failures or self.max_failures < 0

    def update_last_result(self, result, terminate=False):
        result.update(trial_id=self.trial_id, done=terminate)
        if self.experiment_tag:
            result.update(experiment_tag=self.experiment_tag)
        if self.verbose and (terminate or time.time() - self.last_debug >
                             DEBUG_PRINT_INTERVAL):
            print("Result for {}:".format(self))
            print("  {}".format(pretty_print(result).replace("\n", "\n  ")))
            self.last_debug = time.time()
        self.set_location(Location(result.get("node_ip"), result.get("pid")))
        self.last_result = result
        self.last_update_time = time.time()
        self.result_logger.on_result(self.last_result)
        for metric, value in flatten_dict(result).items():
            if isinstance(value, Number):
                if metric not in self.metric_analysis:
                    self.metric_analysis[metric] = {
                        "max": value,
                        "min": value,
                        "last": value
                    }
                else:
                    self.metric_analysis[metric]["max"] = max(
                        value, self.metric_analysis[metric]["max"])
                    self.metric_analysis[metric]["min"] = min(
                        value, self.metric_analysis[metric]["min"])
                    self.metric_analysis[metric]["last"] = value

    def get_trainable_cls(self):
        return get_trainable_cls(self.trainable_name)

    def set_verbose(self, verbose):
        self.verbose = verbose

    def is_finished(self):
        return self.status in [Trial.ERROR, Trial.TERMINATED]

    @property
    def is_restoring(self):
        return self.restoring_from is not None

    def __repr__(self):
        return str(self)

    def __str__(self):
        """Combines ``env`` with ``trainable_name`` and ``trial_id``.

        Can be overridden with a custom string creator.
        """
        if self.custom_trial_name:
            return self.custom_trial_name

        if "env" in self.config:
            env = self.config["env"]
            if isinstance(env, type):
                env = env.__name__
            identifier = "{}_{}".format(self.trainable_name, env)
        else:
            identifier = self.trainable_name
        identifier += "_" + self.trial_id
        return identifier.replace("/", "_")

    def __getstate__(self):
        """Memento generator for Trial.

        Sets RUNNING trials to PENDING, and flushes the result logger.
        Note this can only occur if the trial holds a PERSISTENT checkpoint.
        """
        assert self.checkpoint.storage == Checkpoint.PERSISTENT, (
            "Checkpoint must not be in-memory.")
        state = self.__dict__.copy()
        state["resources"] = resources_to_json(self.resources)

        for key in self._nonjson_fields:
            state[key] = binary_to_hex(cloudpickle.dumps(state.get(key)))

        state["runner"] = None
        state["result_logger"] = None
        if self.result_logger:
            self.result_logger.flush(sync_down=False)
            state["__logger_started__"] = True
        else:
            state["__logger_started__"] = False
        return copy.deepcopy(state)

    def __setstate__(self, state):
        logger_started = state.pop("__logger_started__")
        state["resources"] = json_to_resources(state["resources"])

        if state["status"] == Trial.RUNNING:
            state["status"] = Trial.PENDING
        for key in self._nonjson_fields:
            state[key] = cloudpickle.loads(hex_to_binary(state[key]))

        self.__dict__.update(state)
        validate_trainable(self.trainable_name)
        if logger_started:
            self.init_logger()
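
The checkpoint_deleter(str(self), self.runner) factory used in __init__ and set_runner above is not part of these excerpts. Purely as a sketch of the delete_fn contract that CheckpointManager expects (compare Examples 1 and 9), a hypothetical, simplified stand-in could look like this; the real Ray helper also coordinates deletion with the remote runner:

import os
import shutil


def checkpoint_deleter_stub(trial_name, runner):
    """Hypothetical, simplified stand-in for checkpoint_deleter (not Ray's).

    Returns a callable with the delete_fn signature CheckpointManager
    expects: it receives the Checkpoint to dispose of and removes the
    directory containing its on-disk data. The runner argument is unused
    here, whereas the real helper also involves the remote trainable.
    """
    def delete(checkpoint):
        if checkpoint.value and os.path.exists(checkpoint.value):
            shutil.rmtree(os.path.dirname(checkpoint.value),
                          ignore_errors=True)
    return delete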
Example 7
    def __init__(
        self,
        trainable_name: str,
        config: Optional[Dict] = None,
        trial_id: Optional[str] = None,
        local_dir: Optional[str] = DEFAULT_RESULTS_DIR,
        evaluated_params: Optional[Dict] = None,
        experiment_tag: str = "",
        resources: Optional[Resources] = None,
        placement_group_factory: Optional[PlacementGroupFactory] = None,
        stopping_criterion: Optional[Dict[str, float]] = None,
        remote_checkpoint_dir: Optional[str] = None,
        sync_function_tpl: Optional[str] = None,
        checkpoint_freq: int = 0,
        checkpoint_at_end: bool = False,
        sync_on_checkpoint: bool = True,
        keep_checkpoints_num: Optional[int] = None,
        checkpoint_score_attr: str = TRAINING_ITERATION,
        export_formats: Optional[List[str]] = None,
        restore_path: Optional[str] = None,
        trial_name_creator: Optional[Callable[["Trial"], str]] = None,
        trial_dirname_creator: Optional[Callable[["Trial"], str]] = None,
        log_to_file: Optional[str] = None,
        max_failures: int = 0,
        stub: bool = False,
        _setup_default_resource: bool = True,
    ):
        """Initialize a new trial.

        The args here take the same meaning as the command line flags defined
        in ray.tune.config_parser.

        Args:
            _setup_default_resource: Whether to set up default resources.
                When initializing trials from checkpoints, this field is set to false,
                so that setting up default resources can be delayed till after
                ``trial.config`` is loaded from checkpoints.
        """
        # If this is set, trainables are not validated or looked up.
        # This can be used e.g. to initialize Trial objects from checkpoints
        # without loading the trainable first.
        self.stub = stub

        if not self.stub:
            validate_trainable(trainable_name)
        # Trial config
        self.trainable_name = trainable_name
        self.trial_id = Trial.generate_id() if trial_id is None else trial_id
        self.config = config or {}
        self.local_dir = local_dir  # This remains unexpanded for syncing.

        # Parameters that Tune varies across searches.
        self.evaluated_params = evaluated_params or {}
        self.experiment_tag = experiment_tag
        self.location = Location()
        trainable_cls = self.get_trainable_cls()
        if trainable_cls and _setup_default_resource:
            default_resources = trainable_cls.default_resource_request(
                self.config)

            # If Trainable returns resources, do not allow manual override via
            # `resources_per_trial` by the user.
            if default_resources:
                if resources or placement_group_factory:
                    raise ValueError(
                        "Resources for {} have been automatically set to {} "
                        "by its `default_resource_request()` method. Please "
                        "clear the `resources_per_trial` option.".format(
                            trainable_cls, default_resources))

                if isinstance(default_resources, PlacementGroupFactory):
                    placement_group_factory = default_resources
                    resources = None
                else:
                    placement_group_factory = None
                    resources = default_resources

        self.placement_group_factory = _to_pg_factory(resources,
                                                      placement_group_factory)

        self.stopping_criterion = stopping_criterion or {}

        self.log_to_file = log_to_file
        # Make sure `stdout_file, stderr_file = Trial.log_to_file` works
        if (not self.log_to_file or not isinstance(self.log_to_file, Sequence)
                or not len(self.log_to_file) == 2):
            self.log_to_file = (None, None)

        self.max_failures = max_failures

        # Local trial state that is updated during the run
        self._last_result = {}
        self._default_result_or_future: Union[ray.ObjectRef, dict, None] = None
        self.last_update_time = -float("inf")

        # stores in memory max/min/avg/last-n-avg/last result for each
        # metric by trial
        self.metric_analysis = {}

        # keep a moving average over these last n steps
        self.n_steps = [5, 10]
        self.metric_n_steps = {}

        self.export_formats = export_formats
        self.status = Trial.PENDING
        self.start_time = None
        self.logdir = None
        self.runner = None
        self.last_debug = 0
        self.error_file = None
        self.pickled_error_file = None
        self.trial_name_creator = trial_name_creator
        self.trial_dirname_creator = trial_dirname_creator
        self.custom_trial_name = None
        self.custom_dirname = None

        # Checkpointing fields
        self.saving_to = None
        if remote_checkpoint_dir:
            self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
        else:
            self.remote_checkpoint_dir_prefix = None

        if sync_function_tpl == "auto" or not isinstance(
                sync_function_tpl, str):
            sync_function_tpl = None
        self.sync_function_tpl = sync_function_tpl

        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_at_end = checkpoint_at_end
        self.keep_checkpoints_num = keep_checkpoints_num
        self.checkpoint_score_attr = checkpoint_score_attr
        self.sync_on_checkpoint = sync_on_checkpoint
        self.checkpoint_manager = CheckpointManager(
            keep_checkpoints_num,
            checkpoint_score_attr,
            CheckpointDeleter(self._trainable_name(), self.runner),
        )

        # Restoration fields
        self.restore_path = restore_path
        self.restoring_from = None
        self.num_failures = 0

        # AutoML fields
        self.results = None
        self.best_result = None
        self.param_config = None
        self.extra_arg = None

        if trial_name_creator:
            self.custom_trial_name = trial_name_creator(self)

        if trial_dirname_creator:
            self.custom_dirname = trial_dirname_creator(self)
            if os.path.sep in self.custom_dirname:
                raise ValueError(
                    f"Trial dirname must not contain '/'. Got {self.custom_dirname}"
                )

        self._state_json = None
        self._state_valid = False
Example 8
class Trial:
    """A trial object holds the state for one model training run.

    Trials are themselves managed by the TrialRunner class, which implements
    the event loop for submitting trial runs to a Ray cluster.

    Trials start in the PENDING state, and transition to RUNNING once started.
    On error, a trial transitions to ERROR; on success, to TERMINATED.

    There are resources allocated to each trial. These should be specified
    using ``PlacementGroupFactory``.

    Attributes:
        trainable_name: Name of the trainable object to be executed.
        config: Provided configuration dictionary with evaluated params.
        trial_id: Unique identifier for the trial.
        local_dir: Local_dir as passed to tune.run.
        logdir: Directory where the trial logs are saved.
        evaluated_params: Parameters evaluated by the search algorithm.
        experiment_tag: Identifying trial name to show in the console.
        status: One of PENDING, RUNNING, PAUSED, TERMINATED, ERROR.
        error_file: Path to the error log file for this trial.

    """

    _nonjson_fields = [
        "results",
        "best_result",
        "param_config",
        "extra_arg",
        "placement_group_factory",
    ]

    PENDING = "PENDING"
    RUNNING = "RUNNING"
    PAUSED = "PAUSED"
    TERMINATED = "TERMINATED"
    ERROR = "ERROR"

    def __init__(
        self,
        trainable_name: str,
        config: Optional[Dict] = None,
        trial_id: Optional[str] = None,
        local_dir: Optional[str] = DEFAULT_RESULTS_DIR,
        evaluated_params: Optional[Dict] = None,
        experiment_tag: str = "",
        resources: Optional[Resources] = None,
        placement_group_factory: Optional[PlacementGroupFactory] = None,
        stopping_criterion: Optional[Dict[str, float]] = None,
        remote_checkpoint_dir: Optional[str] = None,
        sync_function_tpl: Optional[str] = None,
        checkpoint_freq: int = 0,
        checkpoint_at_end: bool = False,
        sync_on_checkpoint: bool = True,
        keep_checkpoints_num: Optional[int] = None,
        checkpoint_score_attr: str = TRAINING_ITERATION,
        export_formats: Optional[List[str]] = None,
        restore_path: Optional[str] = None,
        trial_name_creator: Optional[Callable[["Trial"], str]] = None,
        trial_dirname_creator: Optional[Callable[["Trial"], str]] = None,
        log_to_file: Optional[str] = None,
        max_failures: int = 0,
        stub: bool = False,
        _setup_default_resource: bool = True,
    ):
        """Initialize a new trial.

        The args here take the same meaning as the command line flags defined
        in ray.tune.config_parser.

        Args:
            _setup_default_resource: Whether to set up default resources.
                When initializing trials from checkpoints, this field is set to false,
                so that setting up default resources can be delayed till after
                ``trial.config`` is loaded from checkpoints.
        """
        # If this is set, trainables are not validated or looked up.
        # This can be used e.g. to initialize Trial objects from checkpoints
        # without loading the trainable first.
        self.stub = stub

        if not self.stub:
            validate_trainable(trainable_name)
        # Trial config
        self.trainable_name = trainable_name
        self.trial_id = Trial.generate_id() if trial_id is None else trial_id
        self.config = config or {}
        self.local_dir = local_dir  # This remains unexpanded for syncing.

        # Parameters that Tune varies across searches.
        self.evaluated_params = evaluated_params or {}
        self.experiment_tag = experiment_tag
        self.location = Location()
        trainable_cls = self.get_trainable_cls()
        if trainable_cls and _setup_default_resource:
            default_resources = trainable_cls.default_resource_request(
                self.config)

            # If Trainable returns resources, do not allow manual override via
            # `resources_per_trial` by the user.
            if default_resources:
                if resources or placement_group_factory:
                    raise ValueError(
                        "Resources for {} have been automatically set to {} "
                        "by its `default_resource_request()` method. Please "
                        "clear the `resources_per_trial` option.".format(
                            trainable_cls, default_resources))

                if isinstance(default_resources, PlacementGroupFactory):
                    placement_group_factory = default_resources
                    resources = None
                else:
                    placement_group_factory = None
                    resources = default_resources

        self.placement_group_factory = _to_pg_factory(resources,
                                                      placement_group_factory)

        self.stopping_criterion = stopping_criterion or {}

        self.log_to_file = log_to_file
        # Make sure `stdout_file, stderr_file = Trial.log_to_file` works
        if (not self.log_to_file or not isinstance(self.log_to_file, Sequence)
                or not len(self.log_to_file) == 2):
            self.log_to_file = (None, None)

        self.max_failures = max_failures

        # Local trial state that is updated during the run
        self._last_result = {}
        self._default_result_or_future: Union[ray.ObjectRef, dict, None] = None
        self.last_update_time = -float("inf")

        # stores in memory max/min/avg/last-n-avg/last result for each
        # metric by trial
        self.metric_analysis = {}

        # keep a moving average over these last n steps
        self.n_steps = [5, 10]
        self.metric_n_steps = {}

        self.export_formats = export_formats
        self.status = Trial.PENDING
        self.start_time = None
        self.logdir = None
        self.runner = None
        self.last_debug = 0
        self.error_file = None
        self.pickled_error_file = None
        self.trial_name_creator = trial_name_creator
        self.trial_dirname_creator = trial_dirname_creator
        self.custom_trial_name = None
        self.custom_dirname = None

        # Checkpointing fields
        self.saving_to = None
        if remote_checkpoint_dir:
            self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
        else:
            self.remote_checkpoint_dir_prefix = None

        if sync_function_tpl == "auto" or not isinstance(
                sync_function_tpl, str):
            sync_function_tpl = None
        self.sync_function_tpl = sync_function_tpl

        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_at_end = checkpoint_at_end
        self.keep_checkpoints_num = keep_checkpoints_num
        self.checkpoint_score_attr = checkpoint_score_attr
        self.sync_on_checkpoint = sync_on_checkpoint
        self.checkpoint_manager = CheckpointManager(
            keep_checkpoints_num,
            checkpoint_score_attr,
            CheckpointDeleter(self._trainable_name(), self.runner),
        )

        # Restoration fields
        self.restore_path = restore_path
        self.restoring_from = None
        self.num_failures = 0

        # AutoML fields
        self.results = None
        self.best_result = None
        self.param_config = None
        self.extra_arg = None

        if trial_name_creator:
            self.custom_trial_name = trial_name_creator(self)

        if trial_dirname_creator:
            self.custom_dirname = trial_dirname_creator(self)
            if os.path.sep in self.custom_dirname:
                raise ValueError(
                    f"Trial dirname must not contain '/'. Got {self.custom_dirname}"
                )

        self._state_json = None
        self._state_valid = False

    def _get_default_result_or_future(self) -> Optional[dict]:
        """Calls ray.get on self._default_result_or_future and assigns back.

        Returns None in case of exceptions.
        Will also set the trial location if runner is set.
        """
        if self._default_result_or_future and isinstance(
                self._default_result_or_future, ray.ObjectRef):
            try:
                self._default_result_or_future = ray.get(
                    self._default_result_or_future)
            except RayActorError:  # error during initialization
                self._default_result_or_future = None
        if self._default_result_or_future and self.runner:
            self.set_location(
                Location(
                    self._default_result_or_future.get(NODE_IP),
                    self._default_result_or_future.get(PID),
                ))
        return self._default_result_or_future

    @property
    def last_result(self) -> dict:
        # The logic in here is as follows:
        # 1. If the trial has reported at least once, last_result would have
        #    been set and therefore would not be empty. We can just return it.
        # 2. If the trial has not reported at least once but we have the
        #    future for the default results dict, (obtained through
        #    Trainable.get_auto_filled_metrics), we get that future
        #    and return it.
        # 3. In the worst case where we have nothing, we just set the
        #    trial_id and return that.
        result = self._last_result
        if not {k for k in result if k != TRIAL_ID}:
            self._get_default_result_or_future()
            result = self._default_result_or_future or result
        result.setdefault(TRIAL_ID, self.trial_id)
        return result

    @last_result.setter
    def last_result(self, val: dict):
        self._last_result = val

    @property
    def has_reported_at_least_once(self) -> bool:
        return bool(self._last_result)

    @property
    def node_ip(self):
        return self.location.hostname

    @property
    def checkpoint(self):
        """Returns the most recent checkpoint.

        If the trial is in ERROR state, the most recent PERSISTENT checkpoint
        is returned.
        """
        if self.status == Trial.ERROR:
            checkpoint = self.checkpoint_manager.newest_persistent_checkpoint
        else:
            checkpoint = self.checkpoint_manager.newest_checkpoint
        if checkpoint.value is None:
            checkpoint = _TuneCheckpoint(_TuneCheckpoint.PERSISTENT,
                                         self.restore_path)
        return checkpoint

    @classmethod
    def generate_id(cls):
        return str(uuid.uuid1().hex)[:8]

    @property
    def remote_checkpoint_dir(self):
        """This is the **per trial** remote checkpoint dir.

        This is different from the **per experiment** remote checkpoint dir.
        """
        assert self.logdir, "Trial {}: logdir not initialized.".format(self)
        if not self.remote_checkpoint_dir_prefix:
            return None
        logdir_name = os.path.basename(self.logdir)
        return os.path.join(self.remote_checkpoint_dir_prefix, logdir_name)

    @property
    def uses_cloud_checkpointing(self):
        return bool(self.remote_checkpoint_dir)

    def reset(self):
        # If there is `default_resource_request` associated with the trainable,
        # clear `resources` and `placement_group_factory`.
        # This is mainly relevant for RLlib tuning jobs, where we save users
        # of the trouble to specify the resources themselves by having some
        # default resources for popular RLlib algorithms.
        trainable_cls = self.get_trainable_cls()
        clear_resources = trainable_cls and trainable_cls.default_resource_request(
            self.config)
        placement_group_factory = (self.placement_group_factory
                                   if not clear_resources else None)

        return Trial(
            self.trainable_name,
            config=self.config,
            trial_id=None,
            local_dir=self.local_dir,
            evaluated_params=self.evaluated_params,
            experiment_tag=self.experiment_tag,
            resources=None,
            placement_group_factory=placement_group_factory,
            stopping_criterion=self.stopping_criterion,
            remote_checkpoint_dir=self.remote_checkpoint_dir,
            checkpoint_freq=self.checkpoint_freq,
            checkpoint_at_end=self.checkpoint_at_end,
            sync_on_checkpoint=self.sync_on_checkpoint,
            keep_checkpoints_num=self.keep_checkpoints_num,
            checkpoint_score_attr=self.checkpoint_score_attr,
            export_formats=self.export_formats,
            restore_path=self.restore_path,
            trial_name_creator=self.trial_name_creator,
            trial_dirname_creator=self.trial_dirname_creator,
            log_to_file=self.log_to_file,
            max_failures=self.max_failures,
        )

    def init_logdir(self):
        """Init logdir."""
        if not self.logdir:
            self.logdir = create_logdir(self._generate_dirname(),
                                        self.local_dir)
        else:
            os.makedirs(self.logdir, exist_ok=True)

        self.invalidate_json_state()

    def update_resources(self, resources: Union[Dict, PlacementGroupFactory]):
        """EXPERIMENTAL: Updates the resource requirements.

        Should only be called when the trial is not running.

        Raises:
            ValueError: If the trial status is running.
        """
        if self.status is Trial.RUNNING:
            raise ValueError("Cannot update resources while Trial is running.")

        placement_group_factory = None
        if isinstance(resources, PlacementGroupFactory):
            placement_group_factory = resources
        else:
            resources = Resources(**resources)

        self.placement_group_factory = _to_pg_factory(resources,
                                                      placement_group_factory)

        self.invalidate_json_state()

    def set_runner(self, runner):
        self.runner = runner
        if runner:
            # Do not block here, the result will be gotten when last_result
            # property is accessed
            self._default_result_or_future = runner.get_auto_filled_metrics.remote(
                debug_metrics_only=True)
        self.checkpoint_manager.delete = CheckpointDeleter(
            self._trainable_name(), runner)
        # No need to invalidate state cache: runner is not stored in json
        # self.invalidate_json_state()

    def set_location(self, location):
        """Sets the location of the trial."""
        self.location = location
        # No need to invalidate state cache: location is not stored in json
        # self.invalidate_json_state()

    def set_status(self, status):
        """Sets the status of the trial."""
        self.status = status
        if status == Trial.RUNNING:
            if self.start_time is None:
                self.start_time = time.time()
        self.invalidate_json_state()

    def set_config(self, config):
        self.config = config
        self.invalidate_json_state()

    def set_experiment_tag(self, experiment_tag):
        self.experiment_tag = experiment_tag
        self.invalidate_json_state()

    def write_error_log(self,
                        exc: Optional[Union[TuneError, RayTaskError]] = None):
        if exc and self.logdir:
            self.num_failures += 1
            self.error_file = os.path.join(self.logdir, "error.txt")
            if exc and isinstance(exc, RayTaskError):
                # Piping through the actual error to result grid.
                self.pickled_error_file = os.path.join(self.logdir,
                                                       "error.pkl")
                with open(self.pickled_error_file, "wb") as f:
                    cloudpickle.dump(exc, f)
            with open(self.error_file, "a+") as f:
                f.write("Failure # {} (occurred at {})\n".format(
                    self.num_failures, date_str()))
                f.write(str(exc) + "\n")
        self.invalidate_json_state()

    def should_stop(self, result):
        """Whether the given result meets this trial's stopping criteria."""
        if result.get(DONE):
            return True

        for criteria, stop_value in self.stopping_criterion.items():
            if criteria not in result:
                raise TuneError(
                    "Stopping criteria {} not provided in result dict. Keys "
                    "are {}.".format(criteria, list(result.keys())))
            elif isinstance(criteria, dict):
                raise ValueError(
                    "Stopping criteria is now flattened by default. "
                    "Use forward slashes to nest values `key1/key2/key3`.")
            elif result[criteria] >= stop_value:
                return True
        return False

    def should_checkpoint(self):
        """Whether this trial is due for checkpointing."""
        result = self.last_result or {}
        if result.get(DONE) and self.checkpoint_at_end:
            return True
        return (self.checkpoint_freq and
                result.get(TRAINING_ITERATION, 0) % self.checkpoint_freq == 0)

    def has_checkpoint(self):
        return self.checkpoint.value is not None

    def clear_checkpoint(self):
        self.checkpoint.value = None
        self.restoring_from = None
        self.invalidate_json_state()

    def on_checkpoint(self, checkpoint: _TuneCheckpoint):
        """Hook for handling checkpoints taken by the Trainable.

        Args:
            checkpoint: Checkpoint taken.
        """
        self.checkpoint_manager.on_checkpoint(checkpoint)
        self.invalidate_json_state()

    def on_restore(self):
        """Handles restoration completion."""
        assert self.is_restoring
        self.last_result = self.restoring_from.result
        self.restoring_from = None
        self.invalidate_json_state()

    def should_recover(self):
        """Returns whether the trial qualifies for retrying.

        This is if the trial has not failed more than max_failures. Note this
        may return true even when there is no checkpoint, either because
        `self.checkpoint_freq` is `0` or because the trial failed before
        a checkpoint has been made.
        """
        return self.num_failures < self.max_failures or self.max_failures < 0

    def update_last_result(self, result):
        if self.experiment_tag:
            result.update(experiment_tag=self.experiment_tag)

        self.set_location(Location(result.get(NODE_IP), result.get(PID)))
        self.last_result = result
        self.last_update_time = time.time()

        metric_result = self.last_result.copy()
        for remove_metric in DEBUG_METRICS:
            metric_result.pop(remove_metric, None)

        for metric, value in flatten_dict(metric_result).items():
            if isinstance(value, Number):
                if metric not in self.metric_analysis:
                    self.metric_analysis[metric] = {
                        "max": value,
                        "min": value,
                        "avg": value,
                        "last": value,
                    }
                    self.metric_n_steps[metric] = {}
                    for n in self.n_steps:
                        key = "last-{:d}-avg".format(n)
                        self.metric_analysis[metric][key] = value
                        # Store n as string for correct restore.
                        self.metric_n_steps[metric][str(n)] = deque([value],
                                                                    maxlen=n)
                else:
                    step = result["training_iteration"] or 1
                    self.metric_analysis[metric]["max"] = max(
                        value, self.metric_analysis[metric]["max"])
                    self.metric_analysis[metric]["min"] = min(
                        value, self.metric_analysis[metric]["min"])
                    self.metric_analysis[metric]["avg"] = (
                        1 / step *
                        (value +
                         (step - 1) * self.metric_analysis[metric]["avg"]))
                    self.metric_analysis[metric]["last"] = value

                    for n in self.n_steps:
                        key = "last-{:d}-avg".format(n)
                        self.metric_n_steps[metric][str(n)].append(value)
                        self.metric_analysis[metric][key] = sum(
                            self.metric_n_steps[metric][str(n)]) / len(
                                self.metric_n_steps[metric][str(n)])
        self.invalidate_json_state()

    def get_trainable_cls(self):
        if self.stub:
            return None
        return get_trainable_cls(self.trainable_name)

    def is_finished(self):
        return self.status in [Trial.ERROR, Trial.TERMINATED]

    @property
    def is_restoring(self):
        return self.restoring_from is not None

    @property
    def is_saving(self):
        return self.saving_to is not None

    def __repr__(self):
        return self._trainable_name(include_trial_id=True)

    def __str__(self):
        return self._trainable_name(include_trial_id=True)

    def _trainable_name(self, include_trial_id=False):
        """Combines ``env`` with ``trainable_name`` and ``trial_id``.

        Can be overridden with a custom string creator.
        """
        if self.custom_trial_name:
            return self.custom_trial_name

        if "env" in self.config:
            env = self.config["env"]
            if isinstance(env, type):
                env = env.__name__
            identifier = "{}_{}".format(self.trainable_name, env)
        else:
            identifier = self.trainable_name
        if include_trial_id:
            identifier += "_" + self.trial_id
        return identifier.replace("/", "_")

    def _generate_dirname(self):
        if self.custom_dirname:
            generated_dirname = self.custom_dirname
        else:
            MAX_LEN_IDENTIFIER = int(
                os.environ.get("TUNE_MAX_LEN_IDENTIFIER", "130"))
            generated_dirname = f"{str(self)}_{self.experiment_tag}"
            generated_dirname = generated_dirname[:MAX_LEN_IDENTIFIER]
            generated_dirname += f"_{date_str()}"
        # This is the file path used by rsync. ['/', '(', ')'] are not allowed.
        return re.sub("[/()]", "_", generated_dirname)

    def invalidate_json_state(self):
        self._state_valid = False

    def get_json_state(self) -> str:
        if not self._state_json or not self._state_valid:
            json_state = json.dumps(self.__getstate__(),
                                    indent=2,
                                    cls=TuneFunctionEncoder)
            self._state_json = json_state
            self._state_valid = True
        return self._state_json

    def __getstate__(self):
        """Memento generator for Trial.

        Sets RUNNING trials to PENDING.
        Note this can only occur if the trial holds a PERSISTENT checkpoint.
        """
        state = self.__dict__.copy()

        for key in self._nonjson_fields:
            state[key] = binary_to_hex(cloudpickle.dumps(state.get(key)))

        state["runner"] = None
        state["location"] = Location()
        # Avoid waiting for events that will never occur on resume.
        state["restoring_from"] = None
        state["saving_to"] = None

        state["_state_json"] = None
        state["_state_valid"] = False
        state["_default_result_or_future"] = None

        return copy.deepcopy(state)

    def __setstate__(self, state):

        if state["status"] == Trial.RUNNING:
            state["status"] = Trial.PENDING
        for key in self._nonjson_fields:
            state[key] = cloudpickle.loads(hex_to_binary(state[key]))

        # Ensure that stub doesn't get overridden
        stub = state.pop("stub", True)
        self.__dict__.update(state)
        self.stub = stub or getattr(self, "stub", False)

        if not self.stub:
            validate_trainable(self.trainable_name)

        assert self.placement_group_factory

        # Avoid creating logdir in client mode for returned trial results,
        # since the dir might not be creatable locally.
        # TODO(ekl) this is kind of a hack.
        if not ray.util.client.ray.is_connected():
            self.init_logdir()  # Create logdir if it does not exist
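
The "last-{n}-avg" bookkeeping in update_last_result above can be illustrated in isolation. A minimal sketch of the same deque-based moving average (plain Python, not Tune code):

from collections import deque

# A fixed-size deque keeps only the last n values; the "last-n-avg" is then
# recomputed from it after every reported result.
n = 5
window = deque(maxlen=n)
for value in [1.0, 2.0, 4.0, 8.0, 16.0, 32.0]:
    window.append(value)
    last_n_avg = sum(window) / len(window)

print(last_n_avg)  # (2 + 4 + 8 + 16 + 32) / 5 = 12.4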
Example 9
    def checkpoint_manager(self, keep_checkpoints_num):
        return CheckpointManager(
            keep_checkpoints_num, "i", delete_fn=lambda c: None)
Example 10
    def __init__(self,
                 trainable_name,
                 config=None,
                 trial_id=None,
                 local_dir=DEFAULT_RESULTS_DIR,
                 evaluated_params=None,
                 experiment_tag="",
                 resources=None,
                 placement_group_factory=None,
                 stopping_criterion=None,
                 remote_checkpoint_dir=None,
                 sync_to_cloud=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=TRAINING_ITERATION,
                 export_formats=None,
                 restore_path=None,
                 trial_name_creator=None,
                 trial_dirname_creator=None,
                 log_to_file=None,
                 max_failures=0):
        """Initialize a new trial.

        The args here take the same meaning as the command line flags defined
        in ray.tune.config_parser.
        """
        validate_trainable(trainable_name)
        # Trial config
        self.trainable_name = trainable_name
        self.trial_id = Trial.generate_id() if trial_id is None else trial_id
        self.config = config or {}
        self.local_dir = local_dir  # This remains unexpanded for syncing.

        #: Parameters that Tune varies across searches.
        self.evaluated_params = evaluated_params or {}
        self.experiment_tag = experiment_tag
        trainable_cls = self.get_trainable_cls()
        if trainable_cls:
            default_resources = trainable_cls.default_resource_request(
                self.config)

            # If Trainable returns resources, do not allow manual override via
            # `resources_per_trial` by the user.
            if default_resources:
                if resources or placement_group_factory:
                    raise ValueError(
                        "Resources for {} have been automatically set to {} "
                        "by its `default_resource_request()` method. Please "
                        "clear the `resources_per_trial` option.".format(
                            trainable_cls, default_resources))

                # New way: Trainable returns a PlacementGroupFactory object.
                if isinstance(default_resources, PlacementGroupFactory):
                    placement_group_factory = default_resources
                    resources = None
                # Set placement group factory to None for backwards
                # compatibility.
                else:
                    placement_group_factory = None
                    resources = default_resources
        self.location = Location()

        self.resources = resources or Resources(cpu=1, gpu=0)
        self.placement_group_factory = placement_group_factory
        self._setup_resources()

        self.stopping_criterion = stopping_criterion or {}

        self.log_to_file = log_to_file
        # Make sure `stdout_file, stderr_file = Trial.log_to_file` works
        if not self.log_to_file or not isinstance(self.log_to_file, Sequence) \
           or not len(self.log_to_file) == 2:
            self.log_to_file = (None, None)

        self.max_failures = max_failures

        # Local trial state that is updated during the run
        self._last_result = {}
        self._default_result_or_future: Union[ray.ObjectRef, dict,
                                              None] = None
        self.last_update_time = -float("inf")

        # stores in memory max/min/avg/last-n-avg/last result for each
        # metric by trial
        self.metric_analysis = {}

        # keep a moving average over these last n steps
        self.n_steps = [5, 10]
        self.metric_n_steps = {}

        self.export_formats = export_formats
        self.status = Trial.PENDING
        self.start_time = None
        self.logdir = None
        self.runner = None
        self.last_debug = 0
        self.error_file = None
        self.error_msg = None
        self.trial_name_creator = trial_name_creator
        self.custom_trial_name = None
        self.custom_dirname = None

        # Checkpointing fields
        self.saving_to = None
        if remote_checkpoint_dir:
            self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
        else:
            self.remote_checkpoint_dir_prefix = None
        self.sync_to_cloud = sync_to_cloud
        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_at_end = checkpoint_at_end
        self.keep_checkpoints_num = keep_checkpoints_num
        self.checkpoint_score_attr = checkpoint_score_attr
        self.sync_on_checkpoint = sync_on_checkpoint
        self.checkpoint_manager = CheckpointManager(
            keep_checkpoints_num, checkpoint_score_attr,
            CheckpointDeleter(self._trainable_name(), self.runner,
                              self.node_ip))

        # Restoration fields
        self.restore_path = restore_path
        self.restoring_from = None
        self.num_failures = 0
        self.has_new_resources = False

        # AutoML fields
        self.results = None
        self.best_result = None
        self.param_config = None
        self.extra_arg = None

        if trial_name_creator:
            self.custom_trial_name = trial_name_creator(self)

        if trial_dirname_creator:
            self.custom_dirname = trial_dirname_creator(self)
            if os.path.sep in self.custom_dirname:
                raise ValueError("Trial dirname must not contain '/'. "
                                 "Got {self.custom_dirname}")

        self._state_json = None
        self._state_valid = False
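The resource-resolution branch above boils down to a precedence rule: a value returned by the Trainable's default_resource_request() wins over resources_per_trial, and it may arrive either as a PlacementGroupFactory or as a legacy Resources object. A distilled sketch of that rule as a standalone function (resolve_trial_resources is a hypothetical name, and PlacementGroupFactory is assumed to be imported as in the snippet above):

def resolve_trial_resources(default_resources, resources,
                            placement_group_factory):
    # Mirror the precedence used in the constructor (sketch, not a Ray API).
    if default_resources:
        if resources or placement_group_factory:
            raise ValueError(
                "Resources are set by `default_resource_request()`; "
                "clear the `resources_per_trial` option.")
        if isinstance(default_resources, PlacementGroupFactory):
            return None, default_resources  # new style: placement groups
        return default_resources, None  # legacy Resources object
    return resources, placement_group_factory  # user-provided values


# A Trainable that returns a placement group factory takes precedence over
# anything passed via `resources_per_trial`.
resources, pgf = resolve_trial_resources(
    PlacementGroupFactory([{"CPU": 1}]), None, None)
assert resources is None and pgf is not None
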
Esempio n. 11
0
class Trial:
    """A trial object holds the state for one model training run.

    Trials are themselves managed by the TrialRunner class, which implements
    the event loop for submitting trial runs to a Ray cluster.

    Trials start in the PENDING state, and transition to RUNNING once started.
    On error it transitions to ERROR, otherwise TERMINATED on success.

    Attributes:
        trainable_name (str): Name of the trainable object to be executed.
        config (dict): Provided configuration dictionary with evaluated params.
        trial_id (str): Unique identifier for the trial.
        local_dir (str): Local_dir as passed to tune.run.
        logdir (str): Directory where the trial logs are saved.
        evaluated_params (dict): Parameters evaluated by the search algorithm.
        experiment_tag (str): Identifying trial name to show in the console.
        resources (Resources): Amount of resources that this trial will use.
        status (str): One of PENDING, RUNNING, PAUSED, TERMINATED, ERROR.
        error_file (str): Path to the file where errors from this trial are
            logged.

    """

    _nonjson_fields = [
        "results",
        "best_result",
        "param_config",
        "extra_arg",
    ]

    PENDING = "PENDING"
    RUNNING = "RUNNING"
    PAUSED = "PAUSED"
    TERMINATED = "TERMINATED"
    ERROR = "ERROR"

    def __init__(self,
                 trainable_name,
                 config=None,
                 trial_id=None,
                 local_dir=DEFAULT_RESULTS_DIR,
                 evaluated_params=None,
                 experiment_tag="",
                 resources=None,
                 placement_group_factory=None,
                 stopping_criterion=None,
                 remote_checkpoint_dir=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=TRAINING_ITERATION,
                 export_formats=None,
                 restore_path=None,
                 trial_name_creator=None,
                 trial_dirname_creator=None,
                 log_to_file=None,
                 max_failures=0):
        """Initialize a new trial.

        The args here take the same meaning as the command line flags defined
        in ray.tune.config_parser.
        """
        validate_trainable(trainable_name)
        # Trial config
        self.trainable_name = trainable_name
        self.trial_id = Trial.generate_id() if trial_id is None else trial_id
        self.config = config or {}
        self.local_dir = local_dir  # This remains unexpanded for syncing.

        #: Parameters that Tune varies across searches.
        self.evaluated_params = evaluated_params or {}
        self.experiment_tag = experiment_tag
        trainable_cls = self.get_trainable_cls()
        if trainable_cls:
            default_resources = trainable_cls.default_resource_request(
                self.config)

            # If Trainable returns resources, do not allow manual override via
            # `resources_per_trial` by the user.
            if default_resources:
                if resources or placement_group_factory:
                    raise ValueError(
                        "Resources for {} have been automatically set to {} "
                        "by its `default_resource_request()` method. Please "
                        "clear the `resources_per_trial` option.".format(
                            trainable_cls, default_resources))

                # New way: Trainable returns a PlacementGroupFactory object.
                if isinstance(default_resources, PlacementGroupFactory):
                    placement_group_factory = default_resources
                    resources = None
                # Set placement group factory to None for backwards
                # compatibility.
                else:
                    placement_group_factory = None
                    resources = default_resources
        self.location = Location()

        self.resources = resources or Resources(cpu=1, gpu=0)
        self.placement_group_factory = placement_group_factory
        self._setup_resources()

        self.stopping_criterion = stopping_criterion or {}

        self.log_to_file = log_to_file
        # Make sure `stdout_file, stderr_file = Trial.log_to_file` works
        if not self.log_to_file or not isinstance(self.log_to_file, Sequence) \
           or not len(self.log_to_file) == 2:
            self.log_to_file = (None, None)

        self.max_failures = max_failures

        # Local trial state that is updated during the run
        self.last_result = {}
        self.last_update_time = -float("inf")

        # stores in memory max/min/avg/last-n-avg/last result for each
        # metric by trial
        self.metric_analysis = {}

        # keep a moving average over these last n steps
        self.n_steps = [5, 10]
        self.metric_n_steps = {}

        self.export_formats = export_formats
        self.status = Trial.PENDING
        self.start_time = None
        self.logdir = None
        self.runner = None
        self.last_debug = 0
        self.error_file = None
        self.error_msg = None
        self.trial_name_creator = trial_name_creator
        self.custom_trial_name = None
        self.custom_dirname = None

        # Checkpointing fields
        self.saving_to = None
        if remote_checkpoint_dir:
            self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
        else:
            self.remote_checkpoint_dir_prefix = None
        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_at_end = checkpoint_at_end
        self.keep_checkpoints_num = keep_checkpoints_num
        self.checkpoint_score_attr = checkpoint_score_attr
        self.sync_on_checkpoint = sync_on_checkpoint
        self.checkpoint_manager = CheckpointManager(
            keep_checkpoints_num, checkpoint_score_attr,
            checkpoint_deleter(self._trainable_name(), self.runner))

        # Restoration fields
        self.restore_path = restore_path
        self.restoring_from = None
        self.num_failures = 0

        # AutoML fields
        self.results = None
        self.best_result = None
        self.param_config = None
        self.extra_arg = None

        if trial_name_creator:
            self.custom_trial_name = trial_name_creator(self)

        if trial_dirname_creator:
            self.custom_dirname = trial_dirname_creator(self)
            if os.path.sep in self.custom_dirname:
                raise ValueError(f"Trial dirname must not contain '/'. "
                                 "Got {self.custom_dirname}")

        self._state_json = None
        self._state_valid = False

    def _setup_resources(self, log_always: bool = False):
        """Set up resource and placement group requirements.

        This will try to convert the resource request in ``self.resources``
        to a placement group factory object. If this is unsuccessful,
        placement groups will not be used.

        Args:
            log_always (bool): If True, this will always log a warning if
                conversion from a resource dict to a placement group
                definition was unsuccessful (e.g. when passing ``extra_``
                requests).
        """
        if not self.placement_group_factory and \
           not int(os.getenv("TUNE_PLACEMENT_GROUP_AUTO_DISABLED", "0")):
            try:
                self.placement_group_factory = resource_dict_to_pg_factory(
                    self.resources)
            except ValueError as exc:
                if log_always or log_once("tune_pg_extra_resources"):
                    logger.warning(exc)
                self.placement_group_factory = None

        # Set placement group factory flag to True in Resources object.
        if self.placement_group_factory:
            resource_kwargs = self.resources._asdict()
            resource_kwargs["has_placement_group"] = True
            self.resources = Resources(**resource_kwargs)

    @property
    def node_ip(self):
        return self.location.hostname

    @property
    def checkpoint(self):
        """Returns the most recent checkpoint.

        If the trial is in ERROR state, the most recent PERSISTENT checkpoint
        is returned.
        """
        if self.status == Trial.ERROR:
            checkpoint = self.checkpoint_manager.newest_persistent_checkpoint
        else:
            checkpoint = self.checkpoint_manager.newest_checkpoint
        if checkpoint.value is None:
            checkpoint = Checkpoint(Checkpoint.PERSISTENT, self.restore_path)
        return checkpoint

    @classmethod
    def generate_id(cls):
        return str(uuid.uuid1().hex)[:8]

    @property
    def remote_checkpoint_dir(self):
        assert self.logdir, "Trial {}: logdir not initialized.".format(self)
        if not self.remote_checkpoint_dir_prefix:
            return None
        logdir_name = os.path.basename(self.logdir)
        return os.path.join(self.remote_checkpoint_dir_prefix, logdir_name)

    @property
    def uses_placement_groups(self):
        return bool(self.placement_group_factory)

    def reset(self):
        return Trial(
            self.trainable_name,
            config=self.config,
            trial_id=None,
            local_dir=self.local_dir,
            evaluated_params=self.evaluated_params,
            experiment_tag=self.experiment_tag,
            resources=self.resources,
            placement_group_factory=self.placement_group_factory,
            stopping_criterion=self.stopping_criterion,
            remote_checkpoint_dir=self.remote_checkpoint_dir,
            checkpoint_freq=self.checkpoint_freq,
            checkpoint_at_end=self.checkpoint_at_end,
            sync_on_checkpoint=self.sync_on_checkpoint,
            keep_checkpoints_num=self.keep_checkpoints_num,
            checkpoint_score_attr=self.checkpoint_score_attr,
            export_formats=self.export_formats,
            restore_path=self.restore_path,
            trial_name_creator=self.trial_name_creator,
            log_to_file=self.log_to_file,
            max_failures=self.max_failures,
        )

    def init_logdir(self):
        """Init logdir."""
        if not self.logdir:
            self.logdir = create_logdir(self._generate_dirname(),
                                        self.local_dir)
        else:
            os.makedirs(self.logdir, exist_ok=True)
        self.invalidate_json_state()

    def update_resources(self, resources: Union[Dict, Callable,
                                                PlacementGroupFactory]):
        """EXPERIMENTAL: Updates the resource requirements.

        Should only be called when the trial is not running.

        Raises:
            ValueError if trial status is running.
        """
        if self.status is Trial.RUNNING:
            raise ValueError("Cannot update resources while Trial is running.")

        if isinstance(resources, PlacementGroupFactory):
            self.placement_group_factory = resources
        else:
            self.resources = Resources(**resources)

        self._setup_resources()

        self.invalidate_json_state()

    def set_runner(self, runner):
        self.runner = runner
        self.checkpoint_manager.delete = checkpoint_deleter(
            self._trainable_name(), runner)
        # No need to invalidate state cache: runner is not stored in json
        # self.invalidate_json_state()

    def set_location(self, location):
        """Sets the location of the trial."""
        self.location = location
        # No need to invalidate state cache: location is not stored in json
        # self.invalidate_json_state()

    def set_status(self, status):
        """Sets the status of the trial."""
        self.status = status
        if status == Trial.RUNNING:
            if self.start_time is None:
                self.start_time = time.time()
        self.invalidate_json_state()

    def set_config(self, config):
        self.config = config
        self.invalidate_json_state()

    def set_experiment_tag(self, experiment_tag):
        self.experiment_tag = experiment_tag
        self.invalidate_json_state()

    def write_error_log(self, error_msg):
        if error_msg and self.logdir:
            self.num_failures += 1
            self.error_file = os.path.join(self.logdir, "error.txt")
            with open(self.error_file, "a+") as f:
                f.write("Failure # {} (occurred at {})\n".format(
                    self.num_failures, date_str()))
                f.write(error_msg + "\n")
            self.error_msg = error_msg
        self.invalidate_json_state()

    def should_stop(self, result):
        """Whether the given result meets this trial's stopping criteria."""
        if result.get(DONE):
            return True

        for criteria, stop_value in self.stopping_criterion.items():
            if criteria not in result:
                raise TuneError(
                    "Stopping criteria {} not provided in result {}.".format(
                        criteria, result))
            elif isinstance(criteria, dict):
                raise ValueError(
                    "Stopping criteria is now flattened by default. "
                    "Use forward slashes to nest values `key1/key2/key3`.")
            elif result[criteria] >= stop_value:
                return True
        return False

    def should_checkpoint(self):
        """Whether this trial is due for checkpointing."""
        result = self.last_result or {}
        if result.get(DONE) and self.checkpoint_at_end:
            return True
        return (self.checkpoint_freq and
                result.get(TRAINING_ITERATION, 0) % self.checkpoint_freq == 0)

    def has_checkpoint(self):
        return self.checkpoint.value is not None

    def clear_checkpoint(self):
        self.checkpoint.value = None
        self.restoring_from = None
        self.invalidate_json_state()

    def on_checkpoint(self, checkpoint):
        """Hook for handling checkpoints taken by the Trainable.

        Args:
            checkpoint (Checkpoint): Checkpoint taken.
        """
        self.checkpoint_manager.on_checkpoint(checkpoint)
        self.invalidate_json_state()

    def on_restore(self):
        """Handles restoration completion."""
        assert self.is_restoring
        self.last_result = self.restoring_from.result
        self.restoring_from = None
        self.invalidate_json_state()

    def should_recover(self):
        """Returns whether the trial qualifies for retrying.

        This is if the trial has not failed more than max_failures. Note this
        may return true even when there is no checkpoint, either because
        `self.checkpoint_freq` is `0` or because the trial failed before
        a checkpoint has been made.
        """
        return self.num_failures < self.max_failures or self.max_failures < 0

    def update_last_result(self, result, terminate=False):
        if self.experiment_tag:
            result.update(experiment_tag=self.experiment_tag)

        self.set_location(Location(result.get("node_ip"), result.get("pid")))
        self.last_result = result
        self.last_update_time = time.time()

        for metric, value in flatten_dict(result).items():
            if isinstance(value, Number):
                if metric not in self.metric_analysis:
                    self.metric_analysis[metric] = {
                        "max": value,
                        "min": value,
                        "avg": value,
                        "last": value
                    }
                    self.metric_n_steps[metric] = {}
                    for n in self.n_steps:
                        key = "last-{:d}-avg".format(n)
                        self.metric_analysis[metric][key] = value
                        # Store n as string for correct restore.
                        self.metric_n_steps[metric][str(n)] = deque([value],
                                                                    maxlen=n)
                else:
                    step = result["training_iteration"] or 1
                    self.metric_analysis[metric]["max"] = max(
                        value, self.metric_analysis[metric]["max"])
                    self.metric_analysis[metric]["min"] = min(
                        value, self.metric_analysis[metric]["min"])
                    self.metric_analysis[metric]["avg"] = 1 / step * (
                        value +
                        (step - 1) * self.metric_analysis[metric]["avg"])
                    self.metric_analysis[metric]["last"] = value

                    for n in self.n_steps:
                        key = "last-{:d}-avg".format(n)
                        self.metric_n_steps[metric][str(n)].append(value)
                        self.metric_analysis[metric][key] = sum(
                            self.metric_n_steps[metric][str(n)]) / len(
                                self.metric_n_steps[metric][str(n)])
        self.invalidate_json_state()

    def get_trainable_cls(self):
        return get_trainable_cls(self.trainable_name)

    def is_finished(self):
        return self.status in [Trial.ERROR, Trial.TERMINATED]

    @property
    def is_restoring(self):
        return self.restoring_from is not None

    @property
    def is_saving(self):
        return self.saving_to is not None

    def __repr__(self):
        return self._trainable_name(include_trial_id=True)

    def __str__(self):
        return self._trainable_name(include_trial_id=True)

    def _trainable_name(self, include_trial_id=False):
        """Combines ``env`` with ``trainable_name`` and ``trial_id``.

        Can be overridden with a custom string creator.
        """
        if self.custom_trial_name:
            return self.custom_trial_name

        if "env" in self.config:
            env = self.config["env"]
            if isinstance(env, type):
                env = env.__name__
            identifier = "{}_{}".format(self.trainable_name, env)
        else:
            identifier = self.trainable_name
        if include_trial_id:
            identifier += "_" + self.trial_id
        return identifier.replace("/", "_")

    def _generate_dirname(self):
        if self.custom_dirname:
            generated_dirname = self.custom_dirname
        else:
            generated_dirname = f"{str(self)}_{self.experiment_tag}"
            generated_dirname = generated_dirname[:MAX_LEN_IDENTIFIER]
            generated_dirname += f"_{date_str()}"
        return generated_dirname.replace("/", "_")

    def invalidate_json_state(self):
        self._state_valid = False

    def get_json_state(self) -> str:
        if not self._state_json or not self._state_valid:
            json_state = json.dumps(self.__getstate__(),
                                    indent=2,
                                    cls=TuneFunctionEncoder)
            self._state_json = json_state
            self._state_valid = True
        return self._state_json

    def __getstate__(self):
        """Memento generator for Trial.

        Sets RUNNING trials to PENDING.
        Note this can only occur if the trial holds a PERSISTENT checkpoint.
        """
        state = self.__dict__.copy()
        state["resources"] = resources_to_json(self.resources)

        for key in self._nonjson_fields:
            state[key] = binary_to_hex(cloudpickle.dumps(state.get(key)))

        state["runner"] = None
        state["location"] = Location()
        # Avoid waiting for events that will never occur on resume.
        state["restoring_from"] = None
        state["saving_to"] = None

        state["_state_json"] = None
        state["_state_valid"] = False

        return copy.deepcopy(state)

    def __setstate__(self, state):
        state["resources"] = json_to_resources(state["resources"])

        if state["status"] == Trial.RUNNING:
            state["status"] = Trial.PENDING
        for key in self._nonjson_fields:
            state[key] = cloudpickle.loads(hex_to_binary(state[key]))

        self.__dict__.update(state)
        validate_trainable(self.trainable_name)

        # Avoid creating logdir in client mode for returned trial results,
        # since the dir might not be creatable locally. TODO(ekl) this is kind
        # of a hack.
        if not ray.util.client.ray.is_connected():
            self.init_logdir()  # Create logdir if it does not exist
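update_last_result above keeps a running mean incrementally, avg_t = (value + (t - 1) * avg_{t-1}) / t, plus fixed-window averages over the last n results via bounded deques. A minimal standalone sketch of the same bookkeeping for a single metric (update_metric_sketch is illustrative, not part of the Trial class):

from collections import deque


def update_metric_sketch(analysis, value, step, n_steps=(5, 10)):
    # Illustrative restatement of the per-metric bookkeeping in
    # update_last_result; not a Ray API.
    if not analysis:
        analysis.update(max=value, min=value, avg=value, last=value)
        analysis["windows"] = {}
        for n in n_steps:
            analysis["windows"][n] = deque([value], maxlen=n)
            analysis["last-{:d}-avg".format(n)] = value
    else:
        analysis["max"] = max(value, analysis["max"])
        analysis["min"] = min(value, analysis["min"])
        # Incremental mean: avg_t = (value + (t - 1) * avg_{t-1}) / t
        analysis["avg"] = (value + (step - 1) * analysis["avg"]) / step
        analysis["last"] = value
        for n, window in analysis["windows"].items():
            window.append(value)
            analysis["last-{:d}-avg".format(n)] = sum(window) / len(window)
    return analysis


metrics = {}
for step, loss in enumerate([4.0, 2.0, 6.0], start=1):
    update_metric_sketch(metrics, loss, step)
assert metrics["avg"] == 4.0 and metrics["last"] == 6.0
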
Esempio n. 12
0
    def __init__(self,
                 trainable_name,
                 config=None,
                 trial_id=None,
                 local_dir=DEFAULT_RESULTS_DIR,
                 evaluated_params=None,
                 experiment_tag="",
                 resources=None,
                 stopping_criterion=None,
                 remote_checkpoint_dir=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=TRAINING_ITERATION,
                 export_formats=None,
                 restore_path=None,
                 trial_name_creator=None,
                 trial_dirname_creator=None,
                 loggers=None,
                 log_to_file=None,
                 sync_to_driver_fn=None,
                 max_failures=0):
        """Initialize a new trial.

        The args here take the same meaning as the command line flags defined
        in ray.tune.config_parser.
        """
        validate_trainable(trainable_name)
        # Trial config
        self.trainable_name = trainable_name
        self.trial_id = Trial.generate_id() if trial_id is None else trial_id
        self.config = config or {}
        self.local_dir = local_dir  # This remains unexpanded for syncing.

        #: Parameters that Tune varies across searches.
        self.evaluated_params = evaluated_params or {}
        self.experiment_tag = experiment_tag
        trainable_cls = self.get_trainable_cls()
        if trainable_cls:
            default_resources = trainable_cls.default_resource_request(
                self.config)
            if default_resources:
                if resources:
                    raise ValueError(
                        "Resources for {} have been automatically set to {} "
                        "by its `default_resource_request()` method. Please "
                        "clear the `resources_per_trial` option.".format(
                            trainable_cls, default_resources))
                resources = default_resources
        self.location = Location()
        self.resources = resources or Resources(cpu=1, gpu=0)
        self.stopping_criterion = stopping_criterion or {}
        self.loggers = loggers

        self.log_to_file = log_to_file
        # Make sure `stdout_file, stderr_file = Trial.log_to_file` works
        if not self.log_to_file or not isinstance(self.log_to_file, Sequence) \
           or not len(self.log_to_file) == 2:
            self.log_to_file = (None, None)

        self.sync_to_driver_fn = sync_to_driver_fn
        self.verbose = True
        self.max_failures = max_failures

        # Local trial state that is updated during the run
        self.last_result = {}
        self.last_update_time = -float("inf")

        # stores in memory max/min/avg/last-n-avg/last result for each
        # metric by trial
        self.metric_analysis = {}

        # keep a moving average over these last n steps
        self.n_steps = [5, 10]
        self.metric_n_steps = {}

        self.export_formats = export_formats
        self.status = Trial.PENDING
        self.start_time = None
        self.logdir = None
        self.runner = None
        self.result_logger = None
        self.last_debug = 0
        self.error_file = None
        self.error_msg = None
        self.trial_name_creator = trial_name_creator
        self.custom_trial_name = None
        self.custom_dirname = None

        # Checkpointing fields
        self.saving_to = None
        if remote_checkpoint_dir:
            self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
        else:
            self.remote_checkpoint_dir_prefix = None
        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_at_end = checkpoint_at_end
        self.keep_checkpoints_num = keep_checkpoints_num
        self.checkpoint_score_attr = checkpoint_score_attr
        self.sync_on_checkpoint = sync_on_checkpoint
        self.checkpoint_manager = CheckpointManager(
            keep_checkpoints_num, checkpoint_score_attr,
            checkpoint_deleter(self._trainable_name(), self.runner))

        # Restoration fields
        self.restore_path = restore_path
        self.restoring_from = None
        self.num_failures = 0

        # AutoML fields
        self.results = None
        self.best_result = None
        self.param_config = None
        self.extra_arg = None

        self._nonjson_fields = [
            "loggers",
            "sync_to_driver_fn",
            "results",
            "best_result",
            "param_config",
            "extra_arg",
        ]
        if trial_name_creator:
            self.custom_trial_name = trial_name_creator(self)

        if trial_dirname_creator:
            self.custom_dirname = trial_dirname_creator(self)
            if os.path.sep in self.custom_dirname:
                raise ValueError(f"Trial dirname must not contain '/'. "
                                 "Got {self.custom_dirname}")