def test_reload_models_in_config_file(self):
        """Reload the configured models after swapping in replacement files.

        Replaces the "heavy" and "corrupt" models via
        ``replace_corrupt_and_heavy_models``, asks the server to reload the
        models listed in the config file, then checks that:

        * every configured model other than ``bad_path`` is reported by
          ``GetLoadedModels``;
        * the ``corrupt`` and ``heavy`` models report state ``LOADED``.

        The model changes are reverted in a ``finally`` clause so a failing
        assertion cannot leave modified model files behind for later tests.
        """
        # Replace heavy and corrupt model and reload them
        self.replace_corrupt_and_heavy_models()
        try:
            # Only the side effect of the reload matters; the reply is unused.
            self.stub.ReloadConfigModels(service_pb2.ReloadModelsRequest())

            # Get loaded models
            response = self.stub.GetLoadedModels(
                service_pb2.LoadedModelsRequest())
            loaded_models = {model.name for model in response.models}

            # Check that only the bad path model wasn't loaded
            configured_models = [
                model["name"] for model in get_config()["models"]]
            for model in configured_models:
                if model not in loaded_models:
                    self.assertEqual(model, "bad_path")

            # Check heavy and corrupt model state is LOADED
            for name in ("corrupt", "heavy"):
                request = service_pb2.ModelStatusRequest(
                    model=model_pb2.ModelSpec(name=name))
                response = self.stub.GetModelStatus(request)
                state = model_pb2.ModelStatus.ModelState.Name(
                    response.status.state)
                self.assertEqual(state, "LOADED", "model {!r}".format(name))
        finally:
            # Always restore the original model files, even on failure.
            self.revert_model_changes()
Пример #2
0
def create_logger():
    """Create and configure the module-global ``fts`` logger.

    Rebinds the module-level name ``logger`` to the ``fts`` logger. It
    attaches a console ``StreamHandler`` with a timestamped format and sets
    the level from the ``logging_level`` configuration entry.
    """
    global logger
    # Obtain the logger through the logging manager (the documented way)
    # instead of instantiating logging.Logger directly, so that
    # logging.getLogger("fts") elsewhere returns this same configured logger.
    logger = logging.getLogger("fts")
    # A directly-constructed Logger has no parent. Keep records from also
    # reaching the root logger's handlers so output is not duplicated.
    logger.propagate = False
    # getLogger returns a shared instance: attach the console handler only
    # once, even if create_logger() is called repeatedly.
    if not logger.handlers:
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(
            logging.Formatter("%(asctime)s [%(levelname)s]  %(message)s"))
        logger.addHandler(console_handler)
    logger.setLevel(get_config()["logging_level"])
    def test_loaded_models(self):
        """Verify that exactly the non-bad configured models are loaded.

        Every model named in the config file must appear in the
        ``GetLoadedModels`` reply unless it is listed in ``self.BAD_MODELS``.
        Conversely, models listed in ``self.BAD_MODELS`` must not appear.
        """
        reply = self.stub.GetLoadedModels(service_pb2.LoadedModelsRequest())
        loaded_names = [entry.name for entry in reply.models]

        configured_names = [entry["name"] for entry in get_config()["models"]]
        for name in configured_names:
            if name in loaded_names:
                # A model that loaded must not be one of the known-bad ones.
                self.assertNotIn(name, self.BAD_MODELS)
            else:
                # A model that failed to load must be a known-bad one.
                self.assertIn(name, self.BAD_MODELS)
Пример #4
0
def serve():
    """Start the FastText gRPC server and block until interrupted.

    Reads the ``grpc`` configuration section. ``port``, ``max_workers``,
    ``maximum_concurrent_rpcs`` and ``channel_options`` each have a default.
    It then:

    * registers the FastText and health-check servicers;
    * listens on an insecure port on all interfaces;
    * reports ``SERVING`` through the gRPC health checking protocol.

    The calling thread then sleeps until a ``KeyboardInterrupt``. The server
    is stopped on the way out, however the wait ends.
    """
    logger = get_logger()
    logger.info("FastText server starting ...")

    # Read gRPC options
    grpc_config = get_config()["grpc"]
    grpc_port = grpc_config.get("port", 50051)
    grpc_max_workers = grpc_config.get("max_workers", 2)
    grpc_maximum_concurrent_rpcs = grpc_config.get(
        "maximum_concurrent_rpcs", 25)
    logger.info("Concurrent workers: {}".format(grpc_max_workers))
    logger.info("gRPC queue size: {}".format(grpc_maximum_concurrent_rpcs))

    # Read gRPC channel options; grpc.server takes them as (key, value) pairs.
    grpc_options = list(grpc_config.get("channel_options", {}).items())
    for option in grpc_options:
        logger.info("gRPC channel option: {}".format(option))

    # Create server
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=grpc_max_workers),
        maximum_concurrent_rpcs=grpc_maximum_concurrent_rpcs,
        options=grpc_options,
    )

    # Add servicers
    servicer = FastTextServicer()
    service_pb2_grpc.add_FastTextServicer_to_server(servicer, server)
    health_servicer = HealthServicer()
    health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)

    # Run server
    address = "[::]:{}".format(grpc_port)
    server.add_insecure_port(address)
    server.start()
    logger.info("Listening incoming connections at {}".format(address))

    # Mark the server as running using the gRPC health check protocol. Use the
    # public enum constant rather than the private generated descriptor
    # (_HEALTHCHECKRESPONSE_SERVINGSTATUS), which is not part of the API.
    status_code = health_pb2.HealthCheckResponse.SERVING
    health_servicer.set("", status_code)
    logger.info("gRPC health check protocol: {}".format(status_code))

    try:
        # Requests are handled on the executor's worker threads; keep the
        # calling thread alive until interrupted.
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to shut down; return quietly.
        pass
    finally:
        # Stop immediately (no grace period) however the wait ended.
        server.stop(0)
from collections import namedtuple
from pathlib import Path

from fts.service.exceptions import (
    FastTextException,
    MissingArgumentException,
    ModelNotLoadedException,
)
from fts.protos import model_pb2, service_pb2
from fts.utils.config import get_config
from fts.utils.logger import get_logger
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer

# Immutable record grouping the per-model fields: pb_model, ft_model, size
# and state.
Model = namedtuple("Model", ["pb_model", "ft_model", "size", "state"])

# Configuration and logger, resolved once when this module is imported.
config = get_config()
logger = get_logger()


class ModelUpdateHandler(FileSystemEventHandler):
    """Watchdog event handler that forwards creation events to a service.

    Each ``on_created`` event's ``src_path`` is wrapped in a
    :class:`pathlib.Path`. It is then passed to the service's
    ``_handle_file_update`` method.
    """

    def __init__(self, fasttext_service):
        """Remember the service that receives file-creation notifications.

        :param fasttext_service: object exposing ``_handle_file_update(path)``
            (presumably a ``FastTextService``; verify against callers).
        """
        # Initialise the watchdog base class (cooperative inheritance).
        super().__init__()
        self._fasttext_service = fasttext_service

    def on_created(self, event):
        """Forward the created path to the service.

        NOTE(review): watchdog may also dispatch directory creations here.
        No ``event.is_directory`` filter is applied — confirm
        ``_handle_file_update`` tolerates directory paths.
        """
        self._fasttext_service._handle_file_update(Path(event.src_path))


class FastTextService(object):
    def __init__(self):

        self.load_models_in_config_file()