Exemplo n.º 1
0
    def pack(self, data):  # pylint:disable=arguments-differ
        """Wrap a Keras model, plus optional custom objects, in an artifact wrapper.

        ``data`` is either the model itself or a dict holding the model under
        ``'model'`` and, optionally, custom objects under ``'custom_objects'``.
        Without explicit custom objects, ``self.custom_objects`` is used.

        Raises:
            MissingDependencyException: tensorflow cannot be imported.
            InvalidArgument: the model is not a ``tf.keras.models.Model`` and is
                not a ``keras.engine.network.Network`` either (or standalone keras
                cannot be imported to check).
        """
        try:
            import tensorflow as tf
        except ImportError:
            raise MissingDependencyException(
                "Tensorflow package is required to use KerasModelArtifact. BentoML "
                "currently only support using Keras with Tensorflow backend.")

        if not isinstance(data, dict):
            model = data
            custom_objects = self.custom_objects
        else:
            model = data['model']
            if 'custom_objects' in data:
                custom_objects = data['custom_objects']
            else:
                custom_objects = self.custom_objects

        if not isinstance(model, tf.keras.models.Model):
            error_msg = (
                "KerasModelArtifact#pack expects model argument to be type: "
                "keras.engine.network.Network, tf.keras.models.Model, or their "
                "aliases, instead got type: {}".format(type(model)))
            # Fall back to standalone keras; remember its module name when it is used.
            try:
                import keras

                is_keras_network = isinstance(model, keras.engine.network.Network)
            except ImportError:
                raise InvalidArgument(error_msg)
            if not is_keras_network:
                raise InvalidArgument(error_msg)
            self._keras_module_name = keras.__name__

        self.bind_keras_backend_session()
        model._make_predict_function()
        return _KerasModelArtifactWrapper(self, model, custom_objects)
Exemplo n.º 2
0
    def _load_from_dict(self, model):
        """Validate a dict holding a transformers model and tokenizer, then store it.

        Args:
            model: dict expected to carry truthy ``'model'`` and ``'tokenizer'``
                entries whose types are defined in a ``transformers`` module.

        Raises:
            InvalidArgument: a required entry is missing or falsy, or an entry's
                type does not come from the ``transformers`` package.
        """
        for required_key in ('model', 'tokenizer'):
            if not model.get(required_key):
                raise InvalidArgument(
                    "'{}' key is not found in the dictionary. ".format(required_key)
                    + "Expecting a dictionary with keys 'model', 'tokenizer' and 'embedder'"
                )

        model_obj = model.get('model')
        tokenizer_obj = model.get('tokenizer')

        # Both objects must be instances of classes living in the transformers package.
        if not str(type(model_obj).__module__).startswith('transformers'):
            raise InvalidArgument('Expecting a transformers model object '
                                  'but got {}'.format(type(model_obj)))
        if not str(type(tokenizer_obj).__module__).startswith('transformers'):
            raise InvalidArgument('Expecting a transformers tokenizer object '
                                  'but got {}'.format(type(tokenizer_obj)))

        self._model = model
Exemplo n.º 3
0
    def _load_from_dict(self, model):
        """Validate a dict holding a transformers model and tokenizer, then store it.

        Args:
            model: dict expected to carry truthy ``"model"`` and ``"tokenizer"``
                entries whose types are defined in a ``transformers`` module.

        Raises:
            InvalidArgument: a required entry is missing or falsy, or an entry's
                type does not come from the ``transformers`` package.
        """
        if not model.get("model"):
            raise InvalidArgument(
                "'model' key is not found in the dictionary."
                " Expecting a dictionary with keys 'model' and 'tokenizer'"
            )
        if not model.get("tokenizer"):
            raise InvalidArgument(
                "'tokenizer' key is not found in the dictionary. "
                "Expecting a dictionary with keys 'model' and 'tokenizer'"
            )

        model_class = str(type(model.get("model")).__module__)
        tokenizer_class = str(type(model.get("tokenizer")).__module__)
        # Both the model and the tokenizer must be transformers objects.
        if not model_class.startswith("transformers"):
            raise InvalidArgument(
                "Expecting a transformers model object but object passed is {}".format(
                    type(model.get("model"))
                )
            )
        if not tokenizer_class.startswith("transformers"):
            # The message names the tokenizer, not the model, so users know which
            # entry was wrong.
            raise InvalidArgument(
                "Expecting a transformers tokenizer object but object passed is {}".format(
                    type(model.get("tokenizer"))
                )
            )
        self._model = model
Exemplo n.º 4
0
    def decorator(func):
        """Mark ``func`` as a BentoService API and attach its handler and metadata.

        The API name defaults to ``func.__name__`` and the doc to the stripped
        function docstring (or ``DEFAULT_API_DOC``) unless ``api_name`` /
        ``api_doc`` are given. The API name is validated before any attribute is
        set on ``func``, so a rejected name leaves ``func`` unmodified.

        Raises:
            InvalidArgument: in the legacy form (``input is None``) the first
                positional argument is not a ``BentoHandler`` subclass, or the API
                name is not a valid identifier.
        """
        _api_name = func.__name__ if api_name is None else api_name
        _api_doc = ((func.__doc__ or DEFAULT_API_DOC).strip()
                    if api_doc is None else api_doc)

        if input is None:
            # support bentoml<=0.7: the handler class is passed positionally and
            # instantiated here with the remaining args/kwargs.
            if not args or not (inspect.isclass(args[0])
                                and issubclass(args[0], BentoHandler)):
                raise InvalidArgument(
                    "BentoService @api decorator first parameter must "
                    "be class derived from bentoml.handlers.BentoHandler")

            handler = args[0](*args[1:], **kwargs)
        else:
            handler = input

        # Validate before mutating func so an invalid name does not leave it
        # partially marked as an API.
        if not isidentifier(_api_name):
            raise InvalidArgument(
                "Invalid API name: '{}', a valid identifier must contains only letters,"
                " numbers, underscores and not starting with a number.".format(
                    _api_name))

        setattr(func, "_is_api", True)
        setattr(func, "_handler", handler)
        setattr(func, "_output_adapter", output)
        setattr(func, "_api_name", _api_name)
        setattr(func, "_api_doc", _api_doc)

        return func
Exemplo n.º 5
0
def _validate_labels(labels):
    """
    Validate that every label key and value:
        * is between 3 and 63 characters long
        * consists of alphanumerics, dash (-), period (.), and underscore (_)
        * starts and ends with an alphanumeric character
    Args:
        labels: Dictionary mapping label keys to label values

    Returns:
        None; returns normally when all labels are valid.
    Raise:
        InvalidArgument: ``labels`` is not a dict, or a key/value breaks a rule.
    """
    if not isinstance(labels, dict):
        raise InvalidArgument('BentoService labels must be a dictionary')

    # Use fullmatch rather than match with a ``$`` anchor: ``$`` also matches just
    # before a trailing newline, which let values such as "abc\n" through.
    pattern = re.compile(r"(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?")
    for key, value in labels.items():
        if (
            not (3 <= len(key) <= 63)
            or not (3 <= len(value) <= 63)
            or not pattern.fullmatch(key)
            or not pattern.fullmatch(value)
        ):
            raise InvalidArgument(
                f'Invalid label {key}:{value}. Valid label key and value must '
                f'be between 3 to 63 characters and must begin and end with '
                f'an alphanumeric character ([a-z0-9A-Z]) with dashes (-), '
                f'underscores (_), and dots (.).'
            )
Exemplo n.º 6
0
 def pack(self, path_or_model, metadata=None):  # pylint:disable=arguments-differ
     """Pack a saved torchscript module path or a LightningModule instance.

     A path accepted by ``_is_pytorch_lightning_model_file_like`` is stored on
     ``self._model_path``. A ``LightningModule`` is converted with
     ``to_torchscript()`` and stored on ``self._model``. ``metadata`` is not
     used here.

     Returns:
         self

     Raises:
         InvalidArgument: pytorch_lightning cannot be imported, or the input is
             neither a model file path nor a ``LightningModule``.
     """
     if _is_pytorch_lightning_model_file_like(path_or_model):
         logger.info(
             'PytorchLightningArtifact is packing a saved torchscript module '
             'from path'
         )
         self._model_path = path_or_model
         return self

     try:
         from pytorch_lightning.core.lightning import LightningModule
     except ImportError:
         raise InvalidArgument(
             '"pytorch_lightning.lightning.LightningModule" model is required '
             'to pack a PytorchLightningModelArtifact'
         )

     if not isinstance(path_or_model, LightningModule):
         raise InvalidArgument(
             'a LightningModule model is required to pack a '
             'PytorchLightningModelArtifact'
         )

     logger.info(
         'PytorchLightningArtifact is packing a pytorch lightning '
         'model instance as torchscript module'
     )
     self._model = path_or_model.to_torchscript()
     return self
Exemplo n.º 7
0
    def pack(self, model, metadata=None, opts=None, update=False):
        """Pack a trained model into this artifact so it is ready to be saved.

        Parameters:
            model:
                A path to the trained model directory, or a dict such as
                {'model': model_instance}.
            metadata:
                Optional dict of args used to instantiate the model artifact;
                forwarded to both loaders.
            opts:
                Optional dict of args that temporarily override metadata when
                ``update`` is True; forwarded only to the directory loader.
            update:
                Optional flag; when True, matching ``opts`` args temporarily
                override metadata. Forwarded only to the directory loader.

        Returns:
            self

        Raises:
            InvalidArgument: ``model`` is a string that is not a directory, or is
                neither a string nor a dict.
        """
        if isinstance(model, dict):
            self._load_from_dict(model, metadata)
            return self

        if not isinstance(model, str):
            raise InvalidArgument(
                'Expecting model to be a path to the model directory or a dict'
            )
        if not os.path.isdir(model):
            raise InvalidArgument(
                'Expecting a path to the model directory')

        self._load_from_directory(model, metadata, opts, update)
        return self
Exemplo n.º 8
0
    def __init__(self, spec, path_or_model_proto):
        """
        :param spec: parent OnnxModelArtifact
        :param path_or_model_proto: .onnx file path or onnx.ModelProto object
        :raises InvalidArgument: the argument is neither an .onnx model file path
            nor an ``onnx.ModelProto`` (or onnx cannot be imported to check)
        """
        super(_OnnxModelArtifactWrapper, self).__init__(spec)

        self._inference_session = None

        # Exactly one of these two is set below.
        self._onnx_model_path = None
        self._model_proto = None
        if _is_onnx_model_file(path_or_model_proto):
            self._onnx_model_path = path_or_model_proto
        else:
            try:
                import onnx
            except ImportError as err:
                # Keep the original import failure attached as the cause.
                raise InvalidArgument(
                    'onnx.ModelProto or a .onnx model file path is required to pack '
                    'an OnnxModelArtifact') from err

            if not isinstance(path_or_model_proto, onnx.ModelProto):
                raise InvalidArgument(
                    'onnx.ModelProto or a .onnx model file path is required to '
                    'pack an OnnxModelArtifact')
            self._model_proto = path_or_model_proto

        assert self._onnx_model_path or self._model_proto, (
            "Either self._onnx_model_path or self._model_proto has to be initialized "
            "after initializing _OnnxModelArtifactWrapper")
Exemplo n.º 9
0
    def pack(
        self, path_or_model_proto, metadata=None
    ):  # pylint:disable=arguments-renamed
        """Record an .onnx model file path or an ``onnx.ModelProto`` on the artifact.

        A path accepted by ``_is_onnx_model_file`` goes to
        ``self._onnx_model_path``. An ``onnx.ModelProto`` goes to
        ``self._model_proto``. ``metadata`` is not used here.

        Returns:
            self

        Raises:
            InvalidArgument: the input is neither an .onnx file path nor an
                ``onnx.ModelProto``, or onnx cannot be imported to check.
        """
        if not _is_onnx_model_file(path_or_model_proto):
            try:
                import onnx

                is_model_proto = isinstance(path_or_model_proto, onnx.ModelProto)
            except ImportError:
                raise InvalidArgument(
                    "onnx.ModelProto or a .onnx model file path is required to pack "
                    "an OnnxModelArtifact"
                )
            if not is_model_proto:
                raise InvalidArgument(
                    "onnx.ModelProto or a .onnx model file path is required to "
                    "pack an OnnxModelArtifact"
                )
            self._model_proto = path_or_model_proto
        else:
            self._onnx_model_path = path_or_model_proto

        # Internal sanity check: one of the two sources must now be available.
        assert self._onnx_model_path or self._model_proto, (
            "Either self._onnx_model_path or self._model_proto has to be initialized "
            "after initializing _OnnxModelArtifactWrapper"
        )

        return self
Exemplo n.º 10
0
    def decorator(func):
        """Mark ``func`` as a BentoService inference API.

        Resolves the API name (``func.__name__`` unless ``api_name`` is given),
        the doc (``api_doc``, else the stripped docstring or ``DEFAULT_API_DOC``)
        and the micro-batching limits (falling back to ``DEFAULT_MAX_BATCH_SIZE``
        and ``DEFAULT_MAX_LATENCY``). It then builds the input handler and
        attaches everything to ``func`` as private attributes.

        The API name is validated before any attribute is set on ``func``, so a
        rejected name leaves ``func`` unmodified.

        Raises:
            InvalidArgument: in the legacy form the first positional argument is
                not a ``BaseInputAdapter`` subclass, or the API name is not a valid
                identifier, or it is in ``BENTOML_RESERVED_API_NAMES``.
        """
        _api_name = func.__name__ if api_name is None else api_name
        _api_doc = (
            (func.__doc__ or DEFAULT_API_DOC).strip() if api_doc is None else api_doc
        )
        _mb_max_batch_size = (
            DEFAULT_MAX_BATCH_SIZE if mb_max_batch_size is None else mb_max_batch_size
        )
        _mb_max_latency = (
            DEFAULT_MAX_LATENCY if mb_max_latency is None else mb_max_latency
        )

        if input is None:
            # support bentoml<=0.7: the adapter class is passed positionally and
            # instantiated here with the remaining args/kwargs.
            if not args or not (
                inspect.isclass(args[0]) and issubclass(args[0], BaseInputAdapter)
            ):
                raise InvalidArgument(
                    "BentoService @api decorator first parameter must "
                    "be class derived from bentoml.adapters.BaseInputAdapter"
                )

            handler = args[0](*args[1:], output_adapter=output, **kwargs)
        else:
            assert isinstance(input, BaseInputAdapter), (
                "API input parameter must be an instance of any classes inherited from "
                "bentoml.adapters.BaseInputAdapter"
            )
            handler = input
            handler._output_adapter = output

        # Validate the name before touching func so a rejected API is not left
        # half-registered (e.g. with _is_api already set).
        if not isidentifier(_api_name):
            raise InvalidArgument(
                "Invalid API name: '{}', a valid identifier must contains only letters,"
                " numbers, underscores and not starting with a number.".format(
                    _api_name
                )
            )

        if _api_name in BENTOML_RESERVED_API_NAMES:
            raise InvalidArgument(
                "Reserved API name: '{}' is reserved for infra endpoints".format(
                    _api_name
                )
            )

        setattr(func, "_is_api", True)
        setattr(func, "_handler", handler)
        setattr(func, "_api_name", _api_name)
        setattr(func, "_api_doc", _api_doc)

        setattr(func, "_mb_max_batch_size", _mb_max_batch_size)
        setattr(func, "_mb_max_latency", _mb_max_latency)

        return func
Exemplo n.º 11
0
def validate_inference_api_name(api_name):
    """Ensure ``api_name`` may be used as an inference API name.

    Raises:
        InvalidArgument: the name fails ``isidentifier``, or it appears in
            ``BENTOML_RESERVED_API_NAMES``.
    """
    if not isidentifier(api_name):
        message = (
            "Invalid API name: '{}', a valid identifier must contains only letters,"
            " numbers, underscores and not starting with a number.".format(api_name)
        )
        raise InvalidArgument(message)

    if api_name in BENTOML_RESERVED_API_NAMES:
        message = "Reserved API name: '{}' is reserved for infra endpoints".format(
            api_name
        )
        raise InvalidArgument(message)
Exemplo n.º 12
0
def validate_inference_api_route(route: str):
    """Reject inference API routes that are unusable as URL paths or are reserved.

    A route is rejected when it contains ``?`` or ``#``, starts with ``//``,
    or starts with ``:`` (https://tools.ietf.org/html/rfc3986#page-22). It is
    also rejected when it is listed in ``BENTOML_RESERVED_API_NAMES``.

    Raises:
        InvalidArgument: the route is illegal or reserved.
    """
    has_illegal_chars = re.search(r"[?#]+|^(//)|^:", route) is not None
    if has_illegal_chars:
        raise InvalidArgument(
            "The path {} contains illegal url characters".format(route))

    if route in BENTOML_RESERVED_API_NAMES:
        raise InvalidArgument(
            "Reserved API route: '{}' is reserved for infra endpoints".format(
                route))
Exemplo n.º 13
0
    def pack(self, model, opts=None, update=False):
        """Load a model from a directory path or from a dict, then return ``self``.

        ``opts`` and ``update`` are accepted for interface compatibility but are
        not used here.

        Raises:
            InvalidArgument: ``model`` is a string that is not a directory, is
                neither a string nor a dict, or loading left ``self._model`` unset.
        """
        if isinstance(model, str):
            if not os.path.isdir(model):
                raise InvalidArgument(
                    'Expecting a path to the model directory')
            self._load_from_directory(model)
        elif isinstance(model, dict):
            self._load_from_dict(model)
        else:
            # Reject unsupported types explicitly. Relying only on the
            # ``_model is None`` check below would let a model left over from an
            # earlier pack() make bad input pass silently.
            raise InvalidArgument('Expecting a path or dict')

        if self._model is None:
            raise InvalidArgument('Expecting a path or dict')

        return self
Exemplo n.º 14
0
    def pack(self, model, opts=None, update=False):
        """Load the embedder from a pretrained-model directory or a dict.

        ``opts`` and ``update`` are accepted for interface compatibility but are
        not used here.

        Returns:
            self

        Raises:
            InvalidArgument: ``model`` is a string that is not a directory, or is
                neither a string nor a dict.
        """
        if isinstance(model, dict):
            self._load_from_dict(model)
        elif not isinstance(model, str):
            raise InvalidArgument(
                "Expecting a pretrained model path or dictionary of format "
                "{'embedder': <model object>}")
        elif os.path.isdir(model):
            self._load_from_directory(model)
        else:
            raise InvalidArgument(
                'model should be the path name of a directory')

        return self
Exemplo n.º 15
0
    def add(self, deployment_pb):
        """Create an AWS Lambda deployment described by ``deployment_pb``.

        Fills in the default AWS region when none is configured, fetches the
        Bento from the Yatai service and delegates to ``self._add``. Any
        ``BentoMLException`` is recorded on ``deployment_pb.state`` and returned
        as an error ``ApplyDeploymentResponse``.
        """
        try:
            spec = deployment_pb.spec
            lambda_config = spec.aws_lambda_operator_config
            lambda_config.region = lambda_config.region or get_default_aws_region()
            if not lambda_config.region:
                raise InvalidArgument('AWS region is missing')

            get_bento_request = GetBentoRequest(
                bento_name=spec.bento_name,
                bento_version=spec.bento_version,
            )
            bento_pb = self.yatai_service.GetBento(get_bento_request)
            bento_uri = bento_pb.bento.uri
            # Only local and S3 repositories are supported.
            if bento_uri.type not in (BentoUri.LOCAL, BentoUri.S3):
                raise BentoMLException(
                    'BentoML currently not support {} repository'.format(
                        BentoUri.StorageType.Name(bento_uri.type)))

            return self._add(deployment_pb, bento_pb, bento_uri.uri)
        except BentoMLException as error:
            deployment_pb.state.state = DeploymentState.ERROR
            deployment_pb.state.error_message = f'Error: {str(error)}'
            return ApplyDeploymentResponse(
                status=error.status_proto, deployment=deployment_pb
            )
Exemplo n.º 16
0
    def update(self, deployment_pb, previous_deployment):
        """Update an existing AWS EC2 deployment.

        Requires SAM and Docker to be available. Fills in the default AWS region
        when none is configured, fetches the Bento from the Yatai service and
        delegates to ``self._update``. Any ``BentoMLException`` is recorded on
        ``deployment_pb.state`` and returned as an error
        ``ApplyDeploymentResponse``.
        """
        try:
            ensure_sam_available_or_raise()
            ensure_docker_available_or_raise()

            spec = deployment_pb.spec
            ec2_config = spec.aws_ec2_operator_config
            ec2_config.region = ec2_config.region or get_default_aws_region()
            if not ec2_config.region:
                raise InvalidArgument("AWS region is missing")

            get_bento_request = GetBentoRequest(
                bento_name=spec.bento_name,
                bento_version=spec.bento_version,
            )
            bento_pb = self.yatai_service.GetBento(get_bento_request)
            bento_uri = bento_pb.bento.uri

            # Only local and S3 repositories are supported.
            if bento_uri.type not in (BentoUri.LOCAL, BentoUri.S3):
                raise BentoMLException(
                    "BentoML currently not support {} repository".format(
                        BentoUri.StorageType.Name(bento_uri.type)))

            return self._update(
                deployment_pb,
                previous_deployment,
                bento_uri.uri,
                ec2_config.region,
            )
        except BentoMLException as error:
            deployment_pb.state.state = DeploymentState.ERROR
            deployment_pb.state.error_message = f"Error: {str(error)}"
            return ApplyDeploymentResponse(
                status=error.status_proto, deployment=deployment_pb
            )
Exemplo n.º 17
0
    def pack(
        self,
        easyocr_model,
        metadata=None,
        recog_network="english_g2",
        lang_list=None,
        detect_model="craft_mlt_25k",
        gpu=False,
    ):  # pylint:disable=arguments-differ
        """Pack an easyocr ``Reader`` together with the options it was built with.

        Args:
            easyocr_model: the ``easyocr.easyocr.Reader`` instance to pack.
            metadata: accepted for interface compatibility; not used here.
            recog_network: recognition network name, stored on the artifact and
                in ``self._model_params``.
            lang_list: language codes stored in ``self._model_params``; defaults
                to ``['en']`` when empty or None.
            detect_model: detection model name, stored on ``self._detect_model``.
            gpu: GPU flag, stored on the artifact and in ``self._model_params``.

        Returns:
            self

        Raises:
            MissingDependencyException: easyocr cannot be imported, or its
                ``__version__`` is older than 1.3 or cannot be parsed.
            InvalidArgument: ``easyocr_model`` is not an ``easyocr.easyocr.Reader``.
        """
        import re

        try:
            import easyocr
        except ImportError:
            raise MissingDependencyException(
                "easyocr>=1.3 package is required to use EasyOCRArtifact")

        # Compare versions numerically. A plain string comparison ranks "1.10"
        # below "1.3". An ``assert`` would raise a bare AssertionError (not
        # MissingDependencyException) and is skipped entirely under ``python -O``.
        version_match = re.match(
            r"(\d+)(?:\.(\d+))?", str(getattr(easyocr, "__version__", ""))
        )
        if version_match is None or (
            int(version_match.group(1)),
            int(version_match.group(2) or 0),
        ) < (1, 3):
            raise MissingDependencyException(
                "easyocr>=1.3 package is required to use EasyOCRArtifact")

        if not isinstance(easyocr_model, easyocr.easyocr.Reader):
            raise InvalidArgument(
                "'easyocr_model' must be of type  easyocr.easyocr.Reader")

        if not lang_list:
            lang_list = ['en']
        self._model = easyocr_model
        self._detect_model = detect_model
        self._recog_network = recog_network
        self._gpu = gpu
        self._model_params = {
            "lang_list": lang_list,
            "recog_network": recog_network,
            "gpu": gpu,
        }

        return self
Exemplo n.º 18
0
    def decorator(func):
        """Attach inference-API metadata and input/output adapters to ``func``.

        The API name defaults to ``func.__name__`` and is checked with
        ``validate_inference_api_name`` first. Either ``input`` is an adapter
        instance, or, in the legacy form, ``args[0]`` is an adapter class that is
        instantiated here with the remaining ``args``/``kwargs``.

        Raises:
            InvalidArgument: legacy form without a ``BaseInputAdapter`` subclass
                as the first positional argument (or an invalid API name).
        """
        _api_name = api_name if api_name is not None else func.__name__
        validate_inference_api_name(_api_name)

        if input is not None:
            assert isinstance(input, BaseInputAdapter), (
                "API input parameter must be an instance of a class derived from "
                "bentoml.adapters.BaseInputAdapter")
            input_adapter = input
            output_adapter = output or DefaultOutput()
        else:
            # Legacy form: the adapter class is passed positionally and is
            # instantiated here; anything else is rejected.
            is_adapter_class = bool(args) and (
                inspect.isclass(args[0]) and issubclass(args[0], BaseInputAdapter)
            )
            if not is_adapter_class:
                raise InvalidArgument(
                    "BentoService @api decorator first parameter must "
                    "be an instance of a class derived from "
                    "bentoml.adapters.BaseInputAdapter ")

            # noinspection PyPep8Naming
            AdapterClass, *adapter_args = args
            input_adapter = AdapterClass(*adapter_args, **kwargs)
            output_adapter = DefaultOutput()

        api_attributes = {
            "_is_api": True,
            "_input_adapter": input_adapter,
            "_output_adapter": output_adapter,
            "_api_name": _api_name,
            "_api_doc": api_doc,
            "_mb_max_batch_size": mb_max_batch_size,
            "_mb_max_latency": mb_max_latency,
            "_batch": batch,
        }
        for attr_name, attr_value in api_attributes.items():
            setattr(func, attr_name, attr_value)

        return func
Exemplo n.º 19
0
    def decorator(func):
        """Attach the inference handler and API metadata to ``func``.

        The API name defaults to ``func.__name__`` and is checked with
        ``validate_inference_api_name`` first. Either ``input`` is an adapter
        instance (its ``_output_adapter`` is set to ``output``), or, in the
        bentoml<=0.7 form, ``args[0]`` is an adapter class instantiated here.

        Raises:
            InvalidArgument: legacy form without a ``BaseInputAdapter`` subclass
                as the first positional argument (or an invalid API name).
        """
        _api_name = api_name if api_name is not None else func.__name__
        validate_inference_api_name(_api_name)

        if input is not None:
            assert isinstance(input, BaseInputAdapter), (
                "API input parameter must be an instance of any classes inherited from "
                "bentoml.adapters.BaseInputAdapter")
            handler = input
            handler._output_adapter = output
        else:
            # support bentoml<=0.7
            is_adapter_class = bool(args) and (
                inspect.isclass(args[0]) and issubclass(args[0], BaseInputAdapter)
            )
            if not is_adapter_class:
                raise InvalidArgument(
                    "BentoService @api decorator first parameter must "
                    "be class derived from bentoml.adapters.BaseInputAdapter")

            adapter_cls, *adapter_args = args
            handler = adapter_cls(*adapter_args, output_adapter=output, **kwargs)

        api_attributes = {
            "_is_api": True,
            "_handler": handler,
            "_api_name": _api_name,
            "_api_doc": api_doc,
            "_mb_max_batch_size": mb_max_batch_size,
            "_mb_max_latency": mb_max_latency,
        }
        for attr_name, attr_value in api_attributes.items():
            setattr(func, attr_name, attr_value)

        return func
Exemplo n.º 20
0
    def __setitem__(self, key, artifact):
        """Add ``artifact`` via ``self.add``; ``key`` must equal its spec name.

        Raises:
            InvalidArgument: ``key`` differs from ``artifact.spec.name``.
        """
        artifact_name = artifact.spec.name
        if key != artifact_name:
            message = "Must use Artifact name as key, {} not equal to {}".format(
                key, artifact_name)
            raise InvalidArgument(message)

        self.add(artifact)
Exemplo n.º 21
0
    def delete(self, deployment_pb):
        """Delete an AWS Lambda deployment and its related resources.

        Fills in the default AWS region when none is configured. If the
        deployment's ``state.info_json`` names an ``'s3_bucket'``, that bucket is
        cleaned up. Then the CloudFormation stack named after the deployment's
        namespace and name is deleted. A ``BentoMLException`` is turned into an
        error ``DeleteDeploymentResponse``.
        """
        try:
            logger.debug('Deleting AWS Lambda deployment')

            lambda_config = deployment_pb.spec.aws_lambda_operator_config
            lambda_config.region = lambda_config.region or get_default_aws_region()
            region = lambda_config.region
            if not region:
                raise InvalidArgument('AWS region is missing')

            cf_client = boto3.client('cloudformation', region)
            stack_name = generate_aws_compatible_string(
                deployment_pb.namespace, deployment_pb.name)

            info_json = deployment_pb.state.info_json
            if info_json:
                bucket_name = json.loads(info_json).get('s3_bucket')
                if bucket_name:
                    cleanup_s3_bucket_if_exist(bucket_name, region)

            logger.debug(
                'Deleting AWS CloudFormation: %s that includes Lambda function '
                'and related resources',
                stack_name,
            )
            cf_client.delete_stack(StackName=stack_name)
            return DeleteDeploymentResponse(status=Status.OK())

        except BentoMLException as error:
            return DeleteDeploymentResponse(status=error.status_proto)
Exemplo n.º 22
0
    def add(self, deployment_pb):
        """Create a SageMaker deployment described by ``deployment_pb``.

        Checks the SageMaker config and AWS region, requires Docker, fetches the
        Bento from the Yatai service and delegates to ``self._add``. Any
        ``BentoMLException`` is recorded on ``deployment_pb.state`` and returned
        as an error ``ApplyDeploymentResponse``.
        """
        try:
            deployment_spec = deployment_pb.spec
            sagemaker_config = deployment_spec.sagemaker_operator_config
            # Check for a missing config before dereferencing it. Placed after
            # ``sagemaker_config.region`` (as before), this guard could never fire.
            if sagemaker_config is None:
                raise YataiDeploymentException(
                    'Sagemaker configuration is missing.')

            sagemaker_config.region = (sagemaker_config.region
                                       or get_default_aws_region())
            if not sagemaker_config.region:
                raise InvalidArgument('AWS region is missing')

            ensure_docker_available_or_raise()

            bento_pb = self.yatai_service.GetBento(
                GetBentoRequest(
                    bento_name=deployment_spec.bento_name,
                    bento_version=deployment_spec.bento_version,
                ))
            # Only local and S3 repositories are supported.
            if bento_pb.bento.uri.type not in (BentoUri.LOCAL, BentoUri.S3):
                raise BentoMLException(
                    'BentoML currently not support {} repository'.format(
                        BentoUri.StorageType.Name(bento_pb.bento.uri.type)))
            return self._add(deployment_pb, bento_pb, bento_pb.bento.uri.uri)

        except BentoMLException as error:
            deployment_pb.state.state = DeploymentState.ERROR
            deployment_pb.state.error_message = (
                f'Error creating SageMaker deployment: {str(error)}')
            return ApplyDeploymentResponse(status=error.status_proto,
                                           deployment=deployment_pb)
Exemplo n.º 23
0
 def _load_from_dict(self, model):
     """Store the object found under ``model['embedder']`` as ``self._model``.

     Raises:
         InvalidArgument: the mapping has no ``'embedder'`` key.
     """
     if 'embedder' in model:
         self._model = model['embedder']
         return
     raise InvalidArgument(
         "'embedder' key is not found in the dictionary. "
         "Expecting a dictionary with keys 'model', 'tokenizer' and 'embedder'"
     )
Exemplo n.º 24
0
    def delete(self, deployment_pb):
        """Delete the SageMaker endpoint for ``deployment_pb`` and clean up.

        Fills in the default AWS region when none is configured and deletes the
        endpoint named by ``_get_sagemaker_resource_names``. AWS ``ClientError``s
        are converted with ``_aws_client_error_to_bentoml_exception``. Remaining
        resources are then cleaned up with
        ``_try_clean_up_sagemaker_deployment_resource``. A ``BentoMLException``
        becomes an error ``DeleteDeploymentResponse``.
        """
        try:
            sagemaker_config = deployment_pb.spec.sagemaker_operator_config
            sagemaker_config.region = (
                sagemaker_config.region or get_default_aws_region()
            )
            if not sagemaker_config.region:
                raise InvalidArgument('AWS region is missing')

            client = boto3.client('sagemaker', sagemaker_config.region)
            _, _, endpoint_name = _get_sagemaker_resource_names(deployment_pb)

            try:
                response = client.delete_endpoint(EndpointName=endpoint_name)
            except ClientError as e:
                raise _aws_client_error_to_bentoml_exception(e)
            logger.debug("AWS delete endpoint response: %s", response)

            _try_clean_up_sagemaker_deployment_resource(deployment_pb)

            return DeleteDeploymentResponse(status=Status.OK())
        except BentoMLException as error:
            return DeleteDeploymentResponse(status=error.status_proto)
Exemplo n.º 25
0
    def delete(self, deployment_pb):
        """Delete an AWS EC2 deployment and the resources created for it.

        Fills in the default AWS region when none is configured. It then deletes
        the CloudFormation stack and the ECR repository named by
        ``generate_ec2_resource_names``. Finally it cleans up the S3 bucket
        recorded under ``'S3Bucket'`` in ``state.info_json``, if any. A
        ``BentoMLException`` becomes an error ``DeleteDeploymentResponse``.
        """
        try:
            ec2_config = deployment_pb.spec.aws_ec2_operator_config
            ec2_config.region = ec2_config.region or get_default_aws_region()
            region = ec2_config.region
            if not region:
                raise InvalidArgument("AWS region is missing")

            resource_names = generate_ec2_resource_names(
                deployment_pb.namespace, deployment_pb.name)
            _, stack_name, ecr_repository_name, _ = resource_names

            delete_cloudformation_stack(stack_name, region)
            delete_ecr_repository(ecr_repository_name, region)

            info_json = deployment_pb.state.info_json
            if info_json:
                s3_bucket = json.loads(info_json).get('S3Bucket')
                if s3_bucket:
                    cleanup_s3_bucket_if_exist(s3_bucket, region)

            return DeleteDeploymentResponse(status=Status.OK())
        except BentoMLException as error:
            return DeleteDeploymentResponse(status=error.status_proto)
Exemplo n.º 26
0
    def _load_from_dict(self, model):
        """Store ``model`` as ``self._model`` if it carries a truthy ``'state_dict'``.

        Only ``'state_dict'`` is checked here.

        Raises:
            InvalidArgument: ``'state_dict'`` is missing or falsy.
        """
        if model.get('state_dict'):
            self._model = model
            return

        raise InvalidArgument(
            "'state_dict' key is not found in the dictionary. "
            "Expecting a dictionary with keys 'state_dict', 'emoji_list', 'pax_list', 'vocabulary'"
        )
Exemplo n.º 27
0
    def add(self, artifact):
        """Store a packed or loaded artifact wrapper under its spec name.

        Raises:
            InvalidArgument: ``artifact`` is not a ``BentoServiceArtifactWrapper``.
        """
        if not isinstance(artifact, BentoServiceArtifactWrapper):
            # Separate the two sentences; the literals were previously joined as
            # "BentoServiceArtifactWrapper,Must call".
            raise InvalidArgument(
                "ArtifactCollection only accepts type BentoServiceArtifactWrapper. "
                "Must call BentoServiceArtifact#pack or BentoServiceArtifact#load "
                "before adding to an ArtifactCollection"
            )

        # Calls the parent class's __setitem__ directly, presumably to bypass any
        # override on ArtifactCollection itself. TODO confirm.
        super(ArtifactCollection, self).__setitem__(artifact.spec.name, artifact)
Exemplo n.º 28
0
    def pack(self, model):  # pylint:disable=arguments-differ
        """Wrap a fastai ``Learner`` in a ``_FastaiModelArtifactWrapper``.

        Raises:
            InvalidArgument: ``model`` is not a ``fastai.basic_train.Learner``.
        """
        learner_cls = _import_fastai_module().basic_train.Learner
        if isinstance(model, learner_cls):
            return _FastaiModelArtifactWrapper(self, model)

        raise InvalidArgument(
            "Expect `model` argument to be `fastai.basic_train.Learner` instance"
        )
Exemplo n.º 29
0
    def decorator(bento_service_cls):
        """Validate ``artifacts`` and record them on ``bento_service_cls``.

        Returns:
            ``bento_service_cls`` with ``_declared_artifacts`` set to ``artifacts``.

        Raises:
            InvalidArgument: an entry is not a ``BentoServiceArtifact`` instance, or
                two artifacts share the same name.
        """
        artifact_names = set()
        for artifact in artifacts:
            if not isinstance(artifact, BentoServiceArtifact):
                raise InvalidArgument(
                    "BentoService @artifacts decorator only accept list of "
                    "BentoServiceArtifact instances, instead got type: '%s'" %
                    type(artifact))

            if artifact.name in artifact_names:
                # Trailing space keeps "one" and "BentoService" apart in the
                # concatenated message.
                raise InvalidArgument(
                    "Duplicated artifact name `%s` detected. Each artifact within one "
                    "BentoService must have an unique name" % artifact.name)

            artifact_names.add(artifact.name)

        bento_service_cls._declared_artifacts = artifacts
        return bento_service_cls
Exemplo n.º 30
0
def raise_if_api_names_not_found_in_bento_service_metadata(
        metadata, api_names):
    """Raise unless every name in ``api_names`` is an API declared in ``metadata``.

    Args:
        metadata: object whose ``apis`` items each expose a ``name`` attribute.
        api_names: iterable of API names to look up.

    Raises:
        InvalidArgument: at least one requested name is not among ``metadata.apis``.
    """
    declared_names = [api.name for api in metadata.apis]

    if set(api_names) <= set(declared_names):
        return

    raise InvalidArgument(
        "Expect api names {api_names} to be "
        "subset of {all_api_names}".format(
            api_names=api_names, all_api_names=declared_names
        )
    )