def test_operation_error(self):
    """A successful HTTP call whose returned operation reports a failure raises.

    The raised exception is checked against the expected operation error
    status and message.
    """
    instrument_ml_service(payload=OPERATION_ERROR_RESPONSE, status=200)
    with pytest.raises(Exception) as exc_info:
        ml.create_model(model=MODEL_1)
    # Transport succeeded; the failure is carried inside the operation body.
    check_operation_error(
        exc_info,
        OPERATION_ERROR_EXPECTED_STATUS,
        OPERATION_ERROR_MSG,
    )
 def test_rpc_error_create(self):
     """A 400 response to the create request raises, after exactly one request."""
     recorder = instrument_ml_service(
         payload=ERROR_RESPONSE_BAD_REQUEST, status=400)
     with pytest.raises(Exception) as exc_info:
         ml.create_model(model=MODEL_1)
     check_firebase_error(
         exc_info,
         ERROR_STATUS_BAD_REQUEST,
         ERROR_CODE_BAD_REQUEST,
         ERROR_MSG_BAD_REQUEST,
     )
     assert len(recorder) == 1
Example #3
0
def model_list():
    """Create two models (the second tagged 'test_tag123') and yield them.

    Yields:
        A list ``[model_1, model_2]`` of the created models.

    Both models are passed to ``_clean_up_model`` on teardown. Cleanup is
    nested in try/finally blocks. This means a failure while creating the
    second model, or while cleaning up either model, does not leave the
    other model behind in the project. Code after a fixture's ``yield`` does
    not run at all if the fixture raises before yielding.
    """
    ml_model_1 = ml.Model(display_name=_random_identifier('TestModel123_list1_'))
    model_1 = ml.create_model(model=ml_model_1)
    try:
        ml_model_2 = ml.Model(display_name=_random_identifier('TestModel123_list2_'),
                              tags=['test_tag123'])
        model_2 = ml.create_model(model=ml_model_2)
        try:
            yield [model_1, model_2]
        finally:
            _clean_up_model(model_2)
    finally:
        _clean_up_model(model_1)
Example #4
0
def automl_model():
    """Yield a Firebase model built from a pre-existing AutoML model.

    Training an AutoML model takes more than 20 minutes, so this fixture does
    not train one. Instead it searches the app's project (location
    'us-central1') for an AutoML model named 'admin_sdk_integ_test1'. If none
    is found, the test is skipped. The Firebase model created from it is
    passed to ``_clean_up_model`` on teardown.
    """
    assert _AUTOML_ENABLED

    client = automl_v1.AutoMlClient()
    location = client.location_path(
        firebase_admin.get_app().project_id, 'us-central1')
    matches = client.list_models(
        location, filter_="display_name=admin_sdk_integ_test1")

    # Exactly one match is expected; if there are several, the last one wins.
    automl_ref = None
    for candidate in matches:
        automl_ref = candidate.name

    if automl_ref is None:
        pytest.skip("No pre-existing AutoML model found. Skipping test")

    model_format = ml.TFLiteFormat(
        model_source=ml.TFLiteAutoMlSource(automl_ref))
    model = ml.create_model(model=ml.Model(
        display_name=_random_identifier('TestModel_automl_'),
        tags=['test_automl'],
        model_format=model_format))
    yield model
    _clean_up_model(model)
Example #5
0
def firebase_model(request):
    """Create a model described by the parametrized ``request.param`` dict.

    Keys read: 'display_name', 'tags' and, optionally, 'file_name'.

    When 'file_name' is truthy, that test resource is turned into a TFLite
    source through ``TFLiteGCSModelSource.from_tflite_model_file`` and used
    as the model format. Otherwise the model is created without a format.
    The created model is yielded and then passed to ``_clean_up_model``.
    """
    params = request.param
    file_name = params.get('file_name')
    model_format = None
    if file_name:
        model_source = ml.TFLiteGCSModelSource.from_tflite_model_file(
            testutils.resource_filename(file_name))
        model_format = ml.TFLiteFormat(model_source=model_source)

    model = ml.create_model(model=ml.Model(
        display_name=params.get('display_name'),
        tags=params.get('tags'),
        model_format=model_format))
    yield model
    _clean_up_model(model)
Example #6
0
def add_automl_model(model_ref, name, tags=None):
    """Add an AutoML-backed TFLite model to the project and publish it.

    Args:
        model_ref: AutoML model reference, wrapped in ``ml.TFLiteAutoMlSource``.
        name: Display name for the new Firebase model.
        tags: Optional tags to set on the model; left unset when None.
    """
    tflite_format = ml.TFLiteFormat(
        model_source=ml.TFLiteAutoMlSource(model_ref))
    model = ml.Model(display_name=name, model_format=tflite_format)
    if tags is not None:
        model.tags = tags

    # Create the model and wait for it to unlock, since publishing
    # requires an unlocked model.
    created = ml.create_model(model)
    created.wait_for_unlocked()
    ml.publish_model(created.model_id)

    print('Model uploaded and published:')
    print_models([created], headers=False)
    def test_returns_locked(self):
        """create_model returns the locked model when the operation is not done.

        Expects two requests, both carrying the client header:
        - a POST to the create URL;
        - a GET of the created model.
        """
        recorder = instrument_ml_service(
            status=[200, 200],
            payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE])
        expected = ml.Model.from_dict(LOCKED_MODEL_JSON_2)
        created = ml.create_model(MODEL_1)

        assert created == expected
        assert len(recorder) == 2

        create_call = recorder[0]
        assert create_call.method == 'POST'
        assert create_call.url == TestCreateModel._url(PROJECT_ID)
        assert create_call.headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE

        get_call = recorder[1]
        assert get_call.method == 'GET'
        assert get_call.url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1)
        assert get_call.headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
Example #8
0
def test_from_keras_model(keras_model):
    """Convert a Keras model to a GCS TFLite source and create a model from it.

    Checks two things:
    - the source URI is a ``gs://`` path ending in
      ``/Firebase/ML/Models/model2.tflite``;
    - ``ml.create_model`` accepts the converted source.

    The created model is always passed to ``_clean_up_model``, even when the
    checks fail.
    """
    source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite')
    # Raw string with an escaped '.': an unescaped dot would also accept
    # URIs like '.../model2xtflite'.
    assert re.search(
        r'^gs://.*/Firebase/ML/Models/model2\.tflite$',
        source.gcs_tflite_uri) is not None

    # Validate the conversion by creating a model
    model_format = ml.TFLiteFormat(model_source=source)
    model = ml.Model(display_name=_random_identifier('KerasModel_'), model_format=model_format)
    created_model = ml.create_model(model)

    try:
        check_model(created_model, {'display_name': model.display_name})
        check_tflite_gcs_format(created_model)
    finally:
        _clean_up_model(created_model)
Example #9
0
def test_from_saved_model(saved_model_dir):
    """Convert a SavedModel directory to a GCS TFLite source and create a model.

    Checks three things:
    - the source URI is a ``gs://`` path ending in
      ``/Firebase/ML/Models/model3.tflite``;
    - the created model has an id;
    - the created model has no validation error.

    The created model is always passed to ``_clean_up_model``.
    """
    # Test the conversion helper
    source = ml.TFLiteGCSModelSource.from_saved_model(saved_model_dir, 'model3.tflite')
    # Raw string with an escaped '.', so only a literal 'model3.tflite' suffix matches.
    assert re.search(
        r'^gs://.*/Firebase/ML/Models/model3\.tflite$',
        source.gcs_tflite_uri) is not None

    # Validate the conversion by creating a model
    model_format = ml.TFLiteFormat(model_source=source)
    model = ml.Model(display_name=_random_identifier('SavedModel_'), model_format=model_format)
    created_model = ml.create_model(model)

    try:
        assert created_model.model_id is not None
        assert created_model.validation_error is None
    finally:
        _clean_up_model(created_model)
Example #10
0
def upload_model(model_file, name, tags=None):
    """Upload a tflite model file to the project and publish it.

    Args:
        model_file: Path to a local .tflite file. It is uploaded to Cloud
            Storage before the model is created.
        name: Display name for the new Firebase model.
        tags: Optional tags to set on the model; left unset when None.
    """
    # Load a tflite file and upload it to Cloud Storage
    print('Uploading to Cloud Storage...')
    model_source = ml.TFLiteGCSModelSource.from_tflite_model_file(model_file)

    # Create the model object
    tflite_format = ml.TFLiteFormat(model_source=model_source)
    model = ml.Model(display_name=name, model_format=tflite_format)
    if tags is not None:
        model.tags = tags

    # Add the model to your Firebase project and publish it.
    # create_model can return a model still locked by a pending operation,
    # so wait for it to unlock before publishing.
    new_model = ml.create_model(model)
    new_model.wait_for_unlocked()
    ml.publish_model(new_model.model_id)

    print('Model uploaded and published:')
    print_models([new_model], headers=False)
Example #11
0
 def test_missing_op_name(self):
     """An operation response without a name is expected to raise TypeError."""
     instrument_ml_service(
         payload=OPERATION_MISSING_NAME_RESPONSE, status=200)
     with pytest.raises(Exception) as exc_info:
         ml.create_model(model=MODEL_1)
     check_error(exc_info, TypeError)
Example #12
0
def test_create_already_existing_fails(firebase_model):
    """Re-creating an existing model raises AlreadyExistsError naming the model."""
    with pytest.raises(exceptions.AlreadyExistsError) as exc_info:
        ml.create_model(model=firebase_model)
    expected_msg = 'Model \'{0}\' already exists'.format(firebase_model.display_name)
    check_operation_error(exc_info, expected_msg)
Example #13
0
 def test_invalid_op_name(self, op_name):
     """An operation whose name has an invalid format is rejected with ValueError."""
     instrument_ml_service(
         status=200, payload=json.dumps({'name': op_name}))
     with pytest.raises(Exception) as exc_info:
         ml.create_model(model=MODEL_1)
     check_error(exc_info, ValueError, 'Operation name format is invalid.')
Example #14
0
 def test_immediate_done(self):
     """An already-finished create operation yields the created model directly."""
     instrument_ml_service(payload=OPERATION_DONE_RESPONSE, status=200)
     assert ml.create_model(model=MODEL_1) == CREATED_UPDATED_MODEL_1
Example #15
0
 def test_missing_display_name(self):
     """A model built from an empty dict is rejected for lacking a display name."""
     with pytest.raises(Exception) as exc_info:
         ml.create_model(model=ml.Model.from_dict({}))
     check_error(exc_info, ValueError, 'Model must have a display name.')
Example #16
0
 def test_not_model(self, model):
     """Passing something other than an ml.Model is rejected with TypeError."""
     with pytest.raises(Exception) as exc_info:
         ml.create_model(model=model)
     check_error(exc_info, TypeError, 'Model must be an ml.Model.')
Example #17
0
 def test_malformed_operation(self):
     """A malformed operation response surfaces as an UnknownError."""
     instrument_ml_service(payload=OPERATION_MALFORMED_RESPONSE, status=200)
     with pytest.raises(Exception) as exc_info:
         ml.create_model(model=MODEL_1)
     check_error(
         exc_info,
         exceptions.UnknownError,
         'Internal Error: Malformed Operation.',
     )