Beispiel #1
0
 def test_len_error(self, n):
     """Calling ``len`` on a dataset with mismatched row counts raises.

     The first feature tensor is built with 5 rows while the other two
     tensors use ``n`` rows. ``n`` is presumably parametrized with values
     other than 5 (TODO confirm against the parametrization), so no single
     length exists and ``len(ds)`` must raise ``RuntimeError``.
     """
     torch.manual_seed(123)
     features = torch.randn(5, 3)
     codes = torch.randint(200, (n, 2))
     target = torch.randn(n)
     ds = DatasetTuple((features, codes), target)
     with pytest.raises(RuntimeError) as excinfo:
         len(ds)
     # NOTE(review): "lenght" is misspelled, but the assertion compares the
     # exact message, so it must be fixed in the library and here together.
     assert str(excinfo.value) == "Need all tensors to have same lenght."
Beispiel #2
0
 def test_getitem(self):
     """Scalar, list and slice indexing must select matching rows."""
     torch.manual_seed(123)
     n = 10
     features = torch.randn(n, 3)
     codes = torch.randint(200, (n, 2))
     target = torch.randn(n)
     ds = DatasetTuple((features, codes), target)
     # The test expects a scalar index to produce the same structure as a
     # one-element list index and a length-1 slice.
     first = ds[0]
     assert_tupletree_equal(first, ds[[0]])
     assert_tupletree_equal(first, ds[:1])
     # A contiguous slice and the equivalent explicit index list must agree.
     assert_tupletree_equal(ds[2:5], ds[[2, 3, 4]])
Beispiel #3
0
 def test_next_iter(self, batch_size, num_workers):
     """The first batch from ``DataLoaderBatch`` equals the leading rows of the data."""
     torch.manual_seed(123)
     n = 20
     data = tuplefy(
         ((torch.randn(n, 3), torch.randint(200, (n, 2))), torch.randn(n))
     )
     ds = DatasetTuple(*data)
     # The third positional argument is passed as False; presumably this is
     # `shuffle` (TODO confirm against DataLoaderBatch's signature), so the
     # first batch should be rows [0, batch_size).
     loader = DataLoaderBatch(ds, batch_size, False, num_workers=num_workers)
     expected = data.iloc[:batch_size]
     batch = next(iter(loader))
     assert_tupletree_equal(expected, batch)
Beispiel #4
0
 def test_len(self, n):
     """``len`` of a consistently sized dataset equals the shared row count ``n``."""
     torch.manual_seed(123)
     inputs = (torch.randn(n, 3), torch.randint(200, (n, 2)))
     target = torch.randn(n)
     ds = DatasetTuple(inputs, target)
     assert len(ds) == n