def test_model_format_setters(self):
    """Reassigning ``model_source`` must show up in the getter and ``as_dict()``."""
    tflite_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE)

    # Swap in the second source via plain attribute assignment.
    tflite_format.model_source = GCS_TFLITE_MODEL_SOURCE_2

    assert tflite_format.model_source == GCS_TFLITE_MODEL_SOURCE_2
    serialized = tflite_format.as_dict()
    assert serialized == {
        'tfliteModel': {
            'gcsTfliteUri': GCS_TFLITE_URI_2,
        },
    }
def publish_model_to_firebase(tflite_model_name, model_name):
    """Attach a new tflite file to an existing Firebase ML model and publish it.

    The model is looked up by display name; when several models are listed
    for that name, the last one returned is the one updated.

    Args:
        tflite_model_name: Path of the tflite model file handed to
            ``ml.TFLiteGCSModelSource.from_tflite_model_file``.
        model_name: Display name used in the ``list_models`` filter.

    Raises:
        ValueError: If no model is listed for ``model_name``. Previously this
            case surfaced as an ``UnboundLocalError`` on the loop variable.
    """
    source = ml.TFLiteGCSModelSource.from_tflite_model_file(tflite_model_name)
    model_format = ml.TFLiteFormat(model_source=source)

    firebase_models = ml.list_models(
        list_filter="display_name = {0}".format(model_name)).iterate_all()

    # Keep the last listed match; None marks "nothing was listed".
    custom_model = None
    for model in firebase_models:
        custom_model = model
    if custom_model is None:
        raise ValueError(
            'No Firebase ML model found with display name "{0}".'.format(
                model_name))

    custom_model.model_format = model_format
    model_to_publish = ml.update_model(custom_model)
    ml.publish_model(model_to_publish.model_id)
def test_model_as_dict_for_upload(self):
    """``as_dict(for_upload=True)`` is expected to carry GCS_TFLITE_SIGNED_URI."""
    gcs_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)
    model = ml.Model(
        display_name=DISPLAY_NAME_1,
        model_format=ml.TFLiteFormat(model_source=gcs_source),
    )

    upload_payload = model.as_dict(for_upload=True)

    # The source was built from GCS_TFLITE_URI, but the upload form is
    # compared against the signed URI constant.
    assert upload_payload == {
        'displayName': DISPLAY_NAME_1,
        'tfliteModel': {
            'gcsTfliteUri': GCS_TFLITE_SIGNED_URI,
        },
    }
def test_model_format_source_creation(self):
    """A model built from a GCS source serializes with that source's URI."""
    gcs_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)
    model = ml.Model(
        display_name=DISPLAY_NAME_1,
        model_format=ml.TFLiteFormat(model_source=gcs_source),
    )

    serialized = model.as_dict()

    assert serialized == {
        'displayName': DISPLAY_NAME_1,
        'tfliteModel': {
            'gcsTfliteUri': GCS_TFLITE_URI,
        },
    }
def firebase_model(request):
    """Create a model from ``request.param``, yield it, then clean it up.

    ``request.param`` is read with ``.get`` for the keys ``file_name``,
    ``display_name`` and ``tags``. A tflite format is attached only when
    ``file_name`` is truthy; otherwise the model is created without one.
    """
    params = request.param

    resource_name = params.get('file_name')
    if resource_name:
        gcs_source = ml.TFLiteGCSModelSource.from_tflite_model_file(
            testutils.resource_filename(resource_name))
        model_format = ml.TFLiteFormat(model_source=gcs_source)
    else:
        model_format = None

    created = ml.create_model(
        model=ml.Model(
            display_name=params.get('display_name'),
            tags=params.get('tags'),
            model_format=model_format,
        ))

    yield created
    # Teardown: runs once the consumer of the yielded model is finished.
    _clean_up_model(created)
def test_from_keras_model(keras_model):
    """Convert a Keras model to a GCS tflite source and create a model from it."""
    source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite')
    uri_match = re.search(
        '^gs://.*/Firebase/ML/Models/model2.tflite$', source.gcs_tflite_uri)
    assert uri_match is not None

    # Creating a model from the converted source validates the conversion.
    model = ml.Model(
        display_name=_random_identifier('KerasModel_'),
        model_format=ml.TFLiteFormat(model_source=source),
    )
    created_model = ml.create_model(model)
    try:
        check_model(created_model, {'display_name': model.display_name})
        check_tflite_gcs_format(created_model)
    finally:
        # Always remove the created model, even when a check fails.
        _clean_up_model(created_model)
def add_automl_model(model_ref, name, tags=None):
    """Register an AutoML tflite model with the project and publish it.

    Args:
        model_ref: Reference passed to ``ml.TFLiteAutoMlSource``.
        name: Display name for the new model.
        tags: Optional tags; the model's tags are left untouched when ``None``.
    """
    # Describe the model to create.
    automl_source = ml.TFLiteAutoMlSource(model_ref)
    tflite_format = ml.TFLiteFormat(model_source=automl_source)
    pending = ml.Model(display_name=name, model_format=tflite_format)
    if tags is not None:
        pending.tags = tags

    # Create it in the project; publish only after wait_for_unlocked() returns.
    created = ml.create_model(pending)
    created.wait_for_unlocked()
    ml.publish_model(created.model_id)

    print('Model uploaded and published:')
    print_models([created], headers=False)
def test_from_saved_model(saved_model_dir):
    """Convert a saved model to a GCS tflite source and create a model from it."""
    # Exercise the conversion helper.
    source = ml.TFLiteGCSModelSource.from_saved_model(saved_model_dir, 'model3.tflite')
    uri_match = re.search(
        '^gs://.*/Firebase/ML/Models/model3.tflite$', source.gcs_tflite_uri)
    assert uri_match is not None

    # Creating a model from the converted source validates the conversion.
    model = ml.Model(
        display_name=_random_identifier('SavedModel_'),
        model_format=ml.TFLiteFormat(model_source=source),
    )
    created_model = ml.create_model(model)
    try:
        assert created_model.model_id is not None
        assert created_model.validation_error is None
    finally:
        # Always remove the created model, even when an assertion fails.
        _clean_up_model(created_model)
def upload_model(model_file, name, tags=None):
    """Create a model in the project from a tflite file and publish it.

    Args:
        model_file: Path of the tflite file handed to
            ``ml.TFLiteGCSModelSource.from_tflite_model_file``.
        name: Display name for the new model.
        tags: Optional tags; the model's tags are left untouched when ``None``.
    """
    # Build the GCS-backed source from the local tflite file.
    print('Uploading to Cloud Storage...')
    gcs_source = ml.TFLiteGCSModelSource.from_tflite_model_file(model_file)

    # Describe the model to create.
    pending = ml.Model(
        display_name=name,
        model_format=ml.TFLiteFormat(model_source=gcs_source),
    )
    if tags is not None:
        pending.tags = tags

    # Create it in the project, then publish.
    created = ml.create_model(pending)
    ml.publish_model(created.model_id)

    print('Model uploaded and published:')
    print_models([created], headers=False)
def test_model_source_validation_errors(self, model_source):
    """``TFLiteFormat`` must raise ``TypeError`` for the given ``model_source``."""
    expected_message = 'Model source must be a TFLiteModelSource object.'

    with pytest.raises(TypeError) as raised:
        ml.TFLiteFormat(model_source=model_source)

    check_error(raised, TypeError, expected_message)