Exemplo n.º 1
0
    def test_thread_safe(self, persistent_workers, cache_workers,
                         loader_workers):
        """Check three caching datasets directly and through a DataLoader.

        ``CacheDataset``, ``SmartCacheDataset`` and ``PersistentDataset`` are
        each built twice over the integers 1..10 with a fresh
        ``_StatefulTransform()``: once to be listed directly, and once wrapped
        in a ``DataLoader`` that is iterated two times in a row. Every pass
        must yield the expected values (only the first seven of them for
        ``SmartCacheDataset``).

        Args:
            persistent_workers: forwarded to ``DataLoader`` only when
                ``pytorch_after(1, 8)`` is truthy.
            cache_workers: passed as ``num_workers`` to ``CacheDataset`` and
                as ``num_replace_workers`` to ``SmartCacheDataset``.
            loader_workers: passed as ``num_workers`` to every ``DataLoader``.
        """
        # Same ten values as the literal list 102, 202, ..., 1002.
        expected = [100 * idx + 2 for idx in range(1, 11)]

        # Keyword arguments shared by every DataLoader built below.
        loader_kwargs = {"batch_size": 1, "num_workers": loader_workers}
        if pytorch_after(1, 8):
            loader_kwargs["persistent_workers"] = persistent_workers

        data_list = list(range(1, 11))

        def check_two_passes(wanted, loader):
            # The same loader object is consumed twice on purpose; both
            # passes have to produce identical, correct values.
            for _ in range(2):
                self.assertListEqual(wanted, [b.item() for b in loader])

        # --- CacheDataset -------------------------------------------------
        cache_kwargs = {
            "data": data_list,
            "cache_rate": 1.0,
            "num_workers": cache_workers,
            "progress": False,
        }
        cached = CacheDataset(transform=_StatefulTransform(), **cache_kwargs)
        self.assertListEqual(expected, list(cached))
        loader = DataLoader(
            CacheDataset(transform=_StatefulTransform(), **cache_kwargs),
            **loader_kwargs,
        )
        check_two_passes(expected, loader)

        # --- SmartCacheDataset --------------------------------------------
        smart_kwargs = {
            "data": data_list,
            "cache_rate": 0.7,
            "replace_rate": 0.5,
            "num_replace_workers": cache_workers,
            "progress": False,
            "shuffle": False,
        }
        expected_head = expected[:7]
        cached = SmartCacheDataset(transform=_StatefulTransform(),
                                   **smart_kwargs)
        self.assertListEqual(expected_head, list(cached))
        loader = DataLoader(
            SmartCacheDataset(transform=_StatefulTransform(), **smart_kwargs),
            **loader_kwargs,
        )
        check_two_passes(expected_head, loader)

        # --- PersistentDataset --------------------------------------------
        with tempfile.TemporaryDirectory() as tempdir:
            # Both PersistentDataset instances share the same cache_dir.
            persist_kwargs = {"data": data_list, "cache_dir": tempdir}
            persisted = PersistentDataset(transform=_StatefulTransform(),
                                          **persist_kwargs)
            self.assertListEqual(expected, list(persisted))
            loader = DataLoader(
                PersistentDataset(transform=_StatefulTransform(),
                                  **persist_kwargs),
                shuffle=False,
                **loader_kwargs,
            )
            check_two_passes(expected, loader)
Exemplo n.º 2
0
 def __init__(self, pytorch_version_tuple):
     """Store the upper version bound and whether it has been exceeded.

     Args:
         pytorch_version_tuple: version components; unpacked as the
             positional arguments of ``pytorch_after``.
     """
     self.max_version = pytorch_version_tuple
     # The "too new" flag is exactly what ``pytorch_after`` reports for
     # the stored bound.
     past_bound = pytorch_after(*pytorch_version_tuple)
     self.version_too_new = past_bound
Exemplo n.º 3
0
def meshgrid_ij(*tensors):
    """Call ``torch.meshgrid`` on ``tensors``, requesting "ij" indexing.

    The ``indexing="ij"`` keyword is passed only when ``pytorch_after(1, 10)``
    is truthy; otherwise ``torch.meshgrid`` is called with the tensors alone.
    NOTE(review): presumably older releases lack the ``indexing`` keyword and
    already use "ij" ordering -- confirm against the torch documentation.

    Args:
        *tensors: forwarded positionally to ``torch.meshgrid``.

    Returns:
        Whatever ``torch.meshgrid`` returns for the given tensors.
    """
    indexing_kwargs = {"indexing": "ij"} if pytorch_after(1, 10) else {}
    return torch.meshgrid(*tensors, **indexing_kwargs)
Exemplo n.º 4
0
 def __init__(self, pytorch_version_tuple):
     """Store the lower version bound and whether it is unmet.

     Args:
         pytorch_version_tuple: version components; unpacked as the
             positional arguments of ``pytorch_after``.
     """
     self.min_version = pytorch_version_tuple
     # The "too old" flag is the negation of what ``pytorch_after``
     # reports for the stored bound.
     meets_bound = pytorch_after(*pytorch_version_tuple)
     self.version_too_old = not meets_bound