Example #1
0
 def test_packed_fail4(self, packed_tensor, total_numel, last_dim, dtype,
                       wrong_device):
     """Check that ``check_packed_tensor`` rejects a mismatching device.

     When ``throw`` is not passed, the call is expected to raise a
     ``TypeError`` whose message matches the device-mismatch text. With
     ``throw=False`` the same arguments are expected to give a falsy
     result instead of raising.
     """
     call_args = (packed_tensor, total_numel, last_dim, dtype, wrong_device)
     expected_msg = 'tensor device is cpu, should be cuda'

     with pytest.raises(TypeError, match=expected_msg):
         testing.check_packed_tensor(*call_args)

     is_valid = testing.check_packed_tensor(*call_args, throw=False)
     assert not is_valid
Example #2
0
 def test_packed_fail3(self, packed_tensor, total_numel, last_dim,
                       wrong_dtype, device):
     """Check that ``check_packed_tensor`` rejects a mismatching dtype.

     When ``throw`` is not passed, the call is expected to raise a
     ``TypeError`` whose message matches the dtype-mismatch text. With
     ``throw=False`` the same arguments are expected to give a falsy
     result instead of raising.
     """
     # ``match`` is a regular expression searched in the exception message,
     # so the dots in the dtype names are escaped to match literally rather
     # than acting as "any character" wildcards.
     expected_msg = r'tensor dtype is torch\.float32, should be torch\.int64'
     with pytest.raises(TypeError, match=expected_msg):
         testing.check_packed_tensor(packed_tensor, total_numel, last_dim,
                                     wrong_dtype, device)
     assert not testing.check_packed_tensor(packed_tensor, total_numel,
                                            last_dim,
                                            wrong_dtype, device, throw=False)
Example #3
0
 def test_packed_fail2(self, packed_tensor, total_numel, wrong_last_dim,
                       dtype, device):
     """Check that ``check_packed_tensor`` rejects a mismatching last dim.

     When ``throw`` is not passed, the call is expected to raise a
     ``ValueError`` whose message matches the last-dim-mismatch text. With
     ``throw=False`` the same arguments are expected to give a falsy
     result instead of raising.
     """
     bad_args = (packed_tensor, total_numel, wrong_last_dim, dtype, device)
     message_pattern = 'tensor last_dim is 2, should be 3'

     with pytest.raises(ValueError, match=message_pattern):
         testing.check_packed_tensor(*bad_args)

     outcome = testing.check_packed_tensor(*bad_args, throw=False)
     assert not outcome
Example #4
0
 def test_packed_fail1(self, packed_tensor, wrong_total_numel, last_dim,
                       dtype, device):
     """Check that ``check_packed_tensor`` rejects a wrong element count.

     When ``throw`` is not passed, the call is expected to raise a
     ``ValueError`` whose message matches the total-number-of-elements
     text. With ``throw=False`` the same arguments are expected to give a
     falsy result instead of raising.
     """
     checker = testing.check_packed_tensor
     params = (packed_tensor, wrong_total_numel, last_dim, dtype, device)
     pattern = 'tensor total number of elements is 5, should be 6'

     with pytest.raises(ValueError, match=pattern):
         checker(*params)

     passed = checker(*params, throw=False)
     assert not passed
Example #5
0
 def test_list_to_packed_to_list(self, tensor_list, shape_per_tensor,
                                 first_idx,
                                 last_dim, dtype, device):
     """Round-trip tensors through list_to_packed and packed_to_list.

     Verifies the shapes returned by ``batch.list_to_packed``, runs
     ``check_packed_tensor`` on the packed result, compares each slice of
     the packed tensor with the corresponding input, and finally checks
     that ``batch.packed_to_list`` gives back tensors equal to the inputs.
     """
     packed_tensor, output_shape_per_tensor = batch.list_to_packed(
         tensor_list)
     assert torch.equal(output_shape_per_tensor, shape_per_tensor)
     check_packed_tensor(packed_tensor, total_numel=first_idx[-1],
                         last_dim=last_dim,
                         dtype=dtype, device=device)
     # Rows first_idx[i]:first_idx[i + 1] of the packed tensor must equal
     # the i-th input tensor reshaped to (-1, last_dim).
     for i, tensor in enumerate(tensor_list):
         assert torch.equal(packed_tensor[first_idx[i]:first_idx[i + 1]],
                            tensor.reshape(-1, last_dim))
     output_tensor_list = list(batch.packed_to_list(
         packed_tensor, shape_per_tensor, first_idx))
     # zip() stops at the shorter input, so compare the counts explicitly;
     # otherwise missing output tensors would go unnoticed.
     assert len(output_tensor_list) == len(tensor_list)
     for output_tensor, expected_tensor in zip(output_tensor_list,
                                               tensor_list):
         assert torch.equal(output_tensor, expected_tensor)
Example #6
0
 def test_packed_default_success(self, packed_tensor):
     """Expect ``check_packed_tensor`` to accept ``packed_tensor`` alone.

     Only the tensor is passed; no other argument is supplied.
     """
     is_valid = testing.check_packed_tensor(packed_tensor)
     assert is_valid
Example #7
0
 def test_packed_success(self, packed_tensor, total_numel, last_dim, dtype,
                         device):
     """Expect ``check_packed_tensor`` to accept fully matching arguments.

     The tensor is checked against the ``total_numel``, ``last_dim``,
     ``dtype`` and ``device`` fixtures, and the result must be truthy.
     """
     matching_args = (packed_tensor, total_numel, last_dim, dtype, device)
     is_valid = testing.check_packed_tensor(*matching_args)
     assert is_valid