def test_thread_safe(self, persistent_workers, cache_workers, loader_workers):
    """Check that cached datasets give identical values directly and via a DataLoader.

    For each of ``CacheDataset``, ``SmartCacheDataset`` and ``PersistentDataset``:
    a dataset built with a fresh ``_StatefulTransform`` is iterated directly, then a
    second, independently constructed dataset is wrapped in a ``DataLoader``
    (``batch_size=1``) that is iterated twice. Every pass must yield the expected
    values (only the first seven for ``SmartCacheDataset`` with ``cache_rate=0.7``).

    Args:
        persistent_workers: forwarded to ``DataLoader`` only when ``pytorch_after(1, 8)``.
        cache_workers: ``num_workers`` / ``num_replace_workers`` used by the caching datasets.
        loader_workers: ``num_workers`` for every ``DataLoader``.

    NOTE(review): PyTorch's DataLoader rejects ``persistent_workers=True`` together with
    ``num_workers=0``; confirm the parameterization never produces that combination.
    """
    expected = [102, 202, 302, 402, 502, 602, 702, 802, 902, 1002]
    data_list = list(range(1, 11))

    # The ``persistent_workers`` keyword is only passed on PyTorch versions where
    # ``pytorch_after(1, 8)`` holds; otherwise the DataLoader gets no extra options.
    if pytorch_after(1, 8):
        loader_opts = {"persistent_workers": persistent_workers}
    else:
        loader_opts = {}

    def make_loader(source, **extra):
        # Single-sample batches so each item can be compared with ``.item()``.
        return DataLoader(source, batch_size=1, num_workers=loader_workers, **extra, **loader_opts)

    def assert_two_passes(target, dl):
        # Iterate the same loader twice; both epochs must reproduce ``target``.
        for _ in range(2):
            self.assertListEqual(target, [y.item() for y in dl])

    def build_cache():
        # Fully cached dataset (cache_rate=1.0) with its own stateful transform.
        return CacheDataset(
            data=data_list,
            transform=_StatefulTransform(),
            cache_rate=1.0,
            num_workers=cache_workers,
            progress=False,
        )

    def build_smart_cache():
        # Partially cached dataset; with cache_rate=0.7 only 7 of 10 items are exposed.
        return SmartCacheDataset(
            data=data_list,
            transform=_StatefulTransform(),
            cache_rate=0.7,
            replace_rate=0.5,
            num_replace_workers=cache_workers,
            progress=False,
            shuffle=False,
        )

    dataset = build_cache()
    self.assertListEqual(expected, list(dataset))
    loader = make_loader(build_cache())
    assert_two_passes(expected, loader)

    first_seven = expected[:7]
    dataset = build_smart_cache()
    self.assertListEqual(first_seven, list(dataset))
    loader = make_loader(build_smart_cache())
    assert_two_passes(first_seven, loader)

    with tempfile.TemporaryDirectory() as tempdir:
        # Both PersistentDatasets share ``tempdir`` as their cache directory.
        pdata = PersistentDataset(data=data_list, transform=_StatefulTransform(), cache_dir=tempdir)
        self.assertListEqual(expected, list(pdata))
        loader = make_loader(
            PersistentDataset(data=data_list, transform=_StatefulTransform(), cache_dir=tempdir),
            shuffle=False,
        )
        assert_two_passes(expected, loader)
def __init__(self, pytorch_version_tuple):
    """Store an upper PyTorch version bound and evaluate it once.

    Args:
        pytorch_version_tuple: version components, unpacked positionally into
            ``pytorch_after`` (e.g. ``(1, 8)``).

    Sets ``self.max_version`` to the tuple exactly as given, and
    ``self.version_too_new`` to ``pytorch_after(*pytorch_version_tuple)``.
    The flag is presumably True when the installed PyTorch is at or beyond that
    version (TODO confirm against ``pytorch_after``).
    """
    bound = pytorch_version_tuple
    self.max_version = bound
    # Evaluated once, at construction time.
    self.version_too_new = pytorch_after(*bound)
def meshgrid_ij(*tensors):
    """Return ``torch.meshgrid(*tensors)`` with matrix ("ij") indexing.

    The ``indexing="ij"`` keyword is passed only when ``pytorch_after(1, 10)``
    is true. Otherwise ``torch.meshgrid`` is called with no keyword at all, which
    relies on older releases defaulting to "ij" ordering (per the torch.meshgrid
    docs — confirm for the minimum supported version).
    """
    # An empty dict unpacks to no keyword arguments, matching the bare call.
    extra = {"indexing": "ij"} if pytorch_after(1, 10) else {}
    return torch.meshgrid(*tensors, **extra)
def __init__(self, pytorch_version_tuple):
    """Store a lower PyTorch version bound and evaluate it once.

    Args:
        pytorch_version_tuple: version components, unpacked positionally into
            ``pytorch_after`` (e.g. ``(1, 8)``).

    Sets ``self.min_version`` to the tuple exactly as given, and
    ``self.version_too_old`` to the negation of
    ``pytorch_after(*pytorch_version_tuple)``. The flag is presumably True when
    the installed PyTorch predates that version (TODO confirm against
    ``pytorch_after``).
    """
    bound = pytorch_version_tuple
    self.min_version = bound
    reached = pytorch_after(*bound)
    self.version_too_old = not reached