Exemplo n.º 1
0
def deployment_yaml_to_pb(deployment_yaml):
    """Build a ``Deployment`` protobuf message from a parsed deployment YAML mapping.

    Top-level ``name``, ``namespace``, ``labels`` and ``annotations`` are copied
    when present. The ``spec`` mapping supplies ``operator`` (either
    ``'aws-sagemaker'`` or ``'aws_sagemaker'`` style, case-insensitive),
    ``bento_name``, ``bento_version`` and the operator-specific config section
    (``sagemaker_operator_config``, ``aws_lambda_operator_config``,
    ``gcp_function_operator_config`` or ``kubernetes_operator_config``).

    Args:
        deployment_yaml: mapping loaded from a deployment YAML document.

    Returns:
        The populated ``Deployment`` message.

    Raises:
        BentoMLException: if ``spec`` is missing/empty, or the operator is not
            one of SageMaker, Lambda, GCP Function or Kubernetes.
        ValueError: raised by the protobuf enum lookup when ``operator`` does
            not name a ``DeploymentOperator`` value.
    """
    deployment_pb = Deployment()

    if deployment_yaml.get('name') is not None:
        deployment_pb.name = deployment_yaml.get('name')
    if deployment_yaml.get('namespace') is not None:
        deployment_pb.namespace = deployment_yaml.get('namespace')
    if deployment_yaml.get('labels') is not None:
        deployment_pb.labels.update(dict(deployment_yaml.get('labels')))
    if deployment_yaml.get('annotations') is not None:
        deployment_pb.annotations.update(
            dict(deployment_yaml.get('annotations')))

    spec_yaml = deployment_yaml.get('spec')
    if not spec_yaml:
        # Without this guard a missing spec surfaced as AttributeError on None.
        raise BentoMLException('"spec" is required field for deployment')

    platform = spec_yaml.get('operator')
    if platform is not None:
        # Normalize e.g. 'aws-lambda' / 'aws_lambda' to the enum name 'AWS_LAMBDA'.
        deployment_pb.spec.operator = DeploymentSpec.DeploymentOperator.Value(
            platform.replace('-', '_').upper())
    if spec_yaml.get('bento_name'):
        deployment_pb.spec.bento_name = spec_yaml.get('bento_name')
    if spec_yaml.get('bento_version'):
        deployment_pb.spec.bento_version = spec_yaml.get('bento_version')

    # Dispatch on the normalized enum rather than the raw string so that every
    # spelling accepted above ('aws-sagemaker', 'AWS_SAGEMAKER', ...) is handled.
    operator = deployment_pb.spec.operator
    if operator == DeploymentSpec.AWS_SAGEMAKER:
        # `or {}` also covers a key that is present but null in the YAML.
        sagemaker_config = spec_yaml.get('sagemaker_operator_config') or {}
        sagemaker_operator_config_pb = deployment_pb.spec.sagemaker_operator_config
        if sagemaker_config.get('api_name'):
            sagemaker_operator_config_pb.api_name = sagemaker_config.get(
                'api_name')
        if sagemaker_config.get('region'):
            sagemaker_operator_config_pb.region = sagemaker_config.get(
                'region')
        if sagemaker_config.get('instance_count'):
            sagemaker_operator_config_pb.instance_count = int(
                sagemaker_config.get('instance_count'))
        if sagemaker_config.get('instance_type'):
            sagemaker_operator_config_pb.instance_type = sagemaker_config.get(
                'instance_type')
    elif operator == DeploymentSpec.AWS_LAMBDA:
        lambda_config = spec_yaml.get('aws_lambda_operator_config') or {}
        if lambda_config.get('region'):
            deployment_pb.spec.aws_lambda_operator_config.region = lambda_config.get(
                'region')
    elif operator == DeploymentSpec.GCP_FUNCTION:
        gcp_config = spec_yaml.get('gcp_function_operator_config') or {}
        if gcp_config.get('region'):
            deployment_pb.spec.gcp_function_operator_config.region = gcp_config.get(
                'region')
    elif operator == DeploymentSpec.KUBERNETES:
        k8s_config = spec_yaml.get('kubernetes_operator_config') or {}
        k8s_operator_config_pb = deployment_pb.spec.kubernetes_operator_config

        if k8s_config.get('kube_namespace'):
            k8s_operator_config_pb.kube_namespace = k8s_config.get(
                'kube_namespace')
        if k8s_config.get('replicas'):
            # Coerce so string input (e.g. from CLI flags) fits the int field.
            k8s_operator_config_pb.replicas = int(k8s_config.get('replicas'))
        if k8s_config.get('service_name'):
            k8s_operator_config_pb.service_name = k8s_config.get(
                'service_name')
        if k8s_config.get('service_type'):
            k8s_operator_config_pb.service_type = k8s_config.get(
                'service_type')
    else:
        raise BentoMLException(
            'Custom deployment is not supported in the current version of BentoML'
        )

    return deployment_pb
Exemplo n.º 2
0
def deployment_dict_to_pb(deployment_dict):
    """Build a ``Deployment`` protobuf message from a plain deployment dict.

    Top-level ``name``, ``namespace``, ``labels`` and ``annotations`` are copied
    when present. ``spec`` is required and supplies ``operator`` (either
    ``'aws-lambda'`` or ``'aws_lambda'`` style, case-insensitive),
    ``bento_name``, ``bento_version`` and the operator-specific config section.

    Args:
        deployment_dict: mapping describing the deployment (e.g. parsed YAML
            or assembled from CLI arguments).

    Returns:
        The populated ``Deployment`` message.

    Raises:
        BentoMLDeploymentException: if ``spec`` is missing or empty.
        BentoMLException: if the operator is not one of SageMaker, Lambda,
            GCP Function or Kubernetes.
        ValueError: raised by the protobuf enum lookup when ``operator`` does
            not name a ``DeploymentOperator`` value.
    """
    deployment_pb = Deployment()
    if deployment_dict.get('name') is not None:
        deployment_pb.name = deployment_dict.get('name')
    if deployment_dict.get('namespace') is not None:
        deployment_pb.namespace = deployment_dict.get('namespace')
    if deployment_dict.get('labels') is not None:
        deployment_pb.labels.update(deployment_dict.get('labels'))
    if deployment_dict.get('annotations') is not None:
        deployment_pb.annotations.update(deployment_dict.get('annotations'))

    spec_dict = deployment_dict.get('spec')
    if not spec_dict:
        raise BentoMLDeploymentException(
            '"spec" is required field for deployment')

    platform = spec_dict.get('operator')
    if platform is not None:
        # converting platform parameter to DeploymentOperator name in proto
        # e.g. 'aws-lambda' to 'AWS_LAMBDA'
        deployment_pb.spec.operator = DeploymentSpec.DeploymentOperator.Value(
            platform.replace('-', '_').upper())

    if spec_dict.get('bento_name'):
        deployment_pb.spec.bento_name = spec_dict.get('bento_name')
    if spec_dict.get('bento_version'):
        deployment_pb.spec.bento_version = spec_dict.get('bento_version')

    operator = deployment_pb.spec.operator
    if operator == DeploymentSpec.AWS_SAGEMAKER:
        # `or {}` (rather than a .get default) also covers a key whose value
        # is explicitly None, e.g. an empty section in YAML.
        sagemaker_config = spec_dict.get('sagemaker_operator_config') or {}
        sagemaker_operator_config_pb = deployment_pb.spec.sagemaker_operator_config
        if sagemaker_config.get('api_name'):
            sagemaker_operator_config_pb.api_name = sagemaker_config.get(
                'api_name')
        if sagemaker_config.get('region'):
            sagemaker_operator_config_pb.region = sagemaker_config.get(
                'region')
        if sagemaker_config.get('instance_count'):
            sagemaker_operator_config_pb.instance_count = int(
                sagemaker_config.get('instance_count'))
        if sagemaker_config.get('instance_type'):
            sagemaker_operator_config_pb.instance_type = sagemaker_config.get(
                'instance_type')
    elif operator == DeploymentSpec.AWS_LAMBDA:
        lambda_config = spec_dict.get('aws_lambda_operator_config') or {}
        lambda_config_pb = deployment_pb.spec.aws_lambda_operator_config
        if lambda_config.get('region'):
            lambda_config_pb.region = lambda_config.get('region')
        if lambda_config.get('api_name'):
            lambda_config_pb.api_name = lambda_config.get('api_name')
    elif operator == DeploymentSpec.GCP_FUNCTION:
        gcp_config = spec_dict.get('gcp_function_operator_config') or {}
        # Previously api_name was written into aws_lambda_operator_config by
        # mistake; it belongs on the GCP function config.
        gcp_config_pb = deployment_pb.spec.gcp_function_operator_config
        if gcp_config.get('region'):
            gcp_config_pb.region = gcp_config.get('region')
        if gcp_config.get('api_name'):
            gcp_config_pb.api_name = gcp_config.get('api_name')
    elif operator == DeploymentSpec.KUBERNETES:
        k8s_config = spec_dict.get('kubernetes_operator_config') or {}
        k8s_operator_config_pb = deployment_pb.spec.kubernetes_operator_config

        if k8s_config.get('kube_namespace'):
            k8s_operator_config_pb.kube_namespace = k8s_config.get(
                'kube_namespace')
        if k8s_config.get('replicas'):
            # Coerce like instance_count so string input fits the int field.
            k8s_operator_config_pb.replicas = int(k8s_config.get('replicas'))
        if k8s_config.get('service_name'):
            k8s_operator_config_pb.service_name = k8s_config.get(
                'service_name')
        if k8s_config.get('service_type'):
            k8s_operator_config_pb.service_type = k8s_config.get(
                'service_type')
    else:
        raise BentoMLException(
            'Platform "{}" is not supported in the current version of '
            'BentoML'.format(platform))

    return deployment_pb