Exemplo n.º 1
0
def test_marketplace_model(sagemaker_session, cpu_instance_type):
    """Deploy a Marketplace model package and print predictions for a few rows.

    The package ARN is formatted from ``MODEL_PACKAGE_ARN`` with the
    session's region and the account that ``REGION_ACCOUNT_MAP`` lists for
    that region.  Deployment runs inside
    ``timeout_and_delete_endpoint_by_name(..., minutes=20)``.  A subset of
    rows from ``marketplace/training/iris.csv`` (with the first column
    dropped) is sent through a predictor that uses the CSV serializer, and
    the decoded response is printed.  Nothing is asserted on the output.
    """
    region_name = sagemaker_session.boto_region_name
    arn = MODEL_PACKAGE_ARN % (region_name, REGION_ACCOUNT_MAP[region_name])

    def make_csv_predictor(endpoint, session):
        # Predictor factory handed to ModelPackage: serializes payloads as CSV.
        return sagemaker.RealTimePredictor(
            endpoint, session, serializer=sagemaker.predictor.csv_serializer)

    marketplace_model = ModelPackage(
        role="SageMakerRole",
        model_package_arn=arn,
        sagemaker_session=sagemaker_session,
        predictor_cls=make_csv_predictor,
    )

    name = "test-marketplace-model-endpoint{}".format(sagemaker_timestamp())
    with timeout_and_delete_endpoint_by_name(name,
                                             sagemaker_session,
                                             minutes=20):
        csv_predictor = marketplace_model.deploy(
            initial_instance_count=1,
            instance_type=cpu_instance_type,
            endpoint_name=name,
        )

        iris_csv = os.path.join(DATA_DIR, "marketplace", "training",
                                "iris.csv")
        iris = pandas.read_csv(iris_csv, header=None)

        # Row positions 40-49, 90-99 and 140-149 (offsets 0/50/100 plus
        # 40..49); the very last position is dropped below.
        row_ids = [block + offset
                   for block in range(0, 150, 50)
                   for offset in range(40, 50)]
        # Column 0 is excluded; only columns 1.. are sent to the endpoint.
        features = iris.iloc[row_ids[:-1]].iloc[:, 1:]

        response = csv_predictor.predict(features.values)
        print(response.decode("utf-8"))
Exemplo n.º 2
0
def test_marketplace_model(sagemaker_session):
    """Deploy a Marketplace model package on ml.m4.xlarge and print predictions.

    The package ARN is ``MODEL_PACKAGE_ARN`` formatted with the session's
    region.  Deployment runs inside
    ``timeout_and_delete_endpoint_by_name(..., minutes=20)``.  A subset of
    rows from ``marketplace/training/iris.csv`` (with the first column
    dropped) is sent through a predictor that uses the CSV serializer, and
    the decoded response is printed.  Nothing is asserted on the output.
    """
    arn = MODEL_PACKAGE_ARN % sagemaker_session.boto_region_name

    def make_csv_predictor(endpoint, session):
        # Predictor factory handed to ModelPackage: serializes payloads as CSV.
        return sagemaker.RealTimePredictor(
            endpoint, session, serializer=sagemaker.predictor.csv_serializer)

    marketplace_model = ModelPackage(
        role='SageMakerRole',
        model_package_arn=arn,
        sagemaker_session=sagemaker_session,
        predictor_cls=make_csv_predictor)

    name = 'test-marketplace-model-endpoint{}'.format(sagemaker_timestamp())
    with timeout_and_delete_endpoint_by_name(name,
                                             sagemaker_session,
                                             minutes=20):
        csv_predictor = marketplace_model.deploy(
            initial_instance_count=1,
            instance_type='ml.m4.xlarge',
            endpoint_name=name)

        iris_csv = os.path.join(DATA_DIR, 'marketplace', 'training',
                                'iris.csv')
        iris = pandas.read_csv(iris_csv, header=None)

        # Row positions 40-49, 90-99 and 140-149 (offsets 0/50/100 plus
        # 40..49); the very last position is dropped below.
        row_ids = [block + offset
                   for block in range(0, 150, 50)
                   for offset in range(40, 50)]
        # Column 0 is excluded; only columns 1.. are sent to the endpoint.
        features = iris.iloc[row_ids[:-1]].iloc[:, 1:]

        response = csv_predictor.predict(features.values)
        print(response.decode('utf-8'))
                            EndpointName=args.endpoint_name,
                            EndpointConfigName=ep_config_name
                        )

        create_config('Y')
    except ClientError as error: 
        # endpoint does not exist
        if "Could not find endpoint" in error.response['Error']['Message']: 
            model_package_approved = get_approved_package(args.model_package_group_name)
            model_package_arn = model_package_approved["ModelPackageArn"]

            model = ModelPackage(role=args.role, 
                                 model_package_arn=model_package_arn, 
                                 sagemaker_session=sagemaker_session)
            try:
                model.deploy(initial_instance_count=args.initial_instance_count, 
                             instance_type=args.endpoint_instance_type,
                             endpoint_name=args.endpoint_name)
                create_config('Y')
            except ClientError as error:
                print(error.response['Error']['Message'])
                create_config('N')
                error_message = error.response["Error"]["Message"]
                LOGGER.error("{}".format(stacktrace))
                raise Exception(error_message)
        else:
            print(error.response['Error']['Message'])
            create_config('N')
            error_message = error.response["Error"]["Message"]
            LOGGER.error("{}".format(stacktrace))
            raise Exception(error_message)