Exemplo n.º 1
0
 def test_list_models_error(self):
     """A 400 response from the list endpoint surfaces as InvalidArgumentError."""
     requests_made = instrument_ml_service(
         status=400, payload=ERROR_RESPONSE_BAD_REQUEST)

     with pytest.raises(exceptions.InvalidArgumentError) as excinfo:
         ml.list_models()

     check_firebase_error(excinfo,
                          ERROR_STATUS_BAD_REQUEST,
                          ERROR_CODE_BAD_REQUEST,
                          ERROR_MSG_BAD_REQUEST)

     # Exactly one GET, carrying the client header, must hit the list URL.
     assert len(requests_made) == 1
     only_request = requests_made[0]
     assert only_request.method == 'GET'
     assert only_request.url == TestListModels._url(PROJECT_ID)
     assert only_request.headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
Exemplo n.º 2
0
def list_models(filter_exp=''):
    """Print every model in the project that matches ``filter_exp``.

    Each model is printed on one line: its display name left-aligned in 20
    columns, its model ID left-aligned in 10 columns, a space, and then its
    tags joined with ", " (an empty string when ``tags`` is None).

    Args:
        filter_exp: Filter expression forwarded unchanged to
            ``ml.list_models`` as ``list_filter``.
    """
    row_format = '{:<20}{:<10} {}'
    for entry in ml.list_models(list_filter=filter_exp).iterate_all():
        if entry.tags is None:
            tag_text = ''
        else:
            tag_text = ', '.join(entry.tags)
        print(row_format.format(entry.display_name, entry.model_id, tag_text))
Exemplo n.º 3
0
def test_list_models(model_list):
    """Listing with an OR filter returns every model in ``model_list``.

    The filter matches the first fixture model by display name or the second
    one by its first tag; each fixture model's ID must appear in the results.
    """
    first, second = model_list[0], model_list[1]
    filter_str = 'displayName={0} OR tags:{1}'.format(
        first.display_name, second.tags[0])

    listed = ml.list_models(list_filter=filter_str)
    listed_ids = [found.model_id for found in listed.iterate_all()]
    for expected in model_list:
        assert expected.model_id in listed_ids
Exemplo n.º 4
0
 def test_list_models_no_models(self):
     """An empty list response yields an empty page and an empty iterator."""
     requests_made = instrument_ml_service(status=200,
                                           payload=NO_MODELS_LIST_RESPONSE)
     empty_page = ml.list_models()
     assert len(requests_made) == 1
     assert len(empty_page.models) == 0
     assert len(list(empty_page.iterate_all())) == 0
Exemplo n.º 5
0
 def test_list_single_page(self):
     """A last-page response produces a terminal page holding one model."""
     requests_made = instrument_ml_service(status=200,
                                           payload=LAST_PAGE_LIST_RESPONSE)
     only_page = ml.list_models()
     assert len(requests_made) == 1
     # Empty token: no further page is reported or returned.
     assert only_page.next_page_token == ''
     assert only_page.has_next_page is False
     assert only_page.get_next_page() is None
     assert len(list(only_page.iterate_all())) == 1
Exemplo n.º 6
0
def publish_model_to_firebase(tflite_model_name, model_name):
    """Point an existing Firebase ML model at a TFLite file and publish it.

    Looks up the model whose display name is ``model_name``, sets its format
    to a ``TFLiteFormat`` built from ``tflite_model_name`` via
    ``ml.TFLiteGCSModelSource.from_tflite_model_file``, saves it with
    ``ml.update_model`` and publishes the updated model.

    If several models share the display name, the last one yielded by the
    listing is used.

    Args:
        tflite_model_name: The TFLite model file handed to
            ``TFLiteGCSModelSource.from_tflite_model_file``.
        model_name: Display name of the model to update and publish.

    Raises:
        ValueError: If no model with display name ``model_name`` exists.
    """
    matching_models = ml.list_models(
        list_filter="display_name = {0}".format(model_name)).iterate_all()
    custom_model = None
    for model in matching_models:
        # Last match wins when several models share the display name.
        custom_model = model
    if custom_model is None:
        # Without this check an absent model surfaces as an opaque
        # UnboundLocalError when its format is assigned below.
        raise ValueError(
            'No Firebase ML model found with display name {0!r}.'.format(
                model_name))

    # Build the source only once the target model is known to exist. Per the
    # firebase_admin.ml docs, from_tflite_model_file uploads the file to Cloud
    # Storage, so doing this first could leave a stray upload behind.
    source = ml.TFLiteGCSModelSource.from_tflite_model_file(tflite_model_name)
    custom_model.model_format = ml.TFLiteFormat(model_source=source)
    model_to_publish = ml.update_model(custom_model)
    ml.publish_model(model_to_publish.model_id)
Exemplo n.º 7
0
 def test_list_models_stop_iteration(self):
     """Draining iterate_all() leaves the page intact and then stops."""
     requests_made = instrument_ml_service(status=200,
                                           payload=ONE_PAGE_LIST_RESPONSE)
     single_page = ml.list_models()
     assert len(requests_made) == 1
     assert len(single_page.models) == 3

     model_iter = single_page.iterate_all()
     collected = list(model_iter)
     # The page's own model list is unchanged after the iterator is drained.
     assert len(single_page.models) == 3
     # Any further next() on the drained iterator raises StopIteration.
     with pytest.raises(StopIteration):
         next(model_iter)
     assert len(collected) == 3
Exemplo n.º 8
0
 def test_list_models_no_args(self):
     """list_models() without arguments returns the first page of results."""
     requests_made = instrument_ml_service(status=200,
                                           payload=DEFAULT_LIST_RESPONSE)
     first_page = ml.list_models()

     assert len(requests_made) == 1
     sent = requests_made[0]
     assert sent.method == 'GET'
     assert sent.url == TestListModels._url(PROJECT_ID)
     assert sent.headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE

     TestListModels._check_page(first_page, 2)
     assert first_page.has_next_page
     assert first_page.next_page_token == NEXT_PAGE_TOKEN
     assert first_page.models[0] == MODEL_1
     assert first_page.models[1] == MODEL_2
Exemplo n.º 9
0
 def test_list_models_with_all_args(self):
     """Filter, page size and page token are all encoded in the request URL."""
     requests_made = instrument_ml_service(status=200,
                                           payload=LAST_PAGE_LIST_RESPONSE)
     result_page = ml.list_models('display_name=displayName3',
                                  page_size=10,
                                  page_token=PAGE_TOKEN)

     assert len(requests_made) == 1
     sent = requests_made[0]
     assert sent.method == 'GET'
     expected_url = TestListModels._url(PROJECT_ID) + (
         '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}'
         .format(PAGE_TOKEN))
     assert sent.url == expected_url
     assert sent.headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE

     assert isinstance(result_page, ml.ListModelsPage)
     assert len(result_page.models) == 1
     assert result_page.models[0] == MODEL_3
     assert not result_page.has_next_page
Exemplo n.º 10
0
    def test_list_multiple_pages(self):
        """get_next_page() follows the page token and stops at the last page."""
        # First page: two models plus a token for the next page.
        first_calls = instrument_ml_service(status=200,
                                            payload=DEFAULT_LIST_RESPONSE)
        first_page = ml.list_models()
        assert len(first_calls) == 1
        assert len(first_page.models) == 2
        assert first_page.next_page_token == NEXT_PAGE_TOKEN
        assert first_page.has_next_page is True

        # Second page: instrument the last-page payload, then follow the token.
        second_calls = instrument_ml_service(
            status=200, payload=LAST_PAGE_LIST_RESPONSE)
        second_page = first_page.get_next_page()
        assert len(second_calls) == 1
        assert len(second_page.models) == 1
        assert second_page.next_page_token == ''
        assert second_page.has_next_page is False
        assert second_page.get_next_page() is None
Exemplo n.º 11
0
    def test_list_models_paged_iteration(self):
        """iterate_all() moves on to the next page once the first is used up."""
        # First page: reading its two models issues exactly one request.
        page_one_calls = instrument_ml_service(status=200,
                                               payload=DEFAULT_LIST_RESPONSE)
        first_page = ml.list_models()
        assert first_page.next_page_token == NEXT_PAGE_TOKEN
        assert first_page.has_next_page is True
        all_models = first_page.iterate_all()
        for expected_suffix in (1, 2):
            expected_name = 'displayName{0}'.format(expected_suffix)
            assert next(all_models).display_name == expected_name
        assert len(page_one_calls) == 1

        # Second page: the third next() crosses into the last page.
        instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE)
        third = next(all_models)
        assert third.display_name == DISPLAY_NAME_3
        with pytest.raises(StopIteration):
            next(all_models)
Exemplo n.º 12
0
def list_models(filter_exp=''):
    """Print the project's models that match ``filter_exp``.

    Args:
        filter_exp: Filter expression forwarded unchanged to
            ``ml.list_models`` as ``list_filter``; the matching models are
            handed lazily (via ``iterate_all()``) to ``print_models``.
    """
    print_models(ml.list_models(list_filter=filter_exp).iterate_all())
Exemplo n.º 13
0
 def test_list_models_list_filter_validation(self, list_filter):
     """Invalid ``list_filter`` values (from the parametrization) raise TypeError."""
     expected_message = 'List filter must be a string or None.'
     with pytest.raises(TypeError) as raised:
         ml.list_models(list_filter=list_filter)
     check_error(raised, TypeError, expected_message)
Exemplo n.º 14
0
 def test_list_models_page_token_validation(self, page_token):
     """Invalid ``page_token`` values (from the parametrization) raise TypeError."""
     expected_message = 'Page token must be a string or None.'
     with pytest.raises(TypeError) as raised:
         ml.list_models(page_token=page_token)
     check_error(raised, TypeError, expected_message)
Exemplo n.º 15
0
 def evaluate():
     """Expect list_models to reject the 'no_project_id' mock app with ValueError."""
     # NOTE(review): presumably run by a caller that clears any project ID
     # from the environment; confirm against the enclosing test.
     mock_app = firebase_admin.initialize_app(
         testutils.MockCredential(), name='no_project_id')
     with pytest.raises(ValueError):
         ml.list_models(app=mock_app)
Exemplo n.º 16
0
def test_list_models_invalid_filter():
    """An unrecognised filter field is rejected with InvalidArgumentError (400)."""
    with pytest.raises(exceptions.InvalidArgumentError) as raised:
        ml.list_models(list_filter='InvalidFilterParam=123')
    check_firebase_error(raised, 400, 'Request contains an invalid argument.')
Exemplo n.º 17
0
 def test_list_models_page_size_validation(
         self, page_size, exc_type, error_message):
     """An invalid ``page_size`` raises ``exc_type`` with ``error_message``."""
     with pytest.raises(exc_type) as raised:
         ml.list_models(page_size=page_size)
     check_error(raised, exc_type, error_message)