def test_prepare_for_training_calculate_batch_size_1(sagemaker_session):
    """A train channel holding one record should give ``mini_batch_size == 1``."""
    estimator = LinearLearner(
        base_job_name="lr",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    s3_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    single_record_set = RecordSet(
        s3_uri,
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )

    estimator._prepare_for_training(single_record_set)

    expected_batch_size = 1
    assert estimator.mini_batch_size == expected_batch_size
def test_prepare_for_training_multiple_channel(sagemaker_session):
    """A list of two train record sets should give the default mini-batch size."""
    estimator = LinearLearner(
        base_job_name="lr",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    s3_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    train_records = RecordSet(s3_uri, num_records=10000, feature_dim=FEATURE_DIM, channel="train")

    # The same RecordSet instance is deliberately supplied twice.
    channels = [train_records, train_records]
    estimator._prepare_for_training(channels)

    assert estimator.mini_batch_size == DEFAULT_MINI_BATCH_SIZE
def test_prepare_for_training_multiple_channel_no_train(sagemaker_session):
    """Preparing with a list of record sets that has no 'train' channel must fail.

    Both record sets use the 'mock' channel, so ``_prepare_for_training`` is
    expected to raise ``ValueError`` with the message
    'Must provide train channel.'.
    """
    lr = LinearLearner(base_job_name='lr', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

    data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX),
                     num_records=10000,
                     feature_dim=FEATURE_DIM,
                     channel='mock')

    with pytest.raises(ValueError) as ex:
        lr._prepare_for_training([data, data])

    # ``ex`` is pytest's ExceptionInfo wrapper; the raised exception itself is
    # ``ex.value``. Checking the wrapper's string form depends on the pytest
    # version, so the message is matched against the exception directly.
    assert 'Must provide train channel.' in str(ex.value)
def test_prepare_for_training_calculate_batch_size_2(sagemaker_session):
    """A 10000-record train channel should give the default mini-batch size."""
    estimator = LinearLearner(
        base_job_name="lr",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    s3_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    train_records = RecordSet(s3_uri, num_records=10000, feature_dim=FEATURE_DIM, channel="train")

    estimator._prepare_for_training(train_records)

    actual_batch_size = estimator.mini_batch_size
    assert actual_batch_size == DEFAULT_MINI_BATCH_SIZE