Ejemplo n.º 1
0
def test_log_batch_api_req(mock_get_request_json):
    """An oversized batched-logging request body is answered with HTTP 400."""
    # One character past the limit is the smallest payload exercised here.
    oversized_body = "a" * (MAX_BATCH_LOG_REQUEST_SIZE + 1)
    mock_get_request_json.return_value = oversized_body

    response = _log_batch()
    assert response.status_code == 400

    payload = json.loads(response.get_data())
    expected_fragment = (
        "Batched logging API requests must be at most %s bytes" % MAX_BATCH_LOG_REQUEST_SIZE
    )
    assert payload["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    assert expected_fragment in payload["message"]
Ejemplo n.º 2
0
 def test_log_batch_nonexistent_run(self):
     """Logging a batch against an unknown run ID raises RESOURCE_DOES_NOT_EXIST."""
     store = FileStore(self.test_root)
     missing_run_id = uuid.uuid4().hex
     with self.assertRaises(MlflowException) as raised:
         store.log_batch(missing_run_id, [], [], [])
     error = raised.exception
     assert error.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
     assert ("Run '%s' not found" % missing_run_id) in error.message
Ejemplo n.º 3
0
 def __init__(self, message, error_code=INTERNAL_ERROR, **kwargs):
     """
     Build the exception from a message, an error code and optional extra JSON fields.

     :param message: The message describing the error that occurred. This will be included in
                     the exception's serialized JSON representation.
     :param error_code: An appropriate error code for the error that occurred; it will be
                        included in the exception's serialized JSON representation. This should
                        be one of the codes listed in the `mlflow.protos.databricks_pb2` proto.
     :param kwargs: Additional key-value pairs to include in the serialized JSON representation
                    of the MlflowException.
     """
     try:
         resolved_code = ErrorCode.Name(error_code)
     except (ValueError, TypeError):
         # A code that cannot be resolved to a name falls back to INTERNAL_ERROR.
         resolved_code = ErrorCode.Name(INTERNAL_ERROR)
     self.error_code = resolved_code
     self.message = message
     self.json_kwargs = kwargs
     super(MlflowException, self).__init__(message)
Ejemplo n.º 4
0
def test_log_metric_validation(tracking_uri_mock):
    """A non-numeric metric value raises and leaves the run without any metrics."""
    with start_run() as active_run:
        run_id = active_run.info.run_id
        with pytest.raises(MlflowException) as raised:
            mlflow.log_metric("name_1", "apple")
    expected_code = ErrorCode.Name(INVALID_PARAMETER_VALUE)
    assert raised.value.error_code == expected_code

    # Re-read the run after it has ended to confirm nothing was stored.
    recorded_metrics = tracking.MlflowClient().get_run(run_id).data.metrics
    assert len(recorded_metrics) == 0
Ejemplo n.º 5
0
def test_set_experiment_with_deleted_experiment():
    """Selecting a deleted experiment, by name or by ID, raises INVALID_PARAMETER_VALUE."""
    experiment_name = "dead_exp"
    mlflow.set_experiment(experiment_name)
    with start_run() as run:
        deleted_id = run.info.experiment_id

    tracking.MlflowClient().delete_experiment(deleted_id)

    # Both ways of addressing the experiment must be refused after deletion.
    attempts = (
        lambda: mlflow.set_experiment(experiment_name),
        lambda: mlflow.set_experiment(experiment_id=deleted_id),
    )
    for attempt in attempts:
        with pytest.raises(MlflowException, match="Cannot set a deleted experiment") as exc:
            attempt()
        assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
Ejemplo n.º 6
0
def test_validate_param_name():
    """Good names validate without raising; each bad name raises INVALID_PARAMETER_VALUE."""
    for accepted in GOOD_METRIC_OR_PARAM_NAMES:
        _validate_param_name(accepted)

    expected_code = ErrorCode.Name(INVALID_PARAMETER_VALUE)
    for rejected in BAD_METRIC_OR_PARAM_NAMES:
        with pytest.raises(MlflowException, match="Invalid parameter name") as raised:
            _validate_param_name(rejected)
        assert raised.value.error_code == expected_code
Ejemplo n.º 7
0
def test_validate_run_id():
    """Listed good IDs validate without raising; listed bad IDs raise INVALID_PARAMETER_VALUE."""
    accepted_ids = [
        "a" * 32,
        "f0" * 16,
        "abcdef0123456789" * 2,
        "a" * 33,
        "a" * 31,
        "a" * 256,
        "A" * 32,
        "g" * 32,
        "a_" * 32,
        "abcdefghijklmnopqrstuvqxyz",
    ]
    rejected_ids = ["a/bc" * 8, "", "a" * 400, "*" * 5]

    for accepted in accepted_ids:
        _validate_run_id(accepted)

    expected_code = ErrorCode.Name(INVALID_PARAMETER_VALUE)
    for rejected in rejected_ids:
        with pytest.raises(MlflowException, match="Invalid run ID") as raised:
            _validate_run_id(rejected)
        assert raised.value.error_code == expected_code
Ejemplo n.º 8
0
def test_deployment_with_missing_flavor_raises_exception(pretrained_model):
    """Deploying with the "mleap" flavor raises RESOURCE_DOES_NOT_EXIST."""
    deploy_kwargs = dict(
        app_name="missing-flavor",
        model_path=pretrained_model.model_path,
        run_id=pretrained_model.run_id,
        flavor="mleap",
    )
    with pytest.raises(MlflowException) as raised:
        mfs.deploy(**deploy_kwargs)

    expected_code = ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
    assert raised.value.error_code == expected_code
Ejemplo n.º 9
0
def test_deployment_with_unsupported_flavor_raises_exception(pretrained_model):
    """Deploying with an unrecognized flavor string raises INVALID_PARAMETER_VALUE."""
    deploy_kwargs = dict(
        app_name="bad_flavor",
        model_path=pretrained_model.model_path,
        run_id=pretrained_model.run_id,
        flavor="this is not a valid flavor",
    )
    with pytest.raises(MlflowException) as raised:
        mfs.deploy(**deploy_kwargs)

    expected_code = ErrorCode.Name(INVALID_PARAMETER_VALUE)
    assert raised.value.error_code == expected_code
Ejemplo n.º 10
0
def call_endpoints(host_creds, endpoints, json_body, response_proto):
    """
    Try each (endpoint, method) pair in order and return the first successful response.

    A RestException whose error code is ENDPOINT_NOT_FOUND moves on to the next pair, unless the
    failing pair is the last one. Any other RestException is re-raised immediately.
    """
    # The order that the endpoints are called in is defined by the order
    # specified in ModelRegistryService in model_registry.proto
    for position, (endpoint, method) in enumerate(endpoints):
        try:
            return call_endpoint(host_creds, endpoint, method, json_body, response_proto)
        except RestException as err:
            endpoint_missing = err.error_code == ErrorCode.Name(ENDPOINT_NOT_FOUND)
            if endpoint_missing and position != len(endpoints) - 1:
                # Fall through to the next candidate endpoint.
                continue
            raise err
Ejemplo n.º 11
0
def test_deployment_with_unsupported_flavor_raises_exception(pretrained_model):
    """Deploying with an unrecognized flavor raises INVALID_PARAMETER_VALUE with a clear message."""
    expected_message = (
        "The specified flavor: `this is not a valid flavor` is not supported for deployment"
    )
    deploy_kwargs = dict(
        app_name="bad_flavor",
        model_uri=pretrained_model.model_uri,
        flavor="this is not a valid flavor",
    )
    with pytest.raises(MlflowException, match=expected_message) as raised:
        mfs.deploy(**deploy_kwargs)

    expected_code = ErrorCode.Name(INVALID_PARAMETER_VALUE)
    assert raised.value.error_code == expected_code
Ejemplo n.º 12
0
def test_deployment_with_missing_flavor_raises_exception(pretrained_model):
    """Deploying with the "mleap" flavor raises RESOURCE_DOES_NOT_EXIST with a clear message."""
    expected_message = "The specified model does not contain the specified deployment flavor"
    deploy_kwargs = dict(
        app_name="missing-flavor",
        model_uri=pretrained_model.model_uri,
        flavor="mleap",
    )
    with pytest.raises(MlflowException, match=expected_message) as raised:
        mfs.deploy(**deploy_kwargs)

    expected_code = ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
    assert raised.value.error_code == expected_code
def test_delete_deployment_in_asynchronous_mode_without_archiving_raises_exception(
    sagemaker_deployment_client,
):
    """Deleting with archive=False and synchronous=False raises INVALID_PARAMETER_VALUE."""
    delete_config = {"archive": False, "synchronous": False}
    with pytest.raises(MlflowException, match="Resources must be archived") as raised:
        sagemaker_deployment_client.delete_deployment(name="dummy", config=delete_config)

    expected_code = ErrorCode.Name(INVALID_PARAMETER_VALUE)
    assert raised.value.error_code == expected_code
Ejemplo n.º 14
0
def test_run_databricks_throws_exception_when_spec_uses_existing_cluster():
    """A cluster spec naming an existing cluster raises INVALID_PARAMETER_VALUE."""
    fake_credentials = {"DATABRICKS_HOST": "test-host", "DATABRICKS_TOKEN": "foo"}
    existing_cluster_spec = {
        "existing_cluster_id": "1000-123456-clust1",
    }
    # The environment is patched only for the duration of the call and the checks.
    with mock.patch.dict(os.environ, fake_credentials):
        with pytest.raises(MlflowException) as raised:
            run_databricks_project(cluster_spec=existing_cluster_spec)
        error = raised.value
        assert "execution against existing clusters is not currently supported" in error.message
        assert error.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
Ejemplo n.º 15
0
    def test_search_registered_model_pagination(self):
        """
        Exercise paginated registered-model search.

        Covers walking all pages with a fixed page size, varying the page size between
        pages, and the errors raised for an invalid page token and an oversized
        ``max_results``.
        """
        rms = [self._rm_maker(f"RM{i:03}").name for i in range(50)]

        # test flow with fixed max_results
        returned_rms = []
        query = "name LIKE 'RM%'"
        result, token = self._search_registered_models(query, page_token=None, max_results=5)
        returned_rms.extend(result)
        while token:
            result, token = self._search_registered_models(query, page_token=token, max_results=5)
            returned_rms.extend(result)
        self.assertEqual(rms, returned_rms)

        # test that pagination will return all valid results in sorted order
        # by name ascending
        result, token1 = self._search_registered_models(query, max_results=5)
        self.assertNotEqual(token1, None)
        self.assertEqual(result, rms[0:5])

        result, token2 = self._search_registered_models(query, page_token=token1, max_results=10)
        self.assertNotEqual(token2, None)
        self.assertEqual(result, rms[5:15])

        result, token3 = self._search_registered_models(query, page_token=token2, max_results=20)
        self.assertNotEqual(token3, None)
        self.assertEqual(result, rms[15:35])

        result, token4 = self._search_registered_models(query, page_token=token3, max_results=100)
        # assert that page token is None
        self.assertEqual(token4, None)
        self.assertEqual(result, rms[35:])

        # test that providing a completely invalid page token throws
        with self.assertRaises(MlflowException) as exception_context:
            self._search_registered_models(query, page_token="evilhax", max_results=20)
        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        # test that providing too large of a max_results throws
        with self.assertRaises(MlflowException) as exception_context:
            self._search_registered_models(query, page_token="evilhax", max_results=1e15)
        # The checks must sit outside the `with` body: anything placed after the raising
        # call inside it is skipped when the exception propagates.
        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
        self.assertIn("Invalid value for request parameter max_results",
                      exception_context.exception.message)
Ejemplo n.º 16
0
    def test_log_null_param(self):
        """Logging a param whose value is None raises MlflowException with INTERNAL_ERROR."""
        run = self._run_factory()
        null_param = entities.Param('blahmetric', None)

        with self.assertRaises(MlflowException) as raised:
            self.store.log_param(run.info.run_uuid, null_param)

        expected_code = ErrorCode.Name(INTERNAL_ERROR)
        assert raised.exception.error_code == expected_code
Ejemplo n.º 17
0
    def test_log_null_metric(self):
        """Logging a metric whose value is None raises MlflowException with INTERNAL_ERROR."""
        run = self._run_factory()
        # Timestamp is the current time truncated to whole seconds.
        null_metric = entities.Metric('blahmetric', None, int(time.time()))

        with self.assertRaises(MlflowException) as raised:
            self.store.log_metric(run.info.run_uuid, null_metric)

        expected_code = ErrorCode.Name(INTERNAL_ERROR)
        assert raised.exception.error_code == expected_code
Ejemplo n.º 18
0
def test_log_batch_duplicate_entries_raises():
    """A batch with two params sharing one key raises INVALID_PARAMETER_VALUE."""
    expected_message = r"Duplicate parameter keys have been submitted."
    with start_run() as active_run:
        run_id = active_run.info.run_id
        duplicated_params = [Param("a", "1"), Param("a", "2")]
        with pytest.raises(MlflowException, match=expected_message) as raised:
            tracking.MlflowClient().log_batch(run_id=run_id, params=duplicated_params)
        assert raised.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
Ejemplo n.º 19
0
def test_log_batch_validates_entity_names_and_values():
    """Each malformed metric, param or tag in a batch raises INVALID_PARAMETER_VALUE."""
    # Each case pairs the keyword arguments for log_batch with the expected error text.
    # NOTE(review): the tag case passes a Param entity, as the original did — confirm
    # whether a tag entity type was intended.
    invalid_batches = [
        (
            {"metrics": [Metric(key="../bad/metric/name", value=0.3, timestamp=3, step=0)]},
            "Invalid metric name",
        ),
        (
            {"metrics": [Metric(key="ok-name", value="non-numerical-value", timestamp=3, step=0)]},
            "Got invalid value",
        ),
        (
            {
                "metrics": [
                    Metric(key="ok-name", value=0.3, timestamp="non-numerical-timestamp", step=0)
                ]
            },
            "Got invalid timestamp",
        ),
        (
            {"params": [Param(key="../bad/param/name", value="my-val")]},
            "Invalid parameter name",
        ),
        (
            {"tags": [Param(key="../bad/tag/name", value="my-val")]},
            "Invalid tag name",
        ),
    ]

    with start_run() as active_run:
        run_id = active_run.info.run_id
        for batch_kwargs, expected_message in invalid_batches:
            with pytest.raises(MlflowException, match=expected_message) as e:
                tracking.MlflowClient().log_batch(run_id, **batch_kwargs)
            assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
Ejemplo n.º 20
0
def test_catch_mlflow_exception():
    """A handler raising INTERNAL_ERROR yields a 500 response carrying code and message."""

    @catch_mlflow_exception
    def test_handler():
        raise MlflowException('test error', error_code=INTERNAL_ERROR)

    # pylint: disable=assignment-from-no-return
    response = test_handler()
    payload = json.loads(response.get_data())

    assert response.status_code == 500
    expected_code = ErrorCode.Name(INTERNAL_ERROR)
    assert payload['error_code'] == expected_code
    assert payload['message'] == 'test error'
Ejemplo n.º 21
0
def test_attempting_to_deploy_in_asynchronous_mode_without_archiving_throws_exception(
        pretrained_model):
    """Deploying with archive=False and synchronous=False raises INVALID_PARAMETER_VALUE."""
    deploy_kwargs = dict(
        app_name="test-app",
        model_uri=pretrained_model.model_uri,
        mode=mfs.DEPLOYMENT_MODE_CREATE,
        archive=False,
        synchronous=False,
    )
    with pytest.raises(MlflowException) as raised:
        mfs.deploy(**deploy_kwargs)

    error = raised.value
    assert "Resources must be archived" in error.message
    assert error.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
def test_update_deployment_with_create_mode_raises_exception(
    pretrained_model, sagemaker_deployment_client
):
    """Updating a deployment with the create mode raises INVALID_PARAMETER_VALUE."""
    update_config = {"mode": mfs.DEPLOYMENT_MODE_CREATE}
    with pytest.raises(MlflowException, match="Invalid mode") as raised:
        sagemaker_deployment_client.update_deployment(
            name="invalid mode",
            model_uri=pretrained_model.model_uri,
            config=update_config,
        )

    expected_code = ErrorCode.Name(INVALID_PARAMETER_VALUE)
    assert raised.value.error_code == expected_code
Ejemplo n.º 23
0
def test_set_experiment_parameter_validation():
    """Calling set_experiment with no usable argument, or with two, raises INVALID_PARAMETER_VALUE."""
    # Positional argument tuples: nothing, one None, two Nones, and two values at once.
    invalid_argument_sets = [
        (),
        (None,),
        (None, None),
        ("name", "id"),
    ]
    for positional_args in invalid_argument_sets:
        with pytest.raises(MlflowException, match="Must specify exactly one") as exc:
            mlflow.set_experiment(*positional_args)
        assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
Ejemplo n.º 24
0
 def test_search_registered_model_order_by_errors(self):
     """An invalid order_by clause after a valid one raises INVALID_PARAMETER_VALUE."""
     query = "name LIKE 'RM%'"
     rejected_order_by_lists = [
         # invalid column following a valid column
         ['name ASC', 'creation_timestamp DESC'],
         # invalid clause with trailing random text following a valid column
         ['name ASC', 'last_updated_timestamp DESC blah'],
     ]
     for order_by in rejected_order_by_lists:
         with self.assertRaises(MlflowException) as exception_context:
             self._search_registered_models(
                 query, page_token=None, order_by=order_by, max_results=5)
         raised_code = exception_context.exception.error_code
         assert raised_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
Ejemplo n.º 25
0
    def test_list_registered_model_paginated_errors(self):
        """
        Check error handling of paginated registered-model listing.

        Covers an invalid page token, an oversized ``max_results``, and that a deleted
        model no longer appears in the listing.
        """
        rms = [self._rm_maker(f"RM{i:03}").name for i in range(50)]
        # test that providing a completely invalid page token throws
        with self.assertRaises(MlflowException) as exception_context:
            self._list_registered_models(page_token="evilhax", max_results=20)
        assert exception_context.exception.error_code == ErrorCode.Name(
            INVALID_PARAMETER_VALUE)

        # test that providing too large of a max_results throws
        with self.assertRaises(MlflowException) as exception_context:
            self._list_registered_models(page_token="evilhax",
                                         max_results=1e15)
        # The checks must sit outside the `with` body: anything placed after the raising
        # call inside it is skipped when the exception propagates.
        assert exception_context.exception.error_code == ErrorCode.Name(
            INVALID_PARAMETER_VALUE)
        self.assertIn("Invalid value for request parameter max_results",
                      exception_context.exception.message)
        # list should not return deleted models
        self.store.delete_registered_model(name=f"RM{0:03}")
        self.assertEqual(set(self._list_registered_models(max_results=100)),
                         set(rms[1:]))
Ejemplo n.º 26
0
 def test_run_needs_uuid(self):
     """Adding a SqlRun with no identity set raises MlflowException with INTERNAL_ERROR."""
     # Depending on the implementation, a NULL identity key may result in different
     # exceptions, including IntegrityError (sqlite) and FlushError (MySQL).
     # Presumably the store wraps these in an MlflowException — confirm against the store.
     #
     # The "ignore" filter is installed inside catch_warnings so it is undone when the
     # block exits instead of leaking into tests that run afterwards.
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         with self.assertRaises(MlflowException) as exception_context:
             with self.store.ManagedSessionMaker() as session:
                 run = models.SqlRun()
                 session.add(run)
     assert exception_context.exception.error_code == ErrorCode.Name(INTERNAL_ERROR)
def test_attempting_to_deploy_in_asynchronous_mode_without_archiving_throws_exception(
    pretrained_model, sagemaker_deployment_client
):
    """Creating a deployment with archive=False and synchronous=False raises INVALID_PARAMETER_VALUE."""
    create_config = {"archive": False, "synchronous": False}
    with pytest.raises(MlflowException, match="Resources must be archived") as raised:
        sagemaker_deployment_client.create_deployment(
            name="test-app",
            model_uri=pretrained_model.model_uri,
            config=create_config,
        )

    error = raised.value
    assert "Resources must be archived" in error.message
    assert error.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
Ejemplo n.º 28
0
 def test_log_param_enforces_value_immutability(self):
     """Re-logging a param with the same value succeeds; a different value is rejected."""
     store = FileStore(self.test_root)
     key = "new param"
     run_id = self.exp_data[FileStore.DEFAULT_EXPERIMENT_ID]["runs"][0]

     store.log_param(run_id, Param(key, "value1"))
     # Duplicate calls to `log_param` with the same key and value should succeed
     store.log_param(run_id, Param(key, "value1"))

     with pytest.raises(MlflowException) as raised:
         store.log_param(run_id, Param(key, "value2"))
     assert raised.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

     # The originally logged value must still be the one stored on the run.
     stored_params = store.get_run(run_id).data.params
     assert stored_params[key] == "value1"
Ejemplo n.º 29
0
def test_deployment_of_model_with_no_supported_flavors_raises_exception(pretrained_model):
    """Deploying a model whose pyfunc flavor was removed raises RESOURCE_DOES_NOT_EXIST."""
    local_model_dir = _download_artifact_from_uri(pretrained_model.model_uri)
    mlmodel_file = os.path.join(local_model_dir, "MLmodel")

    # Rewrite the MLmodel file in place without the pyfunc flavor entry.
    stripped_config = Model.load(mlmodel_file)
    del stripped_config.flavors[mlflow.pyfunc.FLAVOR_NAME]
    stripped_config.save(path=mlmodel_file)

    expected_message = (
        "The specified model does not contain any of the supported flavors for deployment"
    )
    with pytest.raises(MlflowException, match=expected_message) as raised:
        mfs.deploy(app_name="missing-flavor", model_uri=local_model_dir, flavor=None)

    assert raised.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
Ejemplo n.º 30
0
    def test_log_null_metric(self):
        """Logging a metric whose value is None raises MlflowException with INTERNAL_ERROR."""
        run = self._run_factory()

        tkey = 'blahmetric'
        tval = None
        metric = entities.Metric(tkey, tval, int(time.time()))

        # The "ignore" filter is installed inside catch_warnings so it is undone when the
        # block exits instead of leaking into tests that run afterwards. The former
        # resetwarnings() call followed the raising call and therefore never executed.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with self.assertRaises(MlflowException) as exception_context:
                self.store.log_metric(run.info.run_uuid, metric)
        assert exception_context.exception.error_code == ErrorCode.Name(INTERNAL_ERROR)