def test_operation_error(self):
    """The HTTP request itself succeeds, but the returned long-running
    operation carries a create failure, which must surface as an error."""
    instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE)
    with pytest.raises(Exception) as err_info:
        ml.create_model(MODEL_1)
    check_operation_error(
        err_info, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)
def test_rpc_error_create(self):
    """A 400 response from the backend raises a Firebase error, and exactly
    one HTTP call is recorded."""
    recorder = instrument_ml_service(
        status=400, payload=ERROR_RESPONSE_BAD_REQUEST)
    with pytest.raises(Exception) as err_info:
        ml.create_model(MODEL_1)
    check_firebase_error(
        err_info,
        ERROR_STATUS_BAD_REQUEST,
        ERROR_CODE_BAD_REQUEST,
        ERROR_MSG_BAD_REQUEST)
    assert len(recorder) == 1
def model_list():
    """Fixture: create two test models, yield them, then clean both up.

    Yields:
        list: The two created models (the second one carries a test tag).

    The teardown wraps the first cleanup in try/finally so that a failure
    deleting model_1 does not leak model_2.
    """
    ml_model_1 = ml.Model(display_name=_random_identifier('TestModel123_list1_'))
    model_1 = ml.create_model(model=ml_model_1)
    ml_model_2 = ml.Model(
        display_name=_random_identifier('TestModel123_list2_'),
        tags=['test_tag123'])
    model_2 = ml.create_model(model=ml_model_2)
    yield [model_1, model_2]
    try:
        _clean_up_model(model_1)
    finally:
        # Runs even if cleaning model_1 raised, so model_2 is not leaked.
        _clean_up_model(model_2)
def automl_model():
    """Fixture: wrap a pre-trained AutoML model in a Firebase ML model.

    Training a model takes more than 20 minutes, so a predefined AutoML
    model named 'admin_sdk_integ_test1' is expected to exist in the
    project; the test is skipped when it does not.
    """
    assert _AUTOML_ENABLED
    automl_client = automl_v1.AutoMlClient()
    project_id = firebase_admin.get_app().project_id
    parent = automl_client.location_path(project_id, 'us-central1')
    matches = automl_client.list_models(
        parent, filter_="display_name=admin_sdk_integ_test1")
    # Expecting exactly one match; if there are somehow more, the last
    # one seen is used.
    automl_ref = None
    for match in matches:
        automl_ref = match.name
    if automl_ref is None:
        # No pre-defined model available (training one takes > 20 minutes).
        pytest.skip("No pre-existing AutoML model found. Skipping test")
    tflite_format = ml.TFLiteFormat(model_source=ml.TFLiteAutoMlSource(automl_ref))
    ml_model = ml.Model(
        display_name=_random_identifier('TestModel_automl_'),
        tags=['test_automl'],
        model_format=tflite_format)
    model = ml.create_model(model=ml_model)
    yield model
    _clean_up_model(model)
def firebase_model(request):
    """Parameterized fixture: create a model from the request's args dict.

    The args may include 'display_name', 'tags', and an optional
    'file_name' pointing at a tflite resource; when a file is given the
    model gets a GCS-backed TFLite format.
    """
    params = request.param
    tflite_format = None
    tflite_file = params.get('file_name')
    if tflite_file:
        resource_path = testutils.resource_filename(tflite_file)
        gcs_source = ml.TFLiteGCSModelSource.from_tflite_model_file(resource_path)
        tflite_format = ml.TFLiteFormat(model_source=gcs_source)
    ml_model = ml.Model(
        display_name=params.get('display_name'),
        tags=params.get('tags'),
        model_format=tflite_format)
    model = ml.create_model(model=ml_model)
    yield model
    _clean_up_model(model)
def add_automl_model(model_ref, name, tags=None):
    """Add an AutoML tflite model file to the project and publish it."""
    # Build the model object around the AutoML source.
    automl_source = ml.TFLiteAutoMlSource(model_ref)
    model = ml.Model(
        display_name=name,
        model_format=ml.TFLiteFormat(model_source=automl_source))
    if tags is not None:
        model.tags = tags
    # Register the model with the Firebase project, wait for the create
    # operation to release its lock, then publish it.
    new_model = ml.create_model(model)
    new_model.wait_for_unlocked()
    ml.publish_model(new_model.model_id)
    print('Model uploaded and published:')
    print_models([new_model], headers=False)
def test_returns_locked(self):
    """When the create operation is not done, the SDK polls with a GET and
    returns the (still locked) model from the follow-up response."""
    recorder = instrument_ml_service(
        status=[200, 200],
        payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE])
    expected = ml.Model.from_dict(LOCKED_MODEL_JSON_2)
    model = ml.create_model(MODEL_1)
    assert model == expected
    assert len(recorder) == 2
    create_req = recorder[0]
    poll_req = recorder[1]
    # First call: the POST that creates the model.
    assert create_req.method == 'POST'
    assert create_req.url == TestCreateModel._url(PROJECT_ID)
    assert create_req.headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
    # Second call: the GET that fetches the model while still locked.
    assert poll_req.method == 'GET'
    assert poll_req.url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1)
    assert poll_req.headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
def test_from_keras_model(keras_model):
    """Converting a Keras model uploads a tflite file to the expected GCS
    path, and the resulting source can back a real created model."""
    gcs_source = ml.TFLiteGCSModelSource.from_keras_model(
        keras_model, 'model2.tflite')
    assert re.search(
        '^gs://.*/Firebase/ML/Models/model2.tflite$',
        gcs_source.gcs_tflite_uri) is not None
    # Validate the conversion by creating a model from the source.
    model = ml.Model(
        display_name=_random_identifier('KerasModel_'),
        model_format=ml.TFLiteFormat(model_source=gcs_source))
    created_model = ml.create_model(model)
    try:
        check_model(created_model, {'display_name': model.display_name})
        check_tflite_gcs_format(created_model)
    finally:
        _clean_up_model(created_model)
def test_from_saved_model(saved_model_dir):
    """Converting a TF SavedModel uploads a tflite file to the expected GCS
    path, and a model created from it passes validation."""
    gcs_source = ml.TFLiteGCSModelSource.from_saved_model(
        saved_model_dir, 'model3.tflite')
    assert re.search(
        '^gs://.*/Firebase/ML/Models/model3.tflite$',
        gcs_source.gcs_tflite_uri) is not None
    # Validate the conversion by creating a model from the source.
    model = ml.Model(
        display_name=_random_identifier('SavedModel_'),
        model_format=ml.TFLiteFormat(model_source=gcs_source))
    created_model = ml.create_model(model)
    try:
        assert created_model.model_id is not None
        assert created_model.validation_error is None
    finally:
        _clean_up_model(created_model)
def upload_model(model_file, name, tags=None):
    """Upload a tflite model file to the project and publish it.

    Args:
        model_file: Path to the local tflite model file.
        name: Display name for the new model.
        tags: Optional list of tags to attach to the model.
    """
    # Load a tflite file and upload it to Cloud Storage.
    print('Uploading to Cloud Storage...')
    model_source = ml.TFLiteGCSModelSource.from_tflite_model_file(model_file)
    # Create the model object.
    tflite_format = ml.TFLiteFormat(model_source=model_source)
    model = ml.Model(display_name=name, model_format=tflite_format)
    if tags is not None:
        model.tags = tags
    # Add the model to your Firebase project and publish it. Wait for the
    # create operation to release the model lock before publishing,
    # consistent with add_automl_model; publishing a locked model fails.
    new_model = ml.create_model(model)
    new_model.wait_for_unlocked()
    ml.publish_model(new_model.model_id)
    print('Model uploaded and published:')
    print_models([new_model], headers=False)
def test_missing_op_name(self):
    """An operation response without a name field raises a TypeError."""
    instrument_ml_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE)
    with pytest.raises(Exception) as err_info:
        ml.create_model(MODEL_1)
    check_error(err_info, TypeError)
def test_create_already_existing_fails(firebase_model):
    """Creating a model whose display name already exists raises
    AlreadyExistsError with a message naming the duplicate."""
    with pytest.raises(exceptions.AlreadyExistsError) as err_info:
        ml.create_model(model=firebase_model)
    expected_message = 'Model \'{0}\' already exists'.format(
        firebase_model.display_name)
    check_operation_error(err_info, expected_message)
def test_invalid_op_name(self, op_name):
    """A malformed operation name in the response raises a ValueError."""
    response_payload = json.dumps({'name': op_name})
    instrument_ml_service(status=200, payload=response_payload)
    with pytest.raises(Exception) as err_info:
        ml.create_model(MODEL_1)
    check_error(err_info, ValueError, 'Operation name format is invalid.')
def test_immediate_done(self):
    """An operation that is already done returns the created model
    without any follow-up polling."""
    instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE)
    created = ml.create_model(MODEL_1)
    assert created == CREATED_UPDATED_MODEL_1
def test_missing_display_name(self):
    """A model with no display name is rejected with a ValueError."""
    with pytest.raises(Exception) as err_info:
        ml.create_model(ml.Model.from_dict({}))
    check_error(err_info, ValueError, 'Model must have a display name.')
def test_not_model(self, model):
    """Anything that is not an ml.Model is rejected with a TypeError."""
    with pytest.raises(Exception) as err_info:
        ml.create_model(model)
    check_error(err_info, TypeError, 'Model must be an ml.Model.')
def test_malformed_operation(self):
    """A structurally invalid operation response surfaces as UnknownError."""
    instrument_ml_service(status=200, payload=OPERATION_MALFORMED_RESPONSE)
    with pytest.raises(Exception) as err_info:
        ml.create_model(MODEL_1)
    check_error(
        err_info, exceptions.UnknownError, 'Internal Error: Malformed Operation.')