Exemple #1
0
 def test_max_sizes_incorrect_sizes(self):
     """get_max_sizes must raise AssertionError when a tensor disagrees with ``sizes``.

     The middle tensor has a last dimension of 45 while ``sizes`` pins that
     dimension to 40.
     """
     mismatched_batch = [
         torch.rand(1, 20, 40),
         torch.rand(1, 20, 45),  # violates the fixed last dimension below
         torch.rand(1, 20, 40),
     ]
     fixed_sizes = (1, 20, 40)
     with self.assertRaises(AssertionError):
         PaddingCollater.get_max_sizes(mismatched_batch, sizes=fixed_sizes)
Exemple #2
0
 def val_dataloader(self):
     """Build the validation DataLoader over ``self.va_ds``.

     Samples are served in dataset order (``shuffle=False``) and loaded in the
     main process (``num_workers=0``). Each batch is collated by a
     PaddingCollater that handles the ``"img"`` entry with its first dimension
     fixed to 1 and the remaining two left free.
     """
     # NOTE(review): channel count (1) and worker count (0) are hard-coded
     # here; confirm this is intended if other loaders make them configurable.
     img_collater = PaddingCollater({"img": (1, None, None)})
     return torch.utils.data.DataLoader(
         dataset=self.va_ds,
         batch_size=self.batch_size,
         shuffle=False,
         num_workers=0,
         collate_fn=img_collater,
     )
Exemple #3
0
 def test_max_sizes_no_sizes(self):
     """Without ``sizes``, get_max_sizes reports per-dimension maxima.

     The expected tuple is the batch length followed by the largest value
     seen in each of the three tensor dimensions.
     """
     tensors = [
         torch.rand(3, 20, 40),
         torch.rand(3, 25, 30),
         torch.rand(5, 15, 35),
     ]
     actual = PaddingCollater.get_max_sizes(tensors)
     self.assertEqual((len(tensors), 5, 25, 40), actual)
Exemple #4
0
 def test_collate_tensors(self):
     """collate_tensors combines differently-shaped tensors into ``max_sizes``.

     The result is validated with the test case's ``check_collated`` helper.
     """
     tensors = [
         torch.rand(3, 20, 40),
         torch.rand(3, 25, 30),
         torch.rand(5, 15, 35),
     ]
     target_sizes = (len(tensors), 5, 25, 40)
     self.check_collated(
         tensors,
         target_sizes,
         PaddingCollater.collate_tensors(tensors, target_sizes),
     )
Exemple #5
0
 def test_collate_with_tensor_and_fixed_sizes(self):
     """With fully fixed ``sizes`` and matching inputs, collating equals stacking.

     The collated result is compared directly against ``torch.stack`` of the
     inputs.
     """
     collate_fn = PaddingCollater((1, 20, 40))
     same_shape = [torch.rand(1, 20, 40) for _ in range(3)]
     collated = collate_fn(same_shape)
     torch.testing.assert_allclose(collated, torch.stack(same_shape))
Exemple #6
0
 def test_max_sizes_with_channels_size(self):
     """Fixing only the first dimension leaves the other two at the batch maximum."""
     channels = 1
     tensors = [
         torch.rand(channels, *dims)
         for dims in ((20, 40), (25, 30), (15, 35))
     ]
     result = PaddingCollater.get_max_sizes(
         tensors, sizes=(channels, None, None))
     self.assertEqual((len(tensors), channels, 25, 40), result)
Exemple #7
0
 def test_max_sizes_all_sizes_set(self):
     """With every dimension fixed, the result is the batch length plus ``sizes``."""
     channels = 1
     fixed = (channels, 20, 40)
     tensors = [torch.rand(*fixed) for _ in range(3)]
     result = PaddingCollater.get_max_sizes(tensors, sizes=fixed)
     self.assertEqual((len(tensors), *fixed), result)
Exemple #8
0
 def test_collate_with_numpy(self):
     """Collating NumPy arrays matches stacking their ``torch.from_numpy`` conversions.

     Every array already has the fixed ``sizes`` shape, so the collated
     result should equal the plain stack of the converted arrays.
     """
     sizes = (1, 20, 40)
     collate_fn = PaddingCollater(sizes)
     batch = [
         np.random.rand(1, 20, 40),
         np.random.rand(1, 20, 40),
         np.random.rand(1, 20, 40),
     ]
     x = collate_fn(batch)
     # Use a distinct comprehension variable: naming it ``x`` shadowed the
     # collated result inside the very assertion that compares against it.
     expected = torch.stack([torch.from_numpy(arr) for arr in batch])
     torch.testing.assert_allclose(x, expected)
Exemple #9
0
 def test_dataloader(self) -> DataLoader:
     """Build the test-split DataLoader over ``self.te_ds``.

     Samples are drawn through ``self.get_unpadded_distributed_sampler``.
     Memory is pinned when ``self.trainer.on_gpu`` is truthy. Batches are
     collated by a PaddingCollater that handles the ``"img"`` entry with its
     first dimension fixed to ``self.img_channels`` and
     ``sort_key=by_descending_width``.

     Raises:
         AssertionError: if ``self.te_ds`` is None.
     """
     assert self.te_ds is not None
     dataset = self.te_ds
     sampler = self.get_unpadded_distributed_sampler(dataset)
     img_collater = PaddingCollater(
         {"img": (self.img_channels, None, None)},
         sort_key=by_descending_width,
     )
     return DataLoader(
         dataset=dataset,
         batch_size=self.batch_size,
         sampler=sampler,
         num_workers=self.num_workers,
         pin_memory=self.trainer.on_gpu,
         collate_fn=img_collater,
     )
Exemple #10
0
 def test_collate_with_tensor(self):
     """With all sizes free, collating yields ``(collated, per-item sizes)``.

     Row ``i`` of ``xs`` must equal the original shape of input ``i``. The
     collated tensor is checked against the per-dimension maxima.
     """
     collate_fn = PaddingCollater((None, None, None))
     batch = [
         torch.rand(3, 20, 40),
         torch.rand(3, 25, 30),
         torch.rand(5, 15, 35),
     ]
     x, xs = collate_fn(batch)
     for row, sample in enumerate(batch):
         self.assertEqual(list(sample.size()), xs[row].tolist())
     self.check_collated(batch, (len(batch), 5, 25, 40), x)
Exemple #11
0
 def train_dataloader(self) -> DataLoader:
     """Build the training DataLoader over ``self.tr_ds``.

     Shuffling follows ``self.shuffle_tr``, and workers are initialised with
     ``DataModule.worker_init_fn``. Memory is pinned when
     ``self.trainer.on_gpu`` is truthy. Batches are collated by a
     PaddingCollater that handles the ``"img"`` entry with its first dimension
     fixed to ``self.img_channels`` and ``sort_key=by_descending_width``.

     Raises:
         AssertionError: if ``self.tr_ds`` is None.
     """
     assert self.tr_ds is not None
     img_collater = PaddingCollater(
         {"img": (self.img_channels, None, None)},
         sort_key=by_descending_width,
     )
     return DataLoader(
         dataset=self.tr_ds,
         batch_size=self.batch_size,
         num_workers=self.num_workers,
         shuffle=self.shuffle_tr,
         worker_init_fn=DataModule.worker_init_fn,
         pin_memory=self.trainer.on_gpu,
         collate_fn=img_collater,
     )
Exemple #12
0
 def test_collate_with_dict(self):
     """Dict samples: the ``"img"`` entry is collated into ``(collated, sizes)``.

     Each row of ``xs`` must equal the original shape of that sample's image.
     The collated images are checked against the per-dimension maxima.
     """
     collate_fn = PaddingCollater({"img": (3, None, None)})
     batch = [
         {"img": image}
         for image in (
             torch.rand(3, 20, 40),
             torch.rand(3, 25, 30),
             torch.rand(3, 15, 35),
         )
     ]
     x, xs = collate_fn(batch)["img"]
     for row, sample in enumerate(batch):
         self.assertEqual(list(sample["img"].size()), xs[row].tolist())
     images = [sample["img"] for sample in batch]
     self.check_collated(images, (3, 3, 25, 40), x)
Exemple #13
0
 def test_collate_with_list(self):
     """List input: each group ``batch[i]`` is collated with ``sizes[i]``.

     The collater is expected to yield one ``(collated, sizes)`` pair per
     group. Every size row must match the original tensor shape, and the
     collated tensor must match ``expected[i]``.
     """
     sizes = [(None, None, None), (1, None, None)]
     collate_fn = PaddingCollater(sizes)
     batch = [
         [
             torch.rand(3, 20, 40),
             torch.rand(3, 25, 30),
             torch.rand(5, 15, 35)
         ],
         [
             torch.rand(1, 20, 40),
             torch.rand(1, 25, 30),
             torch.rand(1, 15, 35)
         ],
     ]
     expected = [(3, 5, 25, 40), (3, 1, 25, 40)]
     for i, (x, xs) in enumerate(collate_fn(batch)):
         # Walk the tensors of group i (3 of them), not ``len(batch)`` (the
         # 2 groups); otherwise the last tensor's size row is never checked.
         for j, tensor in enumerate(batch[i]):
             self.assertEqual(list(tensor.size()), xs[j].tolist())
         self.check_collated(batch[i], expected[i], x)