Ejemplo n.º 1
0
 def test_model_format_setters(self):
     """Assigning a new ``model_source`` shows up in the getter and ``as_dict``."""
     tflite_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE)

     # Overwrite the source given at construction with the second one.
     tflite_format.model_source = GCS_TFLITE_MODEL_SOURCE_2
     assert tflite_format.model_source == GCS_TFLITE_MODEL_SOURCE_2

     serialized = tflite_format.as_dict()
     assert serialized == {'tfliteModel': {'gcsTfliteUri': GCS_TFLITE_URI_2}}
Ejemplo n.º 2
0
def publish_model_to_firebase(tflite_model_name, model_name):
    """Attach a new TFLite file to an existing Firebase ML model and publish it.

    The model is looked up by display name. When several models match the
    filter, the last one returned by the listing is used.

    Args:
        tflite_model_name: Passed to
            ``ml.TFLiteGCSModelSource.from_tflite_model_file``.
        model_name: Display name inserted into the ``list_models`` filter.

    Raises:
        ValueError: If no existing model matches ``model_name``.
    """
    firebase_models = ml.list_models(
        list_filter="display_name = {0}".format(model_name)).iterate_all()

    # Start from None so an empty listing is detected explicitly instead of
    # surfacing later as an UnboundLocalError.
    custom_model = None
    for model in firebase_models:
        custom_model = model

    # Fail before building the model source, so nothing is created for a
    # model that does not exist.
    if custom_model is None:
        raise ValueError(
            'No Firebase ML model found with display name "{0}".'.format(
                model_name))

    source = ml.TFLiteGCSModelSource.from_tflite_model_file(tflite_model_name)
    custom_model.model_format = ml.TFLiteFormat(model_source=source)

    model_to_publish = ml.update_model(custom_model)
    ml.publish_model(model_to_publish.model_id)
Ejemplo n.º 3
0
 def test_model_as_dict_for_upload(self):
     """``as_dict(for_upload=True)`` should carry the signed URI.

     The source is built from ``GCS_TFLITE_URI`` while the expected payload
     contains ``GCS_TFLITE_SIGNED_URI``.
     """
     tflite_format = ml.TFLiteFormat(
         model_source=ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI))
     uploadable = ml.Model(display_name=DISPLAY_NAME_1,
                           model_format=tflite_format)

     payload = uploadable.as_dict(for_upload=True)
     assert payload == {
         'displayName': DISPLAY_NAME_1,
         'tfliteModel': {'gcsTfliteUri': GCS_TFLITE_SIGNED_URI},
     }
Ejemplo n.º 4
0
 def test_model_format_source_creation(self):
     """A model built from a GCS source serializes that URI unchanged."""
     tflite_format = ml.TFLiteFormat(
         model_source=ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI))
     built = ml.Model(display_name=DISPLAY_NAME_1,
                      model_format=tflite_format)

     serialized = built.as_dict()
     assert serialized == {
         'displayName': DISPLAY_NAME_1,
         'tfliteModel': {'gcsTfliteUri': GCS_TFLITE_URI},
     }
Ejemplo n.º 5
0
def firebase_model(request):
    """Create a Firebase ML model from ``request.param``, yield it, then clean up.

    ``request.param`` is read with ``.get`` for the keys ``file_name``,
    ``display_name`` and ``tags``. A TFLite format is attached only when
    ``file_name`` is truthy.

    NOTE(review): presumably a parametrized pytest fixture; the decorator is
    not visible here, so confirm against the surrounding file.
    """
    params = request.param

    # Without a model file the model is created with no format at all.
    model_format = None
    resource_name = params.get('file_name')
    if resource_name:
        model_format = ml.TFLiteFormat(
            model_source=ml.TFLiteGCSModelSource.from_tflite_model_file(
                testutils.resource_filename(resource_name)))

    pending = ml.Model(
        display_name=params.get('display_name'),
        tags=params.get('tags'),
        model_format=model_format)
    created = ml.create_model(model=pending)

    yield created

    # Executed once the generator is resumed after the yield.
    _clean_up_model(created)
Ejemplo n.º 6
0
def test_from_keras_model(keras_model):
    """Convert a Keras model into a TFLite GCS source and create a model from it.

    The resulting ``gcs_tflite_uri`` must be a ``gs://`` URI ending in
    ``/Firebase/ML/Models/model2.tflite``. The created model is always
    cleaned up, even when one of the checks fails.
    """
    # Test the conversion helper
    source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite')
    # Raw string with an escaped dot: only a literal ".tflite" suffix matches,
    # where the unescaped "." would have accepted any character.
    assert re.search(
        r'^gs://.*/Firebase/ML/Models/model2\.tflite$',
        source.gcs_tflite_uri) is not None

    # Validate the conversion by creating a model
    model_format = ml.TFLiteFormat(model_source=source)
    model = ml.Model(display_name=_random_identifier('KerasModel_'), model_format=model_format)
    created_model = ml.create_model(model)

    try:
        check_model(created_model, {'display_name': model.display_name})
        check_tflite_gcs_format(created_model)
    finally:
        _clean_up_model(created_model)
Ejemplo n.º 7
0
def add_automl_model(model_ref, name, tags=None):
    """Register an AutoML TFLite model in the project and publish it.

    Args:
        model_ref: Passed unchanged to ``ml.TFLiteAutoMlSource``.
        name: Display name for the new model.
        tags: Optional tags; assigned to the model only when not ``None``.
    """
    # Describe the model: AutoML source -> TFLite format -> Model.
    automl_source = ml.TFLiteAutoMlSource(model_ref)
    automl_format = ml.TFLiteFormat(model_source=automl_source)
    pending = ml.Model(display_name=name, model_format=automl_format)
    if tags is not None:
        pending.tags = tags

    # Create it, call wait_for_unlocked() on the result, then publish by ID.
    created = ml.create_model(pending)
    created.wait_for_unlocked()
    ml.publish_model(created.model_id)

    print('Model uploaded and published:')
    print_models([created], headers=False)
Ejemplo n.º 8
0
def test_from_saved_model(saved_model_dir):
    """Convert a SavedModel directory into a TFLite GCS source and create a model.

    The resulting ``gcs_tflite_uri`` must be a ``gs://`` URI ending in
    ``/Firebase/ML/Models/model3.tflite``. The created model must have an ID
    and no validation error, and is always cleaned up afterwards.
    """
    # Test the conversion helper
    source = ml.TFLiteGCSModelSource.from_saved_model(saved_model_dir, 'model3.tflite')
    # Raw string with an escaped dot: only a literal ".tflite" suffix matches,
    # where the unescaped "." would have accepted any character.
    assert re.search(
        r'^gs://.*/Firebase/ML/Models/model3\.tflite$',
        source.gcs_tflite_uri) is not None

    # Validate the conversion by creating a model
    model_format = ml.TFLiteFormat(model_source=source)
    model = ml.Model(display_name=_random_identifier('SavedModel_'), model_format=model_format)
    created_model = ml.create_model(model)

    try:
        assert created_model.model_id is not None
        assert created_model.validation_error is None
    finally:
        _clean_up_model(created_model)
Ejemplo n.º 9
0
def upload_model(model_file, name, tags=None):
    """Upload a TFLite model file, register it as a model and publish it.

    Args:
        model_file: Passed to
            ``ml.TFLiteGCSModelSource.from_tflite_model_file``.
        name: Display name for the new model.
        tags: Optional tags; assigned to the model only when not ``None``.
    """
    print('Uploading to Cloud Storage...')
    gcs_source = ml.TFLiteGCSModelSource.from_tflite_model_file(model_file)

    # Describe the model on top of the GCS source.
    pending = ml.Model(
        display_name=name,
        model_format=ml.TFLiteFormat(model_source=gcs_source))
    if tags is not None:
        pending.tags = tags

    # Create the model in the project, then publish it by ID.
    created = ml.create_model(pending)
    ml.publish_model(created.model_id)

    print('Model uploaded and published:')
    print_models([created], headers=False)
Ejemplo n.º 10
0
 def test_model_source_validation_errors(self, model_source):
     """``TFLiteFormat`` must reject ``model_source`` with a ``TypeError``."""
     with pytest.raises(TypeError) as raised:
         ml.TFLiteFormat(model_source=model_source)

     # Both the exception type and the exact message are checked.
     check_error(
         raised,
         TypeError,
         'Model source must be a TFLiteModelSource object.')