Example #1
import torch
from torch.utils.data import DataLoader, TensorDataset

from catalyst.data.loader import BatchLimitLoaderWrapper  # import path may differ across Catalyst versions


def test_batch_limit1() -> None:
    for shuffle in (False, True):
        num_samples, num_features = int(1e2), int(1e1)
        X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset,
                            batch_size=4,
                            num_workers=1,
                            shuffle=shuffle)
        # Limit the loader to a single batch: every (re-)iteration should
        # yield that same first batch, even when the underlying loader shuffles.
        loader = BatchLimitLoaderWrapper(loader, num_batches=1)

        batch1 = next(iter(loader))[0]
        batch2 = next(iter(loader))[0]
        batch3 = next(iter(loader))[0]
        assert all(torch.isclose(x, y).all() for x, y in zip(batch1, batch2))
        assert all(torch.isclose(x, y).all() for x, y in zip(batch2, batch3))
Example #2
    def on_epoch_start(self, runner: "IRunner") -> None:
        """Wraps loaders for current epoch.
        If number-of-batches for loader is not provided then the first batch
        from loader will be used for overfitting.

        Args:
            runner: current runner
        """
        epoch_loaders = OrderedDict()

        for name, loader in runner.loaders.items():
            num_batches = self.loader_batches.get(name, 1)
            # A float limit is interpreted as a fraction of the loader's length.
            if isinstance(num_batches, float):
                num_batches = int(len(loader) * num_batches)
            epoch_loaders[name] = BatchLimitLoaderWrapper(
                loader=loader, num_batches=num_batches)

        runner.loaders = epoch_loaders
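
The same limiting logic can be exercised outside the callback. The sketch below is illustrative only: the loader names, the loader_batches mapping, and the catalyst.data.loader import path are assumptions, not part of the original snippet. It shows how an integer limit is taken as an absolute batch count, while a float is first resolved into a fraction of the loader's length before wrapping.

from collections import OrderedDict

import torch
from torch.utils.data import DataLoader, TensorDataset

from catalyst.data.loader import BatchLimitLoaderWrapper  # import path assumed

# Hypothetical per-loader limits: an absolute count for "train",
# a fraction of the loader length for "valid".
loader_batches = {"train": 5, "valid": 0.1}

loaders = {
    "train": DataLoader(TensorDataset(torch.rand(100, 10), torch.rand(100)), batch_size=4),
    "valid": DataLoader(TensorDataset(torch.rand(40, 10), torch.rand(40)), batch_size=4),
}

epoch_loaders = OrderedDict()
for name, loader in loaders.items():
    num_batches = loader_batches.get(name, 1)  # default: overfit on the first batch only
    if isinstance(num_batches, float):
        # Resolve a fractional limit against the number of batches in the loader.
        num_batches = int(len(loader) * num_batches)
    epoch_loaders[name] = BatchLimitLoaderWrapper(loader=loader, num_batches=num_batches)

With these values, "train" is capped at 5 batches and "valid" at roughly 10% of its batches, mirroring the resolution performed in on_epoch_start above.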