示例#1
0
def test_incorrect_test_train_split(dummy_data):
    """Check that ``prepare_data`` rejects ``train_test_split=42``.

    The module is constructed outside the ``pytest.raises`` block, so the
    ``ValueError`` is expected from ``prepare_data`` itself.
    """
    module_kwargs = {
        'data': dummy_data,
        'y_col': 'y',
        'train_test_split': 42,
    }
    module = BaseDataModule(**module_kwargs)

    with pytest.raises(ValueError):
        module.prepare_data()
示例#2
0
def test_stratified_shuffle_regression_warning(dummy_data):
    """Check the warning for stratified shuffling on a regression target.

    Building and preparing a module with both ``is_regression`` and
    ``stratified_shuffle`` enabled must emit a ``SunpyUserWarning``, and
    afterwards the module's ``stratified_shuffle`` attribute must be
    ``False``.
    """
    options = {
        'y_col': 'y',
        'is_regression': True,
        'stratified_shuffle': True,
        'weighted_sampling': False,
    }

    # Construction and preparation both run under the warning check, so the
    # warning may originate from either call.
    with pytest.warns(SunpyUserWarning):
        module = BaseDataModule(data=dummy_data, **options)
        module.prepare_data()

    assert module.stratified_shuffle is False
示例#3
0
def test_disjoint_split(dummy_data):
    """Check that the train, validation and test splits share no index."""
    module = BaseDataModule(data=dummy_data, y_col='y')
    module.prepare_data()
    module.setup()

    train_idx = set(module.train.index)
    assert train_idx.isdisjoint(module.val.index)
    assert train_idx.isdisjoint(module.test.index)

    test_idx = set(module.test.index)
    assert test_idx.isdisjoint(module.val.index)
示例#4
0
def test_split_distribution(dummy_data):
    """Check that every split's ``y`` distribution tracks the full data.

    For the train, test and validation splits (in that order), no component
    of ``distribution(split['y'])`` may differ from the distribution of the
    whole data set by more than 0.1.
    """
    module = BaseDataModule(data=dummy_data, y_col='y')
    module.prepare_data()
    module.setup()

    reference = distribution(module.data['y'])

    for split_name in ('train', 'test', 'val'):
        split_dist = distribution(getattr(module, split_name)['y'])
        deviation = np.abs(split_dist - reference)
        assert not np.any(deviation > 0.1)
示例#5
0
def test_oversampling(dummy_data):
    """Check that labels drawn from the training loader are roughly balanced.

    One full pass is made over ``train_dataloader()`` and the label of every
    batch is collected. The first two entries of ``distribution`` over those
    labels must differ by less than 0.2.
    """
    train_conf = {'is_tabular': True, 'transform': ToTensor()}
    dataloader = BaseDataModule(data=dummy_data,
                                y_col='y',
                                batch_size=1,
                                train_conf=train_conf)
    dataloader.prepare_data()
    dataloader.setup()

    # batch_size=1, so only element 0 of each batch is read.
    # NOTE(review): sample[1] is assumed to be the label tensor -- confirm
    # against the dataset's return order.
    sampled_labels = [
        sample[1].numpy()[0] for sample in dataloader.train_dataloader()
    ]

    data_distribution = distribution(sampled_labels)

    assert np.abs(data_distribution[0] - data_distribution[1]) < 0.2
示例#6
0
def test_conf_passed(dummy_data):
    """Check that ``setup`` emits no warning when all three confs are given.

    A module is built with explicit (empty) ``train_conf``, ``test_conf`` and
    ``val_conf`` dictionaries; after ``prepare_data``, calling ``setup`` must
    not raise any warning.
    """
    # Function-scope import keeps this block self-contained.
    import warnings

    training_conf = {}
    testing_conf = {}
    validation_conf = {}

    dataloader = BaseDataModule(data=dummy_data,
                                y_col='y',
                                train_conf=training_conf,
                                test_conf=testing_conf,
                                val_conf=validation_conf)

    dataloader.prepare_data()

    # ``pytest.warns(None)`` is deprecated and rejected by newer pytest
    # releases, so warnings are recorded with the standard library instead.
    with warnings.catch_warnings(record=True) as record:
        # "always" prevents the default once-per-location filter from hiding
        # a warning that was already emitted elsewhere in the session.
        warnings.simplefilter("always")
        dataloader.setup()
    assert not record, [str(w.message) for w in record]
示例#7
0
def test_no_conf(dummy_data):
    """Check that ``setup`` warns when one split configuration is omitted.

    Three modules are built in turn, leaving out ``train_conf``, then
    ``val_conf``, then ``test_conf``. For each one, ``setup`` must emit a
    ``SunpyUserWarning`` after ``prepare_data`` has run.
    """
    training_conf = {}
    testing_conf = {}
    validation_conf = {}

    # Each entry supplies two of the three confs; the same dict objects are
    # shared between entries.
    partial_confs = (
        # train_conf omitted
        {'test_conf': testing_conf, 'val_conf': validation_conf},
        # val_conf omitted
        {'test_conf': testing_conf, 'train_conf': training_conf},
        # test_conf omitted
        {'train_conf': training_conf, 'val_conf': validation_conf},
    )

    for confs in partial_confs:
        module = BaseDataModule(data=dummy_data, y_col='y', **confs)
        module.prepare_data()
        with pytest.warns(SunpyUserWarning):
            module.setup()
示例#8
0
def test_incorrect_data():
    """Check that ``prepare_data`` raises ``TypeError`` for ``data=42``.

    The module is constructed outside the ``pytest.raises`` block, so the
    error is expected from ``prepare_data`` itself.
    """
    module = BaseDataModule(y_col='y', data=42)

    with pytest.raises(TypeError):
        module.prepare_data()
示例#9
0
def test_default_loader(dummy_data):
    """Smoke test: build a module from only ``data`` and ``y_col``."""
    minimal_kwargs = {'data': dummy_data, 'y_col': 'y'}
    BaseDataModule(**minimal_kwargs)
示例#10
0
def test_X_y_disjoint(dummy_data):
    """Check that the constructor rejects an ``X_col`` that contains ``y_col``.

    ``'y'`` is passed both as ``y_col`` and inside ``X_col``; constructing the
    module must raise ``ValueError``.
    """
    with pytest.raises(ValueError):
        # The instance is not bound to a name: construction is expected to
        # raise, so nothing could be used afterwards.
        BaseDataModule(data=dummy_data,
                       y_col='y',
                       X_col=['X', 'y'])
示例#11
0
def test_no_X_col(dummy_data):
    """Check that omitting ``X_col`` leads to a ``SunpyUserWarning``.

    Construction and ``prepare_data`` both run inside the ``pytest.warns``
    block, so the warning may come from either call.
    """
    with pytest.warns(SunpyUserWarning):
        module = BaseDataModule(y_col='y', data=dummy_data)
        module.prepare_data()