コード例 #1
0
 def test_operation_error(self):
     """A successful HTTP call whose operation reports a failure must raise."""
     instrument_ml_service(payload=OPERATION_ERROR_RESPONSE, status=200)
     with pytest.raises(Exception) as raised:
         ml.update_model(MODEL_1)
     # Transport-level success: the failure is carried inside the returned
     # operation, so the raised error reflects the operation's status/message.
     check_operation_error(
         raised, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)
コード例 #2
0
 def test_rpc_error(self):
     """A 400 response makes update_model raise after a single request."""
     recorder = instrument_ml_service(
         status=400, payload=ERROR_RESPONSE_BAD_REQUEST)
     with pytest.raises(Exception) as raised:
         ml.update_model(MODEL_1)
     check_firebase_error(
         raised,
         ERROR_STATUS_BAD_REQUEST,
         ERROR_CODE_BAD_REQUEST,
         ERROR_MSG_BAD_REQUEST,
     )
     # Exactly one request is recorded: nothing is sent after the error.
     assert len(recorder) == 1
コード例 #3
0
ファイル: test_ml.py プロジェクト: 7chat/e7chats_admin-python
def test_update_non_existing_model(firebase_model):
    """Updating a model after deleting it server-side raises NotFoundError."""
    ml.delete_model(firebase_model.model_id)

    # Make a local change so there is something to send in the update.
    firebase_model.tags = ['tag987']
    with pytest.raises(exceptions.NotFoundError) as raised:
        ml.update_model(firebase_model)

    resource_name = firebase_model.as_dict().get('name')
    expected_msg = 'Model \'{0}\' was not found'.format(resource_name)
    check_operation_error(raised, expected_msg)
コード例 #4
0
ファイル: test_ml.py プロジェクト: 7chat/e7chats_admin-python
def test_update_model(firebase_model):
    """Renaming a model succeeds, and re-submitting the result also succeeds."""
    firebase_model.display_name = NAME_ONLY_ARGS_UPDATED.get('display_name')

    # Pass 1 applies the rename to the fixture model. Pass 2 sends the
    # already-updated model back unchanged, which must not cause an error.
    # Both passes must yield the same expected fields and no model format.
    current = firebase_model
    for _ in range(2):
        current = ml.update_model(current)
        check_model(current, NAME_ONLY_ARGS_UPDATED)
        check_no_model_format(current)
コード例 #5
0
def update_model(model_id,
                 model_file=None,
                 name=None,
                 new_tags=None,
                 remove_tags=None):
    """Update one of the project's models and publish the result.

    Args:
        model_id: ID of the model to fetch with ``ml.get_model``.
        model_file: Optional path to a .tflite file. When given, the file is
            uploaded to Cloud Storage and used as the model's new format.
        name: Optional new display name.
        new_tags: Optional iterable of tags to add. Tags the model already
            has are not added a second time.
        remove_tags: Optional iterable of tags to remove.

    Returns:
        Whatever ``ml.publish_model`` returns for the updated model. The
        original version discarded this value, so existing callers that
        ignore it are unaffected.
    """
    model = ml.get_model(model_id)

    if model_file is not None:
        # Load a tflite file and upload it to Cloud Storage
        print('Uploading to Cloud Storage...')
        model_source = ml.TFLiteGCSModelSource.from_tflite_model_file(
            model_file)
        model.model_format = ml.TFLiteFormat(model_source=model_source)

    if name is not None:
        model.display_name = name

    if new_tags is not None:
        # Existing tags first, then new ones, de-duplicated in order.
        # Re-running with the same tags therefore does not pile up
        # duplicates. Building a fresh list also accepts any iterable
        # (list + tuple would raise TypeError).
        model.tags = list(dict.fromkeys([*(model.tags or []), *new_tags]))

    if remove_tags is not None and model.tags is not None:
        # Filter instead of using a set difference. Set iteration order is
        # arbitrary (it varies with string hash randomization), so a set
        # would shuffle the surviving tags. dict.fromkeys keeps the
        # de-duplication the set performed, but in the original order.
        unwanted = set(remove_tags)
        model.tags = list(dict.fromkeys(
            tag for tag in model.tags if tag not in unwanted))

    updated_model = ml.update_model(model)
    return ml.publish_model(updated_model.model_id)
コード例 #6
0
def publish_model_to_firebase(tflite_model_name, model_name):
    """Attach a TFLite file to an existing model and publish the model.

    Args:
        tflite_model_name: Path of the .tflite file to upload.
        model_name: Display name used to look up the existing model.

    Raises:
        ValueError: If no model has ``model_name`` as its display name.
    """
    firebase_models = ml.list_models(
        list_filter="display_name = {0}".format(model_name)).iterate_all()
    # If several models match, the last one listed wins, as before.
    custom_model = None
    for model in firebase_models:
        custom_model = model
    if custom_model is None:
        # Fail with a clear message rather than an opaque error on the
        # attribute access below.
        raise ValueError(
            'No Firebase ML model with display name {0!r}'.format(model_name))

    # Upload only once we know there is a model to attach the file to.
    source = ml.TFLiteGCSModelSource.from_tflite_model_file(tflite_model_name)
    custom_model.model_format = ml.TFLiteFormat(model_source=source)
    model_to_publish = ml.update_model(custom_model)
    ml.publish_model(model_to_publish.model_id)
コード例 #7
0
    def test_returns_locked(self):
        """A not-done operation is followed by a GET whose model is returned."""
        responses = [OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]
        recorder = instrument_ml_service(status=[200, 200], payload=responses)
        expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2)
        model = ml.update_model(MODEL_1)

        assert model == expected_model
        assert len(recorder) == 2
        # Both requests target the same model URL and carry the client header.
        # The first one is the PATCH, the second the follow-up GET.
        model_url = TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
        for request, verb in zip(recorder, ('PATCH', 'GET')):
            assert request.method == verb
            assert request.url == model_url
            assert request.headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
コード例 #8
0
 def test_invalid_op_name(self, op_name):
     """An operation whose ``name`` is malformed is rejected with ValueError."""
     instrument_ml_service(status=200, payload=json.dumps({'name': op_name}))
     with pytest.raises(Exception) as raised:
         ml.update_model(MODEL_1)
     check_error(raised, ValueError, 'Operation name format is invalid.')
コード例 #9
0
 def test_missing_op_name(self):
     """An operation response without a name makes update_model raise TypeError."""
     instrument_ml_service(
         payload=OPERATION_MISSING_NAME_RESPONSE, status=200)
     with pytest.raises(Exception) as raised:
         ml.update_model(MODEL_1)
     check_error(raised, TypeError)
コード例 #10
0
 def test_missing_display_name(self):
     """A model built from an empty dict is rejected for lacking a display name."""
     with pytest.raises(Exception) as raised:
         nameless_model = ml.Model.from_dict({})
         ml.update_model(nameless_model)
     check_error(raised, ValueError, 'Model must have a display name.')
コード例 #11
0
 def test_not_model(self, model):
     """Passing a non-ml.Model value to update_model raises TypeError.

     ``model`` is presumably supplied by a parametrize decorator outside
     this view. Confirm against the enclosing class.
     """
     expected_msg = 'Model must be an ml.Model.'
     with pytest.raises(Exception) as raised:
         ml.update_model(model)
     check_error(raised, TypeError, expected_msg)
コード例 #12
0
 def test_malformed_operation(self):
     """A malformed operation payload is reported as an UnknownError."""
     instrument_ml_service(payload=OPERATION_MALFORMED_RESPONSE, status=200)
     with pytest.raises(Exception) as raised:
         ml.update_model(MODEL_1)
     check_error(
         raised,
         exceptions.UnknownError,
         'Internal Error: Malformed Operation.',
     )
コード例 #13
0
 def test_immediate_done(self):
     """A single already-done operation response yields the updated model."""
     instrument_ml_service(payload=OPERATION_DONE_RESPONSE, status=200)
     expected = CREATED_UPDATED_MODEL_1
     result = ml.update_model(MODEL_1)
     assert result == expected