def test_prepare_for_training_feature_dim_greater_than_max_allowed(sagemaker_session):
    """Reject a record set whose feature_dim is one above MAX_FEATURE_DIM.

    Builds a RandomCutForest estimator from ALL_REQ_ARGS and expects
    ``_prepare_for_training`` to raise ``TypeError`` or ``ValueError``.
    """
    # NOTE(review): a function with this exact name is defined again further
    # down in this file; the later definition rebinds the name, so only that
    # copy would be collected by pytest -- confirm and deduplicate.
    estimator = RandomCutForest(
        base_job_name="randomcutforest",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    s3_location = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    oversized_dim = MAX_FEATURE_DIM + 1
    records = RecordSet(
        s3_location,
        num_records=1,
        feature_dim=oversized_dim,
        channel='train',
    )

    expected_errors = (TypeError, ValueError)
    with pytest.raises(expected_errors):
        estimator._prepare_for_training(records)
def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):
    """Reject the value 1234 passed as the second positional argument.

    The second argument to ``_prepare_for_training`` is presumably the
    mini-batch size (per the test name); the call is expected to raise
    ``TypeError`` or ``ValueError``.
    """
    # NOTE(review): a function with this exact name is defined again further
    # down in this file; the later definition rebinds the name, so only that
    # copy would be collected by pytest -- confirm and deduplicate.
    estimator = RandomCutForest(
        base_job_name="randomcutforest",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    s3_location = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    records = RecordSet(
        s3_location,
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel='train',
    )

    expected_errors = (TypeError, ValueError)
    with pytest.raises(expected_errors):
        estimator._prepare_for_training(records, 1234)
def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
    """Check mini_batch_size when none is passed to _prepare_for_training.

    After preparing with only the record set, the estimator's
    ``mini_batch_size`` must equal MINI_BATCH_SIZE.
    """
    # NOTE(review): a function with this exact name is defined again further
    # down in this file; the later definition rebinds the name, so only that
    # copy would be collected by pytest -- confirm and deduplicate.
    estimator = RandomCutForest(
        base_job_name="randomcutforest",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    s3_location = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    records = RecordSet(
        s3_location,
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel='train',
    )

    # No explicit mini-batch size is supplied here.
    estimator._prepare_for_training(records)

    actual_batch_size = estimator.mini_batch_size
    assert actual_batch_size == MINI_BATCH_SIZE
def test_prepare_for_training_feature_dim_greater_than_max_allowed(sagemaker_session):
    """Preparing with feature_dim = MAX_FEATURE_DIM + 1 must fail.

    The estimator is built from ALL_REQ_ARGS; ``_prepare_for_training`` is
    expected to raise either ``TypeError`` or ``ValueError``.
    """
    rcf = RandomCutForest(
        base_job_name="randomcutforest",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    too_many_features = MAX_FEATURE_DIM + 1
    train_records = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=too_many_features,
        channel='train',
    )

    with pytest.raises((TypeError, ValueError)):
        rcf._prepare_for_training(train_records)
def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):
    """Preparing with 1234 as the second positional argument must fail.

    That argument is presumably the mini-batch size (per the test name);
    ``_prepare_for_training`` is expected to raise either ``TypeError`` or
    ``ValueError``.
    """
    rcf = RandomCutForest(
        base_job_name="randomcutforest",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    train_records = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel='train',
    )

    rejected_value = 1234
    with pytest.raises((TypeError, ValueError)):
        rcf._prepare_for_training(train_records, rejected_value)
def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
    """Omitting the mini-batch size leaves mini_batch_size == MINI_BATCH_SIZE.

    Calls ``_prepare_for_training`` with only the record set and then checks
    the estimator's ``mini_batch_size`` attribute.
    """
    rcf = RandomCutForest(
        base_job_name="randomcutforest",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    train_records = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel='train',
    )

    rcf._prepare_for_training(train_records)

    assert rcf.mini_batch_size == MINI_BATCH_SIZE