def __init__(
            self,
            trainingInstanceType,
            trainingInstanceCount,
            endpointInstanceType,
            endpointInitialInstanceCount,
            sagemakerRole=IAMRoleFromConfig(),
            requestRowSerializer=ProtobufRequestRowSerializer(),
            responseRowDeserializer=LinearLearnerBinaryClassifierProtobufResponseRowDeserializer(),
            trainingInputS3DataPath=S3AutoCreatePath(),
            trainingOutputS3DataPath=S3AutoCreatePath(),
            trainingInstanceVolumeSizeInGB=1024,
            trainingProjectedColumns=None,
            trainingChannelName="train",
            trainingContentType=None,
            trainingS3DataDistribution="ShardedByS3Key",
            trainingSparkDataFormat="sagemaker",
            trainingSparkDataFormatOptions=None,
            trainingInputMode="File",
            trainingCompressionCodec=None,
            trainingMaxRuntimeInSeconds=24*60*60,
            trainingKmsKeyId=None,
            modelEnvironmentVariables=None,
            endpointCreationPolicy=EndpointCreationPolicy.CREATE_ON_CONSTRUCT,
            sagemakerClient=SageMakerClients.create_sagemaker_client(),
            region=None,
            s3Client=SageMakerClients.create_s3_default_client(),
            stsClient=SageMakerClients.create_sts_default_client(),
            modelPrependInputRowsToTransformationRows=True,
            deleteStagingDataAfterTraining=True,
            namePolicyFactory=RandomNamePolicyFactory(),
            uid=None,
            javaObject=None) -> None:
        """Construct a LinearLearner estimator preset for binary classification.

        All constructor arguments, including ``javaObject``, are forwarded
        unchanged as keyword arguments to the parent class's ``__init__``.
        Afterwards the ``predictor_type`` param default is set to
        ``'binary_classifier'``.

        Before forwarding:

        * ``trainingSparkDataFormatOptions`` and ``modelEnvironmentVariables``
          are replaced by fresh empty dicts when they are ``None``.
        * ``uid`` is replaced by ``Identifiable._randomUID()`` when it is
          ``None``.

        ``trainingMaxRuntimeInSeconds`` defaults to 24*60*60 seconds (one day).

        NOTE(review): the object-valued defaults (``IAMRoleFromConfig()``,
        ``ProtobufRequestRowSerializer()``, the response deserializer,
        ``S3AutoCreatePath()``, the ``SageMakerClients.create_*`` clients and
        ``RandomNamePolicyFactory()``) are evaluated once, when this ``def``
        executes. Every instance that omits those arguments therefore shares
        the same objects; confirm that sharing is intended.
        """

        # Use a fresh dict per call rather than a shared mutable default.
        if trainingSparkDataFormatOptions is None:
            trainingSparkDataFormatOptions = {}

        if modelEnvironmentVariables is None:
            modelEnvironmentVariables = {}

        if uid is None:
            uid = Identifiable._randomUID()

        # Snapshot every local bound so far (the parameters, with the
        # None-defaults above already filled in) to forward as keyword args.
        # NOTE(review): any new local assigned above this line is forwarded
        # too. Because this body references ``super``, CPython may also
        # include the implicit ``__class__`` cell here; confirm the parent
        # __init__ tolerates that.
        kwargs = locals().copy()
        # ``self`` is supplied by the bound super() call, so drop it here.
        del kwargs['self']

        super(LinearLearnerBinaryClassifier, self).__init__(**kwargs)

        # This variant always defaults the predictor type to binary classifier.
        default_params = {
            'predictor_type': 'binary_classifier'
        }

        self._setDefault(**default_params)
Example #2
0
    def __init__(
        self,
        endpointInstanceType,
        endpointInitialInstanceCount,
        requestRowSerializer,
        responseRowDeserializer,
        existingEndpointName=None,
        modelImage=None,
        modelPath=None,
        modelEnvironmentVariables=None,
        modelExecutionRoleARN=None,
        endpointCreationPolicy=EndpointCreationPolicy.CREATE_ON_CONSTRUCT,
        sagemakerClient=SageMakerClients.create_sagemaker_client(),
        prependResultRows=True,
        namePolicy=RandomNamePolicy(),
        uid=None,
        javaObject=None,
    ):
        """Wrap a SageMaker model, adopting ``javaObject`` or building a new JVM peer.

        If ``javaObject`` is truthy, it becomes the underlying Java object as-is
        and the remaining arguments are not passed anywhere.

        Otherwise a new ``SageMakerModel._wrapped_class`` instance is created
        via ``_new_java_obj``:

        * ``endpointInstanceType``, ``endpointInitialInstanceCount``,
          ``existingEndpointName``, ``modelImage``, ``modelPath`` and
          ``modelExecutionRoleARN`` are wrapped in ``Option`` first.
        * ``modelEnvironmentVariables`` falls back to ``{}``.
        * ``uid`` falls back to ``Identifiable._randomUID()``.

        In both cases this wrapper's uid is then reset to the uid reported by
        the Java object.

        NOTE(review): the ``sagemakerClient`` and ``namePolicy`` defaults are
        evaluated once, at definition time. They are shared by every instance
        that omits them; confirm that sharing is intended.
        """
        super(SageMakerModel, self).__init__()

        env_vars = {} if modelEnvironmentVariables is None else modelEnvironmentVariables

        java_obj = javaObject
        if not java_obj:
            # No existing peer supplied: build one from the arguments.
            model_uid = Identifiable._randomUID() if uid is None else uid
            java_obj = self._new_java_obj(
                SageMakerModel._wrapped_class,
                Option(endpointInstanceType),
                Option(endpointInitialInstanceCount),
                requestRowSerializer,
                responseRowDeserializer,
                Option(existingEndpointName),
                Option(modelImage),
                Option(modelPath),
                env_vars,
                Option(modelExecutionRoleARN),
                endpointCreationPolicy,
                sagemakerClient,
                prependResultRows,
                namePolicy,
                model_uid,
            )
        self._java_obj = java_obj

        # Keep the Python-side uid in sync with the Java object's uid.
        self._resetUid(self._call_java("uid"))
def test_spark_integration():
    """Check that ``Param`` exposes the wrapped Spark param's name and a string value."""
    spark_param = SparkMLParam(Identifiable(), "name", "doc")
    wrapped = Param(spark_param, 123)
    # The key is expected to surface as the Spark param's name ...
    assert wrapped.key == "name"
    # ... and the integer value is expected to come back as a string.
    assert wrapped.value == "123"