Example #1
    def __init__(self, artifact_uri):
        super(DatabricksArtifactRepository, self).__init__(artifact_uri)
        if not artifact_uri.startswith('dbfs:/'):
            raise MlflowException(
                message=
                'DatabricksArtifactRepository URI must start with dbfs:/',
                error_code=INVALID_PARAMETER_VALUE)
        if not is_databricks_acled_artifacts_uri(artifact_uri):
            raise MlflowException(
                message=('Artifact URI incorrect. Expected path prefix to be'
                         ' databricks/mlflow-tracking/path/to/artifact/..'),
                error_code=INVALID_PARAMETER_VALUE)
        self.run_id = self._extract_run_id(self.artifact_uri)

        # Fetch the artifact root for the MLflow Run associated with `artifact_uri` and compute
        # the path of `artifact_uri` relative to the MLflow Run's artifact root
        # (the `run_relative_artifact_repo_root_path`). All operations performed on this artifact
        # repository will be performed relative to this computed location
        artifact_repo_root_path = extract_and_normalize_path(artifact_uri)
        run_artifact_root_uri = self._get_run_artifact_root(self.run_id)
        run_artifact_root_path = extract_and_normalize_path(
            run_artifact_root_uri)
        run_relative_root_path = posixpath.relpath(
            path=artifact_repo_root_path, start=run_artifact_root_path)
        # If the paths are equal, then use empty string over "./" for ListArtifact compatibility.
        self.run_relative_artifact_repo_root_path = \
            "" if run_artifact_root_path == artifact_repo_root_path else run_relative_root_path
Example #2
    def _download_from_cloud(self, cloud_credential, local_file_path):
        """
        Downloads a file from the signed URI in `cloud_credential` and saves
        it to `local_file_path`.

        Since the download mechanism is the same for both cloud services (Azure
        and AWS), a single download method is sufficient.

        By default, `requests.get` downloads the entire response body
        immediately, which can be inefficient for large files. The `stream`
        parameter is therefore set to True: only the response headers are
        downloaded at first and the connection is kept open, allowing content
        to be retrieved via `iter_content`.
        In addition, since the connection is kept open, refreshing credentials
        is not required.
        """
        if cloud_credential.type not in [
                ArtifactCredentialType.AZURE_SAS_URI,
                ArtifactCredentialType.AWS_PRESIGNED_URL
        ]:
            raise MlflowException(message='Cloud provider not supported.',
                                  error_code=INTERNAL_ERROR)
        try:
            signed_read_uri = cloud_credential.signed_uri
            with requests.get(signed_read_uri, stream=True) as response:
                response.raise_for_status()
                with open(local_file_path, "wb") as output_file:
                    for chunk in response.iter_content(
                            chunk_size=_DOWNLOAD_CHUNK_SIZE):
                        if not chunk:
                            break
                        output_file.write(chunk)
        except Exception as err:
            raise MlflowException(err)
Example #3
def test_rest_exception():
    mlflow_exception = MlflowException('test',
                                       error_code=RESOURCE_ALREADY_EXISTS)
    json_exception = mlflow_exception.serialize_as_json()
    deserialized_rest_exception = RestException(json.loads(json_exception))
    assert deserialized_rest_exception.error_code == "RESOURCE_ALREADY_EXISTS"
    assert "test" in deserialized_rest_exception.message
Example #4
def _get_flavor_configuration_from_uri(model_uri, flavor_name):
    """
    Obtains the configuration for the specified flavor from the specified
    MLflow model URI. If the model does not contain the specified flavor,
    an exception will be thrown.

    :param model_uri: The URI of the MLflow model for which to load the
                      specified flavor configuration.
    :param flavor_name: The name of the flavor configuration to load.
    :return: The flavor configuration as a dictionary.
    """
    try:
        ml_model_file = _download_artifact_from_uri(
            artifact_uri=append_to_uri_path(model_uri, MLMODEL_FILE_NAME))
    except Exception as ex:
        raise MlflowException(
            "Failed to download an \"{model_file}\" model file from \"{model_uri}\": {ex}"
            .format(model_file=MLMODEL_FILE_NAME, model_uri=model_uri,
                    ex=ex), RESOURCE_DOES_NOT_EXIST)
    model_conf = Model.load(ml_model_file)
    if flavor_name not in model_conf.flavors:
        raise MlflowException(
            "Model does not have the \"{flavor_name}\" flavor".format(
                flavor_name=flavor_name), RESOURCE_DOES_NOT_EXIST)
    return model_conf.flavors[flavor_name]
Example #5
    def create_experiment(self, name, artifact_location=None):
        if name is None or name == '':
            raise MlflowException('Invalid experiment name',
                                  INVALID_PARAMETER_VALUE)

        with self.ManagedSessionMaker() as session:
            try:
                experiment = SqlExperiment(
                    name=name,
                    lifecycle_stage=LifecycleStage.ACTIVE,
                    artifact_location=artifact_location)
                session.add(experiment)
                if not artifact_location:
                    # This requires a double write: the first write generates an
                    # auto-incremented experiment ID, which is then used to derive
                    # the default artifact location.
                    eid = session.query(SqlExperiment).filter_by(
                        name=name).first().experiment_id
                    experiment.artifact_location = self._get_artifact_location(
                        eid)
            except sqlalchemy.exc.IntegrityError as e:
                raise MlflowException(
                    'Experiment(name={}) already exists. '
                    'Error: {}'.format(name, str(e)), RESOURCE_ALREADY_EXISTS)

            session.flush()
            return str(experiment.experiment_id)
Example #6
    def __init__(self,
                 host,
                 username=None,
                 password=None,
                 token=None,
                 ignore_tls_verification=False,
                 client_cert_path=None,
                 server_cert_path=None):
        if not host:
            raise MlflowException(
                message="host is a required parameter for MlflowHostCreds",
                error_code=INVALID_PARAMETER_VALUE,
            )
        if ignore_tls_verification and (server_cert_path is not None):
            raise MlflowException(
                message=(
                    "When 'ignore_tls_verification' is true then 'server_cert_path' "
                    "must not be set! This error may have occurred because the "
                    "'MLFLOW_TRACKING_INSECURE_TLS' and 'MLFLOW_TRACKING_SERVER_CERT_PATH' "
                    "environment variables are both set - only one of these environment "
                    "variables may be set."),
                error_code=INVALID_PARAMETER_VALUE,
            )
        self.host = host
        self.username = username
        self.password = password
        self.token = token
        self.ignore_tls_verification = ignore_tls_verification
        self.client_cert_path = client_cert_path
        self.server_cert_path = server_cert_path
Example #7
    def _get_registered_model(cls, session, name, eager=False):
        """
        :param eager: If ``True``, eagerly loads the registered model's tags.
                      If ``False``, these attributes are not eagerly loaded and
                      will be loaded when their corresponding object properties
                      are accessed from the resulting ``SqlRegisteredModel`` object.
        """
        _validate_model_name(name)
        query_options = (cls._get_eager_registered_model_query_options()
                         if eager else [])
        rms = session \
            .query(SqlRegisteredModel) \
            .options(*query_options) \
            .filter(SqlRegisteredModel.name == name) \
            .all()

        if len(rms) == 0:
            raise MlflowException(
                'Registered Model with name={} not found'.format(name),
                RESOURCE_DOES_NOT_EXIST)
        if len(rms) > 1:
            raise MlflowException(
                'Expected only 1 registered model with name={}. '
                'Found {}.'.format(name, len(rms)), INVALID_STATE)
        return rms[0]
Example #8
    def parse_runs_uri(run_uri):
        parsed = urllib.parse.urlparse(run_uri)
        if parsed.scheme != "runs":
            raise MlflowException(
                "Not a proper runs:/ URI: %s. " % run_uri +
                "Runs URIs must be of the form 'runs:/<run_id>/run-relative/path/to/artifact'")
        # hostname = parsed.netloc  # TODO: support later

        path = parsed.path
        if not path.startswith('/') or len(path) <= 1:
            raise MlflowException(
                "Not a proper runs:/ URI: %s. " % run_uri +
                "Runs URIs must be of the form 'runs:/<run_id>/run-relative/path/to/artifact'")
        path = path[1:]

        path_parts = path.split('/')
        run_id = path_parts[0]
        if run_id == '':
            raise MlflowException(
                "Not a proper runs:/ URI: %s. " % run_uri +
                "Runs URIs must be of the form 'runs:/<run_id>/run-relative/path/to/artifact'")

        artifact_path = '/'.join(path_parts[1:]) if len(path_parts) > 1 else None
        artifact_path = artifact_path if artifact_path != '' else None

        return run_id, artifact_path
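Tracing the parsing logic with concrete values (run ID and path made up):

run_id, artifact_path = parse_runs_uri("runs:/abc123/model/conda.yaml")
# run_id == "abc123", artifact_path == "model/conda.yaml"
parse_runs_uri("runs:/abc123")   # -> ("abc123", None)
parse_runs_uri("runs://abc123")  # raises MlflowException: path is empty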
Example #9
def gc(backend_store_uri, run_ids):
    """
    Permanently delete runs in the `deleted` lifecycle stage from the specified backend store.
    This command deletes all artifacts and metadata associated with the specified runs.
    """
    backend_store = _get_store(backend_store_uri, None)
    if not hasattr(backend_store, '_hard_delete_run'):
        raise MlflowException(
            "This cli can only be used with a backend that allows hard-deleting runs"
        )
    if not run_ids:
        run_ids = backend_store._get_deleted_runs()
    else:
        run_ids = run_ids.split(',')

    for run_id in run_ids:
        run = backend_store.get_run(run_id)
        if run.info.lifecycle_stage != LifecycleStage.DELETED:
            raise MlflowException(
                'Run {} is not in `deleted` lifecycle stage. Only runs in '
                '`deleted` lifecycle stage can be deleted.'.format(run_id))
        artifact_repo = get_artifact_repository(run.info.artifact_uri)
        artifact_repo.delete_artifacts()
        backend_store._hard_delete_run(run_id)
        print("Run with ID %s has been permanently deleted." % str(run_id))
Example #10
def _get_flavor_configuration(model_path, flavor_name):
    """
    Obtains the configuration for the specified flavor from the specified
    MLflow model path. If the model does not contain the specified flavor,
    an exception will be thrown.

    :param model_path: The path to the root directory of the MLflow model for which to load
                       the specified flavor configuration.
    :param flavor_name: The name of the flavor configuration to load.
    :return: The flavor configuration as a dictionary.
    """
    model_configuration_path = os.path.join(model_path, MLMODEL_FILE_NAME)
    if not os.path.exists(model_configuration_path):
        raise MlflowException(
            "Could not find an \"{model_file}\" configuration file at \"{model_path}\""
            .format(model_file=MLMODEL_FILE_NAME,
                    model_path=model_path), RESOURCE_DOES_NOT_EXIST)

    model_conf = Model.load(model_configuration_path)
    if flavor_name not in model_conf.flavors:
        raise MlflowException(
            "Model does not have the \"{flavor_name}\" flavor".format(
                flavor_name=flavor_name), RESOURCE_DOES_NOT_EXIST)
    conf = model_conf.flavors[flavor_name]
    return conf
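A minimal usage sketch, assuming a local model directory containing an MLmodel file (the path and flavor name are illustrative):

# Load the configuration of a hypothetical "python_function" flavor:
pyfunc_conf = _get_flavor_configuration(
    model_path="/tmp/my_model", flavor_name="python_function")
loader_module = pyfunc_conf.get("loader_module")  # e.g. a dotted module name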
Example #11
    def _get_sql_model_version(cls, session, name, version, eager=False):
        """
        :param eager: If ``True``, eagerly loads the model version's tags.
                      If ``False``, these attributes are not eagerly loaded and
                      will be loaded when their corresponding object properties
                      are accessed from the resulting ``SqlModelVersion`` object.
        """
        _validate_model_name(name)
        _validate_model_version(version)
        query_options = (cls._get_eager_model_version_query_options()
                         if eager else [])
        conditions = [
            SqlModelVersion.name == name, SqlModelVersion.version == version,
            SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL
        ]
        versions = session.query(SqlModelVersion).options(
            *query_options).filter(*conditions).all()

        if len(versions) == 0:
            raise MlflowException(
                'Model Version (name={}, version={}) '
                'not found'.format(name, version), RESOURCE_DOES_NOT_EXIST)
        if len(versions) > 1:
            raise MlflowException(
                'Expected only 1 model version with (name={}, version={}). '
                'Found {}.'.format(name, version, len(versions)),
                INVALID_STATE)
        return versions[0]
Example #12
File: __init__.py Project: iPieter/kiwi
    def __init__(self, model_meta: Model, model_impl: Any):
        if not hasattr(model_impl, "predict"):
            raise MlflowException(
                "Model implementation is missing required predict method.")
        if not model_meta:
            raise MlflowException("Model is missing metadata.")
        self._model_meta = model_meta
        self._model_impl = model_impl
Example #13
def _validate_metric_name(name):
    """Check that `name` is a valid metric name and raise an exception if it isn't."""
    if name is None or not _VALID_PARAM_AND_METRIC_NAMES.match(name):
        raise MlflowException(
            "Invalid metric name: '%s'. %s" % (name, _BAD_CHARACTERS_MESSAGE),
            INVALID_PARAMETER_VALUE)
    if path_not_unique(name):
        raise MlflowException(
            "Invalid metric name: '%s'. %s" % (name, bad_path_message(name)),
            INVALID_PARAMETER_VALUE)
Example #14
def _validate_experiment_name(experiment_name):
    """Check that `experiment_name` is a valid string and raise an exception if it isn't."""
    if experiment_name == "" or experiment_name is None:
        raise MlflowException("Invalid experiment name: '%s'" %
                              experiment_name,
                              error_code=INVALID_PARAMETER_VALUE)

    if not is_string_type(experiment_name):
        raise MlflowException(
            "Invalid experiment name: %s. Expects a string." % experiment_name,
            error_code=INVALID_PARAMETER_VALUE)
Example #15
    def matches_view_type(cls, view_type, lifecycle_stage):
        if not cls.is_valid(lifecycle_stage):
            raise MlflowException("Invalid lifecycle stage '%s'" % str(lifecycle_stage))

        if view_type == ViewType.ALL:
            return True
        elif view_type == ViewType.ACTIVE_ONLY:
            return lifecycle_stage == LifecycleStage.ACTIVE
        elif view_type == ViewType.DELETED_ONLY:
            return lifecycle_stage == LifecycleStage.DELETED
        else:
            raise MlflowException("Invalid view type '%s'" % str(view_type))
Example #16
File: __init__.py Project: iPieter/kiwi
def _enforce_type(name, values: pandas.Series, t: DataType):
    """
    Enforce the input column type matches the declared in model input schema.

    The following type conversions are allowed:

    1. np.object -> string
    2. int -> long (upcast)
    3. float -> double (upcast)

    Any other type mismatch will raise error.
    """
    if values.dtype == np.object and t not in (DataType.binary,
                                               DataType.string):
        values = values.infer_objects()

    if values.dtype in (t.to_pandas(), t.to_numpy()):
        # The types are already compatible => conversion is not necessary.
        return values

    if t == DataType.binary and values.dtype.kind == t.binary.to_numpy().kind:
        # NB: bytes in numpy have variable itemsize depending on the length of the longest
        # element in the array (column). Since MLflow binary type is length agnostic, we ignore
        # itemsize when matching binary columns.
        return values

    if t == DataType.string and values.dtype == np.object:
        #  NB: strings are by default parsed and inferred as objects, but it is
        # recommended to use StringDtype extension type if available. See
        #
        # `https://pandas.pydata.org/pandas-docs/stable/user_guide/text.html`
        #
        # for more detail.
        try:
            return values.astype(t.to_pandas(), errors="raise")
        except ValueError:
            raise MlflowException(
                "Failed to convert column {0} from type {1} to {2}.".format(
                    name, values.dtype, t))

    numpy_type = t.to_numpy()
    is_compatible_type = values.dtype.kind == numpy_type.kind
    is_upcast = values.dtype.itemsize <= numpy_type.itemsize
    if is_compatible_type and is_upcast:
        return values.astype(numpy_type, errors="raise")
    else:
        # NB: conversion between incompatible types (e.g. floats -> ints or
        # double -> float) are not allowed. While supported by pandas and numpy,
        # these conversions alter the values significantly.
        raise MlflowException("Incompatible input types for column {0}. "
                              "Can not safely convert {1} to {2}.".format(
                                  name, values.dtype, numpy_type))
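A sketch of the upcast path, assuming `DataType.long` maps to int64 via `to_numpy()`:

import pandas as pd

s = pd.Series([1, 2, 3], dtype="int32")
# int32 -> int64 is a same-kind upcast, so this succeeds:
_enforce_type("x", s, DataType.long)
# A float64 column passed as DataType.long would be a kind mismatch
# ("f" vs "i") and raise MlflowException instead.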
Example #17
    def __getitem__(self, item):
        """Override __getitem__ so that we can directly look up plugins via dict-like syntax"""
        try:
            target_name = parse_target_uri(item)
            plugin_like = self.registry[target_name]
        except KeyError:
            msg = 'No plugin found for managing model deployments to "{target}". ' \
                  'In order to deploy models to "{target}", find and install an appropriate ' \
                  'plugin from ' \
                  'https://mlflow.org/docs/latest/plugins.html#community-plugins using ' \
                  'your package manager (pip, conda etc).'.format(target=item)
            raise MlflowException(msg, error_code=RESOURCE_DOES_NOT_EXIST)

        if isinstance(plugin_like, entrypoints.EntryPoint):
            try:
                plugin_obj = plugin_like.load()
            except (AttributeError, ImportError) as exc:
                raise RuntimeError('Failed to load the plugin "{}": {}'.format(
                    item, str(exc)))
            self.registry[item] = plugin_obj
        else:
            plugin_obj = plugin_like

        # Testing whether the plugin is valid or not
        expected = {'target_help', 'run_local'}
        deployment_classes = []
        for name, obj in inspect.getmembers(plugin_obj):
            if name in expected:
                expected.remove(name)
            elif inspect.isclass(obj) and \
                    issubclass(obj, BaseDeploymentClient) and \
                    not obj == BaseDeploymentClient:
                deployment_classes.append(name)
        if len(expected) > 0:
            raise MlflowException(
                "Plugin registered for the target {} does not has all "
                "the required interfaces. Raise an issue with the "
                "plugin developers.\n"
                "Missing interfaces: {}".format(item, expected),
                error_code=INTERNAL_ERROR)
        if len(deployment_classes) > 1:
            raise MlflowException(
                "Plugin registered for the target {} has more than one "
                "child class of BaseDeploymentClient. Raise an issue with"
                " the plugin developers. "
                "Classes found are {}".format(item, deployment_classes))
        elif len(deployment_classes) == 0:
            raise MlflowException(
                "Plugin registered for the target {} has no child class"
                " of BaseDeploymentClient. Raise an issue with the "
                "plugin developers".format(item))
        return plugin_obj
Example #18
    def _get_run_info(self, run_uuid):
        """
        Note: Will get both active and deleted runs.
        """
        exp_id, run_dir = self._find_run_root(run_uuid)
        if run_dir is None:
            raise MlflowException("Run '%s' not found" % run_uuid,
                                  databricks_pb2.RESOURCE_DOES_NOT_EXIST)
        run_info = self._get_run_info_from_dir(run_dir)
        if run_info.experiment_id != exp_id:
            raise MlflowException(
                "Run '%s' metadata is in invalid state." % run_uuid,
                databricks_pb2.INVALID_STATE)
        return run_info
Example #19
File: __init__.py Project: iPieter/kiwi
def _enforce_schema(pdf: pandas.DataFrame, input_schema: Schema):
    """
    Enforce column names and types match the input schema.

    For column names, we check there are no missing columns and reorder the columns to match the
    ordering declared in schema if necessary. Any extra columns are ignored.

    For column types, we make sure the types match schema or can be safely converted to match the
    input schema.
    """
    if isinstance(pdf, list):
        pdf = pandas.DataFrame(pdf)
    if not isinstance(pdf, pandas.DataFrame):
        message = ('Expected input to be DataFrame or list. Found: %s'
                   % type(pdf).__name__)
        raise MlflowException(message)

    if input_schema.has_column_names():
        # make sure there are no missing columns
        col_names = input_schema.column_names()
        expected_names = set(col_names)
        actual_names = set(pdf.columns)
        missing_cols = expected_names - actual_names
        extra_cols = actual_names - expected_names
        # Preserve order from the original columns, since missing/extra columns are likely to
        # be in same order.
        missing_cols = [c for c in col_names if c in missing_cols]
        extra_cols = [c for c in pdf.columns if c in extra_cols]
        if missing_cols:
            message = ("Model input is missing columns {0}."
                       " Note that there were extra columns: {1}".format(
                           missing_cols, extra_cols))
            raise MlflowException(message)
    else:
        # The model signature does not specify column names => we can only verify column count.
        if len(pdf.columns) < len(input_schema.columns):
            message = (
                "Model input is missing input columns. The model signature declares "
                "{0} input columns but the provided input only has "
                "{1} columns. Note: the columns were not named in the signature so we can "
                "only verify their count.").format(len(input_schema.columns),
                                                   len(pdf.columns))
            raise MlflowException(message)
        col_names = pdf.columns[:len(input_schema.columns)]
    col_types = input_schema.column_types()
    new_pdf = pandas.DataFrame()
    for i, x in enumerate(col_names):
        new_pdf[x] = _enforce_type(x, pdf[x], col_types[i])
    return new_pdf
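A hedged end-to-end sketch; the Schema/ColSpec constructors are assumed and may differ in this fork:

# Assumed constructors: Schema([...]) of ColSpec(type, name) entries.
schema = Schema([ColSpec(DataType.double, "x"), ColSpec(DataType.string, "y")])
pdf = pandas.DataFrame({"y": ["a", "b"], "extra": [0, 1], "x": [1.0, 2.0]})
# Columns are reordered to match the schema ("x", "y"); "extra" is ignored.
_enforce_schema(pdf, schema)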
Example #20
File: utils.py Project: iPieter/kiwi
def parse_target_uri(target_uri):
    """Parse out the deployment target from the provided target uri"""
    parsed = urllib.parse.urlparse(target_uri)
    if not parsed.scheme:
        if parsed.path:
            # uri = 'target_name' (without :/<path>)
            return parsed.path
        raise MlflowException(
            "Not a proper deployment URI: %s. " % target_uri +
            "Deployment URIs must be of the form 'target' or 'target:/suffix'")
    if parsed.netloc:  # Handle e.g. target_name://suffix, where 'suffix' gets parsed as netloc
        raise MlflowException(
            "Not a proper deployment URI: %s. " % target_uri +
            "Deployment URIs must be of the form 'target:/suffix'")
    return parsed.scheme
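Tracing the branches with concrete URIs (target names illustrative):

parse_target_uri("sagemaker")              # bare target name -> "sagemaker"
parse_target_uri("sagemaker:/us-west-2")   # scheme parsed -> "sagemaker"
parse_target_uri("sagemaker://us-west-2")  # netloc present -> raises MlflowException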
Example #21
    def restore_experiment(self, experiment_id):
        experiment_dir = self._get_experiment_path(experiment_id,
                                                   ViewType.DELETED_ONLY)
        if experiment_dir is None:
            raise MlflowException(
                "Could not find deleted experiment with ID %d" % experiment_id,
                databricks_pb2.RESOURCE_DOES_NOT_EXIST)
        conflict_experiment = self._get_experiment_path(
            experiment_id, ViewType.ACTIVE_ONLY)
        if conflict_experiment is not None:
            raise MlflowException(
                "Cannot restore experiment with ID %d. "
                "An experiment with same ID already exists." % experiment_id,
                databricks_pb2.RESOURCE_ALREADY_EXISTS)
        mv(experiment_dir, self.root_directory)
Example #22
    def search_model_versions(self, filter_string):
        """
        Search for model versions in backend that satisfy the filter criteria.

        :param filter_string: A filter string expression. Currently supports a single filter
                              condition either name of model like ``name = 'model_name'`` or
                              ``run_id = '...'``.
        :return: PagedList of :py:class:`mlflow.entities.model_registry.ModelVersion`
                 objects.
        """
        parsed_filter = SearchUtils.parse_filter_for_model_versions(
            filter_string)
        if len(parsed_filter) == 0:
            conditions = []
        elif len(parsed_filter) == 1:
            filter_dict = parsed_filter[0]
            if filter_dict["comparator"] != "=":
                raise MlflowException(
                    'Model Registry search filter only supports equality(=) '
                    'comparator. Input filter string: %s' % filter_string,
                    error_code=INVALID_PARAMETER_VALUE)
            if filter_dict["key"] == "name":
                conditions = [SqlModelVersion.name == filter_dict["value"]]
            elif filter_dict["key"] == "source_path":
                conditions = [SqlModelVersion.source == filter_dict["value"]]
            elif filter_dict["key"] == "run_id":
                conditions = [SqlModelVersion.run_id == filter_dict["value"]]
            else:
                raise MlflowException('Invalid filter string: %s' %
                                      filter_string,
                                      error_code=INVALID_PARAMETER_VALUE)
        else:
            raise MlflowException(
                'Model Registry expects filter to be one of '
                '"name = \'<model_name>\'" or '
                '"source_path = \'<source_path>\'" or "run_id = \'<run_id>\'". '
                'Input filter string: %s. ' % filter_string,
                error_code=INVALID_PARAMETER_VALUE)

        with self.ManagedSessionMaker() as session:
            conditions.append(
                SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL)
            sql_model_version = session.query(SqlModelVersion).filter(
                *conditions).all()
            model_versions = [
                mv.to_mlflow_entity() for mv in sql_model_version
            ]
            return PagedList(model_versions, None)
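Usage sketch; the store construction is assumed (only the filter strings are grounded in the parsing logic above):

store = SqlAlchemyStore("sqlite:///mlruns.db")  # hypothetical registry store
store.search_model_versions("name = 'my_model'")
store.search_model_versions("run_id = 'abc123'")
store.search_model_versions("name != 'my_model'")  # raises: only '=' is supported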
Example #23
def _load_model(path, **kwargs):
    """
    :param path: The path to a serialized PyTorch model.
    :param kwargs: Additional kwargs to pass to the PyTorch ``torch.load`` function.
    """
    import torch

    if os.path.isdir(path):
        # `path` is a directory containing a serialized PyTorch model and a text file containing
        # information about the pickle module that should be used by PyTorch to load it
        model_path = os.path.join(path, "model.pth")
        pickle_module_path = os.path.join(path, _PICKLE_MODULE_INFO_FILE_NAME)
        with open(pickle_module_path, "r") as f:
            pickle_module_name = f.read()
        if "pickle_module" in kwargs and kwargs[
                "pickle_module"].__name__ != pickle_module_name:
            _logger.warning(
                "Attempting to load the PyTorch model with a pickle module, '%s', that does not"
                " match the pickle module that was used to save the model: '%s'.",
                kwargs["pickle_module"].__name__, pickle_module_name)
        else:
            try:
                kwargs["pickle_module"] = importlib.import_module(
                    pickle_module_name)
            except ImportError:
                raise MlflowException(message=(
                    "Failed to import the pickle module that was used to save the PyTorch"
                    " model. Pickle module name: `{pickle_module_name}`".
                    format(pickle_module_name=pickle_module_name)),
                                      error_code=RESOURCE_DOES_NOT_EXIST)

    else:
        model_path = path

    return torch.load(model_path, **kwargs)
Example #24
File: utils.py Project: iPieter/kiwi
def _infer_numpy_dtype(dtype: np.dtype) -> DataType:
    if not isinstance(dtype, np.dtype):
        raise TypeError("Expected numpy.dtype, got '{}'.".format(type(dtype)))
    if dtype.kind == "b":
        return DataType.boolean
    elif dtype.kind == "i" or dtype.kind == "u":
        if dtype.itemsize < 4 or (dtype.kind == "i" and dtype.itemsize == 4):
            return DataType.integer
        elif dtype.itemsize < 8 or (dtype.kind == "i" and dtype.itemsize == 8):
            return DataType.long
    elif dtype.kind == "f":
        if dtype.itemsize <= 4:
            return DataType.float
        elif dtype.itemsize <= 8:
            return DataType.double

    elif dtype.kind == "U":
        return DataType.string
    elif dtype.kind == "S":
        return DataType.binary
    elif dtype.kind == "O":
        raise Exception(
            "Can not infer np.object without looking at the values, call "
            "_map_numpy_array instead.")
    raise MlflowException(
        "Unsupported numpy data type '{0}', kind '{1}'".format(
            dtype, dtype.kind))
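Concrete mappings implied by the branches above:

import numpy as np

_infer_numpy_dtype(np.dtype("bool"))        # -> DataType.boolean
_infer_numpy_dtype(np.dtype("int32"))       # -> DataType.integer
_infer_numpy_dtype(np.dtype("float64"))     # -> DataType.double
_infer_numpy_dtype(np.dtype("complex128"))  # kind "c" -> raises MlflowException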
Example #25
File: __init__.py Project: iPieter/kiwi
def _create_dockerfile(output_path, mlflow_path=None):
    """
    Creates a Dockerfile containing additional Docker build steps to execute
    when building the Azure container image. These build steps perform the following tasks:

    - Install MLflow

    :param output_path: The path where the Dockerfile will be written.
    :param mlflow_path: Path to a local copy of the MLflow GitHub repository. If specified, the
                        Dockerfile command for MLflow installation will install MLflow from this
                        directory. Otherwise, it will install MLflow from pip.
    """
    docker_cmds = ["RUN apt-get update && apt-get install -y default-jre"]
    docker_cmds.append("RUN pip install azureml-sdk")

    if mlflow_path is not None:
        mlflow_install_cmd = "RUN pip install -e {mlflow_path}".format(
            mlflow_path=_get_container_path(mlflow_path))
    elif not mlflow_version.endswith("dev"):
        mlflow_install_cmd = "RUN pip install mlflow=={mlflow_version}".format(
            mlflow_version=mlflow_version)
    else:
        raise MlflowException(
            "You are running a 'dev' version of MLflow: `{mlflow_version}` that cannot be"
            " installed from pip. In order to build a container image, either specify the"
            " path to a local copy of the MLflow GitHub repository using the `mlflow_home`"
            " parameter or install a release version of MLflow from pip".format(
                mlflow_version=mlflow_version))
    docker_cmds.append(mlflow_install_cmd)

    with open(output_path, "w") as f:
        f.write("\n".join(docker_cmds))
Example #26
    def list_artifacts(self, path=None):
        if path:
            dbfs_path = self._get_dbfs_path(path)
        else:
            dbfs_path = self._get_dbfs_path('')
        dbfs_list_json = {'path': dbfs_path}
        response = self._dbfs_list_api(dbfs_list_json)
        try:
            json_response = json.loads(response.text)
        except ValueError:
            raise MlflowException(
                "API request to list files under DBFS path %s failed with status code %s. "
                "Response body: %s" %
                (dbfs_path, response.status_code, response.text))
        # /api/2.0/dbfs/list will not have the 'files' key in the response for empty directories.
        infos = []
        artifact_prefix = strip_prefix(self.artifact_uri, 'dbfs:')
        if json_response.get('error_code', None) == RESOURCE_DOES_NOT_EXIST:
            return []
        dbfs_files = json_response.get('files', [])
        for dbfs_file in dbfs_files:
            stripped_path = strip_prefix(dbfs_file['path'],
                                         artifact_prefix + '/')
            # If `path` is a file, the DBFS list API returns a single list element with the
            # same name as `path`. The list_artifacts API expects us to return an empty list
            # in this case, so we do so here.
            if stripped_path == path:
                return []
            is_dir = dbfs_file['is_dir']
            artifact_size = None if is_dir else dbfs_file['file_size']
            infos.append(FileInfo(stripped_path, is_dir, artifact_size))
        return sorted(infos, key=lambda f: f.path)
Example #27
def _validate_model_version(model_version):
    try:
        model_version = int(model_version)
    except ValueError:
        raise MlflowException(
            "Model version must be an integer, got '{}'".format(model_version),
            error_code=INVALID_PARAMETER_VALUE)
Example #28
def dbfs_artifact_repo_factory(artifact_uri):
    """
    Returns an ArtifactRepository subclass for storing artifacts on DBFS.

    This factory method is used with URIs of the form ``dbfs:/<path>``. DBFS-backed artifact
    storage can only be used together with the RestStore.

    In the special case where the URI is of the form
    ``dbfs:/databricks/mlflow-tracking/<Exp-ID>/<Run-ID>/<path>``,
    a DatabricksArtifactRepository is returned. This repository is capable of
    storing access-controlled artifacts.

    :param artifact_uri: DBFS root artifact URI (string).
    :return: Subclass of ArtifactRepository capable of storing artifacts on DBFS.
    """
    cleaned_artifact_uri = artifact_uri.rstrip('/')
    uri_scheme = get_uri_scheme(artifact_uri)
    if uri_scheme != 'dbfs':
        raise MlflowException(
            "DBFS URI must be of the form "
            "dbfs:/<path>, but received {uri}".format(uri=artifact_uri))
    if is_databricks_acled_artifacts_uri(artifact_uri):
        return DatabricksArtifactRepository(cleaned_artifact_uri)
    elif kiwi.utils.databricks_utils.is_dbfs_fuse_available() \
            and os.environ.get(USE_FUSE_ENV_VAR, "").lower() != "false" \
            and not artifact_uri.startswith("dbfs:/databricks/mlflow-registry"):
        # If the DBFS FUSE mount is available, write artifacts directly to /dbfs/... using
        # local filesystem APIs
        file_uri = "file:///dbfs/{}".format(
            strip_prefix(cleaned_artifact_uri, "dbfs:/"))
        return LocalArtifactRepository(file_uri)
    return DbfsRestArtifactRepository(cleaned_artifact_uri)
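A sketch of the dispatch (paths illustrative; the FUSE branch depends on the environment):

# ACL'd tracking path -> DatabricksArtifactRepository:
dbfs_artifact_repo_factory("dbfs:/databricks/mlflow-tracking/123/run456/artifacts")
# Any other dbfs:/ path -> LocalArtifactRepository when the FUSE mount is
# available, else DbfsRestArtifactRepository:
dbfs_artifact_repo_factory("dbfs:/my/artifacts")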
Example #29
def _validate_batch_log_api_req(json_req):
    if len(json_req) > MAX_BATCH_LOG_REQUEST_SIZE:
        error_msg = (
            "Batched logging API requests must be at most {limit} bytes, got a "
            "request of size {size}.").format(limit=MAX_BATCH_LOG_REQUEST_SIZE,
                                              size=len(json_req))
        raise MlflowException(error_msg, error_code=INVALID_PARAMETER_VALUE)
Example #30
def _validate_batch_limit(entity_name, limit, length):
    if length > limit:
        error_msg = (
            "A batch logging request can contain at most {limit} {name}. "
            "Got {count} {name}. Please split up {name} across multiple requests and try "
            "again.").format(name=entity_name, count=length, limit=limit)
        raise MlflowException(error_msg, error_code=INVALID_PARAMETER_VALUE)