Esempio n. 1
0
def test_client_create_run_overrides(mock_store):
    """Check that ``create_run`` forwards explicit arguments and tags to the store.

    The store is expected to receive the ``MLFLOW_USER`` tag value as
    ``user_id`` and every tag converted to a ``RunTag``. The same call is
    exercised twice, with the store mock reset in between.
    """
    experiment_id, user, start_time = mock.Mock(), mock.Mock(), mock.Mock()
    tags = {
        MLFLOW_USER: user,
        MLFLOW_PARENT_RUN_ID: mock.Mock(),
        MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.JOB),
        MLFLOW_SOURCE_NAME: mock.Mock(),
        MLFLOW_PROJECT_ENTRY_POINT: mock.Mock(),
        MLFLOW_GIT_COMMIT: mock.Mock(),
        "other-key": "other-value",
    }

    def expected_kwargs():
        # Built at assertion time so it reflects the current content of ``tags``.
        return {
            "experiment_id": experiment_id,
            "user_id": user,
            "start_time": start_time,
            "tags": [RunTag(key, value) for key, value in tags.items()],
        }

    # First invocation.
    MlflowClient().create_run(experiment_id, start_time, tags)
    mock_store.create_run.assert_called_once_with(**expected_kwargs())

    # Second invocation against a freshly reset store mock.
    mock_store.reset_mock()
    MlflowClient().create_run(experiment_id, start_time, tags)
    mock_store.create_run.assert_called_once_with(**expected_kwargs())
Esempio n. 2
0
 def get_underlying_uri(uri):
     """
     Resolve a model URI to the download URI of the model version it refers to.

     The URI is split into ``(name, version, stage)``. When a stage is present,
     the first entry returned by ``get_latest_versions`` for that stage supplies
     the version.

     :param uri: Model URI accepted by ``ModelsArtifactRepository._parse_uri``.
     :return: The result of ``client.get_model_version_download_uri(name, version)``.
     :raises MlflowException: If a stage is given and no version is found for it.
     """
     # Note: to support a registry URI that is different from the tracking URI here,
     # we'll need to add setting of registry URIs via environment variables.
     from kiwi.tracking import MlflowClient
     client = MlflowClient()
     name, version, stage = ModelsArtifactRepository._parse_uri(uri)

     # No stage: the parsed version is used directly.
     if stage is None:
         return client.get_model_version_download_uri(name, version)

     latest = client.get_latest_versions(name, [stage])
     if len(latest) == 0:
         message = ("No versions of model with name '{name}' and "
                    "stage '{stage}' found").format(name=name, stage=stage)
         raise MlflowException(message)
     return client.get_model_version_download_uri(name, latest[0].version)
Esempio n. 3
0
def test_registry_uri_from_tracking_uri_param():
    """An explicit ``tracking_uri`` argument is also used as the client's registry URI.

    ``get_tracking_uri`` is patched to return a different value, and the test
    asserts that the explicit argument is the one stored in ``_registry_uri``.
    """
    patch_target = "mlflow.tracking._tracking_service.utils.get_tracking_uri"
    explicit_uri = "databricks://tracking_vhawoierj"
    with mock.patch(patch_target, return_value="databricks://default_tracking"):
        client = MlflowClient(tracking_uri=explicit_uri)
        assert client._registry_uri == explicit_uri
Esempio n. 4
0
def test_registry_uri_from_implicit_tracking_uri():
    """Without any URI arguments, the registry URI equals the patched tracking URI."""
    patch_target = "mlflow.tracking._tracking_service.utils.get_tracking_uri"
    implicit_uri = "databricks://tracking_wierojasdf"
    with mock.patch(patch_target) as get_tracking_uri_mock:
        get_tracking_uri_mock.return_value = implicit_uri
        registry_uri = MlflowClient()._registry_uri
        assert registry_uri == implicit_uri
Esempio n. 5
0
def test_client_search_runs_max_results(mock_store):
    """Positional ``search_runs`` arguments, including max results, reach the store."""
    MlflowClient().search_runs([5], "my filter", ViewType.ALL, 2876)

    expected_kwargs = {
        "experiment_ids": [5],
        "filter_string": "my filter",
        "run_view_type": ViewType.ALL,
        "max_results": 2876,
        "order_by": None,
        "page_token": None,
    }
    mock_store.search_runs.assert_called_once_with(**expected_kwargs)
Esempio n. 6
0
def test_client_search_runs_filter(mock_store):
    """With only experiment ids and a filter given, the store gets the default options.

    The asserted defaults are ``ViewType.ACTIVE_ONLY``,
    ``SEARCH_MAX_RESULTS_DEFAULT`` and ``None`` for ordering and page token.
    """
    experiment_ids = ["a", "b", "c"]
    MlflowClient().search_runs(experiment_ids, "my filter")

    expected_kwargs = {
        "experiment_ids": ["a", "b", "c"],
        "filter_string": "my filter",
        "run_view_type": ViewType.ACTIVE_ONLY,
        "max_results": SEARCH_MAX_RESULTS_DEFAULT,
        "order_by": None,
        "page_token": None,
    }
    mock_store.search_runs.assert_called_once_with(**expected_kwargs)
Esempio n. 7
0
def test_client_create_run(mock_store, mock_time):
    """``create_run`` with only an experiment id passes the expected defaults to the store.

    Asserted defaults: ``user_id`` of ``"unknown"``, ``start_time`` equal to
    ``mock_time`` converted to integer milliseconds, and an empty tag list.
    """
    experiment_id = mock.Mock()
    expected_start_ms = int(mock_time * 1000)

    MlflowClient().create_run(experiment_id)

    mock_store.create_run.assert_called_once_with(
        experiment_id=experiment_id,
        user_id="unknown",
        start_time=expected_start_ms,
        tags=[],
    )
Esempio n. 8
0
def register_model(model_uri, name):
    """
    Register the model files at ``model_uri`` as a new version of the model ``name``.

    A registered model called ``name`` is created first; if the backend reports
    that it already exists, the existing one is reused. Progress messages are
    written with ``eprint``.

    Note that this method assumes the model registry backend URI is the same as that of the
    tracking backend.

    :param model_uri: URI referring to the MLmodel directory. Use a ``runs:/`` URI if you want to
                      record the run ID with the model in model registry. ``models:/`` URIs are
                      currently not supported.
    :param name: Name of the registered model under which to create a new model version. If a
                 registered model with the given name does not exist, it will be created
                 automatically.
    :return: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object created by
             backend.
    :raises MlflowException: Re-raised when creating the registered model fails for any
                             reason other than ``RESOURCE_ALREADY_EXISTS``.
    """
    client = MlflowClient()

    # Step 1: make sure a registered model with this name exists.
    try:
        registered = client.create_registered_model(name)
        eprint("Successfully registered model '%s'." % registered.name)
    except MlflowException as e:
        already_exists = ErrorCode.Name(RESOURCE_ALREADY_EXISTS)
        if e.error_code != already_exists:
            raise e
        eprint(
            "Registered model '%s' already exists. Creating a new version of this model..."
            % name)

    # Step 2: create the version. For ``runs:/`` URIs the underlying artifact
    # location and the run id are recorded; otherwise the URI is used as-is.
    if not RunsArtifactRepository.is_runs_uri(model_uri):
        new_version = client.create_model_version(name, source=model_uri, run_id=None)
    else:
        source = RunsArtifactRepository.get_underlying_uri(model_uri)
        run_id, _ = RunsArtifactRepository.parse_runs_uri(model_uri)
        new_version = client.create_model_version(name, source, run_id)

    eprint("Created version '{version}' of model '{model_name}'.".format(
        version=new_version.version, model_name=new_version.name))
    return new_version
Esempio n. 9
0
def test_update_model_version(mock_registry_store):
    """
    ``update_model_version`` forwards name, version and description to the registry
    store, returns the store's result, and does not trigger a stage transition.
    """
    store_result = "some expected return value."
    mock_registry_store.update_model_version.return_value = store_result

    client = MlflowClient(registry_uri="sqlite:///somedb.db")
    res = client.update_model_version(name="orig name", version="1", description="desc")

    assert store_result == res
    mock_registry_store.update_model_version.assert_called_once_with(
        name="orig name", version="1", description="desc")
    mock_registry_store.transition_model_version_stage.assert_not_called()
Esempio n. 10
0
def test_update_registered_model(mock_registry_store):
    """
    ``update_registered_model`` returns the store's update result and never calls
    the store's rename operation.
    """
    rename_result = "some expected return value."
    update_result = "other expected return value."
    mock_registry_store.rename_registered_model.return_value = rename_result
    mock_registry_store.update_registered_model.return_value = update_result

    client = MlflowClient(registry_uri="sqlite:///somedb.db")
    res = client.update_registered_model(name="orig name", description="new description")

    # The value must come from the update call, not from rename.
    assert update_result == res
    mock_registry_store.update_registered_model.assert_called_once_with(
        name="orig name", description="new description")
    mock_registry_store.rename_registered_model.assert_not_called()
Esempio n. 11
0
def test_transition_model_version_stage(mock_registry_store):
    """
    ``transition_model_version_stage`` passes name, version and stage to the store with
    ``archive_existing_versions=False`` and returns the store's ``ModelVersion``.
    """
    name, version, stage = "Model 1", "12", "Production"
    expected_result = ModelVersion(
        name, version, creation_timestamp=123, current_stage=stage)
    mock_registry_store.transition_model_version_stage.return_value = expected_result

    client = MlflowClient(registry_uri="sqlite:///somedb.db")
    actual_result = client.transition_model_version_stage(name, version, stage)

    mock_registry_store.transition_model_version_stage.assert_called_once_with(
        name=name, version=version, stage=stage, archive_existing_versions=False)
    assert expected_result == actual_result
Esempio n. 12
0
def test_client_registry_operations_raise_exception_with_unsupported_registry_store():
    """
    Model Registry operations on an ``MlflowClient`` whose registry URI is a plain
    temporary directory path must each raise an ``MlflowException`` carrying the
    ``FEATURE_DISABLED`` error code.
    """
    feature_disabled = ErrorCode.Name(FEATURE_DISABLED)
    with TempDir() as tmp:
        client = MlflowClient(registry_uri=tmp.path())
        # Each entry is a zero-argument callable that is expected to fail.
        failing_operations = (
            client._get_registry_client,
            lambda: client.create_registered_model("test"),
            lambda: client.get_registered_model("test"),
            lambda: client.create_model_version("test", "source", "run_id"),
            lambda: client.get_model_version("test", 1),
        )
        for operation in failing_operations:
            with pytest.raises(MlflowException) as exc:
                operation()
            assert exc.value.error_code == feature_disabled
Esempio n. 13
0
def set_tag_mock():
    """Yield the ``set_tag`` attribute of a mock wrapping a real ``MlflowClient``.

    While the generator is suspended, ``MlflowClient`` in
    ``mlflow.projects.databricks.tracking`` is patched so that calling it
    returns the wrapping mock.
    """
    patch_target = "mlflow.projects.databricks.tracking.MlflowClient"
    with mock.patch(patch_target) as client_cls_mock:
        # The wrapped client is created inside the patch context, as before.
        wrapped_client = mock.Mock(wraps=MlflowClient())
        client_cls_mock.return_value = wrapped_client
        yield wrapped_client.set_tag
Esempio n. 14
0

def rand_str(max_len=40):
    """Return a random string of ASCII letters with length between 1 and ``max_len``.

    Letters are drawn with replacement, so the same letter may appear more than
    once and ``max_len`` may exceed the 52 available ASCII letters. Sampling
    without replacement would raise ``ValueError`` whenever the drawn length
    was larger than 52.

    :param max_len: Inclusive upper bound on the length; must be at least 1,
                    otherwise ``random.randint`` raises ``ValueError``.
    :return: A string made only of characters from ``string.ascii_letters``.
    """
    length = random.randint(1, max_len)
    return "".join(random.choices(string.ascii_letters, k=length))


if __name__ == '__main__':
    # Command-line interface: a single optional --large switch.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--large",
        action="store_true",
        help="If true, will also generate larger datasets for testing UI performance.",
    )
    args = parser.parse_args()
    client = MlflowClient()

    # Simple run
    # One run is started per (l1, alpha) combination; each logs its two
    # parameters as strings and three metrics filled from rand().
    for l1, alpha in itertools.product([0, 0.25, 0.5, 0.75, 1], [0, 0.5, 1]):
        with kiwi.start_run(run_name='ipython'):
            parameters = {'l1': str(l1), 'alpha': str(alpha)}
            metrics = {
                metric_name: [rand()]
                for metric_name in ('MAE', 'R2', 'RMSE')
            }
            log_params(parameters)
            log_metrics(metrics)
Esempio n. 15
0
def mlflow_client(tracking_server_uri):
    """Yield a mock wrapping an ``MlflowClient`` pointed at ``tracking_server_uri``.

    The global tracking URI is set to ``tracking_server_uri`` before the client
    is yielded and set to ``None`` once the generator is resumed.
    """
    kiwi.set_tracking_uri(tracking_server_uri)
    wrapped_client = mock.Mock(wraps=MlflowClient(tracking_server_uri))
    yield wrapped_client
    kiwi.set_tracking_uri(None)
Esempio n. 16
0
def test_registry_uri_from_set_registry_uri():
    """A registry URI set via ``set_registry_uri`` wins over the ``tracking_uri`` argument.

    The globally set registry URI is always reset to ``None`` afterwards, even
    when the assertion fails, so the setting does not leak into other tests.
    """
    uri = "sqlite:///somedb.db"
    set_registry_uri(uri)
    try:
        client = MlflowClient(tracking_uri="databricks://tracking")
        assert client._registry_uri == uri
    finally:
        # Undo the global setting regardless of the test outcome.
        set_registry_uri(None)
Esempio n. 17
0
def test_registry_uri_set_as_param():
    """An explicit ``registry_uri`` argument is stored as the client's registry URI."""
    uri = "sqlite:///somedb.db"
    client_kwargs = {"tracking_uri": "databricks://tracking", "registry_uri": uri}
    client = MlflowClient(**client_kwargs)
    assert client._registry_uri == uri