def test_alias_check_fail_multiple_operators(self):
    """A chain of ops on IncorrectAliasTensor inputs must trip the alias check."""
    with self.assertRaisesRegex(
        RuntimeError,
        "Argument input is not defined to alias output but was aliasing"
    ):
        lhs = torch.rand((3, 3), requires_grad=True)
        rhs = torch.zeros((3, 3), requires_grad=True)
        with enable_torch_dispatch_mode(SchemaCheckMode()):
            IncorrectAliasTensor(lhs).sin().relu().add(
                IncorrectAliasTensor(rhs), alpha=2)
def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self):
    """In-place add where both arguments are the same tensor records both
    arguments as mutated and both as aliasing the single output."""
    tensor = torch.rand((3, 3))
    alias = tensor
    mode = SchemaCheckMode()
    with enable_torch_dispatch_mode(mode):
        tensor.add_(alias)
    expected_mutated = [('aten::add_', 'input'), ('aten::add_', 'other')]
    expected_aliasing = [
        ('aten::add_', 'input', 'output_0'),
        ('aten::add_', 'other', 'output_0'),
    ]
    self.assertEqual(expected_mutated, mode.mutated)
    self.assertEqual(expected_aliasing, mode.aliasing)
def test_mutation_check_fail_multiple_operators(self):
    """BatchNorm's running-stat update after a chain of ops must be flagged
    as an undeclared mutation."""
    with self.assertRaisesRegex(
        RuntimeError,
        "Argument running_mean is not defined as mutable but was mutated"
    ):
        data = torch.rand((3, 3), requires_grad=True)
        batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
        with enable_torch_dispatch_mode(SchemaCheckMode()):
            batch(data.sinh().tanh().relu())
def test_schema_check_mode_mutated_aliasing_multiple_outputs(self):
    """frexp with out= tensors records one mutation per out argument and
    one alias pairing each out argument with its own output slot."""
    source = torch.arange(9.)
    mantissa_out = torch.arange(9.)
    exponent_out = torch.zeros([9], dtype=torch.int32)
    mode = SchemaCheckMode()
    with enable_torch_dispatch_mode(mode):
        torch.frexp(source, out=(mantissa_out, exponent_out))
    self.assertEqual(
        [('aten::frexp', 'mantissa'), ('aten::frexp', 'exponent')],
        mode.mutated)
    self.assertEqual(
        [('aten::frexp', 'mantissa', 'output_0'),
         ('aten::frexp', 'exponent', 'output_1')],
        mode.aliasing)
def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self):
    """aminmax with one tensor supplied for both outputs records each out
    argument as mutated and as aliasing both output slots."""
    source = torch.rand((3, 3))
    shared_out = torch.zeros(3)
    mode = SchemaCheckMode()
    with enable_torch_dispatch_mode(mode):
        torch.aminmax(source, dim=0, out=[shared_out, shared_out])
    self.assertEqual(
        [('aten::aminmax', 'min'), ('aten::aminmax', 'max')],
        mode.mutated)
    self.assertEqual(
        [('aten::aminmax', 'min', 'output_0'),
         ('aten::aminmax', 'min', 'output_1'),
         ('aten::aminmax', 'max', 'output_0'),
         ('aten::aminmax', 'max', 'output_1')],
        mode.aliasing)
def test_schema_check_mode_functionality_nested_training_op(self):
    """A sequence of in-place ops followed by a BatchNorm module call
    produces the same result with the mode enabled as without it."""
    actual = torch.rand((3, 3))
    batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
    expected = torch.clone(actual)
    # In-place ops return self, so the reference path can be chained.
    expected.sinh_().tanh_().relu_()
    expected = batch(expected)
    with enable_torch_dispatch_mode(SchemaCheckMode()):
        actual.sinh_().tanh_().relu_()
        actual = batch(actual)
    self.assertEqual(expected, actual)
def test_schema_check_mode_functionality_kwarg_tensor(self):
    """A tensor passed through a keyword argument (stft's window) produces
    the same result with the mode enabled as without it."""
    signal = torch.rand((3, 5))
    window = torch.rand((4))
    kwargs = dict(win_length=4, window=window, return_complex=True)
    expected = torch.stft(signal, 4, **kwargs)
    with enable_torch_dispatch_mode(SchemaCheckMode()):
        actual = torch.stft(signal, 4, **kwargs)
    self.assertEqual(expected, actual)
def test_schema_check_mode_mutated_aliasing_as_strided(self):
    """as_strided_ records its input as mutated and as aliasing the output."""
    tensor = torch.rand((3, 6, 4))
    mode = SchemaCheckMode()
    with enable_torch_dispatch_mode(mode):
        tensor.as_strided_([3, 6, 4], [9, 1, 1])
    self.assertEqual([('aten::as_strided_', 'input')], mode.mutated)
    self.assertEqual(
        [('aten::as_strided_', 'input', 'output_0')],
        mode.aliasing)
def test_schema_check_tensor_operator_order_without_grad(self):
    """Ops are recorded in dispatch order when autograd is not involved."""
    mode = SchemaCheckMode()
    with enable_torch_dispatch_mode(mode):
        tensor = torch.rand((3, 3), requires_grad=False)
        tensor.relu().sin()
    self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], mode.ops)
def test_alias_check_fail_outputs_unexpectedly_aliasing(self):
    """Outputs that alias each other without schema permission must raise."""
    with self.assertRaisesRegex(RuntimeError, "Outputs 0 and 1 alias unexpectedly"):
        source = torch.rand((3, 3))
        mode = SchemaCheckMode()
        with enable_torch_dispatch_mode(mode):
            IncorrectAliasTensor(source).aminmax(dim=0)
def test_mutation_check_fail(self):
    """An undeclared mutation of a plain argument by IncorrectAliasTensor's
    sub must be reported by the mutation check.

    Renamed from ``test_mutation_check_fail_multiple_operators``: that name
    is already used by an earlier test in this class (the BatchNorm
    running-stats case), so this later definition silently shadowed it and
    the earlier test was never collected or run.
    """
    with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
        x = torch.rand((3, 3))
        y = torch.rand((3, 3))
        with enable_torch_dispatch_mode(SchemaCheckMode()):
            IncorrectAliasTensor(x).sin().cos().sub(IncorrectAliasTensor(y))
def test_schema_check_mode_empty_list_input(self):
    """An empty-list input is handled identically with the mode enabled."""
    baseline = torch.atleast_1d([])
    with enable_torch_dispatch_mode(SchemaCheckMode()):
        checked = torch.atleast_1d([])
    self.assertEqual(baseline, checked)
def test_schema_check_mode_functionality_device_input(self):
    """An op taking explicit device/dtype kwargs works under the mode."""
    with enable_torch_dispatch_mode(SchemaCheckMode()):
        tensor = torch.rand((3, 3), device="cpu", dtype=torch.double)
        doubled = tensor + tensor
    self.assertEqual(tensor + tensor, doubled)
def test_schema_check_mode_functionality_wildcard_after(self):
    """chunk produces the same pieces with the mode enabled as without it."""
    source = torch.rand((3, 3))
    baseline = source.chunk(6)
    with enable_torch_dispatch_mode(SchemaCheckMode()):
        checked = source.chunk(6)
    self.assertEqual(baseline, checked)
def test_schema_check_tensor_functionality(self):
    """relu followed by sin produces the same result under the mode."""
    source = torch.rand((3, 3), requires_grad=True)
    baseline = source.relu().sin()
    with enable_torch_dispatch_mode(SchemaCheckMode()):
        checked = source.relu().sin()
    self.assertEqual(baseline, checked)
def test_schema_check_tensor_functionality_default_replaced(self):
    """Overriding a defaulted kwarg (add's alpha) produces the same result
    with the mode enabled as without it."""
    source = torch.rand((3, 3), requires_grad=True)
    baseline = source.add(source, alpha=2)
    with enable_torch_dispatch_mode(SchemaCheckMode()):
        checked = source.add(source, alpha=2)
    self.assertEqual(baseline, checked)
def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self):
    """Writing both aminmax outputs into one tensor leaves it equal to the
    max (the second output written), matching eager behavior."""
    source = torch.rand((3, 3))
    shared_out = torch.zeros(3)
    with enable_torch_dispatch_mode(SchemaCheckMode()):
        torch.aminmax(source, dim=0, out=[shared_out, shared_out])
    self.assertEqual(torch.amax(source, dim=0), shared_out)