Example #1
0
 def test_set_input(self):
     """Check that a set of two input tensors yields a set of two dummy tensors."""
     image_like = torch.ones(size=(32, 3, 32, 32), dtype=torch.float64)
     label_like = torch.ones(size=(32, 10), dtype=torch.int8)
     inp = {image_like, label_like}
     # NOTE(review): the Linear layer's 784-wide input does not match inp;
     # presumably FeInputSpec only uses the model for placement — confirm.
     model = torch.nn.Sequential(torch.nn.Linear(28 * 28, 100))
     dummy = FeInputSpec(inp, model).get_dummy_input()
     self.assertIsInstance(dummy, set, "Spec should return a set of tensors")
     self.assertEqual(2, len(dummy), "Spec should return two tensors")
Example #2
0
 def test_simple_input(self):
     """Check that a lone float16 tensor yields a tensor of the same shape and dtype."""
     expected_shape = [32, 1, 28, 28]
     inp = torch.ones(size=tuple(expected_shape), dtype=torch.float16)
     # NOTE(review): model input width (784) differs from inp's last dim;
     # presumably only used by FeInputSpec for placement — confirm.
     model = torch.nn.Sequential(torch.nn.Linear(28 * 28, 100))
     dummy = FeInputSpec(inp, model).get_dummy_input()
     self.assertIsInstance(dummy, torch.Tensor,
                           "Spec should return a torch tensor")
     self.assertListEqual(expected_shape, list(dummy.shape),
                          "Result shape is incorrect")
     self.assertEqual(torch.float16, dummy.dtype,
                      "Result should be float16 dtype")
Example #3
0
 def test_map_input(self):
     """Check that a dict input yields a dict of dummy tensors.

     The result must keep the original key, shape and dtype.
     """
     inp = {'inp': torch.ones(size=(32, 3, 32, 32), dtype=torch.float64)}
     model = torch.nn.Sequential(torch.nn.Linear(28 * 28, 100))
     spec = FeInputSpec(inp, model)
     result = spec.get_dummy_input()
     # Failure messages must describe what is actually asserted: a dict
     # (not a tuple) holding exactly one tensor (not two).
     self.assertIsInstance(result, dict,
                           "Spec should return a dict of tensors")
     self.assertEqual(1, len(result), "Spec should return one tensor")
     # Assert the key explicitly so a missing key is reported as a test
     # failure instead of a KeyError error on the lookup below.
     self.assertIn('inp', result, "Result should keep the 'inp' key")
     self.assertIsInstance(result['inp'], torch.Tensor,
                           "Spec should return a torch tensor")
     self.assertListEqual([32, 3, 32, 32], list(result['inp'].shape),
                          "Result shape is incorrect")
     self.assertEqual(torch.float64, result['inp'].dtype,
                      "Result dtype is incorrect")
Example #4
0
 def test_list_input(self):
     """Check that a list input yields a list of dummy tensors.

     Each returned tensor must match its input's shape and dtype.
     """
     # (shape, dtype) for each input tensor, in list order.
     expected = [
         ([32, 3, 32, 32], torch.float64),
         ([32, 10], torch.int8),
     ]
     inp = [torch.ones(size=tuple(shape), dtype=dtype)
            for shape, dtype in expected]
     model = torch.nn.Sequential(torch.nn.Linear(28 * 28, 100))
     dummy = FeInputSpec(inp, model).get_dummy_input()
     self.assertIsInstance(dummy, list,
                           "Spec should return a list of tensors")
     self.assertEqual(2, len(dummy), "Spec should return two tensors")
     # The length check above guarantees zip pairs every returned tensor.
     for tensor, (shape, dtype) in zip(dummy, expected):
         self.assertIsInstance(tensor, torch.Tensor,
                               "Spec should return a torch tensor")
         self.assertListEqual(shape, list(tensor.shape),
                              "Result shape is incorrect")
         self.assertEqual(dtype, tensor.dtype,
                          "Result dtype is incorrect")