def test_operation_error(self):
    """A 200 response whose operation payload reports a failure still raises."""
    instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE)

    with pytest.raises(Exception) as raised:
        ml.update_model(MODEL_1)

    # The HTTP call itself succeeded; the failure is carried inside the
    # returned operation, so the error details come from that payload.
    check_operation_error(
        raised,
        OPERATION_ERROR_EXPECTED_STATUS,
        OPERATION_ERROR_MSG,
    )
def test_rpc_error(self):
    """An HTTP 400 from the service raises and issues exactly one request."""
    recorder = instrument_ml_service(
        status=400,
        payload=ERROR_RESPONSE_BAD_REQUEST,
    )

    with pytest.raises(Exception) as raised:
        ml.update_model(MODEL_1)

    check_firebase_error(
        raised,
        ERROR_STATUS_BAD_REQUEST,
        ERROR_CODE_BAD_REQUEST,
        ERROR_MSG_BAD_REQUEST,
    )
    # No retry: a client error must not be re-sent.
    assert len(recorder) == 1
def test_update_non_existing_model(firebase_model):
    """Updating a model that has just been deleted raises NotFoundError."""
    ml.delete_model(firebase_model.model_id)
    # Give the update something to change so a real request is attempted.
    firebase_model.tags = ['tag987']

    with pytest.raises(exceptions.NotFoundError) as raised:
        ml.update_model(firebase_model)

    resource_name = firebase_model.as_dict().get('name')
    check_operation_error(
        raised, "Model '{0}' was not found".format(resource_name))
def test_update_model(firebase_model):
    """Renaming a model succeeds, and re-submitting the returned model also succeeds."""
    firebase_model.display_name = NAME_ONLY_ARGS_UPDATED.get('display_name')

    # The first pass applies the rename. The second pass feeds the returned
    # model straight back in to confirm that an unchanged update does not error.
    current = firebase_model
    for _ in range(2):
        current = ml.update_model(current)
        check_model(current, NAME_ONLY_ARGS_UPDATED)
        check_no_model_format(current)
def update_model(model_id, model_file=None, name=None, new_tags=None, remove_tags=None):
    """Update one of the project's models, then publish it.

    Args:
        model_id: ID of the model to fetch with ``ml.get_model`` and modify.
        model_file: Optional path to a .tflite file. When given, it is uploaded
            to Cloud Storage and set as the model's format.
        name: Optional new display name.
        new_tags: Optional iterable of tags to add. Tags already present are
            not duplicated.
        remove_tags: Optional iterable of tags to remove. The remaining tags
            keep their existing order.
    """
    model = ml.get_model(model_id)

    if model_file is not None:
        # Load a tflite file and upload it to Cloud Storage
        print('Uploading to Cloud Storage...')
        model_source = ml.TFLiteGCSModelSource.from_tflite_model_file(model_file)
        model.model_format = ml.TFLiteFormat(model_source=model_source)

    if name is not None:
        model.display_name = name

    if new_tags is not None:
        existing = [] if model.tags is None else list(model.tags)
        # dict.fromkeys de-duplicates while keeping first-seen order, so
        # re-adding a tag the model already has is a no-op. list() also
        # accepts tuples or other iterables for new_tags.
        model.tags = list(dict.fromkeys(existing + list(new_tags)))

    if remove_tags is not None and model.tags is not None:
        unwanted = set(remove_tags)
        # Filter rather than take a set difference: a set would reorder the
        # surviving tags unpredictably (string hashing is randomized).
        model.tags = list(dict.fromkeys(
            tag for tag in model.tags if tag not in unwanted))

    updated_model = ml.update_model(model)
    ml.publish_model(updated_model.model_id)
def publish_model_to_firebase(tflite_model_name, model_name):
    """Upload a TFLite file, attach it to an existing Firebase ML model, and publish it.

    Args:
        tflite_model_name: Path of the local .tflite file to upload.
        model_name: Display name of the Firebase ML model to update.

    Raises:
        ValueError: If no model has ``model_name`` as its display name.
    """
    matches = ml.list_models(
        list_filter="display_name = {0}".format(model_name)).iterate_all()
    # Use the first match instead of paging through every result.
    # NOTE(review): this assumes display names are unique per project, so the
    # first match is the only one -- confirm.
    custom_model = next(iter(matches), None)
    if custom_model is None:
        raise ValueError(
            'No Firebase ML model with display name {0!r} was found.'.format(model_name))

    # Upload only once the target model is known, so a failed lookup does not
    # leave an orphaned file in Cloud Storage.
    source = ml.TFLiteGCSModelSource.from_tflite_model_file(tflite_model_name)
    custom_model.model_format = ml.TFLiteFormat(model_source=source)
    model_to_publish = ml.update_model(custom_model)
    ml.publish_model(model_to_publish.model_id)
def test_returns_locked(self):
    """A not-done operation is followed by a GET that returns the locked model."""
    recorder = instrument_ml_service(
        status=[200, 200],
        payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE],
    )
    expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2)

    model = ml.update_model(MODEL_1)

    assert model == expected_model
    assert len(recorder) == 2
    # The first request is the PATCH that starts the update; the second is the
    # GET that fetches the model. Both hit the same URL with the client header.
    for request, verb in zip(recorder, ('PATCH', 'GET')):
        assert request.method == verb
        assert request.url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
        assert request.headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
def test_invalid_op_name(self, op_name):
    """An operation whose ``name`` is badly formatted is rejected with ValueError."""
    instrument_ml_service(status=200, payload=json.dumps({'name': op_name}))

    with pytest.raises(Exception) as raised:
        ml.update_model(MODEL_1)

    check_error(raised, ValueError, 'Operation name format is invalid.')
def test_missing_op_name(self):
    """An operation response with no name raises TypeError."""
    instrument_ml_service(
        status=200,
        payload=OPERATION_MISSING_NAME_RESPONSE,
    )

    with pytest.raises(Exception) as raised:
        ml.update_model(MODEL_1)

    check_error(raised, TypeError)
def test_missing_display_name(self):
    """A model built from an empty dict has no display name and is rejected."""
    with pytest.raises(Exception) as raised:
        nameless = ml.Model.from_dict({})
        ml.update_model(nameless)

    check_error(raised, ValueError, 'Model must have a display name.')
def test_not_model(self, model):
    """Passing anything other than an ml.Model raises TypeError."""
    with pytest.raises(Exception) as raised:
        ml.update_model(model)

    check_error(
        raised,
        TypeError,
        'Model must be an ml.Model.',
    )
def test_malformed_operation(self):
    """A malformed operation payload surfaces as exceptions.UnknownError."""
    instrument_ml_service(
        status=200,
        payload=OPERATION_MALFORMED_RESPONSE,
    )

    with pytest.raises(Exception) as raised:
        ml.update_model(MODEL_1)

    check_error(
        raised,
        exceptions.UnknownError,
        'Internal Error: Malformed Operation.',
    )
def test_immediate_done(self):
    """An operation that is already done returns the updated model directly."""
    instrument_ml_service(
        status=200,
        payload=OPERATION_DONE_RESPONSE,
    )

    result = ml.update_model(MODEL_1)

    assert result == CREATED_UPDATED_MODEL_1