Exemplo n.º 1
0
    def initialize_impl(
        self,
        project_dir: str,
        client: Union[PythonClient, TensorFlowClient, ONNXClient],
        job_spec: Optional[dict] = None,
    ):
        """
        Initialize predictor class as provided by the user.

        job_spec is a dictionary when the "kind" of the API is set to "BatchAPI". Otherwise, it's None.

        The user's constructor only receives "config" and "job_spec" when its
        __init__ declares them; the client is passed as "python_client" (Python
        predictor with models specified), "tensorflow_client" (TensorFlow /
        TensorFlow Neuron predictor) or "onnx_client" (ONNX predictor).
        For a RealtimeAPI with models specified, model-management crons are
        registered; every cron in self.crons is then started.

        Returns the initialized instance of the user's predictor class.

        Raises:
            ValueError: if self.type is not a supported predictor type.
            UserRuntimeException: if constructing the user's class (or binding
                its load_model method) raises.
        """

        # build args
        class_impl = self.class_impl(project_dir)
        constructor_args = inspect.getfullargspec(class_impl.__init__).args
        config = deepcopy(self.config)
        args = {}
        if job_spec is not None and job_spec.get("config") is not None:
            # job-level config values overwrite the API-level ones
            util.merge_dicts_in_place_overwrite(config, job_spec["config"])
        if "config" in constructor_args:
            args["config"] = config
        if "job_spec" in constructor_args:
            args["job_spec"] = job_spec

        # Decide how the client is handed to the user's constructor. An explicit
        # fallback is needed: an unrecognized type would otherwise leave
        # initialized_impl unbound and only fail at the final return, after the
        # crons below had already been started.
        uses_python_client = False
        if self.type == PythonPredictorType:
            uses_python_client = _are_models_specified(self.api_spec)
            if uses_python_client:
                args["python_client"] = client
        elif self.type in [
                TensorFlowPredictorType, TensorFlowNeuronPredictorType
        ]:
            args["tensorflow_client"] = client
        elif self.type == ONNXPredictorType:
            args["onnx_client"] = client
        else:
            raise ValueError(f"unsupported predictor type: {self.type}")

        # initialize predictor class
        try:
            if uses_python_client:
                # set a load method before construction so the client can already be
                # used from within the user's constructor; the instance does not exist
                # yet, so load_model is called with self=None (it must not rely on self)
                client.set_load_method(
                    lambda model_path: class_impl.load_model(None, model_path))
            initialized_impl = class_impl(**args)
            if uses_python_client:
                # now that the instance exists, bind its real load_model method
                client.set_load_method(initialized_impl.load_model)
        except Exception as e:
            raise UserRuntimeException(self.path, "__init__", str(e)) from e
        finally:
            refresh_logger()

        # initialize the crons if models have been specified and if the API kind is RealtimeAPI
        if _are_models_specified(
                self.api_spec) and self.api_spec["kind"] == "RealtimeAPI":
            # single-process deployments with caching enabled get the model tree
            # updater and the models GC
            if not self.multiple_processes and self.caching_enabled:
                self.crons += [
                    ModelTreeUpdater(
                        interval=10,
                        api_spec=self.api_spec,
                        tree=self.models_tree,
                        ondisk_models_dir=self.model_dir,
                    ),
                    ModelsGC(
                        interval=10,
                        api_spec=self.api_spec,
                        models=self.models,
                        tree=self.models_tree,
                    ),
                ]

            # without caching, Python and ONNX predictors get the file-based models GC
            if not self.caching_enabled and self.type in [
                    PythonPredictorType, ONNXPredictorType
            ]:
                self.crons += [
                    FileBasedModelsGC(interval=10,
                                      models=self.models,
                                      download_dir=self.model_dir)
                ]

        for cron in self.crons:
            cron.start()

        return initialized_impl
Exemplo n.º 2
0
    def initialize_impl(
        self,
        project_dir: str,
        client: Union[PythonClient, TensorFlowClient],
        metrics_client: DogStatsd,
        job_spec: Optional[Dict[str, Any]] = None,
        proto_module_pb2: Optional[Any] = None,
    ):
        """
        Initialize predictor class as provided by the user.

        job_spec is a dictionary when the "kind" of the API is set to "BatchAPI". Otherwise, it's None.
        proto_module_pb2 is a module of the compiled proto when grpc is enabled for the "RealtimeAPI" kind. Otherwise, it's None.

        "config", "job_spec", "metrics_client" and "proto_module_pb2" are only
        passed to the user's constructor when its __init__ declares them. The
        client is passed as "python_client" (Python predictor with models
        specified) or "tensorflow_client" (TensorFlow / TensorFlow Neuron
        predictor). For a RealtimeAPI with models specified, model-management
        crons are registered; every cron in self.crons is then started.

        Returns the initialized instance of the user's predictor class.

        Can raise UserRuntimeException/UserException/CortexException.
        Raises ValueError if self.type is not a supported predictor type.
        """

        # build args
        class_impl = self.class_impl(project_dir)
        constructor_args = inspect.getfullargspec(class_impl.__init__).args
        config = deepcopy(self.config)
        args = {}
        if job_spec is not None and job_spec.get("config") is not None:
            # job-level config values overwrite the API-level ones
            util.merge_dicts_in_place_overwrite(config, job_spec["config"])
        if "config" in constructor_args:
            args["config"] = config
        if "job_spec" in constructor_args:
            args["job_spec"] = job_spec
        if "metrics_client" in constructor_args:
            args["metrics_client"] = metrics_client
        if "proto_module_pb2" in constructor_args:
            args["proto_module_pb2"] = proto_module_pb2

        # Decide how the client is handed to the user's constructor. An explicit
        # fallback is needed: an unrecognized type would otherwise leave
        # initialized_impl unbound and only fail at the final return, after the
        # crons below had already been started.
        uses_python_client = False
        if self.type == PythonPredictorType:
            uses_python_client = are_models_specified(self.api_spec)
            if uses_python_client:
                args["python_client"] = client
        elif self.type in [
                TensorFlowPredictorType, TensorFlowNeuronPredictorType
        ]:
            args["tensorflow_client"] = client
        else:
            raise ValueError(f"unsupported predictor type: {self.type}")

        # initialize predictor class
        try:
            if uses_python_client:
                # set load method to enable the use of the client in the constructor
                # setting/getting from self in load_model won't work because self will be set to None
                client.set_load_method(
                    lambda model_path: class_impl.load_model(None, model_path))
            initialized_impl = class_impl(**args)
            if uses_python_client:
                # now that the instance exists, bind its real load_model method
                client.set_load_method(initialized_impl.load_model)
        except Exception as e:
            raise UserRuntimeException(self.path, "__init__", str(e)) from e

        # initialize the crons if models have been specified and if the API kind is RealtimeAPI
        if are_models_specified(
                self.api_spec) and self.api_spec["kind"] == "RealtimeAPI":
            # single-process deployments with caching enabled get the model tree
            # updater and the models GC
            if not self.multiple_processes and self.caching_enabled:
                self.crons += [
                    ModelTreeUpdater(
                        interval=10,
                        api_spec=self.api_spec,
                        tree=self.models_tree,
                        ondisk_models_dir=self.model_dir,
                    ),
                    ModelsGC(
                        interval=10,
                        api_spec=self.api_spec,
                        models=self.models,
                        tree=self.models_tree,
                    ),
                ]

            # without caching, the Python predictor gets the file-based models GC
            if not self.caching_enabled and self.type == PythonPredictorType:
                self.crons += [
                    FileBasedModelsGC(interval=10,
                                      models=self.models,
                                      download_dir=self.model_dir)
                ]

        for cron in self.crons:
            cron.start()

        return initialized_impl