예제 #1
0
def test_get_non_existing_model(firebase_model):
    """get_model on a deleted model's ID raises NotFoundError (status 404)."""
    stale_id = firebase_model.model_id
    # Delete first so the ID is well-formed but refers to nothing anymore.
    ml.delete_model(stale_id)

    with pytest.raises(exceptions.NotFoundError) as raised:
        ml.get_model(stale_id)
    check_firebase_error(raised, 404, 'Requested entity was not found.')
예제 #2
0
 def test_get_model_error(self):
     """A 404 reply surfaces as NotFoundError after exactly one GET request."""
     calls = instrument_ml_service(status=404,
                                   payload=ERROR_RESPONSE_NOT_FOUND)
     with pytest.raises(exceptions.NotFoundError) as raised:
         ml.get_model(MODEL_ID_1)
     check_firebase_error(raised, ERROR_STATUS_NOT_FOUND,
                          ERROR_CODE_NOT_FOUND, ERROR_MSG_NOT_FOUND)

     # Inspect the single request that was recorded.
     assert len(calls) == 1
     request = calls[0]
     assert request.method == 'GET'
     assert request.url == TestGetModel._url(PROJECT_ID, MODEL_ID_1)
     assert request.headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
예제 #3
0
def update_model(model_id,
                 model_file=None,
                 name=None,
                 new_tags=None,
                 remove_tags=None):
    """Update one of the project's models, then publish it.

    Args:
        model_id: ID of the model to update.
        model_file: Optional path to a TFLite file. When given, it is uploaded
            to Cloud Storage and used as the model's new format.
        name: Optional new display name.
        new_tags: Optional iterable of tags to add. Tags already present are
            not duplicated; the existing order is kept and new tags are
            appended.
        remove_tags: Optional iterable of tags to remove. The remaining tags
            keep their original relative order.
    """
    model = ml.get_model(model_id)

    if model_file is not None:
        # Load a tflite file and upload it to Cloud Storage
        print('Uploading to Cloud Storage...')
        model_source = ml.TFLiteGCSModelSource.from_tflite_model_file(
            model_file)
        model.model_format = ml.TFLiteFormat(model_source=model_source)

    if name is not None:
        model.display_name = name

    if new_tags is not None:
        current = list(model.tags or [])
        # dict.fromkeys de-duplicates while keeping first-seen order, so
        # re-adding an existing tag does not create a duplicate entry. The
        # list() calls also allow tuples or other iterables as input.
        model.tags = list(dict.fromkeys(current + list(new_tags)))

    if remove_tags is not None and model.tags is not None:
        to_remove = set(remove_tags)
        # Filter instead of using a set difference: set iteration order is
        # not stable across runs, which would scramble the remaining tags.
        # Duplicates are still collapsed, as the set-based version did.
        model.tags = list(dict.fromkeys(
            tag for tag in model.tags if tag not in to_remove))

    updated_model = ml.update_model(model)
    ml.publish_model(updated_model.model_id)
예제 #4
0
 def test_get_model(self):
     """get_model sends one GET (with the client header) and returns MODEL_1."""
     calls = instrument_ml_service(status=200,
                                   payload=DEFAULT_GET_RESPONSE)
     fetched = ml.get_model(MODEL_ID_1)

     # Verify the single outgoing request.
     assert len(calls) == 1
     request = calls[0]
     assert request.method == 'GET'
     assert request.url == TestGetModel._url(PROJECT_ID, MODEL_ID_1)
     assert request.headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE

     # Verify the model parsed from the response.
     assert fetched == MODEL_1
     assert fetched.model_id == MODEL_ID_1
     assert fetched.display_name == DISPLAY_NAME_1
예제 #5
0
def get_model_info(model_id):
    """Print a compact, left-aligned table describing one model.

    Shows the display name, ID, tags, publish state, ETag, SHA256 hash and
    the creation/update timestamps (rendered in UTC, to second precision).
    """
    def as_utc_text(millis):
        # Timestamps are divided by 1000 before conversion, i.e. they are
        # treated as milliseconds since the Unix epoch.
        stamp = datetime.fromtimestamp(millis / 1000, timezone.utc)
        return stamp.isoformat(' ', timespec='seconds')

    model = ml.get_model(model_id)
    created_text = as_utc_text(model.create_time)
    updated_text = as_utc_text(model.update_time)

    labels = [
        'Name:', 'ID:', 'Tags:', 'Published:', 'ETag:', 'SHA256:', 'Created:',
        'Updated:'
    ]
    values = [
        model.display_name,
        model.model_id,
        ', '.join(model.tags) if model.tags else '',
        'Yes' if model.published else 'No',
        model.etag,
        model.model_hash,
        created_text,
        updated_text,
    ]

    table = BeautifulTable()
    table.columns.append(labels)
    table.columns.append(values)
    table.set_style(BeautifulTable.STYLE_COMPACT)
    table.columns.alignment = BeautifulTable.ALIGN_LEFT
    print(table)
예제 #6
0
def test_get_model(firebase_model):
    """Fetch the fixture's model by ID and check its fields and lack of format."""
    fetched = ml.get_model(firebase_model.model_id)
    check_model(fetched, NAME_AND_TAGS_ARGS)
    check_no_model_format(fetched)
예제 #7
0
 def evaluate():
     """get_model against the 'no_project_id' app must raise ValueError."""
     # NOTE(review): the app name suggests it is configured without a project
     # ID; that configuration is set up by the caller, not visible here.
     no_project_app = firebase_admin.initialize_app(
         testutils.MockCredential(), name='no_project_id')
     with pytest.raises(ValueError):
         ml.get_model(MODEL_ID_1, no_project_app)
예제 #8
0
 def test_get_model_validation_errors(self, model_id, exc_type):
     """Each parametrized model_id makes get_model raise its paired exc_type."""
     with pytest.raises(exc_type) as raised:
         ml.get_model(model_id)
     check_error(raised, exc_type)