def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
    """Check that optional ``transformer()`` arguments reach the returned transformer.

    Every optional keyword handed to ``Estimator.transformer`` is expected to be
    readable back from the resulting object under the same attribute name, the
    estimator's ``base_job_name`` is expected to surface as
    ``base_transform_job_name``, and the explicit ``role`` is expected to be
    forwarded to ``create_model_from_job``.
    """
    job_prefix = 'foo'
    trained_estimator = Estimator(
        image_name=IMAGE_NAME,
        role=ROLE,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        sagemaker_session=sagemaker_session,
        base_job_name=job_prefix,
    )
    trained_estimator.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
    # Stub model creation on the session fixture so it reports the job name as the model name.
    sagemaker_session.create_model_from_job.return_value = JOB_NAME

    # Optional settings, keyed by the name used both as keyword argument and as
    # the attribute read back from the transformer.
    optional_settings = {
        'strategy': 'MultiRecord',
        'assemble_with': 'Line',
        'output_path': OUTPUT_PATH,
        'output_kms_key': 'key',
        'accept': 'text/csv',
        'tags': TAGS,
        'max_concurrent_transforms': 1,
        'max_payload': 6,
        'env': {'FOO': 'BAR'},
    }

    created = trained_estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE, role=ROLE, **optional_settings)

    sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE)
    for attribute, expected_value in optional_settings.items():
        assert getattr(created, attribute) == expected_value
    assert created.base_transform_job_name == job_prefix
def test_estimator_transformer_creation(sagemaker_session):
    """Check ``Estimator.transformer`` when only instance count and type are given.

    The model is expected to be created from the latest training job with
    ``role=None``, and the returned ``Transformer`` is expected to carry the
    session, the instance settings, the stubbed model name, and no tags.
    """
    trained_estimator = Estimator(
        image_name=IMAGE_NAME,
        role=ROLE,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        sagemaker_session=sagemaker_session,
    )
    trained_estimator.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
    # Stub model creation on the session fixture so it reports the job name as the model name.
    sagemaker_session.create_model_from_job.return_value = JOB_NAME

    created = trained_estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE)

    # No role was passed to transformer(), so the call is expected with role=None.
    sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=None)

    assert isinstance(created, Transformer)
    assert created.sagemaker_session == sagemaker_session
    assert created.instance_count == INSTANCE_COUNT
    assert created.instance_type == INSTANCE_TYPE
    assert created.model_name == JOB_NAME
    assert created.tags is None