def test_result_type(self, device, dtypes):
    "Test result_type for every pairing of tensor, 0-dim tensor and Python scalar operands."

    def _get_dtype(x):
        "Get the dtype of x if x is a tensor. If x is a scalar, get its corresponding dtype if it were a tensor."
        if torch.is_tensor(x):
            return x.dtype
        # bool must be tested before int: bool is a subclass of int in Python.
        elif isinstance(x, bool):
            return torch.bool
        elif isinstance(x, int):
            return torch.int64
        elif isinstance(x, float):
            return torch.float32
        elif isinstance(x, complex):
            return torch.complex64
        else:
            raise AssertionError(f"Unknown type {x}")

    # Build three "classes" of operand per dtype: a 1-d tensor, a 0-dim tensor,
    # and the equivalent Python scalar.
    a_tensor = torch.tensor((0, 1), device=device, dtype=dtypes[0])
    a_single_tensor = torch.tensor(1, device=device, dtype=dtypes[0])
    a_scalar = a_single_tensor.item()
    b_tensor = torch.tensor((1, 0), device=device, dtype=dtypes[1])
    b_single_tensor = torch.tensor(1, device=device, dtype=dtypes[1])
    b_scalar = b_single_tensor.item()
    combo = ((a_tensor, a_single_tensor, a_scalar), (b_tensor, b_single_tensor, b_scalar))
    for a, b in itertools.product(*combo):
        dtype_a = _get_dtype(a)
        dtype_b = _get_dtype(b)
        try:
            result = a + b
        except RuntimeError:
            # If addition is rejected, both promotion APIs must reject the pair too.
            with self.assertRaises(RuntimeError):
                torch.promote_types(dtype_a, dtype_b)
            with self.assertRaises(RuntimeError):
                torch.result_type(a, b)
        else:
            dtype_res = _get_dtype(result)
            if a is a_scalar and b is b_scalar and dtype_a == torch.bool and dtype_b == torch.bool:
                # special case: in Python, True + True is an integer
                self.assertEqual(dtype_res, torch.int64, f"a == {a}, b == {b}")
            else:
                self.assertEqual(dtype_res, torch.result_type(a, b), f"a == {a}, b == {b}")
            if a is a_scalar and b is b_scalar:
                # Python internal type determination is good enough in this case
                continue
            if any(a is a0 and b is b0 for a0, b0 in zip(*combo)):
                # a and b belong to the same class (both 1-d or both 0-dim tensors),
                # so plain dtype promotion must agree with the observed result.
                self.assertEqual(dtype_res, torch.promote_types(dtype_a, dtype_b), f"a == {a}, b == {b}")
def test_addr_type_promotion(self, device, dtypes):
    """Check that addr's output dtype follows promotion of the matrix with alpha/beta."""
    vec1 = torch.randn(5).to(device=device, dtype=dtypes[0])
    vec2 = torch.randn(5).to(device=device, dtype=dtypes[1])
    mat = torch.randn(5, 5).to(device=device, dtype=torch.result_type(vec1, vec2))

    for addr_fn in (torch.addr, torch.Tensor.addr):
        # alpha and beta both default to the integer 1, so the expected dtype is
        # the promotion of `mat` with an integer scalar.
        self.assertEqual(addr_fn(mat, vec1, vec2).dtype, torch.result_type(mat, 1))

        # A float alpha takes part in promotion as a float scalar.
        self.assertEqual(
            addr_fn(mat, vec1, vec2, beta=0, alpha=2.).dtype,
            torch.result_type(mat, 2.),
        )
def test_outer_type_promotion(self, device, dtypes):
    """Check that outer/ger produce the promoted dtype of their two vector inputs."""
    lhs = torch.randn(5).to(device=device, dtype=dtypes[0])
    rhs = torch.randn(5).to(device=device, dtype=dtypes[1])
    expected_dtype = torch.result_type(lhs, rhs)
    outer_ops = (torch.outer, torch.Tensor.outer, torch.ger, torch.Tensor.ger)
    for outer_op in outer_ops:
        self.assertEqual(outer_op(lhs, rhs).dtype, expected_dtype)
def forward(self, condition: torch.Tensor, X: torch.Tensor, Y: torch.Tensor):
    """Elementwise select from X where `condition` holds, else from Y.

    Both branches are first cast to ``torch.result_type(X, Y)``, so the output
    dtype is the promoted dtype of X and Y.

    Args:
        condition: selection mask passed to ``torch.where``.
        X: values taken where ``condition`` is true.
        Y: values taken where ``condition`` is false. (Previously written as
            ``Y=torch.Tensor``, which made the Tensor *class* a default value
            instead of an annotation.)

    Returns:
        The tensor produced by ``torch.where`` on the promoted operands.
    """
    res_type = torch.result_type(X, Y)
    output = torch.where(condition, X.to(res_type), Y.to(res_type))
    return output
def forward(self, input, other):
    """Divide `input` by `other`; floor the quotient unless the promoted dtype is floating point.

    Returns the true quotient for floating-point promoted dtypes, otherwise the
    floored quotient cast back to the promoted dtype.
    """
    out_dtype = torch.result_type(input, other)
    quotient = torch.true_divide(input, other)
    if not out_dtype.is_floating_point:
        # Non-floating promoted dtype: round toward -inf, then cast back.
        return torch.floor(quotient).to(out_dtype)
    return quotient
def test_result_type(self, device):
    """Spot-check torch.result_type on a fixed table of operand pairs."""

    def t(value, dtype=None):
        # Shorthand for a tensor on the test device (dtype=None means "infer").
        return torch.tensor(value, dtype=dtype, device=device)

    cases = (
        (t(1, torch.int), 1, torch.int),
        (1, t(1, torch.int), torch.int),
        (1, 1., torch.get_default_dtype()),
        (t(1), 1., torch.get_default_dtype()),
        (t(1, torch.long), t([1, 1], torch.int), torch.int),
        (t([1., 1.], torch.float), 1., torch.float),
        (t(1., torch.float), t(1, torch.double), torch.double),
    )
    for lhs, rhs, expected in cases:
        self.assertEqual(torch.result_type(lhs, rhs), expected)
def get_binary_float_result_type(x, y):
    """Return the floating-point result dtype for a binary op on tensors x and y.

    float/float follows ``torch.result_type``; a float paired with an int keeps
    the float operand's dtype; int/int yields ``default_float``. Returns None
    when no combination matches (e.g. a dtype that ``is_float``/``is_int``, both
    defined elsewhere, reject).
    """
    dtype1 = x.dtype
    dtype2 = y.dtype
    if is_float(dtype1):
        if is_float(dtype2):
            return torch.result_type(x, y)
        if is_int(dtype2):
            return dtype1
    if is_int(dtype1):
        if is_float(dtype2):
            return dtype2
        if is_int(dtype2):
            return default_float
    return None
def test_cat_different_dtypes(self, device):
    """torch.cat on mixed-dtype inputs must produce the promoted dtype and values."""
    dtypes = torch.testing.get_all_dtypes(include_bfloat16=False)
    for x_dtype, y_dtype in itertools.product(dtypes, dtypes):
        x = torch.tensor([1, 2, 3], device=device, dtype=x_dtype)
        y = torch.tensor([4, 5, 6], device=device, dtype=y_dtype)

        # Stored as bool, every nonzero input becomes True, i.e. 1.
        x_vals = [1, 1, 1] if x_dtype is torch.bool else [1, 2, 3]
        y_vals = [1, 1, 1] if y_dtype is torch.bool else [4, 5, 6]

        expected_res = torch.tensor(x_vals + y_vals, device=device, dtype=torch.result_type(x, y))
        self.assertEqual(torch.cat([x, y]), expected_res, exact_dtype=True)
def _test_sparse_op(self, op_name, inplace, dtype1, dtype2, device, coalesced):
    """Check a binary tensor method (optionally in-place) across sparse/dense operand mixes.

    The dense result ``op(dense1, dense2)`` is the reference. The code branches on
    op_name 'add'/'sub'/'div'; comments also mention 'mul' — presumably those four
    are the expected values (TODO confirm with callers). Complex dtypes are skipped.
    """
    if dtype1.is_complex or dtype2.is_complex:
        return

    suffix = '_' if inplace else ''
    err = "{} {}({}, {})".format(" coalesced" if coalesced else "uncoalesced", op_name + suffix, dtype1, dtype2)

    def op(t1, t2):
        return getattr(t1, op_name + suffix)(t2)

    add_sub = op_name == 'add' or op_name == 'sub'

    (dense1, sparse1) = self._test_sparse_op_input_tensors(device, dtype1, coalesced)
    (dense2, sparse2) = self._test_sparse_op_input_tensors(device, dtype2, coalesced, op_name != 'div')

    common_dtype = torch.result_type(dense1, dense2)
    if self.device_type == 'cpu' and common_dtype == torch.half:
        # Fixed: this previously referenced s1/d2, which are only assigned further
        # down, so calling the lambda raised NameError instead of RuntimeError.
        self.assertRaises(RuntimeError, lambda: op(sparse1, dense2))

    # Skip inplace tests that would fail due to inability to cast to the output type.
    # Some of these would also raise errors due to not being a supported op.
    if inplace and not torch.can_cast(common_dtype, dtype1):
        self.assertRaises(RuntimeError, lambda: op(dense1, sparse2))
        self.assertRaises(RuntimeError, lambda: op(sparse1, sparse2))
        self.assertRaises(RuntimeError, lambda: op(sparse1, dense2))
        return

    expected = op(dense1.clone(), dense2)
    precision = self._get_precision(expected.dtype, coalesced)
    test_tensors = [expected, dense1, sparse1, dense2, sparse2]
    # In-place ops mutate their first operand, so work on fresh clones.
    e, d1, s1, d2, s2 = [x.clone() for x in test_tensors] if inplace else test_tensors

    # Test op(sparse, sparse)
    if op_name != 'div':
        sparse = op(s1, s2)
        self.assertEqual(sparse.dtype, e.dtype)
        self.assertEqual(e, sparse.to_dense(), atol=precision, message=err)
    else:
        # sparse division only supports division by a scalar
        self.assertRaises(RuntimeError, lambda: op(s1, s2).to_dense())

    # Test op(dense, sparse)
    if add_sub:
        if inplace:
            e, d1, s1, d2, s2 = [x.clone() for x in test_tensors]
        dense_sparse = op(d1, s2)
        self.assertEqual(e, dense_sparse, atol=precision, message=err)
    else:
        # sparse division only supports division by a scalar
        # mul: Didn't find kernel to dispatch to for operator 'aten::_nnz'
        self.assertRaises(RuntimeError, lambda: op(d1, s2))

    # Test op(sparse, dense) not supported for any ops:
    # add(sparse, dense) is not supported. Use add(dense, sparse) instead.
    # sparse division only supports division by a scalar
    # mul: Didn't find kernel to dispatch to for operator 'aten::_nnz'.
    self.assertRaises(RuntimeError, lambda: op(s1, d2))

    # Test op(sparse, scalar)
    if not add_sub and not (self.device_type == 'cpu' and dtype1 == torch.half):
        if inplace:
            e, d1, s1, d2, s2 = [x.clone() for x in test_tensors]
        scalar = d2.view(d2.numel())[0].item()

        sparse = op(s1, scalar)
        dense_scalar = op(d1, scalar)
        self.assertEqual(sparse.dtype, dense_scalar.dtype)
        self.assertEqual(dense_scalar, sparse.to_dense(), atol=precision, message=err)
    else:
        # add(sparse, dense) is not supported. Use add(dense, sparse) instead.
        # "mul_cpu" / "div_cpu" not implemented for 'Half'
        self.assertRaises(RuntimeError, lambda: op(s1, d2.view(d2.numel())[0].item()))
def __rpow__(self, other):
    """Compute ``other ** self``, materializing `other` as a tensor of the promoted dtype."""
    promoted = torch.result_type(other, self)
    base = torch.tensor(other, dtype=promoted, device=self.device)
    return base ** self
def _test_spot(a, b, res_dtype):
    # result_type must give res_dtype regardless of operand order.
    for first, second in ((a, b), (b, a)):
        self.assertEqual(torch.result_type(first, second), res_dtype)
def get_result(x_1, x_2):
    """Return the first of x_1, x_2 whose dtype equals their promoted result dtype.

    Raises IndexError when neither operand already has the promoted dtype.
    """
    common = torch.result_type(x_1, x_2)
    matches = [arg for arg in (x_1, x_2) if arg.dtype == common]
    return matches[0]