def test_prepare_for_training_calculate_batch_size_1(sagemaker_session):
    """A training RecordSet holding a single record should give mini_batch_size == 1."""
    estimator = LinearLearner(
        sagemaker_session=sagemaker_session,
        base_job_name="lr",
        **ALL_REQ_ARGS
    )
    s3_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    single_record_set = RecordSet(
        s3_uri,
        channel="train",
        feature_dim=FEATURE_DIM,
        num_records=1,
    )

    estimator._prepare_for_training(single_record_set)

    assert estimator.mini_batch_size == 1
def test_prepare_for_training_multiple_channel(sagemaker_session):
    """Passing a list of two 'train' RecordSets keeps the default mini-batch size."""
    estimator = LinearLearner(
        sagemaker_session=sagemaker_session,
        base_job_name="lr",
        **ALL_REQ_ARGS
    )
    train_records = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        channel="train",
        feature_dim=FEATURE_DIM,
        num_records=10000,
    )

    # The same RecordSet object appears twice in the channel list.
    channels = [train_records] * 2
    estimator._prepare_for_training(channels)

    assert estimator.mini_batch_size == DEFAULT_MINI_BATCH_SIZE
def test_prepare_for_training_multiple_channel_no_train(sagemaker_session):
    """Reject a channel list that contains no 'train' channel.

    Both RecordSets use the 'mock' channel, so _prepare_for_training must
    raise ValueError whose message mentions the missing train channel.
    """
    lr = LinearLearner(base_job_name="lr", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    data = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=10000,
        feature_dim=FEATURE_DIM,
        channel="mock",
    )

    with pytest.raises(ValueError) as ex:
        lr._prepare_for_training([data, data])

    # Check the raised exception itself: str() of the pytest ExceptionInfo
    # wrapper is its repr, not the exception message.
    assert "Must provide train channel." in str(ex.value)
def test_prepare_for_training_calculate_batch_size_2(sagemaker_session):
    """A 10000-record training RecordSet should keep the default mini-batch size."""
    estimator = LinearLearner(
        sagemaker_session=sagemaker_session,
        base_job_name="lr",
        **ALL_REQ_ARGS
    )
    s3_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    large_record_set = RecordSet(
        s3_uri,
        channel="train",
        feature_dim=FEATURE_DIM,
        num_records=10000,
    )

    estimator._prepare_for_training(large_record_set)

    expected_batch_size = DEFAULT_MINI_BATCH_SIZE
    assert estimator.mini_batch_size == expected_batch_size