class BatchEndpointDeploymentSchema(EndpointDeploymentSchema):
    """Schema for one deployment that lives under a batch endpoint.

    Adds batch-specific scale settings, batch settings and a compute binding
    on top of the fields inherited from ``EndpointDeploymentSchema``, and
    loads into an ``InternalBatchEndpointDeployment``.
    """

    scale_settings = PatchedNested(DeploymentScaleSettingsSchema)
    batch_settings = PatchedNested(BatchDeploymentSettingsSchema)
    compute = PatchedNested(ComputeBindingSchema)

    @post_load
    def make(self, data: Any, **kwargs: Any) -> InternalBatchEndpointDeployment:
        """Turn the loaded field dict into an ``InternalBatchEndpointDeployment``.

        The deployment's base path is taken from the schema context entry
        keyed by ``BASE_PATH_CONTEXT_KEY``.
        """
        base_path = self.context[BASE_PATH_CONTEXT_KEY]
        return InternalBatchEndpointDeployment(base_path=base_path, **data)
class BatchDeploymentSettingsSchema(PatchedBaseSchema):
    """Schema for the batch-specific settings of a batch deployment.

    Loads into an ``InternalDeploymentBatchSettings`` built directly from the
    loaded fields.
    """

    compute_id = fields.Str()
    partitioning_scheme = PatchedNested(BatchPartitioningSchemeSchema)
    output_configuration = PatchedNested(BatchOutputConfigurationSchema)
    error_threshold = fields.Int()
    retry_settings = PatchedNested(BatchRetrySettingsSchema)
    logging_level = fields.Str()

    @post_load
    def make(self, data: Any, **kwargs: Any) -> InternalDeploymentBatchSettings:
        """Wrap the loaded field dict in an ``InternalDeploymentBatchSettings``."""
        settings = InternalDeploymentBatchSettings(**data)
        return settings
class OnlineEndpointDeploymentSchema(EndpointDeploymentSchema):
    """Schema for one deployment that lives under an online endpoint.

    Adds sku, App Insights toggle, resource requirements, scale/request
    settings, a liveness probe and a read-only provisioning status on top of
    the fields inherited from ``EndpointDeploymentSchema``, and loads into an
    ``InternalOnlineEndpointDeployment``.
    """

    sku = fields.Str()
    app_insights_enabled = fields.Bool()
    resource_requirements = PatchedNested(ResourceRequirementsSchema, required=False)
    scale_settings = PatchedNested(ScaleSettingsSchema)
    request_settings = PatchedNested(RequestSettingsSchema)
    liveness_probe = PatchedNested(LivenessProbeSchema)
    # Only emitted on dump; never accepted as input.
    provisioning_status = fields.Str(dump_only=True)

    @post_load
    def make(self, data: Any, **kwargs: Any) -> InternalOnlineEndpointDeployment:
        """Turn the loaded field dict into an ``InternalOnlineEndpointDeployment``.

        The deployment's base path is taken from the schema context entry
        keyed by ``BASE_PATH_CONTEXT_KEY``.
        """
        base_path = self.context[BASE_PATH_CONTEXT_KEY]
        return InternalOnlineEndpointDeployment(base_path=base_path, **data)
class EndpointDeploymentSchema(PathAwareSchema):
    """Fields shared by every endpoint deployment schema.

    Defines no ``post_load`` hook of its own; the online and batch deployment
    subclasses each provide their own ``make``.
    """

    # Server-assigned identifiers: emitted on dump, never accepted on load.
    id = fields.Str(dump_only=True)
    type = fields.Str(dump_only=True)
    tags = fields.Dict()
    properties = fields.Dict()
    # Either an ARM versioned reference to a model asset or an inline model definition.
    model = UnionField(
        [
            ArmVersionedStr(asset_type=AssetType.MODEL),
            PatchedNested(ModelSchema),
        ],
        required=True,
    )
    code_configuration = PatchedNested(CodeConfigurationSchema)
    # Either an ARM versioned reference to an environment asset or an inline environment definition.
    environment = UnionField(
        [
            ArmVersionedStr(asset_type=AssetType.ENVIRONMENT),
            PatchedNested(EnvironmentSchema),
        ]
    )
class OnlineEndpointSchema(EndpointSchema):
    """Schema for an online endpoint and its named online deployments.

    On load it:

    * derives the cluster type from ``infrastructure`` (a case-insensitive
      match on ``ComputeType.MANAGED`` gives a managed cluster; anything else
      is treated as AKS);
    * defaults ``auth_mode`` to ``KEY`` when it is missing or empty;
    * requires every deployment on a managed cluster to specify a ``sku``;
    * copies each deployment's dictionary key onto it as ``name``.
    """

    # TODO: need to revisit here since the backend has some fields not available yet
    deployments = fields.Dict(keys=fields.Str(), values=PatchedNested(OnlineEndpointDeploymentSchema))
    provisioning_status = fields.Str(dump_only=True)

    @post_load
    def make(self, data: Any, **kwargs: Any) -> InternalOnlineEndpoint:
        """Build an ``InternalOnlineEndpoint`` from the loaded fields.

        :param data: Field values loaded by the schema.
        :return: The endpoint, with ``base_path`` taken from the schema context
            and the derived ``cluster_type``.
        :raises ValidationError: If ``infrastructure`` is missing, or if a
            deployment on a managed cluster has no ``sku``.
        """
        infra = data.get("infrastructure", None)
        if infra is None:
            # Previously this crashed with AttributeError on ``None.lower()``;
            # surface it as a schema validation error instead.
            raise ValidationError("An infrastructure must be specified for an online endpoint")
        cluster_type = ComputeType.MANAGED if infra.lower() == ComputeType.MANAGED else ComputeType.AKS

        # cluster_type is always MANAGED or AKS at this point, and both default
        # to key-based auth, so no further cluster check is needed.
        if not data.get("auth_mode", None):
            data["auth_mode"] = KEY

        deployments = data.get("deployments", None)
        if deployments:
            for name, deployment in deployments.items():
                if not deployment.sku and cluster_type == ComputeType.MANAGED:
                    raise ValidationError("A sku must be specified for a managed inference cluster")
                # The deployment's name is its key in the ``deployments`` mapping.
                deployment.name = name
        return InternalOnlineEndpoint(
            base_path=self.context[BASE_PATH_CONTEXT_KEY], **data, cluster_type=cluster_type
        )
class CodeConfigurationSchema(PathAwareSchema):
    """Schema pairing a code asset with the scoring script to run from it.

    Loads into an ``InternalCodeConfiguration``.
    """

    # Either an ARM versioned reference to a code asset or an inline code asset definition.
    code = UnionField(
        [
            ArmVersionedStr(asset_type=AssetType.CODE),
            PatchedNested(CodeAssetSchema),
        ]
    )
    scoring_script = fields.Str()

    @post_load
    def make(self, data: Any, **kwargs: Any) -> InternalCodeConfiguration:
        """Turn the loaded field dict into an ``InternalCodeConfiguration``.

        The base path is taken from the schema context entry keyed by
        ``BASE_PATH_CONTEXT_KEY``.
        """
        base_path = self.context[BASE_PATH_CONTEXT_KEY]
        return InternalCodeConfiguration(base_path=base_path, **data)
class BatchEndpointSchema(EndpointSchema):
    """Schema for a batch endpoint and its named batch deployments.

    Loads into an ``InternalBatchEndpoint``.
    """

    deployments = fields.Dict(keys=fields.Str(), values=PatchedNested(BatchEndpointDeploymentSchema))

    @post_load
    def make(self, data: Any, **kwargs: Any) -> InternalBatchEndpoint:
        """Turn the loaded field dict into an ``InternalBatchEndpoint``.

        The base path is taken from the schema context entry keyed by
        ``BASE_PATH_CONTEXT_KEY``.
        """
        # NOTE(review): unlike OnlineEndpointSchema.make, the keys of the
        # ``deployments`` mapping are not copied onto the deployments as
        # ``name`` here. Confirm whether that is intended.
        base_path = self.context[BASE_PATH_CONTEXT_KEY]
        return InternalBatchEndpoint(base_path=base_path, **data)