def test_correct(self):
        """Loading the replaced ``corrupt`` and ``heavy`` models succeeds.

        The two models are swapped via ``replace_corrupt_and_heavy_models``,
        loaded through ``LoadModels`` and then both must report ``LOADED``.
        The model files are restored in ``finally`` so a failing assertion
        does not leave the replaced fixtures behind for later tests.
        """
        self.replace_corrupt_and_heavy_models()
        try:
            request = service_pb2.LoadModelsRequest(models=(
                model_pb2.ModelSpec(name="corrupt",
                                    base_path="test/resources/models/corrupt"),
                model_pb2.ModelSpec(name="heavy",
                                    base_path="test/resources/models/heavy"),
            ))
            response = self.stub.LoadModels(request)
            self.assertTrue(response.success)

            # Check both models' state is LOADED; assertEqual reports the
            # actual state name on failure (assertTrue(a == b) would not).
            for name in ("corrupt", "heavy"):
                request = service_pb2.ModelStatusRequest(
                    model=model_pb2.ModelSpec(name=name))
                response = self.stub.GetModelStatus(request)
                state = model_pb2.ModelStatus.ModelState.Name(
                    response.status.state)
                self.assertEqual(state, "LOADED", msg=name)
        finally:
            self.revert_model_changes()
    def test_reload_models_in_config_file(self):
        """Reloading the config loads all models except ``bad_path``.

        After the ``corrupt`` and ``heavy`` models are replaced,
        ``ReloadConfigModels`` is called. Every configured model other than
        ``bad_path`` must then be among the loaded models, and both replaced
        models must report ``LOADED``. Model files are restored in
        ``finally`` even when an assertion fails.
        """
        # Replace heavy and corrupt model and reload them
        self.replace_corrupt_and_heavy_models()
        try:
            self.stub.ReloadConfigModels(service_pb2.ReloadModelsRequest())

            # Get loaded models
            response = self.stub.GetLoadedModels(
                service_pb2.LoadedModelsRequest())
            loaded_models = {model.name for model in response.models}

            # Check that only the bad path model wasn't loaded: the configured
            # models missing from the loaded set must be a subset of
            # {"bad_path"}.
            configured_models = {
                model["name"] for model in get_config()["models"]}
            self.assertLessEqual(configured_models - loaded_models,
                                 {"bad_path"})

            # Check heavy and corrupt model state is LOADED
            for name in ("corrupt", "heavy"):
                request = service_pb2.ModelStatusRequest(
                    model=model_pb2.ModelSpec(name=name))
                response = self.stub.GetModelStatus(request)
                state = model_pb2.ModelStatus.ModelState.Name(
                    response.status.state)
                self.assertEqual(state, "LOADED", msg=name)
        finally:
            self.revert_model_changes()
 def test_available(self):
     """The ``heavy`` model reports the ``AVAILABLE`` state."""
     request = service_pb2.ModelStatusRequest(
         model=model_pb2.ModelSpec(name="heavy"))
     response = self.stub.GetModelStatus(request)
     state = model_pb2.ModelStatus.ModelState.Name(response.status.state)
     # assertEqual shows the actual state on failure; assertTrue(a == b)
     # would only report "False is not true".
     self.assertEqual(state, "AVAILABLE")
 def test_failed(self):
     """The ``corrupt`` model reports the ``FAILED`` state."""
     request = service_pb2.ModelStatusRequest(
         model=model_pb2.ModelSpec(name="corrupt"))
     response = self.stub.GetModelStatus(request)
     state = model_pb2.ModelStatus.ModelState.Name(response.status.state)
     # assertEqual shows the actual state on failure.
     self.assertEqual(state, "FAILED")
 def test_unknown(self):
     """A model name the service does not know (``foo``) reports ``UNKNOWN``."""
     request = service_pb2.ModelStatusRequest(
         model=model_pb2.ModelSpec(name="foo"))
     response = self.stub.GetModelStatus(request)
     state = model_pb2.ModelStatus.ModelState.Name(response.status.state)
     # assertEqual shows the actual state on failure.
     self.assertEqual(state, "UNKNOWN")
 def test_update_bad_path_model(self):
     """Copying a valid model to ``BAD_PATH`` makes ``bad_path`` load.

     The copy is removed in ``finally``. Without that, a failed assertion
     would leave ``BAD_PATH`` on disk, and the next run's ``copytree``
     would raise ``FileExistsError``.
     """
     copytree(self.CORRECT_MODEL_PATH, self.BAD_PATH)
     try:
         # Fixed wait for the server to pick up the new files.
         # NOTE(review): presumes the server notices changes within ~3 s.
         time.sleep(3)
         request = service_pb2.ModelStatusRequest(
             model=model_pb2.ModelSpec(name="bad_path"))
         response = self.stub.GetModelStatus(request)
         state = model_pb2.ModelStatus.ModelState.Name(response.status.state)
         self.assertEqual(state, "LOADED")
     finally:
         rmtree(self.BAD_PATH)
 def test_update_corrupt_model(self):
     """Adding a valid version ``2`` to the corrupt model makes it load.

     Version ``1`` of the correct model is copied as version ``2`` of the
     ``corrupt`` model. After a fixed wait, the ``corrupt`` model must report
     ``LOADED``. The added version is removed in ``finally`` so the fixture
     is restored even if the assertion fails.
     """
     copytree(self.CORRECT_MODEL_PATH / "1", self.CORRUPT_MODEL_PATH / "2")
     try:
         # Fixed wait for the server to pick up the new version.
         # NOTE(review): presumes the server notices changes within ~3 s.
         time.sleep(3)
         request = service_pb2.ModelStatusRequest(
             model=model_pb2.ModelSpec(name="corrupt"))
         response = self.stub.GetModelStatus(request)
         state = model_pb2.ModelStatus.ModelState.Name(response.status.state)
         self.assertEqual(state, "LOADED")
     finally:
         rmtree(self.CORRUPT_MODEL_PATH / "2")
 def test_update_loaded_model(self):
     """A new version directory for ``correct`` is picked up automatically.

     ``correct`` starts at version 1. Version ``1`` is then copied to ``2``,
     and after a fixed wait (no explicit reload call) the reported version
     must be 2. The copied directory is removed in ``finally`` so later
     tests, which expect version 1, are not affected by a failure here.
     """
     request = service_pb2.ModelStatusRequest(
         model=model_pb2.ModelSpec(name="correct"))
     previous_response = self.stub.GetModelStatus(request)
     self.assertEqual(previous_response.status.version, 1)
     copytree(self.CORRECT_MODEL_PATH / "1", self.CORRECT_MODEL_PATH / "2")
     try:
         # Fixed wait for the server to pick up the new version.
         # NOTE(review): presumes the server notices changes within ~3 s.
         time.sleep(3)
         later_response = self.stub.GetModelStatus(request)
         self.assertEqual(later_response.status.version, 2)
     finally:
         rmtree(self.CORRECT_MODEL_PATH / "2")
 def test_highest_version(self):
     """``ReloadConfigModels`` serves the highest available model version.

     ``correct`` starts at version 1. Version ``1`` is then copied to ``2``
     and the config models are reloaded explicitly; the reported version
     must be 2. The copied directory is removed in ``finally`` even when an
     assertion fails.
     """
     request = service_pb2.ModelStatusRequest(
         model=model_pb2.ModelSpec(name="correct"))
     previous_response = self.stub.GetModelStatus(request)
     self.assertEqual(previous_response.status.version, 1)
     copytree(self.CORRECT_MODEL_PATH / "1", self.CORRECT_MODEL_PATH / "2")
     try:
         self.stub.ReloadConfigModels(service_pb2.ReloadModelsRequest())
         later_response = self.stub.GetModelStatus(request)
         self.assertEqual(later_response.status.version, 2)
     finally:
         rmtree(self.CORRECT_MODEL_PATH / "2")
    def test_corrupt_model(self):
        """Loading the corrupt model reports failure and leaves it ``FAILED``."""
        model_spec = model_pb2.ModelSpec(
            name="corrupt", base_path="test/resources/models/corrupt")
        request = service_pb2.LoadModelsRequest(models=(model_spec, ))
        response = self.stub.LoadModels(request)
        self.assertFalse(response.success)

        # Check model's state is FAILED; assertEqual shows the actual state
        # on failure.
        request = service_pb2.ModelStatusRequest(model=model_spec)
        response = self.stub.GetModelStatus(request)
        state = model_pb2.ModelStatus.ModelState.Name(response.status.state)
        self.assertEqual(state, "FAILED")
    def test_reloading_new_model(self):
        """A model newly added to the service config becomes ``LOADED``.

        A ``new`` entry is appended to the YAML config at
        ``$SERVICE_CONFIG_PATH`` and the correct model is duplicated. The
        ``new`` model must then report ``LOADED``. Config and model changes
        are reverted in ``finally`` so a failure does not leak into other
        tests.
        """
        config_path = os.environ["SERVICE_CONFIG_PATH"]
        try:
            # Configure a new model. safe_load replaces the Loader-less
            # yaml.load, which is unsafe and a TypeError on PyYAML >= 6.
            # Rewind before dumping and truncate afterwards so the file is
            # rewritten in place. Dumping right after the read would append
            # a second copy of the config after the original text.
            with open(config_path, "r+") as f:
                config = yaml.safe_load(f)
                config["models"].append({"base_path": "new", "name": "new"})
                f.seek(0)
                yaml.dump(config, f)
                f.truncate()

            # Make the new model available and reload models.
            # NOTE(review): no explicit ReloadConfigModels call here;
            # presumably duplicate_correct_model (or a watcher) triggers the
            # reload — confirm.
            self.duplicate_correct_model()

            # Check that the new model is loaded
            request = service_pb2.ModelStatusRequest(
                model=model_pb2.ModelSpec(name="new"))
            response = self.stub.GetModelStatus(request)
            state = model_pb2.ModelStatus.ModelState.Name(
                response.status.state)
            self.assertEqual(state, "LOADED")
        finally:
            # Revert changes
            self.revert_config_changes()
            self.revert_model_changes()