class TestEagerFusionOpInfo(TestCase):
    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
    # entries in here don't work and need to be fixed.
    # Each one of these is a bug (or needs to be investigated)
    @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', {
        xfail('linalg.cholesky'),
        skip('msort'),
        xfail('nn.functional.dropout'),
        xfail('to_sparse'),
        xfail('addcdiv'),
        xfail('cholesky'),
        xfail('cumulative_trapezoid'),
        xfail('diag_embed'),
        xfail('linalg.householder_product'),
        xfail('logit'),
        xfail('trapezoid'),
        xfail('trapz'),
        xfail('corrcoef'),
        xfail('cov'),
        skip('nn.functional.binary_cross_entropy_with_logits'),  # seems to fail sometimes?
        skip('nn.functional.margin_ranking_loss'),  # seems flaky
    })
    def test_aot_autograd_exhaustive(self, device, dtype, op):
        def f(args, kwargs):
            return op.op(*args, **kwargs)
        if not op.supports_autograd:
            return
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in sample_inputs_itr:
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            if not all([isinstance(i, torch.Tensor) and i.dtype == torch.float for i in args]):
                self.skipTest("not all inputs are float tensors")
            if not all([isinstance(i, torch.Tensor) and i.dtype == torch.float for i in kwargs.values()]):
                self.skipTest("not all inputs are float tensors")
                continue
            t = f(args, kwargs)
            if isinstance(t, tuple):
                self.skipTest("output is a tuple")
                continue

            def reset_grads():
                def f(x):
                    x.grad = None
                pytree.tree_map(f, args)

            def get_grads(args):
                return pytree.tree_map(lambda x: x.grad, args)

            compiled_f = compiled_function(f, nop, nop)

            # Gradients from the AOT-compiled function should match eager mode.
            reset_grads()
            compiled_f(args, kwargs).sum().backward()
            compiled_grad = get_grads(args)

            reset_grads()
            f(args, kwargs).sum().backward()
            orig_grad = get_grads(args)
            self.assertEqual(orig_grad, compiled_grad)

            # Re-randomize the inputs and check that the compiled function still
            # agrees with eager (i.e. it was not specialized to the original values).
            def create_new_arg(x):
                return x.detach().uniform_(0, 1).requires_grad_(x.requires_grad)

            args = pytree.tree_map(create_new_arg, args)

            reset_grads()
            compiled_f(args, kwargs).sum().backward()
            compiled_grad = get_grads(args)

            reset_grads()
            f(args, kwargs).sum().backward()
            orig_grad = get_grads(args)
            self.assertEqual(orig_grad, compiled_grad)
        mod.zero_grad()
        mod(inp).sum().backward()
        grads2 = [a.grad for a in mod.parameters()]
        self.assertEqual(grads, grads2)


make_fx_failures = {
    xfail('allclose'),
    xfail('nn.functional.dropout'),
    xfail('linalg.eigvals'),
    xfail('nn.functional.max_pool1d', device_type='cpu'),  # precision problems?
    xfail('randn_like'),  # randomness
    xfail('rand_like'),  # randomness
    xfail('randint_like'),  # randomness
    skip('new_empty'),  # nondeterministic
    skip('empty_like'),  # nondeterministic
    skip('linalg.lstsq', 'grad_oriented'),  # flaky
    xfail('normal', '', device_type='cpu'),
    xfail('normal', 'number_mean', device_type='cpu'),
    xfail('multinomial', device_type='cpu'),
    xfail('nn.functional.feature_alpha_dropout', 'with_train', device_type='cpu'),
    xfail('bernoulli', device_type='cpu'),
    xfail('nn.functional.dropout2d', device_type='cpu'),
    skip('nn.functional.max_unpool1d', '', device_type='cpu'),  # flaky
    skip('nn.functional.max_unpool2d', '', device_type='cpu'),  # flaky
    skip('nn.functional.max_unpool3d', '', device_type='cpu'),  # flaky
    skip('linalg.lstsq'),  # flaky, probably just a precision issue
    xfail('histogram'),