예제 #1
0
def test_dataloaders():
    """Check that the train, val and test dataloaders each yield a batch of shape (1, 1, 28, 28).

    Each batch is unpacked as an ``(input, target)`` pair; only the input's
    shape is checked.
    """
    dm = DataModule(DummyDataset(), DummyDataset(), DummyDataset(), num_workers=0)
    # All three dataloaders are built up front, before any batch is drawn.
    loaders = (dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader())
    for loader in loaders:
        images, _targets = next(iter(loader))
        assert images.shape == (1, 1, 28, 28)
예제 #2
0
def test_dataloaders_with_sampler(mock_dataloader):
    """Check that a custom sampler is forwarded only to the training dataloader.

    ``mock_dataloader`` is a fixture defined outside this view; presumably it
    patches the DataLoader constructor used by ``DataModule`` so the keyword
    arguments of its most recent call can be read from ``call_args``.
    """
    train_ds = val_ds = test_ds = 'dataset'
    mock_sampler = 'sampler'
    dm = DataModule(train_ds, val_ds, test_ds, num_workers=0, sampler=mock_sampler)
    assert dm.sampler is mock_sampler

    dm.train_dataloader()
    kwargs = mock_dataloader.call_args[1]
    assert 'sampler' in kwargs
    assert kwargs['sampler'] is mock_sampler

    # Build each eval dataloader immediately before inspecting ``call_args``.
    # ``call_args`` only records the most recent call, so creating both loaders
    # up front (as a list literal would) would check the test loader twice and
    # never check the validation loader.
    for make_dataloader in (dm.val_dataloader, dm.test_dataloader):
        make_dataloader()
        kwargs = mock_dataloader.call_args[1]
        assert 'sampler' not in kwargs
예제 #3
0
def test_init():
    """Check that DataModule populates only the datasets it was given.

    Builds DataModules from one, two and three ``DatasetInput`` objects. For
    each, asserts that exactly the corresponding dataset attributes are truthy.
    """
    train_input = DatasetInput(RunningStage.TRAINING, DummyDataset())
    val_input = DatasetInput(RunningStage.VALIDATING, DummyDataset())
    test_input = DatasetInput(RunningStage.TESTING, DummyDataset())

    # Training input only.
    train_only = DataModule(train_input, batch_size=1)
    assert train_only.train_dataset
    assert not train_only.val_dataset
    assert not train_only.test_dataset

    # Training + validation inputs.
    train_val = DataModule(train_input, val_input, batch_size=1)
    assert train_val.train_dataset
    assert train_val.val_dataset
    assert not train_val.test_dataset

    # All three inputs.
    all_splits = DataModule(train_input, val_input, test_input, batch_size=1)
    assert all_splits.train_dataset
    assert all_splits.val_dataset
    assert all_splits.test_dataset
예제 #4
0
def test_dataloaders_with_sampler(mock_dataloader):
    """Check that the sampler is used for the training dataloader only.

    ``mock_dataloader`` is a fixture defined outside this view; presumably it
    patches the DataLoader constructor used by ``DataModule`` so the keyword
    arguments of its most recent call can be read from ``call_args``.
    """
    train_ds = val_ds = test_ds = "dataset"
    mock_sampler = mock.MagicMock()
    dm = DataModule(train_ds,
                    val_ds,
                    test_ds,
                    num_workers=0,
                    sampler=mock_sampler)
    assert dm.sampler is mock_sampler

    dm.train_dataloader()
    kwargs = mock_dataloader.call_args[1]
    assert "sampler" in kwargs
    # The test expects DataModule to call ``sampler`` and forward the result,
    # i.e. the mock's ``return_value``.
    assert kwargs["sampler"] is mock_sampler.return_value

    # Build each eval dataloader immediately before inspecting ``call_args``.
    # ``call_args`` only records the most recent call, so creating both loaders
    # up front would check the test loader twice and never the validation one.
    for make_dataloader in (dm.val_dataloader, dm.test_dataloader):
        make_dataloader()
        kwargs = mock_dataloader.call_args[1]
        assert "sampler" not in kwargs
예제 #5
0
def test_dataloaders():
    """Check that every stage's dataloader yields a batch whose input has shape (1, 1, 28, 28).

    Batches are indexed by ``DataKeys.INPUT`` to get the input entry.
    """
    dm = DataModule(
        DatasetInput(RunningStage.TRAINING, DummyDataset()),
        DatasetInput(RunningStage.VALIDATING, DummyDataset()),
        DatasetInput(RunningStage.TESTING, DummyDataset()),
        num_workers=0,
        batch_size=1,
    )
    # All three dataloaders are built up front, before any batch is drawn.
    loaders = [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]
    for loader in loaders:
        batch = next(iter(loader))
        assert batch[DataKeys.INPUT].shape == (1, 1, 28, 28)
예제 #6
0
def test_cpu_count_none():
    """Check the inferred worker count when ``num_workers=None``.

    The test expects 0 workers on macOS ("Darwin") and Windows, and a
    positive count on every other platform.
    """
    dm = DataModule(DummyDataset(), num_workers=None)
    expects_no_workers = platform.system() in ("Darwin", "Windows")
    if not expects_no_workers:
        assert dm.num_workers > 0
    else:
        assert dm.num_workers == 0
예제 #7
0
def test_cpu_count_none():
    """Check the inferred worker count when ``num_workers=None``.

    The test expects 0 workers on macOS ("Darwin") and a positive count on
    every other platform.
    """
    # NOTE(review): an earlier version patched ``os.cpu_count`` to return None
    # and expected a "Could not infer" UserWarning. That patch is disabled, so
    # this exercises the real CPU count.
    dm = DataModule(DummyDataset(), num_workers=None)
    on_macos = platform.system() == "Darwin"
    if not on_macos:
        assert dm.num_workers > 0
    else:
        assert dm.num_workers == 0
예제 #8
0
def test_init():
    """Check that DataModule accepts one, two or three datasets, or none.

    A DataModule built with no arguments must still expose a truthy
    ``data_pipeline``.
    """
    datasets = (DummyDataset(), DummyDataset(), DummyDataset())
    # Pass the train, train+val, and train+val+test prefixes positionally.
    for count in range(1, len(datasets) + 1):
        DataModule(*datasets[:count])
    assert DataModule().data_pipeline
예제 #9
0
def test_cpu_count_none():
    """Check that ``num_workers=None`` resolves to 0 workers."""
    data_module = DataModule(DummyDataset(), num_workers=None)
    assert data_module.num_workers == 0
예제 #10
0
def test_cpu_count_none():
    """Check that ``num_workers=None`` resolves to 0 workers."""
    # NOTE(review): an earlier version patched ``os.cpu_count`` to return None
    # and expected a "Could not infer" UserWarning. That patch is disabled, so
    # this runs against the real CPU count.
    data_module = DataModule(DummyDataset(), num_workers=None)
    assert data_module.num_workers == 0