Example #1
    def testSameCheckpoint(self):
        checkpoint_manager = _CheckpointManager(
            1, "i", delete_fn=lambda c: os.remove(c.value))

        tmpfiles = []
        for i in range(3):
            _, tmpfile = tempfile.mkstemp()
            with open(tmpfile, "wt") as fp:
                fp.write("")
            tmpfiles.append(tmpfile)

        checkpoints = [
            _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, tmpfiles[0],
                            self.mock_result(5, 5)),
            _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, tmpfiles[1],
                            self.mock_result(10, 10)),
            _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, tmpfiles[2],
                            self.mock_result(0, 0)),
            _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, tmpfiles[1],
                            self.mock_result(20, 20)),
        ]
        for checkpoint in checkpoints:
            checkpoint_manager.on_checkpoint(checkpoint)
            self.assertTrue(os.path.exists(checkpoint.value))

        for tmpfile in tmpfiles:
            if os.path.exists(tmpfile):
                os.remove(tmpfile)
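Both this example and the next call a mock_result helper that is not shown in the snippet. A minimal sketch of what it might return, assuming the manager ranks checkpoints by the "i" score attribute passed above (the exact keys are an assumption, not taken from the source):

    # Hypothetical helper: builds a result dict carrying the score attribute
    # "i" that _CheckpointManager sorts by, plus the training iteration.
    def mock_result(self, metric, i):
        return {"metric": metric, "training_iteration": i, "i": i}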
Example #2
    def testSameCheckpoint(self):
        checkpoint_manager = _CheckpointManager(
            keep_checkpoints_num=1,
            checkpoint_score_attr="i",
            delete_fn=lambda c: os.remove(c.dir_or_data),
        )

        tmpfiles = []
        for i in range(3):
            _, tmpfile = tempfile.mkstemp()
            with open(tmpfile, "wt") as fp:
                fp.write("")
            tmpfiles.append(tmpfile)

        checkpoints = [
            _TrackedCheckpoint(
                dir_or_data=tmpfiles[0],
                storage_mode=CheckpointStorage.PERSISTENT,
                metrics=self.mock_result(5, 5),
            ),
            _TrackedCheckpoint(
                dir_or_data=tmpfiles[1],
                storage_mode=CheckpointStorage.PERSISTENT,
                metrics=self.mock_result(10, 10),
            ),
            _TrackedCheckpoint(
                dir_or_data=tmpfiles[2],
                storage_mode=CheckpointStorage.PERSISTENT,
                metrics=self.mock_result(0, 0),
            ),
            _TrackedCheckpoint(
                dir_or_data=tmpfiles[1],
                storage_mode=CheckpointStorage.PERSISTENT,
                metrics=self.mock_result(20, 20),
            ),
        ]
        for checkpoint in checkpoints:
            checkpoint_manager.on_checkpoint(checkpoint)
            self.assertTrue(os.path.exists(checkpoint.dir_or_data))

        for tmpfile in tmpfiles:
            if os.path.exists(tmpfile):
                os.remove(tmpfile)
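The only change to the delete callback between the two examples is the attribute that holds the checkpoint path:

    # As used above: the older _TuneCheckpoint API stores the path on c.value,
    # while the newer _TrackedCheckpoint API exposes it as c.dir_or_data.
    delete_fn_old = lambda c: os.remove(c.value)
    delete_fn_new = lambda c: os.remove(c.dir_or_data)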
Example #3
    def __init__(
        self,
        trainable_name: str,
        config: Optional[Dict] = None,
        trial_id: Optional[str] = None,
        local_dir: Optional[str] = DEFAULT_RESULTS_DIR,
        evaluated_params: Optional[Dict] = None,
        experiment_tag: str = "",
        resources: Optional[Resources] = None,
        placement_group_factory: Optional[PlacementGroupFactory] = None,
        stopping_criterion: Optional[Dict[str, float]] = None,
        remote_checkpoint_dir: Optional[str] = None,
        sync_function_tpl: Optional[str] = None,
        checkpoint_freq: int = 0,
        checkpoint_at_end: bool = False,
        sync_on_checkpoint: bool = True,
        keep_checkpoints_num: Optional[int] = None,
        checkpoint_score_attr: str = TRAINING_ITERATION,
        export_formats: Optional[List[str]] = None,
        restore_path: Optional[str] = None,
        trial_name_creator: Optional[Callable[["Trial"], str]] = None,
        trial_dirname_creator: Optional[Callable[["Trial"], str]] = None,
        log_to_file: Optional[str] = None,
        max_failures: int = 0,
        stub: bool = False,
        _setup_default_resource: bool = True,
    ):
        """Initialize a new trial.

        The args here take the same meaning as the command line flags defined
        in ray.tune.config_parser.

        Args:
            _setup_default_resource: Whether to set up default resources.
                When initializing trials from checkpoints, this field is set to
                False so that setting up default resources can be deferred
                until ``trial.config`` has been loaded from the checkpoint.
        """
        # If this is set, trainables are not validated or looked up.
        # This can be used e.g. to initialize Trial objects from checkpoints
        # without loading the trainable first.
        self.stub = stub

        if not self.stub:
            validate_trainable(trainable_name)
        # Trial config
        self.trainable_name = trainable_name
        self.trial_id = Trial.generate_id() if trial_id is None else trial_id
        self.config = config or {}
        self.local_dir = local_dir  # This remains unexpanded for syncing.

        # Parameters that Tune varies across searches.
        self.evaluated_params = evaluated_params or {}
        self.experiment_tag = experiment_tag
        self.location = _Location()
        trainable_cls = self.get_trainable_cls()
        if trainable_cls and _setup_default_resource:
            default_resources = trainable_cls.default_resource_request(self.config)

            # If Trainable returns resources, do not allow manual override via
            # `resources_per_trial` by the user.
            if default_resources:
                if resources or placement_group_factory:
                    raise ValueError(
                        "Resources for {} have been automatically set to {} "
                        "by its `default_resource_request()` method. Please "
                        "clear the `resources_per_trial` option.".format(
                            trainable_cls, default_resources
                        )
                    )

                if isinstance(default_resources, PlacementGroupFactory):
                    placement_group_factory = default_resources
                    resources = None
                else:
                    placement_group_factory = None
                    resources = default_resources

        self.placement_group_factory = _to_pg_factory(
            resources, placement_group_factory
        )

        self.stopping_criterion = stopping_criterion or {}

        self.log_to_file = log_to_file
        # Make sure `stdout_file, stderr_file = Trial.log_to_file` works
        if (
            not self.log_to_file
            or not isinstance(self.log_to_file, Sequence)
            or not len(self.log_to_file) == 2
        ):
            self.log_to_file = (None, None)

        self.max_failures = max_failures

        # Local trial state that is updated during the run
        self._last_result = {}
        self._default_result_or_future: Union[ray.ObjectRef, dict, None] = None
        self.last_update_time = -float("inf")

        # stores in memory max/min/avg/last-n-avg/last result for each
        # metric by trial
        self.metric_analysis = {}

        # keep a moving average over these last n steps
        self.n_steps = [5, 10]
        self.metric_n_steps = {}

        self.export_formats = export_formats
        self.status = Trial.PENDING
        self.start_time = None
        self.logdir = None
        self.runner = None
        self.last_debug = 0
        self.error_file = None
        self.pickled_error_file = None
        self.trial_name_creator = trial_name_creator
        self.trial_dirname_creator = trial_dirname_creator
        self.custom_trial_name = None
        self.custom_dirname = None

        # Checkpointing fields
        self.saving_to = None
        if remote_checkpoint_dir:
            self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
        else:
            self.remote_checkpoint_dir_prefix = None

        if sync_function_tpl == "auto" or not isinstance(sync_function_tpl, str):
            sync_function_tpl = None
        self.sync_function_tpl = sync_function_tpl

        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_at_end = checkpoint_at_end
        self.keep_checkpoints_num = keep_checkpoints_num
        self.checkpoint_score_attr = checkpoint_score_attr
        self.sync_on_checkpoint = sync_on_checkpoint
        self.checkpoint_manager = _CheckpointManager(
            keep_checkpoints_num,
            checkpoint_score_attr,
            _CheckpointDeleter(self._trainable_name(), self.runner),
        )

        # Restoration fields
        self.restore_path = restore_path
        self.restoring_from = None
        self.num_failures = 0

        # AutoML fields
        self.results = None
        self.best_result = None
        self.param_config = None
        self.extra_arg = None

        if trial_name_creator:
            self.custom_trial_name = trial_name_creator(self)

        if trial_dirname_creator:
            self.custom_dirname = trial_dirname_creator(self)
            if os.path.sep in self.custom_dirname:
                raise ValueError(
                    f"Trial dirname must not contain '/'. Got {self.custom_dirname}"
                )

        self._state_json = None
        self._state_valid = False
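A hedged usage sketch of this constructor: the trainable name and metric below are illustrative placeholders rather than values taken from the source, and the trainable is assumed to have been registered with Tune beforehand.

# Illustrative only: a Trial wired up with the checkpointing options that
# the __init__ above forwards to its _CheckpointManager.
trial = Trial(
    "my_trainable",  # hypothetical registered trainable
    config={"lr": 0.01},
    checkpoint_freq=2,
    keep_checkpoints_num=1,
    checkpoint_score_attr="mean_accuracy",  # hypothetical metric name
)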
Example #4
    def checkpoint_manager(self, keep_checkpoints_num):
        return _CheckpointManager(
            keep_checkpoints_num=keep_checkpoints_num,
            checkpoint_score_attr="i",
            delete_fn=lambda c: None,
        )
Example #5
    def checkpoint_manager(self, keep_checkpoints_num):
        return _CheckpointManager(keep_checkpoints_num,
                                  "i",
                                  delete_fn=lambda c: None)
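As in Example #4, this helper appears to be a method on a test class; an illustrative call site, where the keep_checkpoints_num value is an assumption:

    # Illustrative only: build a manager that keeps the two best checkpoints
    # ranked by the metric "i".
    manager = self.checkpoint_manager(keep_checkpoints_num=2)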