예제 #1
0
파일: track.py 프로젝트: vincehass/orion
    def get_trial(self, trial=None, uid=None):
        """See :func:`~orion.storage.BaseStorageProtocol.get_trial`

        Fetch a single trial from the track backend.

        The trial is identified either by ``trial`` (its ``id`` is used) or
        directly by ``uid``; when both are given they must agree. The uid is
        parsed as ``<hash>`` or ``<hash>_<revision>``, splitting only at the
        first underscore.

        Args:
            trial: Optional trial object whose ``id`` identifies the trial.
            uid: Optional trial identifier string.

        Returns:
            A ``TrialAdapter`` wrapping the backend trial, or ``None`` when
            the backend returns ``None`` or an empty result.

        Raises:
            MissingArguments: If neither ``trial`` nor ``uid`` is provided.
            AssertionError: If ``trial.id`` and ``uid`` disagree, or if the
                backend returns more than one trial.
        """
        if trial is not None and uid is not None:
            assert trial.id == uid

        if uid is None:
            if trial is None:
                raise MissingArguments(
                    'trial or uid argument should be populated')

            uid = trial.id

        # partition() splits at the first underscore only. A uid without an
        # underscore has no revision part, so the revision stays the integer 0.
        _hash, sep, rev = uid.partition('_')
        _rev = rev if sep else 0

        trials = self.backend.get_trial(TrackTrial(_hash=_hash, revision=_rev))

        # An empty result means "no match" just like None; without this the
        # assert below would crash instead of reporting a missing trial.
        if not trials:
            return None

        assert len(trials) == 1
        return TrialAdapter(trials[0], objective=self.objective)
예제 #2
0
    def register_trial(self, trial):
        """Create a new trial to be executed.

        Stamps ``trial`` with the current UTC time as its submit time, collects
        its parameter/metric type metadata and metric values, and asks the
        track backend to create the matching record.

        Args:
            trial: The trial to register. Its ``submit_time`` attribute is
                overwritten in place.

        Returns:
            A ``TrialAdapter`` wrapping the trial returned by the backend.

        Raises:
            DuplicateKeyError: If the backend returns ``None`` instead of a
                newly created trial.
        """
        trial.submit_time = datetime.datetime.utcnow()

        # Entries are listed in the same order the values were computed
        # originally, so key order and helper-call order are unchanged.
        # pylint: disable=protected-access
        metadata = {
            "params_types": {
                remove_leading_slash(param.name): param.type
                for param in trial._params
            },
            "submit_time": to_json(trial.submit_time),
            "end_time": to_json(trial.end_time),
            "worker": trial.worker,
            "metric_types": {
                remove_leading_slash(result.name): result.type
                for result in trial.results
            },
        }
        # The configured objective name is always tagged as "objective",
        # overriding any type reported by the trial's own results.
        metadata["metric_types"][self.objective] = "objective"

        serialized_heartbeat = to_json(trial.heartbeat)
        metadata["heartbeat"] = (
            0 if serialized_heartbeat is None else serialized_heartbeat
        )

        # One single-value list per metric name (a later duplicate name
        # replaces an earlier one); kept as a defaultdict(list).
        metrics = defaultdict(
            list, {result.name: [result.value] for result in trial.results}
        )

        # NOTE(review): _get_project presumably populates self.project; its
        # return value is not used here.
        if self.project is None:
            self._get_project(self.group.project_id)

        created = self.backend.new_trial(
            TrackTrial(
                _hash=trial.hash_name,
                status=get_track_status(trial.status),
                project_id=self.project.uid,
                group_id=self.group.uid,
                parameters=trial.params,
                metadata=metadata,
                metrics=metrics,
            ),
            auto_increment=False,
        )

        if created is None:
            raise DuplicateKeyError("Was not able to register Trial!")

        return TrialAdapter(created, objective=self.objective)
예제 #3
0
    def get_trial(self, trial=None, uid=None):
        """See :meth:`orion.storage.base.BaseStorageProtocol.get_trial`

        Fetch a single trial from the track backend.

        The identifier is resolved from ``trial`` and/or ``uid`` by
        ``get_uid``, then parsed as ``<hash>`` or ``<hash>_<revision>``,
        splitting only at the first underscore.

        Args:
            trial: Optional trial object identifying the trial.
            uid: Optional trial identifier string.

        Returns:
            A ``TrialAdapter`` wrapping the backend trial, or ``None`` when
            the backend returns ``None`` or an empty result.

        Raises:
            AssertionError: If the backend returns more than one trial.
            Any error raised by ``get_uid`` while resolving the identifier.
        """
        uid = get_uid(trial, uid)

        # partition() splits at the first underscore only. A uid without an
        # underscore has no revision part, so the revision stays the integer 0.
        _hash, sep, rev = uid.partition("_")
        _rev = rev if sep else 0

        trials = self.backend.get_trial(TrackTrial(_hash=_hash, revision=_rev))

        # An empty result means "no match" just like None; without this the
        # assert below would crash instead of reporting a missing trial.
        if not trials:
            return None

        assert len(trials) == 1
        return TrialAdapter(trials[0], objective=self.objective)