Example #1
    def test_variant_consistency_eager(self, device, dtype, op):
        test_backward = op.test_complex_grad or not dtype.is_complex
        samples = op.sample_inputs(device, dtype, requires_grad=test_backward)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        for sample in samples:
            # Acquires variants to test
            method = op.get_method()
            inplace = op.get_inplace()
            variants = (v for v in (method, inplace) if v is not None)
            # Computes expected forward by calling the op's function variant
            expected_forward = op(*sample.input, *sample.args, **sample.kwargs)

            # Computes expected backward
            # NOTE: backward may fail for some dtypes
            exception_during_backwards = False
            expected_grad = None
            try:
                expected_forward.sum().backward()
                expected_grad = sample.input.grad
                sample.input.grad = None
            except Exception as e:
                exception_during_backwards = True

            # Test eager consistency
            for variant in variants:
                # Verifies that inplace operations that promote int->float fail
                #   on tensors with integer dtypes.
                if (variant is inplace
                        and not torch.can_cast(expected_forward.dtype, dtype)):
                    try:
                        variant_forward = variant(
                            *(clone_input_helper(input)
                              for input in sample.input), *sample.args,
                            **sample.kwargs)
                    except Exception as e:
                        continue
                    self.fail(
                        "Inplace operation on integer tensor that should be promoted to float didn't fail!"
                    )
                # Compares variant's forward
                # Note: copy the tensor-type inputs when testing inplace operation
                variant_forward = variant(
                    *(clone_input_helper(input)
                      if variant is inplace else input
                      for input in sample.input), *sample.args,
                    **sample.kwargs)
                self.assertEqual(variant_forward, expected_forward)

                # Compares variant's backward
                if test_backward and (variant is not inplace
                                      or op.test_inplace_grad):
                    self.check_variant_backward(sample.input, variant_forward,
                                                expected_grad,
                                                exception_during_backwards)
Example #2
def can_cast(from_, to):
    """Wrapper of `torch.can_cast`.

    Parameters
    ----------
    from_ : data-type
        Data type to cast from.
    to : data-type
        Data type to cast to.
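
    Returns
    -------
    bool
        True if a cast from `from_` to `to` is allowed under PyTorch's
        type promotion rules.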
    """
    return torch.can_cast(from_, to)
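For reference, a minimal usage sketch (not part of the original snippet) of what the wrapper returns; it simply forwards to torch.can_cast, which follows PyTorch's type promotion rules:

import torch

# Casts that stay within a category or widen it are allowed; casts that would
# change category lossily (float -> int, complex -> float) are not.
assert torch.can_cast(torch.float32, torch.float64)        # True
assert torch.can_cast(torch.int64, torch.float32)          # True
assert not torch.can_cast(torch.float32, torch.int32)      # False
assert not torch.can_cast(torch.complex64, torch.float32)  # False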
Example #3
    def test_unary_op_out_casting(self, device, dtypes):
        t = torch.tensor((1), dtype=dtypes[0], device=device)
        out = torch.empty(1, dtype=dtypes[1], device=device)

        ops = (torch.neg, )
        for op in ops:
            if torch.can_cast(dtypes[0], dtypes[1]):
                self.assertEqual(op(t, out=out), out)
            else:
                with self.assertRaisesRegex(RuntimeError, 'can\'t be cast'):
                    op(t, out=out)
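For context, a standalone sketch (illustrative, not from the test file) of the behavior this test asserts: writing into an out= tensor works only when the result dtype can be cast to the out dtype under torch.can_cast:

import torch

t = torch.tensor([1.0], dtype=torch.float32)

out = torch.empty(1, dtype=torch.float64)
torch.neg(t, out=out)          # ok: torch.can_cast(torch.float32, torch.float64) is True
print(out)                     # tensor([-1.], dtype=torch.float64)

out_int = torch.empty(1, dtype=torch.int32)
# torch.neg(t, out=out_int)    # raises RuntimeError ("... can't be cast ...")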
Example #4
 def forward(self, input: torch.Tensor) -> torch.Tensor:
     # Convert the shape to a tensor for TorchScript compatibility, since torch.Size is not supported
     t_shape = torch.tensor(input.shape)
     if self._shape is not None and list(t_shape) != list(self._shape):
         if self._broadcastable:
             if not self._broadcast(t_shape, self._shape):
                 raise ValueError(
                     f'Shapes {self._shape} and {input.shape} are not'
                     ' broadcastable')
         else:
             raise ValueError(
                 f'Expected {self._shape}, input shape is {input.shape}')
     if self._dtype is not None and input.dtype != self._dtype:
         if self._can_cast:
             if not torch.can_cast(input.dtype, self._dtype):
                 raise ValueError(
                     f'Input dtype {input.dtype} can\'t be cast to'
                     f' {self._dtype}')
         else:
             raise ValueError(
                 f'Expected {self._dtype}, input dtype is {input.dtype}')
     return input
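A hypothetical, trimmed-down sketch of the dtype branch above, written as a free function so it runs outside the module (the surrounding class and its constructor are not shown in the snippet):

import torch

def check_dtype(t: torch.Tensor, expected: torch.dtype) -> torch.Tensor:
    # Same rule as the _can_cast branch above: accept the input only if its
    # dtype can be promoted to the expected dtype under torch.can_cast.
    if t.dtype != expected and not torch.can_cast(t.dtype, expected):
        raise ValueError(f"Input dtype {t.dtype} can't be cast to {expected}")
    return t

check_dtype(torch.ones(2, dtype=torch.float32), torch.float64)   # passes
# check_dtype(torch.ones(2, dtype=torch.float64), torch.int32)   # raises ValueError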
Example #5
    def test_jit_alias_remapping(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype, requires_grad=True)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        # NOTE: only tests on first sample
        sample = samples[0]

        # [Scripting Data Preparation]
        # Prepare data for test scripting
        # Below we prepare strings of args/kwargs with and without type annotations.
        # These strings are inserted into function templates, which are then compiled with TorchScript.
        # - args string is ["t0"] corresponding to the "input" tensor required by the op
        # - args_annot_kw is the string for the template function signature, for example,
        # ["t0", "s0: float", "s1: bool", "max: float = 1.0", "min: float = 0.0"] ->
        #    def fn(t0, s0: float, s1: bool, max: float = 1.0, min: float = 0.0)
        # - args_kw is the string of args/kwargs used to call the op, same as args_annot_kw but
        # without type annotations
        args = ["t0"]
        args_annot_kw = args + \
            [f"s{i}: {type(v).__name__}" for i, v in enumerate(sample.args)] + \
            [f"{k}: {type(v).__name__} = {v}" for k, v in sample.kwargs.items()]
        args_kw = args + \
            [f"s{i}" for i in range(len(sample.args))] + \
            [f"{k}={v}" for k, v in sample.kwargs.items()]

        # Prepare data for test tracing
        sample_args_kwargs = ()
        if len(sample.args) > 0:
            sample_args_kwargs += (sample.args, )
        if len(sample.kwargs) > 0:
            sample_args_kwargs += (sample.kwargs, )

        original_name = op.aten_name
        original_name_inplace = original_name + "_"
        expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype

        for a_op in op.aliases:
            inplace = a_op.inplace_variant
            method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
            variants = (v for v in (a_op.op, a_op.method_variant,
                                    a_op.inplace_variant) if v is not None)

            # Test scripting:
            for variant in variants:
                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                if variant in method_or_inplace:
                    fn_template = '''
                        def _fn(t0{c}{args_annot_kw}):
                            return t0.{alias_name}({args_kw})
                    '''
                    # remove the first input tensor
                    script = fn_template.format(
                        c=", " if len(args_kw[1:]) > 1
                        or len(args_annot_kw[1:]) >= 1 else "",
                        args_annot_kw=", ".join(args_annot_kw[1:]),
                        args_kw=", ".join(args_kw[1:]),
                        alias_name=variant_name,
                    )
                else:
                    fn_template = '''
                        def _fn({args_annot_kw}):
                            return variant({args_kw})
                    '''
                    script = fn_template.format(
                        args_annot_kw=", ".join(args_annot_kw),
                        args_kw=", ".join(args_kw),
                    )
                scripted = torch.jit.CompilationUnit(script)._fn

                if (variant is inplace
                        and not torch.can_cast(expected_dtype, dtype)):
                    try:
                        inp = clone_input_helper(sample.input)
                        scripted(inp, *sample.args, **sample.kwargs)
                    except Exception as e:
                        continue
                    self.fail(
                        "Inplace operation on integer tensor that should be promoted to float didn't fail!"
                    )

                inp = clone_input_helper(sample.input)
                scripted(inp, *sample.args, **sample.kwargs)
                inp = clone_input_helper(sample.input)
                graph = scripted.graph_for(inp, *sample.args, **sample.kwargs)
                FileCheck().check(
                    op.aten_name).check_not(variant_name).run(graph)

            # Test tracing:
            for variant in variants:
                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                def _fn(*sample_args, **sample_kwargs):
                    return variant(*sample_args, **sample_kwargs)

                inp = (clone_input_helper(sample.input), ) + sample_args_kwargs
                traced = torch.jit.trace(_fn, *inp)
                inp = (clone_input_helper(sample.input), ) + sample_args_kwargs
                traced(*inp)
                inp = (clone_input_helper(sample.input), ) + sample_args_kwargs
                graph = traced.graph_for(*inp)
                FileCheck().check(op_name).check_not(variant_name).run(graph)
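To make the template mechanics concrete, this is what the generated source looks like for a hypothetical method alias such as Tensor.absolute (an alias of abs) with no extra args or kwargs; the names are illustrative, not taken from an actual run:

# With args_kw == ["t0"], empty annotations, and variant_name == "absolute",
# the method/inplace branch of fn_template formats to:
script = '''
    def _fn(t0):
        return t0.absolute()
'''
# torch.jit.CompilationUnit(script)._fn compiles this source, and the FileCheck
# call asserts that the resulting graph records the original aten name ("abs")
# rather than the alias name ("absolute").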
Example #6
    def _test_sparse_op(self, op_name, inplace, dtype1, dtype2, device,
                        coalesced):
        if dtype1.is_complex or dtype2.is_complex:
            return

        suffix = '_' if inplace else ''
        err = "{} {}({}, {})".format(
            "  coalesced" if coalesced else "uncoalesced", op_name + suffix,
            dtype1, dtype2)

        def op(t1, t2):
            return getattr(t1, op_name + suffix)(t2)

        add_sub = op_name == 'add' or op_name == 'sub'

        dense1, sparse1 = self._test_sparse_op_input_tensors(
            device, dtype1, coalesced)
        dense2, sparse2 = self._test_sparse_op_input_tensors(
            device, dtype2, coalesced, op_name != 'div')

        common_dtype = torch.result_type(dense1, dense2)
        if self.device_type == 'cpu' and common_dtype == torch.half:
            self.assertRaises(RuntimeError, lambda: op(sparse1, dense2))

        # Skip inplace tests that would fail due to inability to cast to the output type.
        # Some of these would also raise errors due to not being a supported op.
        if inplace and not torch.can_cast(common_dtype, dtype1):
            self.assertRaises(RuntimeError, lambda: op(dense1, sparse2))
            self.assertRaises(RuntimeError, lambda: op(sparse1, sparse2))
            self.assertRaises(RuntimeError, lambda: op(sparse1, dense2))
            return

        expected = op(dense1.clone(), dense2)
        precision = self._get_precision(expected.dtype, coalesced)
        test_tensors = [expected, dense1, sparse1, dense2, sparse2]
        e, d1, s1, d2, s2 = ([x.clone() for x in test_tensors]
                             if inplace else test_tensors)

        # Test op(sparse, sparse)
        if op_name != 'div':
            sparse = op(s1, s2)
            self.assertEqual(sparse.dtype, e.dtype)
            self.assertEqual(e, sparse.to_dense(), atol=precision, message=err)
        else:
            # sparse division only supports division by a scalar
            self.assertRaises(RuntimeError, lambda: op(s1, s2).to_dense())

        # Test op(dense, sparse)
        if add_sub:
            if inplace:
                e, d1, s1, d2, s2 = [x.clone() for x in test_tensors]
            dense_sparse = op(d1, s2)
            self.assertEqual(e, dense_sparse, atol=precision, message=err)
        else:
            # sparse division only supports division by a scalar
            # mul: Didn't find kernel to dispatch to for operator 'aten::_nnz'
            self.assertRaises(RuntimeError, lambda: op(d1, s2))

        # Test op(sparse, dense) not supported for any ops:
        # add(sparse, dense) is not supported. Use add(dense, sparse) instead.
        # sparse division only supports division by a scalar
        # mul: Didn't find kernel to dispatch to for operator 'aten::_nnz'.
        self.assertRaises(RuntimeError, lambda: op(s1, d2))

        # Test op(sparse, scalar)
        if not add_sub and not (self.device_type == 'cpu'
                                and dtype1 == torch.half):
            if inplace:
                e, d1, s1, d2, s2 = [x.clone() for x in test_tensors]
            scalar = d2.view(d2.numel())[0].item()

            sparse = op(s1, scalar)
            dense_scalar = op(d1, scalar)
            self.assertEqual(sparse.dtype, dense_scalar.dtype)
            self.assertEqual(dense_scalar,
                             sparse.to_dense(),
                             atol=precision,
                             message=err)
        else:
            # add(sparse, dense) is not supported. Use add(dense, sparse) instead.
            # "mul_cpu" / "div_cpu" not implemented for 'Half'
            self.assertRaises(RuntimeError,
                              lambda: op(s1,
                                         d2.view(d2.numel())[0].item()))
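For reference, a minimal dense-tensor sketch (not taken from the test file) of the promotion rule behind the inplace skip above:

import torch

a = torch.ones(3, dtype=torch.int32)
b = torch.ones(3, dtype=torch.float32)

common = torch.result_type(a, b)   # torch.float32
torch.can_cast(common, a.dtype)    # False: the float result cannot be written
                                   # back into the int32 tensor in place
# a.add_(b)                        # would raise RuntimeError for the same reason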
Example #7
 def test_can_cast(self, device):
     self.assertTrue(torch.can_cast(torch.double, torch.float))
     self.assertFalse(torch.can_cast(torch.float, torch.int))
Example #8
    def test_jit_alias_remapping(self, device, dtype, op):
        # Required to avoid an "undefined value: tensor" error when JIT-compiling the function template
        tensor = torch.tensor

        samples = op.sample_inputs(device, dtype, requires_grad=True)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        # NOTE: only tests on first sample
        sample = samples[0]

        # [Scripting Data Preparation]
        # Prepare data for test scripting
        # Below we prepare strings of args/kwargs with and without type annotations.
        # These strings are inserted into function templates, which are then compiled with TorchScript.
        # - args string is ["t0"] corresponding to the "input" tensor required by the op
        # - args_kw is the value of args and strings of kwargs used to call the op (without type annotations), for example,
        # ["to", "1.0", "(1,)", "True", "tensor(1.0)"] -> def fn(t0): return variant(t0, 1.0, (1,), True, tensor(1.0))
        args = ["t0"]

        def quote_strs(v):
            if isinstance(v, str):
                return f"'{v}'"

            return str(v)

        args_kw = args + \
            [f"{v}" for v in sample.args] + \
            [f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]

        # Prepare data for test tracing
        sample_args_kwargs = ()
        if len(sample.args) > 0:
            sample_args_kwargs += (sample.args, )
        if len(sample.kwargs) > 0:
            sample_args_kwargs += (sample.kwargs, )

        original_name = op.aten_name
        original_name_inplace = original_name + "_"
        expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype

        for a_op in op.aliases:
            inplace = a_op.inplace_variant
            method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
            variants = (v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant) if v is not None)

            # Test scripting:
            for variant in variants:
                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                if variant in method_or_inplace:
                    fn_template = '''
                        def _fn(t0{c}):
                            return t0.{alias_name}({args_kw})
                    '''
                    # remove the first input tensor
                    script = fn_template.format(
                        c=", " if len(args_kw[1:]) > 1 else "",
                        args_kw=", ".join(args_kw[1:]),
                        alias_name=variant_name,
                    )
                else:
                    fn_template = '''
                        def _fn({args}):
                            return variant({args_kw})
                    '''
                    script = fn_template.format(
                        args=", ".join(args),
                        args_kw=", ".join(args_kw),
                    )
                scripted = torch.jit.CompilationUnit(script)._fn

                if (variant is inplace and not torch.can_cast(expected_dtype, dtype)):
                    try:
                        inp = clone_input_helper(sample.input)
                        scripted(inp)
                    except Exception as e:
                        continue
                    self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")

                inp = clone_input_helper(sample.input)
                scripted(inp)
                inp = clone_input_helper(sample.input)
                graph = scripted.graph_for(inp)
                FileCheck().check(op.aten_name).check_not(variant_name).run(graph)

            # Test tracing:
            for variant in variants:
                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                def _fn(*sample_args, **sample_kwargs):
                    return variant(*sample_args, **sample_kwargs)

                inp = (clone_input_helper(sample.input),) + sample_args_kwargs
                traced = torch.jit.trace(_fn, *inp)
                inp = (clone_input_helper(sample.input),) + sample_args_kwargs
                traced(*inp)
                inp = (clone_input_helper(sample.input),) + sample_args_kwargs
                graph = traced.graph_for(*inp)
                FileCheck().check(op_name).check_not(variant_name).run(graph)
Example #9
    def test_jit_alias_remapping(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype, requires_grad=True)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        # NOTE: only tests on first sample
        sample = samples[0]

        # Prepare data for test scripting
        args = [f"t{i}" for i in range(len(sample.input))] + \
               [f"s{i}" for i in range(len(sample.args))]
        args_annot_kw = args + [f"{k}: {type(v).__name__} = {v}" for k, v in sample.kwargs.items()]
        args_kw = args + [f"{k}={v}" for k, v in sample.kwargs.items()]

        # Prepare data for test tracing
        sample_args_kwargs = ()
        if len(sample.args) > 0:
            sample_args_kwargs += (sample.args, )
        if len(sample.kwargs) > 0:
            sample_args_kwargs += (sample.kwargs, )

        original_name = op.name
        original_name_inplace = original_name + "_"
        expected_dtype = op(*sample.input, *sample.args, **sample.kwargs).dtype

        for a_op in op.aliases:  
            inplace = a_op.inplace_variant
            method_or_inplace = [a_op.inplace_variant, a_op.method_variant]            
            variants = (v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant) if v is not None)

            # Test scripting:
            for variant in variants:
                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                if variant in method_or_inplace:
                    fn_template = '''
                    def _fn(t0{c}{args_annot_kw}):
                        return t0.{alias_name}({args_kw})
                    '''
                    # remove the first input tensor
                    script = fn_template.format(
                        c=", " if len(args_kw[1:]) > 1 else "",
                        args_annot_kw=", ".join(args_annot_kw[1:]),
                        args_kw=", ".join(args_kw[1:]),
                        alias_name=variant_name,
                    )
                else:
                    fn_template = '''
                        def _fn({args_annot_kw}):
                            return variant({args_kw})
                    '''
                    script = fn_template.format(
                        args_annot_kw=", ".join(args_annot_kw),
                        args_kw=", ".join(args_kw),
                    )
                scripted = torch.jit.CompilationUnit(script)._fn

                if (variant is inplace and not torch.can_cast(expected_dtype, dtype)):
                    try:
                        inp = (clone_input_helper(input) for input in sample.input)
                        scripted(*inp, *sample.args, **sample.kwargs)
                    except Exception as e:
                        continue
                    self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")

                inp = (clone_input_helper(input) for input in sample.input)
                scripted(*inp, *sample.args, **sample.kwargs)
                inp = (clone_input_helper(input) for input in sample.input)
                graph = scripted.graph_for(*inp, *sample.args, **sample.kwargs)
                FileCheck().check(op_name).check_not(variant_name).run(graph)

            # Test tracing:
            for variant in variants:
                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                def _fn(*sample_args, **sample_kwargs):
                    return variant(*sample_args, **sample_kwargs)

                inp = (*(clone_input_helper(input) for input in sample.input), ) + sample_args_kwargs
                traced = torch.jit.trace(_fn, *inp)
                inp = (*(clone_input_helper(input) for input in sample.input), ) + sample_args_kwargs
                traced(*inp)
                inp = (*(clone_input_helper(input) for input in sample.input), ) + sample_args_kwargs
                graph = traced.graph_for(*inp)
                FileCheck().check(op_name).check_not(variant_name).run(graph)