Example #1
    def test_result_type(self, device, dtypes):
        "Test result_type for tensor vs tensor and scalar vs scalar."

        def _get_dtype(x):
            "Get the dtype of x if x is a tensor. If x is a scalar, get its corresponding dtype if it were a tensor."
            if torch.is_tensor(x):
                return x.dtype
            elif isinstance(x, bool):
                return torch.bool
            elif isinstance(x, int):
                return torch.int64
            elif isinstance(x, float):
                return torch.float32
            elif isinstance(x, complex):
                return torch.complex64
            else:
                raise AssertionError(f"Unknown type {x}")

        # tensor against tensor
        a_tensor = torch.tensor((0, 1), device=device, dtype=dtypes[0])
        a_single_tensor = torch.tensor(1, device=device, dtype=dtypes[0])
        a_scalar = a_single_tensor.item()
        b_tensor = torch.tensor((1, 0), device=device, dtype=dtypes[1])
        b_single_tensor = torch.tensor(1, device=device, dtype=dtypes[1])
        b_scalar = b_single_tensor.item()
        combo = ((a_tensor, a_single_tensor, a_scalar),
                 (b_tensor, b_single_tensor, b_scalar))
        for a, b in itertools.product(*combo):
            dtype_a = _get_dtype(a)
            dtype_b = _get_dtype(b)
            try:
                result = a + b
            except RuntimeError:
                with self.assertRaises(RuntimeError):
                    torch.promote_types(dtype_a, dtype_b)
                with self.assertRaises(RuntimeError):
                    torch.result_type(a, b)
            else:
                dtype_res = _get_dtype(result)
                if a is a_scalar and b is b_scalar and dtype_a == torch.bool and dtype_b == torch.bool:
                    # special case: in Python, True + True is an integer
                    self.assertEqual(dtype_res, torch.int64,
                                     f"a == {a}, b == {b}")
                else:
                    self.assertEqual(dtype_res, torch.result_type(a, b),
                                     f"a == {a}, b == {b}")
                if a is a_scalar and b is b_scalar:  # Python internal type determination is good enough in this case
                    continue
            if any(a is a0 and b is b0 for a0, b0 in zip(*combo)):
                # a and b are of the same kind (tensor/tensor, 0-dim
                # tensor/0-dim tensor, or scalar/scalar), so the result
                # dtype must also match torch.promote_types
                self.assertEqual(dtype_res,
                                 torch.promote_types(dtype_a, dtype_b),
                                 f"a == {a}, b == {b}")
Example #2
    def test_addr_type_promotion(self, device, dtypes):
        a = torch.randn(5).to(device=device, dtype=dtypes[0])
        b = torch.randn(5).to(device=device, dtype=dtypes[1])
        m = torch.randn(5, 5).to(device=device, dtype=torch.result_type(a, b))
        for op in (torch.addr, torch.Tensor.addr):
            # pass the integer 1 to torch.result_type, since the default
            # values of alpha and beta are both integers (alpha=1, beta=1)
            desired_dtype = torch.result_type(m, 1)
            result = op(m, a, b)
            self.assertEqual(result.dtype, desired_dtype)

            desired_dtype = torch.result_type(m, 2.)
            result = op(m, a, b, beta=0, alpha=2.)
            self.assertEqual(result.dtype, desired_dtype)
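
The two assertions hinge on how the scalar's kind feeds into promotion: an integer scalar never upgrades an integer-dtype m, while a float alpha does. An illustrative standalone check:

m_int = torch.ones(2, 2, dtype=torch.int32)
print(torch.result_type(m_int, 1))   # torch.int32 -- int scalar is "weak"
print(torch.result_type(m_int, 2.))  # torch.float32 -- float scalar promotes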
Example #3
 def test_outer_type_promotion(self, device, dtypes):
     a = torch.randn(5).to(device=device, dtype=dtypes[0])
     b = torch.randn(5).to(device=device, dtype=dtypes[1])
     for op in (torch.outer, torch.Tensor.outer, torch.ger,
                torch.Tensor.ger):
         result = op(a, b)
         self.assertEqual(result.dtype, torch.result_type(a, b))
Example #4
 def forward(self,
             condition: torch.Tensor,
             X: torch.Tensor,
             Y: torch.Tensor):
     res_type = torch.result_type(X, Y)
     output = torch.where(condition, X.to(res_type), Y.to(res_type))
     return output
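
A sketch of what this forward computes when X and Y disagree on dtype; torch.where needs (or, in newer releases, performs) a common dtype, which is what the explicit .to(res_type) casts provide. The variable names below are illustrative:

cond = torch.tensor([True, False])
x = torch.tensor([1, 2], dtype=torch.int64)
y = torch.tensor([0.5, 0.5], dtype=torch.float32)
res_type = torch.result_type(x, y)   # torch.float32
out = torch.where(cond, x.to(res_type), y.to(res_type))
print(out, out.dtype)                # tensor([1.0000, 0.5000]) torch.float32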
Example #5
 def forward(self, input, other):
     res_type = torch.result_type(input, other)
     true_quotient = torch.true_divide(input, other)
     if res_type.is_floating_point:
         res = true_quotient
     else:
         res = torch.floor(true_quotient).to(res_type)
     return res
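
This mirrors Python's floor-division semantics on tensors: a true quotient when the promoted dtype is floating point, and a floored quotient cast back otherwise. A standalone illustration:

a = torch.tensor([7, -7])
b = torch.tensor([2, 2])
res_type = torch.result_type(a, b)   # torch.int64 -- not floating point
q = torch.floor(torch.true_divide(a, b)).to(res_type)
print(q)                             # tensor([ 3, -4]), matching Python's //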
Example #6
 def test_result_type(self, device):
     self.assertEqual(
         torch.result_type(torch.tensor(1, dtype=torch.int, device=device),
                           1), torch.int)
     self.assertEqual(
         torch.result_type(1, torch.tensor(1,
                                           dtype=torch.int,
                                           device=device)), torch.int)
     self.assertEqual(torch.result_type(1, 1.), torch.get_default_dtype())
     self.assertEqual(torch.result_type(torch.tensor(1, device=device), 1.),
                      torch.get_default_dtype())
     self.assertEqual(
         torch.result_type(
             torch.tensor(1, dtype=torch.long, device=device),
             torch.tensor([1, 1], dtype=torch.int, device=device)),
         torch.int)
     self.assertEqual(
         torch.result_type(
             torch.tensor([1., 1.], dtype=torch.float, device=device), 1.),
         torch.float)
     self.assertEqual(
         torch.result_type(
             torch.tensor(1., dtype=torch.float, device=device),
             torch.tensor(1, dtype=torch.double, device=device)),
         torch.double)
Example #7
 def get_binary_float_result_type(x, y):
     dtype1 = x.dtype
     dtype2 = y.dtype
     if is_float(dtype1) and is_float(dtype2):
         return torch.result_type(x, y)
     elif is_float(dtype1) and is_int(dtype2):
         return dtype1
     elif is_int(dtype1) and is_float(dtype2):
         return dtype2
     elif is_int(dtype1) and is_int(dtype2):
         return default_float
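
This snippet depends on is_float, is_int, and default_float, which are defined elsewhere in its module and not shown here. A plausible minimal stand-in, assuming dtype predicates and the global default float dtype:

import torch

def is_float(dtype: torch.dtype) -> bool:
    # floating-point dtypes only (float16/32/64, bfloat16)
    return dtype.is_floating_point

def is_int(dtype: torch.dtype) -> bool:
    # integer dtypes, excluding bool and complex
    return (not dtype.is_floating_point and not dtype.is_complex
            and dtype is not torch.bool)

default_float = torch.get_default_dtype()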
Example #8
    def test_cat_different_dtypes(self, device):
        dtypes = torch.testing.get_all_dtypes(include_bfloat16=False)
        for x_dtype, y_dtype in itertools.product(dtypes, dtypes):
            x_vals, y_vals = [1, 2, 3], [4, 5, 6]

            x = torch.tensor(x_vals, device=device, dtype=x_dtype)
            y = torch.tensor(y_vals, device=device, dtype=y_dtype)

            # bool tensors hold the values 1, 2, 3 as True, so rebuild the
            # expected value lists; x and y themselves are already built above
            if x_dtype is torch.bool:
                x_vals = [1, 1, 1]
            if y_dtype is torch.bool:
                y_vals = [1, 1, 1]

            res_dtype = torch.result_type(x, y)
            expected_res = torch.tensor(x_vals + y_vals, device=device, dtype=res_dtype)
            res = torch.cat([x, y])
            self.assertEqual(res, expected_res, exact_dtype=True)
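
The promotion the test exercises can be seen directly (illustrative snippet):

x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([4., 5., 6.], dtype=torch.float64)
print(torch.cat([x, y]).dtype)  # torch.float64 == torch.result_type(x, y)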
Example #9
    def _test_sparse_op(self, op_name, inplace, dtype1, dtype2, device,
                        coalesced):
        if dtype1.is_complex or dtype2.is_complex:
            return

        suffix = '_' if inplace else ''
        err = "{} {}({}, {})".format(
            "  coalesced" if coalesced else "uncoalesced", op_name + suffix,
            dtype1, dtype2)

        def op(t1, t2):
            return getattr(t1, op_name + suffix)(t2)

        add_sub = op_name == 'add' or op_name == 'sub'

        dense1, sparse1 = self._test_sparse_op_input_tensors(
            device, dtype1, coalesced)
        dense2, sparse2 = self._test_sparse_op_input_tensors(
            device, dtype2, coalesced, op_name != 'div')

        common_dtype = torch.result_type(dense1, dense2)
        if self.device_type == 'cpu' and common_dtype == torch.half:
            # s1/d2 are only defined further below; use the input tensors here
            self.assertRaises(RuntimeError, lambda: op(sparse1, dense2))

        # Skip inplace tests that would fail due to inability to cast to the output type.
        # Some of these would also raise errors due to not being a supported op.
        if inplace and not torch.can_cast(common_dtype, dtype1):
            self.assertRaises(RuntimeError, lambda: op(dense1, sparse2))
            self.assertRaises(RuntimeError, lambda: op(sparse1, sparse2))
            self.assertRaises(RuntimeError, lambda: op(sparse1, dense2))
            return

        expected = op(dense1.clone(), dense2)
        precision = self._get_precision(expected.dtype, coalesced)
        test_tensors = [expected, dense1, sparse1, dense2, sparse2]
        e, d1, s1, d2, s2 = ([x.clone() for x in test_tensors]
                             if inplace else test_tensors)

        # Test op(sparse, sparse)
        if op_name != 'div':
            sparse = op(s1, s2)
            self.assertEqual(sparse.dtype, e.dtype)
            self.assertEqual(e, sparse.to_dense(), atol=precision, message=err)
        else:
            # sparse division only supports division by a scalar
            self.assertRaises(RuntimeError, lambda: op(s1, s2).to_dense())

        # Test op(dense, sparse)
        if add_sub:
            if inplace:
                e, d1, s1, d2, s2 = [x.clone() for x in test_tensors]
            dense_sparse = op(d1, s2)
            self.assertEqual(e, dense_sparse, atol=precision, message=err)
        else:
            # sparse division only supports division by a scalar
            # mul: Didn't find kernel to dispatch to for operator 'aten::_nnz'
            self.assertRaises(RuntimeError, lambda: op(d1, s2))

        # Test op(sparse, dense) not supported for any ops:
        # add(sparse, dense) is not supported. Use add(dense, sparse) instead.
        # sparse division only supports division by a scalar
        # mul: Didn't find kernel to dispatch to for operator 'aten::_nnz'.
        self.assertRaises(RuntimeError, lambda: op(s1, d2))

        # Test op(sparse, scalar)
        if not add_sub and not (self.device_type == 'cpu'
                                and dtype1 == torch.half):
            if inplace:
                e, d1, s1, d2, s2 = [x.clone() for x in test_tensors]
            scalar = d2.view(d2.numel())[0].item()

            sparse = op(s1, scalar)
            dense_scalar = op(d1, scalar)
            self.assertEqual(sparse.dtype, dense_scalar.dtype)
            self.assertEqual(dense_scalar,
                             sparse.to_dense(),
                             atol=precision,
                             message=err)
        else:
            # add(sparse, dense) is not supported. Use add(dense, sparse) instead.
            # "mul_cpu" / "div_cpu" not implemented for 'Half'
            self.assertRaises(RuntimeError,
                              lambda: op(s1,
                                         d2.view(d2.numel())[0].item()))
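
The inplace branch above is gated by torch.can_cast, which encodes PyTorch's rule that an in-place op must be able to write the promoted dtype back into the left operand's dtype, and in particular that floating values never implicitly narrow to integer dtypes. For instance:

print(torch.can_cast(torch.float32, torch.int64))  # False
print(torch.can_cast(torch.int64, torch.float32))  # True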
Example #10
 def __rpow__(self, other):
     dtype = torch.result_type(other, self)
     return torch.tensor(other, dtype=dtype, device=self.device)**self
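
The reflected operator runs when the left operand is a plain Python scalar, so promotion is driven by the scalar's kind. Illustrated with a plain torch.Tensor:

t = torch.arange(3, dtype=torch.int64)
print(torch.result_type(2, t))    # torch.int64 -- int base stays integral
print(torch.result_type(2.0, t))  # torch.float32 -- float base promotes
print(2 ** t)                     # tensor([1, 2, 4])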
Example #11
 def _test_spot(a, b, res_dtype):
     # nested helper: `self` is captured from the enclosing test method;
     # torch.result_type must be symmetric in its arguments
     self.assertEqual(torch.result_type(a, b), res_dtype)
     self.assertEqual(torch.result_type(b, a), res_dtype)
Example #12
 def get_result(x_1, x_2):
     args = [x_1, x_2]
     required_type = torch.result_type(*args)
     return list(filter(lambda arg: arg.dtype == required_type, args))[0]
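
This helper returns whichever input already carries the promoted dtype, avoiding a copy. Note that it raises IndexError when promotion yields a dtype neither input has (e.g. torch.int8 with torch.uint8 promotes to torch.int16). Illustrative usage:

a = torch.ones(2, dtype=torch.float64)
b = torch.ones(2, dtype=torch.int32)
print(get_result(a, b).dtype)  # torch.float64 -- `a` already has the promoted dtype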