Ejemplo n.º 1
0
def test_prepare_for_training_wrong_value_mini_batch_size(sagemaker_session):
    """Passing a mini-batch size of 0 to _prepare_for_training raises ValueError."""
    estimator = KMeans(base_job_name='kmeans',
                       sagemaker_session=sagemaker_session,
                       **ALL_REQ_ARGS)

    train_uri = 's3://{}/{}'.format(BUCKET_NAME, PREFIX)
    records = RecordSet(train_uri,
                        num_records=1,
                        feature_dim=FEATURE_DIM,
                        channel='train')

    # 0 is passed positionally as the mini-batch size argument.
    invalid_batch_size = 0
    with pytest.raises(ValueError):
        estimator._prepare_for_training(records, invalid_batch_size)
Ejemplo n.º 2
0
def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
    """Omitting the mini-batch size leaves the estimator with mini_batch_size == 5000."""
    estimator = KMeans(base_job_name='kmeans',
                       sagemaker_session=sagemaker_session,
                       **ALL_REQ_ARGS)

    train_uri = 's3://{}/{}'.format(BUCKET_NAME, PREFIX)
    records = RecordSet(train_uri,
                        num_records=1,
                        feature_dim=FEATURE_DIM,
                        channel='train')

    # Only the records are supplied; no mini-batch size argument.
    estimator._prepare_for_training(records)

    expected_batch_size = 5000
    assert estimator.mini_batch_size == expected_batch_size
Ejemplo n.º 3
0
def test_prepare_for_training_wrong_value_mini_batch_size(sagemaker_session):
    """Passing a mini-batch size of 0 to _prepare_for_training raises ValueError."""
    estimator = KMeans(
        base_job_name="kmeans",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    train_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    records = RecordSet(train_uri, num_records=1, feature_dim=FEATURE_DIM, channel="train")

    # 0 is passed positionally as the mini-batch size argument.
    invalid_batch_size = 0
    with pytest.raises(ValueError):
        estimator._prepare_for_training(records, invalid_batch_size)
Ejemplo n.º 4
0
def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
    """Omitting the mini-batch size leaves the estimator with mini_batch_size == 5000."""
    estimator = KMeans(
        base_job_name="kmeans",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    train_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    records = RecordSet(train_uri, num_records=1, feature_dim=FEATURE_DIM, channel="train")

    # Only the records are supplied; no mini-batch size argument.
    estimator._prepare_for_training(records)

    expected_batch_size = 5000
    assert estimator.mini_batch_size == expected_batch_size
Ejemplo n.º 5
0
def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):
    """A non-numeric mini-batch size ('some') is rejected with TypeError or ValueError."""
    estimator = KMeans(base_job_name='kmeans', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

    train_uri = 's3://{}/{}'.format(BUCKET_NAME, PREFIX)
    records = RecordSet(train_uri, num_records=1, feature_dim=FEATURE_DIM, channel='train')

    # Either exception type is accepted, so the test does not pin down
    # which one the implementation raises for a string argument.
    accepted_errors = (TypeError, ValueError)
    bad_batch_size = 'some'
    with pytest.raises(accepted_errors):
        estimator._prepare_for_training(records, bad_batch_size)