def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
    """Create a Transformer from a trained Estimator, passing every optional argument.

    Verifies that:
    * the model is created from the latest training job using the supplied role,
    * the required fields (session, instance count/type, model name) are still set
      correctly when optional arguments are also given,
    * each optional argument is forwarded unchanged to the resulting Transformer,
    * the transform job base name is taken from the estimator's ``base_job_name``.
    """
    base_name = 'foo'
    estimator = Estimator(image_name=IMAGE_NAME, role=ROLE, train_instance_count=INSTANCE_COUNT,
                          train_instance_type=INSTANCE_TYPE, sagemaker_session=sagemaker_session,
                          base_job_name=base_name)
    estimator.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
    # The mocked session reports the created model's name as JOB_NAME.
    sagemaker_session.create_model_from_job.return_value = JOB_NAME

    strategy = 'MultiRecord'
    assemble_with = 'Line'
    kms_key = 'key'
    accept = 'text/csv'
    max_concurrent_transforms = 1
    max_payload = 6
    env = {'FOO': 'BAR'}

    transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with,
                                        output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS,
                                        max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
                                        env=env, role=ROLE)

    sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE)

    # Required fields: passing optional arguments must not disturb these
    # (mirrors the checks in the required-params-only test).
    assert isinstance(transformer, Transformer)
    assert transformer.sagemaker_session == sagemaker_session
    assert transformer.instance_count == INSTANCE_COUNT
    assert transformer.instance_type == INSTANCE_TYPE
    assert transformer.model_name == JOB_NAME

    # Optional fields are forwarded as given.
    assert transformer.strategy == strategy
    assert transformer.assemble_with == assemble_with
    assert transformer.output_path == OUTPUT_PATH
    assert transformer.output_kms_key == kms_key
    assert transformer.accept == accept
    assert transformer.max_concurrent_transforms == max_concurrent_transforms
    assert transformer.max_payload == max_payload
    assert transformer.env == env
    assert transformer.base_transform_job_name == base_name
    assert transformer.tags == TAGS


def test_estimator_transformer_creation(sagemaker_session):
    """Create a Transformer from a trained Estimator using only the required arguments.

    The model must be created from the latest training job with no explicit
    role. The returned Transformer must carry the estimator's session, the
    requested instance count/type and the created model's name, and no tags.
    """
    est = Estimator(
        image_name=IMAGE_NAME,
        role=ROLE,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        sagemaker_session=sagemaker_session,
    )
    est.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
    # The mocked session reports the created model's name as JOB_NAME.
    sagemaker_session.create_model_from_job.return_value = JOB_NAME

    result = est.transformer(INSTANCE_COUNT, INSTANCE_TYPE)

    sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=None)
    assert isinstance(result, Transformer)

    expected_attrs = {
        'sagemaker_session': sagemaker_session,
        'instance_count': INSTANCE_COUNT,
        'instance_type': INSTANCE_TYPE,
        'model_name': JOB_NAME,
    }
    for attr_name, expected_value in expected_attrs.items():
        assert getattr(result, attr_name) == expected_value, attr_name

    assert result.tags is None


def main():
    """Train a TensorFlow 2 model in SageMaker local mode, then run Batch Transform.

    Steps:
    1. Download the training/evaluation data.
    2. Fit an ``Estimator`` built from the local
       ``sagemaker-tensorflow2-batch-transform-local`` image on ``./data/train``
       and ``./data/test``.
    3. Run a local Batch Transform over ``./data/input`` (CSV, split by line),
       writing results under ``./data/output``.
    4. Print the content of the transform output file.
    """
    download_training_and_eval_data()

    image = 'sagemaker-tensorflow2-batch-transform-local'

    # Environment variables set in the container for the transform job
    # (presumably read by the model server inside the image -- verify there).
    env = {
        "MODEL_SERVER_WORKERS": "2"
    }

    print('Starting model training.')
    california_housing_estimator = Estimator(
        image,
        DUMMY_IAM_ROLE,
        hyperparameters={'epochs': 10,
                         'batch_size': 64,
                         'learning_rate': 0.1},
        instance_count=1,
        instance_type="local")

    inputs = {'train': 'file://./data/train', 'test': 'file://./data/test'}
    california_housing_estimator.fit(inputs, logs=True)
    print('Completed model training')

    print('Running Batch Transform in local mode')
    tensorflow_serving_transformer = california_housing_estimator.transformer(
        instance_count=1,
        instance_type='local',
        output_path='file:./data/output',
        env=env
    )

    tensorflow_serving_transformer.transform('file://./data/input',
                                             split_type='Line',
                                             content_type='text/csv')

    print('Printing Batch Transform output file content')
    # Batch Transform names each result after its input object with a ".out"
    # suffix; this assumes ./data/input contains x_test.csv.
    # Use a context manager so the file handle is closed deterministically
    # instead of being leaked until garbage collection.
    with open('./data/output/x_test.csv.out', 'r') as output_file:
        output_content = output_file.read()
    print(output_content)