Example #1
0
 def test_mutation_check_fail_multiple_operators(self):
     """Expect a RuntimeError when several in-place ops and a BatchNorm
     forward run on a grad-requiring tensor under schema checking.

     NOTE(review): the entire body sits inside assertRaises, and an
     in-place op such as ``x.sinh_()`` on a leaf tensor that requires
     grad raises RuntimeError on its own — so the assertion may be
     satisfied before BatchNorm or SchemaCheckTensor is ever reached.
     Confirm which raise this test is meant to exercise before
     narrowing the ``with`` block.
     """
     with self.assertRaises(RuntimeError):
         x = torch.rand((3, 3), requires_grad=True)
         x.sinh_()
         x.tanh_()
         x.relu_()
         batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
         batch(SchemaCheckTensor(x))
Example #2
0
 def test_mutation_check_fail(self):
     """Expect a RuntimeError from the schema mutation check when a
     BatchNorm1d with running stats is applied to a SchemaCheckTensor
     wrapping a grad-requiring tensor.
     """
     # Build the fixture OUTSIDE assertRaises so that an unexpected
     # error during setup is reported as an error, not silently
     # accepted as the expected raise. Only the call under test
     # remains inside the context manager.
     x = torch.rand((3, 3), requires_grad=True)
     batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
     with self.assertRaises(RuntimeError):
         batch(SchemaCheckTensor(x))
Example #3
0
 def test_schema_check_tensor_functionality_mutable_inputs(self):
     """An in-place sinh_ routed through SchemaCheckTensor matches a
     plain in-place sinh_ on an identical tensor."""
     reference = torch.rand((3, 3), requires_grad=False)
     wrapped_source = torch.clone(reference)
     reference.sinh_()
     SchemaCheckTensor(wrapped_source).sinh_()
     self.assertEqual(reference, wrapped_source)
Example #4
0
 def test_schema_check_tensor_functionality_default_replaced(self):
     """add() with an explicitly supplied alpha gives the same result
     with and without the SchemaCheckTensor wrapper."""
     x = torch.rand((3, 3), requires_grad=True)
     expected = x.add(x, alpha=2)
     wrapped = SchemaCheckTensor(x).add(SchemaCheckTensor(x), alpha=2)
     self.assertEqual(expected, wrapped.elem)
Example #5
0
 def test_schema_check_tensor_functionality(self):
     """A relu().sin() chain behaves identically under the wrapper."""
     x = torch.rand((3, 3), requires_grad=True)
     expected = x.relu().sin()
     actual = SchemaCheckTensor(x).relu().sin().elem
     self.assertEqual(expected, actual)
Example #6
0
 def test_schema_check_tensor_operator_order_no_grad(self):
     """The wrapper logs each dispatched aten op, in call order."""
     source = torch.rand((3, 3), requires_grad=False)
     SchemaCheckTensor(source).relu().sin()
     expected_ops = ["aten::relu", "aten::sin"]
     self.assertEqual(expected_ops, SchemaCheckTensor.recorded_ops)
Example #7
0
 def setUp(self):
     # Reset SchemaCheckTensor's state before each test; presumably
     # this clears the class-level recorded_ops log that the
     # operator-order test reads — confirm against SchemaCheckTensor.
     SchemaCheckTensor.reset_cache()