def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
    """Optional ``Estimator.transformer`` arguments end up on the Transformer.

    Also verifies that the estimator's ``base_job_name`` is reused as the
    transformer's ``base_transform_job_name`` and that ``role`` is forwarded
    to ``create_model_from_job``.
    """
    job_prefix = 'foo'
    est = Estimator(
        image_name=IMAGE_NAME,
        role=ROLE,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        sagemaker_session=sagemaker_session,
        base_job_name=job_prefix,
    )
    est.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
    sagemaker_session.create_model_from_job.return_value = JOB_NAME

    # Each optional setting below is passed to transformer() and is expected
    # to be stored unchanged under the attribute of the same name.
    expected_attrs = {
        'strategy': 'MultiRecord',
        'assemble_with': 'Line',
        'output_path': OUTPUT_PATH,
        'output_kms_key': 'key',
        'accept': 'text/csv',
        'max_concurrent_transforms': 1,
        'max_payload': 6,
        'env': {'FOO': 'BAR'},
    }

    transformer = est.transformer(
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        tags=TAGS,
        role=ROLE,
        **expected_attrs
    )

    sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE)

    for attr_name, wanted in expected_attrs.items():
        assert getattr(transformer, attr_name) == wanted
    assert transformer.base_transform_job_name == job_prefix
    assert transformer.tags == TAGS
def test_estimator_transformer_creation(sagemaker_session):
    """A transformer built from only the required arguments uses defaults.

    No role is forwarded to ``create_model_from_job`` and no tags are set.
    The model name comes from the mocked ``create_model_from_job``.
    """
    est = Estimator(
        image_name=IMAGE_NAME,
        role=ROLE,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        sagemaker_session=sagemaker_session,
    )
    est.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)

    create_model = sagemaker_session.create_model_from_job
    create_model.return_value = JOB_NAME

    result = est.transformer(INSTANCE_COUNT, INSTANCE_TYPE)

    create_model.assert_called_with(JOB_NAME, role=None)
    assert isinstance(result, Transformer)
    assert result.sagemaker_session == sagemaker_session
    assert (result.instance_count, result.instance_type) == (INSTANCE_COUNT, INSTANCE_TYPE)
    assert result.model_name == JOB_NAME
    assert result.tags is None
def main():
    """Train an estimator and run Batch Transform with it in SageMaker local mode.

    Steps:
      1. Download the training/eval data.
      2. Fit the estimator on the local ``file://`` train/test channels.
      3. Run a local batch transform over ``./data/input``, writing results
         under ``./data/output``.
      4. Print the contents of the transform output file.
    """
    download_training_and_eval_data()

    image = 'sagemaker-tensorflow2-batch-transform-local'
    # Environment variables handed to the transformer (see transformer(env=...)).
    env = {"MODEL_SERVER_WORKERS": "2"}

    print('Starting model training.')
    california_housing_estimator = Estimator(
        image,
        DUMMY_IAM_ROLE,
        hyperparameters={'epochs': 10, 'batch_size': 64, 'learning_rate': 0.1},
        instance_count=1,
        instance_type="local",
    )

    inputs = {'train': 'file://./data/train', 'test': 'file://./data/test'}
    california_housing_estimator.fit(inputs, logs=True)
    print('Completed model training')

    print('Running Batch Transform in local mode')
    tensorflow_serving_transformer = california_housing_estimator.transformer(
        instance_count=1,
        instance_type='local',
        output_path='file:./data/output',
        env=env,
    )
    tensorflow_serving_transformer.transform(
        'file://./data/input',
        split_type='Line',
        content_type='text/csv',
    )

    print('Printing Batch Transform output file content')
    # NOTE(review): the name assumes the input directory contains x_test.csv
    # and that the transform writes "<input file>.out"; confirm if inputs change.
    # A context manager closes the handle; the original bare open().read() leaked it.
    with open('./data/output/x_test.csv.out', 'r') as output_handle:
        output_file = output_handle.read()
    print(output_file)