Example #1
0
 def test_max_sizes_incorrect_sizes(self):
     """A tensor that conflicts with an explicitly fixed size must be rejected.

     ``sizes`` pins the last dimension to 40, but the second tensor is 45
     along that dimension, so ``get_max_sizes`` is expected to raise
     ``AssertionError``.
     """
     fixed_sizes = (1, 20, 40)
     last_dims = (40, 45, 40)
     tensors = [torch.rand(1, 20, last) for last in last_dims]
     with self.assertRaises(AssertionError):
         PaddingCollater.get_max_sizes(tensors, sizes=fixed_sizes)
Example #2
0
 def test_max_sizes_no_sizes(self):
     """Without ``sizes``, the per-dimension maxima over the batch are expected.

     The expected result is the batch length followed by the largest
     extent seen in each tensor dimension.
     """
     shapes = [(3, 20, 40), (3, 25, 30), (5, 15, 35)]
     tensors = [torch.rand(*shape) for shape in shapes]
     result = PaddingCollater.get_max_sizes(tensors)
     self.assertEqual((len(tensors), 5, 25, 40), result)
Example #3
0
 def test_max_sizes_with_channels_size(self):
     """Fixing only the first dimension still yields maxima for the others.

     ``None`` entries in ``sizes`` leave those dimensions unconstrained, so
     the expected result uses the fixed first size plus the batch-wide
     maxima of the remaining dimensions.
     """
     channels = 1
     trailing_dims = ((20, 40), (25, 30), (15, 35))
     tensors = [torch.rand(channels, a, b) for a, b in trailing_dims]
     result = PaddingCollater.get_max_sizes(tensors, sizes=(channels, None, None))
     self.assertEqual((len(tensors), channels, 25, 40), result)
Example #4
0
 def test_max_sizes_all_sizes_set(self):
     """With every size fixed and all tensors matching, the fixed sizes come back.

     The expected result is the batch length prepended to ``sizes``.
     """
     channels = 1
     fixed_sizes = (channels, 20, 40)
     tensors = [torch.rand(*fixed_sizes) for _ in range(3)]
     result = PaddingCollater.get_max_sizes(tensors, sizes=fixed_sizes)
     self.assertEqual((len(tensors),) + fixed_sizes, result)