# NOTE(review): a function with this exact name is defined again further down
# in this file. At module level the later ``def`` rebinds the name, so pytest
# only collects the later one and this definition is dead code. Consider
# deleting this duplicate.
def test_prepare_for_training_wrong_value_mini_batch_size(sagemaker_session):
    """A mini-batch size of 0 passed to ``_prepare_for_training`` raises ValueError."""
    estimator = KMeans(
        base_job_name="kmeans",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    train_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    records = RecordSet(
        train_uri,
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )

    invalid_batch_size = 0
    with pytest.raises(ValueError):
        estimator._prepare_for_training(records, invalid_batch_size)
# NOTE(review): a function with this exact name is defined again further down
# in this file. At module level the later ``def`` rebinds the name, so pytest
# only collects the later one and this definition is dead code. Consider
# deleting this duplicate.
def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
    """Calling ``_prepare_for_training`` without a mini-batch size leaves
    ``mini_batch_size`` set to 5000 on the estimator."""
    estimator = KMeans(
        base_job_name="kmeans",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    train_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    records = RecordSet(
        train_uri,
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )

    estimator._prepare_for_training(records)

    expected_default = 5000
    assert estimator.mini_batch_size == expected_default
def test_prepare_for_training_wrong_value_mini_batch_size(sagemaker_session):
    """Reject a zero mini-batch size.

    ``_prepare_for_training`` is expected to raise ValueError when it is
    handed 0 as the mini-batch size.
    """
    kmeans_estimator = KMeans(
        base_job_name="kmeans", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS
    )
    s3_location = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    train_records = RecordSet(
        s3_location, num_records=1, feature_dim=FEATURE_DIM, channel="train"
    )

    zero_batch = 0
    with pytest.raises(ValueError):
        kmeans_estimator._prepare_for_training(train_records, zero_batch)
def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
    """Fall back to a mini-batch size of 5000 when none is supplied.

    After ``_prepare_for_training`` runs with only the records argument,
    the estimator's ``mini_batch_size`` attribute must equal 5000.
    """
    kmeans_estimator = KMeans(
        base_job_name="kmeans", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS
    )
    s3_location = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    train_records = RecordSet(
        s3_location, num_records=1, feature_dim=FEATURE_DIM, channel="train"
    )

    kmeans_estimator._prepare_for_training(train_records)

    default_batch = 5000
    assert kmeans_estimator.mini_batch_size == default_batch
def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):
    """Reject a non-numeric mini-batch size.

    Passing the string ``'some'`` as the mini-batch size to
    ``_prepare_for_training`` must raise either TypeError or ValueError.
    """
    estimator = KMeans(
        base_job_name="kmeans",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    train_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    records = RecordSet(
        train_uri,
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )

    # Either exception type is accepted; the test only pins that the bad
    # value is rejected, not which validation step rejects it.
    accepted_errors = (TypeError, ValueError)
    non_numeric_batch_size = "some"
    with pytest.raises(accepted_errors):
        estimator._prepare_for_training(records, non_numeric_batch_size)