def test_correct(self):
    """Load the replaced corrupt and heavy models and expect both to be LOADED.

    ``replace_corrupt_and_heavy_models`` presumably swaps loadable model files
    into the ``corrupt`` and ``heavy`` model directories (TODO confirm). The
    test asserts that ``LoadModels`` reports success and that each model's
    status is ``LOADED``. The replacement is reverted in a ``finally`` block,
    so a failing assertion does not leak modified fixtures into later tests.
    """
    self.replace_corrupt_and_heavy_models()
    try:
        request = service_pb2.LoadModelsRequest(models=(
            model_pb2.ModelSpec(name="corrupt",
                                base_path="test/resources/models/corrupt"),
            model_pb2.ModelSpec(name="heavy",
                                base_path="test/resources/models/heavy"),
        ))
        response = self.stub.LoadModels(request)
        self.assertTrue(response.success)

        # Check both models' state is LOADED
        for name in ("corrupt", "heavy"):
            request = service_pb2.ModelStatusRequest(
                model=model_pb2.ModelSpec(name=name))
            response = self.stub.GetModelStatus(request)
            # assertEqual (not assertTrue(==)) so a failure shows the state.
            self.assertEqual(
                model_pb2.ModelStatus.ModelState.Name(response.status.state),
                "LOADED",
                msg="state of model %r" % name)
    finally:
        # Always restore the original model fixtures.
        self.revert_model_changes()
def test_reload_models_in_config_file(self):
    """Reload the configured models and check which ones end up loaded.

    After ``replace_corrupt_and_heavy_models`` (presumably swapping loadable
    files into the corrupt/heavy model directories — TODO confirm), the test
    triggers ``ReloadConfigModels`` and then checks two things:

    * every model listed in ``get_config()["models"]`` is reported by
      ``GetLoadedModels``, except ``bad_path``;
    * the ``corrupt`` and ``heavy`` models report the ``LOADED`` state.

    Model changes are reverted in a ``finally`` block, so a failing assertion
    does not leak modified fixtures into later tests.
    """
    # Replace heavy and corrupt model and reload them
    self.replace_corrupt_and_heavy_models()
    try:
        # The reload response itself is not inspected.
        self.stub.ReloadConfigModels(service_pb2.ReloadModelsRequest())

        # Get loaded models
        response = self.stub.GetLoadedModels(service_pb2.LoadedModelsRequest())
        loaded_models = {model.name for model in response.models}

        # Check that only the bad path model wasn't loaded; report every
        # unexpected model at once instead of stopping at the first.
        configured_models = [model["name"] for model in get_config()["models"]]
        unexpectedly_missing = {
            name for name in configured_models
            if name not in loaded_models and name != "bad_path"
        }
        self.assertEqual(unexpectedly_missing, set())

        # Check heavy and corrupt model state is LOADED
        for name in ("corrupt", "heavy"):
            request = service_pb2.ModelStatusRequest(
                model=model_pb2.ModelSpec(name=name))
            response = self.stub.GetModelStatus(request)
            self.assertEqual(
                model_pb2.ModelStatus.ModelState.Name(response.status.state),
                "LOADED",
                msg="state of model %r" % name)
    finally:
        # Always restore the original model fixtures.
        self.revert_model_changes()
def test_available(self):
    """Expect the ``heavy`` model to report the ``AVAILABLE`` state.

    No load request is issued here, so this checks the state the server
    reports for ``heavy`` given whatever setup ran before this test.
    """
    request = service_pb2.ModelStatusRequest(model=model_pb2.ModelSpec(
        name="heavy"))
    response = self.stub.GetModelStatus(request)
    # assertEqual (not assertTrue(==)) so a failure shows the actual state.
    self.assertEqual(
        model_pb2.ModelStatus.ModelState.Name(response.status.state),
        "AVAILABLE")
def test_failed(self):
    """Expect the ``corrupt`` model to report the ``FAILED`` state.

    No load request is issued here, so this checks the state the server
    reports for ``corrupt`` given whatever setup ran before this test.
    """
    request = service_pb2.ModelStatusRequest(model=model_pb2.ModelSpec(
        name="corrupt"))
    response = self.stub.GetModelStatus(request)
    # assertEqual (not assertTrue(==)) so a failure shows the actual state.
    self.assertEqual(
        model_pb2.ModelStatus.ModelState.Name(response.status.state),
        "FAILED")
def test_unknown(self):
    """Expect a model name the server does not know to report ``UNKNOWN``.

    ``foo`` is presumably not configured on the server under test
    (TODO confirm against the test config).
    """
    request = service_pb2.ModelStatusRequest(model=model_pb2.ModelSpec(
        name="foo"))
    response = self.stub.GetModelStatus(request)
    # assertEqual (not assertTrue(==)) so a failure shows the actual state.
    self.assertEqual(
        model_pb2.ModelStatus.ModelState.Name(response.status.state),
        "UNKNOWN")
def test_update_bad_path_model(self):
    """Expect ``bad_path`` to become LOADED once a model appears at its path.

    Copies the correct model directory to ``BAD_PATH``, waits a fixed 3 s
    (presumably so the server's directory watcher can pick up the change —
    TODO confirm its polling interval) and asserts ``bad_path`` is LOADED.
    The copied directory is removed in a ``finally`` block. Without that, a
    failed assertion would leave it behind and the next run's ``copytree``
    would fail because the destination already exists.
    """
    copytree(self.CORRECT_MODEL_PATH, self.BAD_PATH)
    try:
        time.sleep(3)
        request = service_pb2.ModelStatusRequest(model=model_pb2.ModelSpec(
            name="bad_path"))
        response = self.stub.GetModelStatus(request)
        self.assertEqual(
            model_pb2.ModelStatus.ModelState.Name(response.status.state),
            "LOADED")
    finally:
        rmtree(self.BAD_PATH)
def test_update_corrupt_model(self):
    """Expect ``corrupt`` to become LOADED once a valid version 2 is added.

    Copies version ``1`` of the correct model into ``CORRUPT_MODEL_PATH/2``,
    waits a fixed 3 s (presumably for the server's directory watcher — TODO
    confirm) and asserts ``corrupt`` is LOADED. The added version directory
    is removed in a ``finally`` block, so a failed assertion does not break
    subsequent runs.
    """
    copytree(self.CORRECT_MODEL_PATH / "1", self.CORRUPT_MODEL_PATH / "2")
    try:
        time.sleep(3)
        request = service_pb2.ModelStatusRequest(model=model_pb2.ModelSpec(
            name="corrupt"))
        response = self.stub.GetModelStatus(request)
        self.assertEqual(
            model_pb2.ModelStatus.ModelState.Name(response.status.state),
            "LOADED")
    finally:
        rmtree(self.CORRUPT_MODEL_PATH / "2")
def test_update_loaded_model(self):
    """Expect the ``correct`` model to move from version 1 to version 2.

    Asserts the reported version is 1, copies version ``1`` into a new
    version ``2`` directory, waits a fixed 3 s (presumably for the server's
    directory watcher — TODO confirm) and asserts the reported version is 2.
    The added version directory is removed in a ``finally`` block, so a
    failed assertion does not break subsequent runs.
    """
    request = service_pb2.ModelStatusRequest(model=model_pb2.ModelSpec(
        name="correct"))
    previous_response = self.stub.GetModelStatus(request)
    self.assertEqual(previous_response.status.version, 1)

    copytree(self.CORRECT_MODEL_PATH / "1", self.CORRECT_MODEL_PATH / "2")
    try:
        time.sleep(3)
        later_response = self.stub.GetModelStatus(request)
        self.assertEqual(later_response.status.version, 2)
    finally:
        rmtree(self.CORRECT_MODEL_PATH / "2")
def test_highest_version(self):
    """Expect an explicit config reload to pick the highest model version.

    Asserts ``correct`` is at version 1, adds a version ``2`` directory,
    calls ``ReloadConfigModels`` and asserts the reported version is 2.
    The added version directory is removed in a ``finally`` block, so a
    failed assertion does not break subsequent runs.
    """
    request = service_pb2.ModelStatusRequest(model=model_pb2.ModelSpec(
        name="correct"))
    previous_response = self.stub.GetModelStatus(request)
    self.assertEqual(previous_response.status.version, 1)

    copytree(self.CORRECT_MODEL_PATH / "1", self.CORRECT_MODEL_PATH / "2")
    try:
        self.stub.ReloadConfigModels(service_pb2.ReloadModelsRequest())
        later_response = self.stub.GetModelStatus(request)
        self.assertEqual(later_response.status.version, 2)
    finally:
        rmtree(self.CORRECT_MODEL_PATH / "2")
def test_corrupt_model(self):
    """Expect loading the unmodified corrupt model to fail.

    Asserts ``LoadModels`` reports ``success == False`` and that the model's
    status afterwards is ``FAILED``.
    """
    model_spec = model_pb2.ModelSpec(
        name="corrupt", base_path="test/resources/models/corrupt")
    request = service_pb2.LoadModelsRequest(models=(model_spec, ))
    response = self.stub.LoadModels(request)
    self.assertFalse(response.success)

    # Check model's state is FAILED
    request = service_pb2.ModelStatusRequest(model=model_spec)
    response = self.stub.GetModelStatus(request)
    # assertEqual (not assertTrue(==)) so a failure shows the actual state.
    self.assertEqual(
        model_pb2.ModelStatus.ModelState.Name(response.status.state),
        "FAILED")
def test_reloading_new_model(self):
    """Add a model named ``new`` to the service config and expect it LOADED.

    Appends ``{"base_path": "new", "name": "new"}`` to the ``models`` list of
    the YAML file named by ``$SERVICE_CONFIG_PATH``. It then calls
    ``duplicate_correct_model``, which presumably copies the correct model to
    ``new`` and triggers a reload (TODO confirm). Finally it asserts that
    ``new`` reports ``LOADED``. Config and model changes are reverted in a
    ``finally`` block, so a failing assertion does not leak them into later
    tests.
    """
    # Configure a new model
    config_path = os.environ["SERVICE_CONFIG_PATH"]
    with open(config_path, "r+") as f:
        # safe_load: bare yaml.load() without a Loader is deprecated and is
        # a TypeError on PyYAML >= 6.
        config = yaml.safe_load(f)
        config["models"].append({"base_path": "new", "name": "new"})
        # Reading left the position at EOF; rewind and truncate so the
        # updated config replaces the file instead of being appended to it.
        f.seek(0)
        f.truncate()
        yaml.safe_dump(config, f)

    try:
        # Make the new model available and reload models
        self.duplicate_correct_model()

        # Check that the new model is loaded
        request = service_pb2.ModelStatusRequest(model=model_pb2.ModelSpec(
            name="new"))
        response = self.stub.GetModelStatus(request)
        self.assertEqual(
            model_pb2.ModelStatus.ModelState.Name(response.status.state),
            "LOADED")
    finally:
        # Revert changes
        self.revert_config_changes()
        self.revert_model_changes()