コード例 #1
0
    def __init__(
            self,
            trainingInstanceType,
            trainingInstanceCount,
            endpointInstanceType,
            endpointInitialInstanceCount,
            sagemakerRole=IAMRoleFromConfig(),
            requestRowSerializer=ProtobufRequestRowSerializer(),
            responseRowDeserializer=LinearLearnerBinaryClassifierProtobufResponseRowDeserializer(),
            trainingInputS3DataPath=S3AutoCreatePath(),
            trainingOutputS3DataPath=S3AutoCreatePath(),
            trainingInstanceVolumeSizeInGB=1024,
            trainingProjectedColumns=None,
            trainingChannelName="train",
            trainingContentType=None,
            trainingS3DataDistribution="ShardedByS3Key",
            trainingSparkDataFormat="sagemaker",
            trainingSparkDataFormatOptions=None,
            trainingInputMode="File",
            trainingCompressionCodec=None,
            trainingMaxRuntimeInSeconds=24*60*60,
            trainingKmsKeyId=None,
            modelEnvironmentVariables=None,
            endpointCreationPolicy=EndpointCreationPolicy.CREATE_ON_CONSTRUCT,
            sagemakerClient=SageMakerClients.create_sagemaker_client(),
            region=None,
            s3Client=SageMakerClients.create_s3_default_client(),
            stsClient=SageMakerClients.create_sts_default_client(),
            modelPrependInputRowsToTransformationRows=True,
            deleteStagingDataAfterTraining=True,
            namePolicyFactory=RandomNamePolicyFactory(),
            uid=None,
            javaObject=None):
        """Build the estimator and default ``predictor_type`` to ``'binary_classifier'``.

        Every argument other than ``self`` is forwarded by keyword to the
        parent class constructor. Three of them are normalised first when
        left as ``None``:

        * ``uid`` becomes ``Identifiable._randomUID()``
        * ``modelEnvironmentVariables`` becomes ``{}``
        * ``trainingSparkDataFormatOptions`` becomes ``{}``

        The meaning of each argument is defined by the parent constructor,
        which is not visible from this method.

        NOTE(review): defaults written as call expressions (for example
        ``SageMakerClients.create_sagemaker_client()`` or
        ``S3AutoCreatePath()``) are evaluated once, when this method is
        defined, so every instance relying on them receives the same object.
        Confirm that this sharing is intended.
        """
        # Replace the None placeholders before the arguments are captured,
        # so the parent constructor receives the final values.
        if uid is None:
            uid = Identifiable._randomUID()

        if modelEnvironmentVariables is None:
            modelEnvironmentVariables = {}

        if trainingSparkDataFormatOptions is None:
            trainingSparkDataFormatOptions = {}

        # Snapshot of the current locals: at this point they are exactly the
        # constructor arguments. Any local variable introduced above this line
        # would be forwarded to the parent as an unexpected keyword argument.
        init_args = dict(locals())
        init_args.pop('self')

        # Keep the explicit two-argument super(): zero-argument super() would
        # make an implicit ``__class__`` entry show up in locals() above.
        super(LinearLearnerBinaryClassifier, self).__init__(**init_args)

        # This subclass fixes the predictor type as a parameter default.
        self._setDefault(**{'predictor_type': 'binary_classifier'})
コード例 #2
0
def test_linearLearnerBinaryClassifier_passes_correct_params_to_scala():
    """Check that constructor arguments can be read back from the estimator.

    Builds a ``LinearLearnerBinaryClassifier`` with explicit values and
    asserts that the matching estimator attributes return those values.

    NOTE(review): going by the test name, the attributes are presumably
    served by the wrapped Scala object; that wiring is not visible here.
    """

    # Instance configuration handed to the estimator and compared below.
    training_instance_type = "c4.8xlarge"
    training_instance_count = 3
    endpoint_instance_type = "c4.8xlarge"
    endpoint_initial_instance_count = 3

    # S3 bucket and prefixes used to build the input/output S3DataPath
    # arguments, plus the role ARN string that is wrapped in IAMRole.
    training_bucket = "random-bucket"
    input_prefix = "linear-learner-binary-classifier-training"
    output_prefix = "linear-learner-binary-classifier-out"
    integTestingRole = "arn:aws:iam::123456789:role/SageMakerRole"

    estimator = LinearLearnerBinaryClassifier(
        trainingInstanceType=training_instance_type,
        trainingInstanceCount=training_instance_count,
        endpointInstanceType=endpoint_instance_type,
        endpointInitialInstanceCount=endpoint_initial_instance_count,
        sagemakerRole=IAMRole(integTestingRole),
        requestRowSerializer=ProtobufRequestRowSerializer(),
        responseRowDeserializer=
        LinearLearnerBinaryClassifierProtobufResponseRowDeserializer(),
        trainingInstanceVolumeSizeInGB=2048,
        trainingInputS3DataPath=S3DataPath(training_bucket, input_prefix),
        trainingOutputS3DataPath=S3DataPath(training_bucket, output_prefix),
        trainingMaxRuntimeInSeconds=1,
        endpointCreationPolicy=EndpointCreationPolicy.CREATE_ON_TRANSFORM,
        sagemakerClient=SageMakerClients.create_sagemaker_client(),
        s3Client=SageMakerClients.create_s3_default_client(),
        stsClient=SageMakerClients.create_sts_default_client(),
        modelPrependInputRowsToTransformationRows=True,
        namePolicyFactory=RandomNamePolicyFactory(),
        uid="sagemaker")

    # Values passed to the constructor must be readable from the estimator.
    assert estimator.trainingInputS3DataPath.bucket == training_bucket
    assert estimator.trainingInputS3DataPath.objectPath == input_prefix
    assert estimator.trainingInstanceCount == training_instance_count
    assert estimator.trainingInstanceType == training_instance_type
    assert estimator.endpointInstanceType == endpoint_instance_type
    assert estimator.endpointInitialInstanceCount == endpoint_initial_instance_count
    assert estimator.trainingInstanceVolumeSizeInGB == 2048
    assert estimator.trainingMaxRuntimeInSeconds == 1
    # trainingKmsKeyId is not passed above, so it is expected to read as None.
    assert estimator.trainingKmsKeyId is None