def replace_corrupt_and_heavy_models(cls):
    """Overwrite the corrupt and heavy models with the correct one and reload.

    Calls ``cls.backup_model_changes()`` first, then replaces the directory
    trees at ``CORRUPT_MODEL_PATH`` and ``HEAVY_MODEL_PATH`` with copies of
    ``CORRECT_MODEL_PATH``, and finally sends a ``ReloadConfigModels``
    request through ``cls.stub``.
    """
    cls.backup_model_changes()

    # Corrupt first, then heavy; each target is removed before being
    # re-created because the copy is made onto a path that must not exist.
    for target_path in (cls.CORRUPT_MODEL_PATH, cls.HEAVY_MODEL_PATH):
        rmtree(target_path)
        copytree(cls.CORRECT_MODEL_PATH, target_path)

    reload_request = service_pb2.ReloadModelsRequest()
    cls.stub.ReloadConfigModels(reload_request)
def test_reload_models_in_config_file(self):
    """Check that reloading picks up the repaired corrupt and heavy models.

    The corrupt and heavy models are replaced with copies of the correct
    model, the configured models are reloaded, and the test then verifies
    that:

    * every configured model missing from the loaded list is named
      ``"bad_path"``;
    * the ``"corrupt"`` and ``"heavy"`` models both report state ``LOADED``.

    The model changes are reverted even when a check fails, so that a
    failure here does not leave modified model directories behind.
    """
    # Replace heavy and corrupt model and reload them
    self.replace_corrupt_and_heavy_models()
    try:
        self.stub.ReloadConfigModels(service_pb2.ReloadModelsRequest())

        # Get loaded models
        response = self.stub.GetLoadedModels(
            service_pb2.LoadedModelsRequest())
        loaded_models = {model.name for model in response.models}

        # Check that only bad path model wasn't loaded
        configured_models = [
            model["name"] for model in get_config()["models"]
        ]
        for model_name in configured_models:
            if model_name not in loaded_models:
                self.assertEqual(model_name, "bad_path")

        # Check heavy and corrupt model state is LOADED
        for model_name in ("corrupt", "heavy"):
            request = service_pb2.ModelStatusRequest(
                model=model_pb2.ModelSpec(name=model_name))
            response = self.stub.GetModelStatus(request)
            state_name = model_pb2.ModelStatus.ModelState.Name(
                response.status.state)
            # assertEqual reports the actual state on failure, unlike
            # assertTrue on a comparison result.
            self.assertEqual(state_name, "LOADED")
    finally:
        # Always restore the models, including when an assertion fails.
        self.revert_model_changes()
def test_highest_version(self):
    """Check that the reported version of ``"correct"`` goes from 1 to 2.

    The status of the ``"correct"`` model is queried and expected to be at
    version 1. Its ``1`` directory is then copied to ``2``, the configured
    models are reloaded, and the status is expected to report version 2.

    The copied ``2`` directory is removed even when the reload or the final
    check fails; otherwise a leftover directory would make the copy step
    fail on the next run.
    """
    request = service_pb2.ModelStatusRequest(
        model=model_pb2.ModelSpec(name="correct"))
    previous_response = self.stub.GetModelStatus(request)
    self.assertEqual(previous_response.status.version, 1)

    new_version_dir = self.CORRECT_MODEL_PATH / "2"
    copytree(self.CORRECT_MODEL_PATH / "1", new_version_dir)
    try:
        self.stub.ReloadConfigModels(service_pb2.ReloadModelsRequest())
        later_response = self.stub.GetModelStatus(request)
        self.assertEqual(later_response.status.version, 2)
    finally:
        # NOTE(review): no reload is requested after the directory is
        # removed, matching the original behaviour -- confirm whether the
        # server should be told to drop version 2 here.
        rmtree(new_version_dir)
def revert_model_changes(cls):
    """Restore the models directory from its backup and reload the models.

    Deletes the tree at ``MODELS_DIR``, re-creates it as a copy of
    ``MODELS_BACKUP_DIR``, then sends a ``ReloadConfigModels`` request
    through ``cls.stub``.
    """
    models_dir = cls.MODELS_DIR

    # The current tree has to go first: the copy target must not exist.
    rmtree(models_dir)
    copytree(cls.MODELS_BACKUP_DIR, models_dir)

    reload_request = service_pb2.ReloadModelsRequest()
    cls.stub.ReloadConfigModels(reload_request)
def duplicate_correct_model(cls):
    """Copy the correct model to the new-model path and reload the models.

    Calls ``cls.backup_model_changes()`` first, then copies the tree at
    ``CORRECT_MODEL_PATH`` to ``NEW_MODEL_PATH`` and sends a
    ``ReloadConfigModels`` request through ``cls.stub``.
    """
    cls.backup_model_changes()

    source_dir = cls.CORRECT_MODEL_PATH
    duplicate_dir = cls.NEW_MODEL_PATH
    copytree(source_dir, duplicate_dir)

    reload_request = service_pb2.ReloadModelsRequest()
    cls.stub.ReloadConfigModels(reload_request)
def revert_config_changes(cls):
    """Restore the config file from its backup and reload the models.

    Removes the file at ``CONFIG_PATH``, copies ``CONFIG_BACKUP_PATH`` back
    into its place, then sends a ``ReloadConfigModels`` request through
    ``cls.stub``.
    """
    config_path = cls.CONFIG_PATH

    # Drop the modified config before putting the backup back in place.
    os.remove(config_path)
    copy(cls.CONFIG_BACKUP_PATH, config_path)

    reload_request = service_pb2.ReloadModelsRequest()
    cls.stub.ReloadConfigModels(reload_request)