# Example no. 1
# 0
class TestEagerFusionOpInfo(TestCase):
    """OpInfo-driven tests comparing ``compiled_function`` gradients to eager ones."""

    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
    # Entries in here don't work and need to be fixed.
    # Each one of these is a bug (or needs to be investigated).
    @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', {
        xfail('linalg.cholesky'),
        skip('msort'),
        xfail('nn.functional.dropout'),
        xfail('to_sparse'),
        xfail('addcdiv'),
        xfail('cholesky'),
        xfail('cumulative_trapezoid'),
        xfail('diag_embed'),
        xfail('linalg.householder_product'),
        xfail('logit'),
        xfail('trapezoid'),
        xfail('trapz'),
        xfail('corrcoef'),
        xfail('cov'),
        skip('nn.functional.binary_cross_entropy_with_logits'),  # seems to fail sometimes?
        skip('nn.functional.margin_ranking_loss'),  # seems flaky
    })
    def test_aot_autograd_exhaustive(self, device, dtype, op):
        """Check that ``compiled_function(f, nop, nop)`` reproduces eager gradients.

        For every sample input of ``op``:

        * the summed output is back-propagated once through the compiled
          function and once eagerly, and the ``.grad`` of every positional
          argument must be equal;
        * the positional arguments are then replaced by freshly randomized
          copies and the same comparison is repeated, re-running the same
          compiled function on new values.

        ``self.skipTest`` raises, so the *whole* test is skipped as soon as
        one sample has a positional or keyword argument that is not a float
        tensor, or the op returns a tuple.
        Ops without autograd support return immediately.
        """
        def f(args, kwargs):
            return op.op(*args, **kwargs)

        if not op.supports_autograd:
            return

        def is_float_tensor(x):
            return isinstance(x, torch.Tensor) and x.dtype == torch.float

        def reset_grads(args):
            def clear_grad(x):
                x.grad = None
            pytree.tree_map(clear_grad, args)

        def get_grads(args):
            return pytree.tree_map(lambda x: x.grad, args)

        def create_new_arg(x):
            # detach() shares storage with x, so this overwrites x's data in
            # place with U(0, 1) values before re-enabling grad tracking.
            return x.detach().uniform_(0, 1).requires_grad_(x.requires_grad)

        def assert_grads_match(compiled_f, args, kwargs):
            """Backprop through compiled and eager ``f``; assert equal grads."""
            reset_grads(args)
            compiled_f(args, kwargs).sum().backward()
            compiled_grad = get_grads(args)

            reset_grads(args)
            f(args, kwargs).sum().backward()
            orig_grad = get_grads(args)

            self.assertEqual(orig_grad, compiled_grad)

        for sample_input in op.sample_inputs(device, dtype, requires_grad=True):
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            if not all(is_float_tensor(i) for i in args):
                self.skipTest("not all inputs are float tensors")
            if not all(is_float_tensor(i) for i in kwargs.values()):
                self.skipTest("not all inputs are float tensors")
            if isinstance(f(args, kwargs), tuple):
                self.skipTest("output is a tuple")

            # A fresh compiled function per sample; presumably it specializes
            # on the shapes seen at first call -- TODO confirm before hoisting.
            compiled_f = compiled_function(f, nop, nop)
            assert_grads_match(compiled_f, args, kwargs)

            args = pytree.tree_map(create_new_arg, args)
            assert_grads_match(compiled_f, args, kwargs)


make_fx_failures = {
    xfail('allclose'),
    xfail('nn.functional.dropout'),
    xfail('linalg.eigvals'),
    xfail('nn.functional.max_pool1d',
          device_type='cpu'),  # precision problems?
    xfail('randn_like'),  # randomness
    xfail('rand_like'),  # randomness
    xfail('randint_like'),  # randomness
    skip('new_empty'),  # nondeterministic
    skip('empty_like'),  # nondeterministic
    skip('linalg.lstsq', 'grad_oriented'),  # flaky
    xfail('normal', '', device_type='cpu'),
    xfail('normal', 'number_mean', device_type='cpu'),
    xfail('multinomial', device_type='cpu'),
    xfail('nn.functional.feature_alpha_dropout',
          'with_train',
          device_type='cpu'),
    xfail('bernoulli', device_type='cpu'),
    xfail('nn.functional.dropout2d', device_type='cpu'),
    skip('nn.functional.max_unpool1d', '', device_type='cpu'),  # flaky
    skip('nn.functional.max_unpool2d', '', device_type='cpu'),  # flaky
    skip('nn.functional.max_unpool3d', '', device_type='cpu'),  # flaky
    skip('linalg.lstsq'),  # flaky, probably just a precision issue
    xfail('histogram'),