def test_variant_consistency_eager(self, device, dtype, op):
    """Check that an op's method and inplace variants agree with its function variant.

    For every sample input, the function variant's forward result (and, when
    backward is tested, its gradient) is the reference. Each available
    method/inplace variant is run on the same sample and compared against it.
    Inplace variants whose result dtype cannot be cast back to ``dtype`` are
    expected to raise instead.
    """
    test_backward = op.test_complex_grad or not dtype.is_complex
    samples = op.sample_inputs(device, dtype, requires_grad=test_backward)
    if len(samples) == 0:
        self.skipTest("Skipped! No sample inputs!")

    # The variants depend only on ``op``, so look them up once instead of once
    # per sample. A tuple (unlike a generator) can be iterated for every sample.
    method = op.get_method()
    inplace = op.get_inplace()
    variants = tuple(v for v in (method, inplace) if v is not None)

    for sample in samples:
        # Reference forward: the op's function variant.
        expected_forward = op(*sample.input, *sample.args, **sample.kwargs)

        # Reference backward.
        # NOTE: backward may fail for some dtypes; the failure is recorded and
        # each variant is then expected to fail in the same way.
        # NOTE(review): ``sample.input`` is unpacked as a sequence above but
        # ``.grad`` is read from it directly here; if it is a tuple, the
        # resulting AttributeError is swallowed by this broad ``except`` and
        # recorded as a backward failure. Confirm this is intended.
        exception_during_backwards = False
        expected_grad = None
        try:
            expected_forward.sum().backward()
            expected_grad = sample.input.grad
            sample.input.grad = None
        except Exception:
            exception_during_backwards = True

        for variant in variants:
            is_inplace = variant is inplace

            # Inplace ops whose result would be promoted (e.g. int -> float)
            # cannot write it back into the input, so they must raise.
            if is_inplace and not torch.can_cast(expected_forward.dtype, dtype):
                try:
                    variant(
                        *(clone_input_helper(t) for t in sample.input),
                        *sample.args, **sample.kwargs)
                except Exception:
                    continue
                self.fail(
                    "Inplace operation on integer tensor that should be promoted to float didn't fail!"
                )

            # Compare forward. Inplace variants get cloned inputs so the
            # sample's own tensors are not mutated.
            variant_forward = variant(
                *(clone_input_helper(t) if is_inplace else t
                  for t in sample.input),
                *sample.args, **sample.kwargs)
            self.assertEqual(variant_forward, expected_forward)

            # Compare backward.
            if test_backward and (not is_inplace or op.test_inplace_grad):
                self.check_variant_backward(
                    sample.input, variant_forward, expected_grad,
                    exception_during_backwards)
def can_cast(from_, to):
    """Report whether ``torch.can_cast`` allows casting one dtype to another.

    This is a thin wrapper that forwards both arguments unchanged to
    :func:`torch.can_cast`.

    Parameters
    ----------
    from_ : data-type
        Data type to cast from.
    to : data-type
        Data type to cast to.

    Returns
    -------
    bool
        The value returned by ``torch.can_cast(from_, to)``.
    """
    permitted = torch.can_cast(from_, to)
    return permitted
def test_unary_op_out_casting(self, device, dtypes): t = torch.tensor((1), dtype=dtypes[0], device=device) out = torch.empty(1, dtype=dtypes[1], device=device) ops = (torch.neg, ) for op in ops: if torch.can_cast(dtypes[0], dtypes[1]): self.assertEqual(op(t, out=out), out) else: with self.assertRaisesRegex(RuntimeError, 'can\'t be cast'): op(t, out=out)
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Check ``input`` against the configured shape and dtype; return it as-is.

    Shape check (skipped when ``self._shape`` is None): a mismatch is an
    error unless ``self._broadcastable`` is set and ``self._broadcast``
    accepts the two shapes.

    Dtype check (skipped when ``self._dtype`` is None): a mismatch is an
    error unless ``self._can_cast`` is set and ``torch.can_cast`` allows
    casting ``input.dtype`` to ``self._dtype``.

    Raises:
        ValueError: if either check fails.
    """
    # The shape is carried as a tensor for TorchScript compatibility, since
    # torch.Size does not work there.
    t_shape = torch.tensor(input.shape)

    if self._shape is not None and list(t_shape) != list(self._shape):
        if not self._broadcastable:
            raise ValueError(
                f'Expected {self._shape}, input shape is {input.shape}')
        if not self._broadcast(t_shape, self._shape):
            raise ValueError(
                f'Shapes {self._shape} and {input.shape} are non'
                ' broadcastable')

    if self._dtype is not None and input.dtype != self._dtype:
        if not self._can_cast:
            raise ValueError(
                f'Expected {self._dtype}, input dtype is {input.dtype}')
        if not torch.can_cast(input.dtype, self._dtype):
            raise ValueError(
                f'Input dtype {input.dtype} can\'t be casted to'
                f' {self._dtype}')

    return input
def test_jit_alias_remapping(self, device, dtype, op):
    """Check that JIT-compiled op aliases are remapped to the original op.

    Only the first sample input is used. For each alias in ``op.aliases`` and
    each of its non-None variants (function, method, inplace), the variant is
    TorchScripted from a generated source template and then traced. In both
    cases the resulting graph must contain the original op name and must not
    contain the variant's name. Inplace variants whose result dtype cannot be
    cast back to ``dtype`` must raise instead.
    """
    samples = op.sample_inputs(device, dtype, requires_grad=True)
    if len(samples) == 0:
        self.skipTest("Skipped! No sample inputs!")

    # NOTE: only tests on first sample
    sample = samples[0]

    def as_literal(v):
        # Strings must be quoted to be valid literals in the generated source;
        # other values use their default f-string formatting.
        return f"'{v}'" if isinstance(v, str) else f"{v}"

    # [Scripting Data Preparation]
    # Prepare data for test scripting
    # Below we prepare strings of args/kwargs with and without type annotations.
    # These strings are inserted into function template strings which is then torch scripted.
    # - args string is ["t0"] corresponding to the "input" tensor required by the op
    # - args_annot_kw is the string for the template function signature, for example,
    # ["t0", "s0: float", "s1: bool", "max: float = 1.0", "min: float = 0.0"] ->
    #    def fn(t0, s0: float, s1: bool, max: float = 1.0, min: float = 0.0)
    # - args_kw is the string of args/kwargs used to call the op, same as args_annot_kw but
    # without type annotations
    args = ["t0"]
    args_annot_kw = args + \
        [f"s{i}: {type(v).__name__}" for i, v in enumerate(sample.args)] + \
        [f"{k}: {type(v).__name__} = {as_literal(v)}" for k, v in sample.kwargs.items()]
    args_kw = args + \
        [f"s{i}" for i in range(len(sample.args))] + \
        [f"{k}={as_literal(v)}" for k, v in sample.kwargs.items()]

    original_name = op.aten_name
    original_name_inplace = original_name + "_"
    expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype
    # An inplace variant cannot store a promoted result (e.g. int -> float).
    inplace_must_fail = not torch.can_cast(expected_dtype, dtype)

    for a_op in op.aliases:
        inplace = a_op.inplace_variant
        method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
        # Use a tuple, not a generator: both the scripting loop and the tracing
        # loop iterate it. A generator is exhausted after the first loop, so
        # tracing would silently never run.
        variants = tuple(v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant)
                         if v is not None)

        # Test scripting:
        for variant in variants:
            variant_name = variant.__name__
            if variant in method_or_inplace:
                fn_template = '''
                    def _fn(t0{c}{args_annot_kw}):
                        return t0.{alias_name}({args_kw})
                '''
                # remove the first input tensor
                script = fn_template.format(
                    c=", " if len(args_annot_kw) > 1 else "",
                    args_annot_kw=", ".join(args_annot_kw[1:]),
                    args_kw=", ".join(args_kw[1:]),
                    alias_name=variant_name,
                )
            else:
                fn_template = '''
                    def _fn({args_annot_kw}):
                        return variant({args_kw})
                '''
                script = fn_template.format(
                    args_annot_kw=", ".join(args_annot_kw),
                    args_kw=", ".join(args_kw),
                )
            # NOTE(review): the template's free name ``variant`` is presumably
            # resolved from this frame by CompilationUnit, so keep it a local here.
            scripted = torch.jit.CompilationUnit(script)._fn

            if variant is inplace and inplace_must_fail:
                try:
                    inp = clone_input_helper(sample.input)
                    scripted(inp, *sample.args, **sample.kwargs)
                except Exception:
                    continue
                self.fail(
                    "Inplace operation on integer tensor that should be promoted to float didn't fail!"
                )

            inp = clone_input_helper(sample.input)
            scripted(inp, *sample.args, **sample.kwargs)
            inp = clone_input_helper(sample.input)
            graph = scripted.graph_for(inp, *sample.args, **sample.kwargs)
            FileCheck().check(op.aten_name).check_not(variant_name).run(graph)

        # Test tracing:
        for variant in variants:
            variant_name = variant.__name__
            op_name = original_name_inplace if variant is inplace else original_name
            if variant is inplace and inplace_must_fail:
                # The expected failure was already asserted while scripting.
                continue

            # Only the input tensor is an example input. The args/kwargs are
            # closed over, because torch.jit.trace's second positional argument
            # is the example-input tuple and later positionals are trace options.
            def _fn(t0):
                return variant(t0, *sample.args, **sample.kwargs)

            traced = torch.jit.trace(_fn, (clone_input_helper(sample.input),))
            traced(clone_input_helper(sample.input))
            graph = traced.graph_for(clone_input_helper(sample.input))
            FileCheck().check(op_name).check_not(variant_name).run(graph)
def _test_sparse_op(self, op_name, inplace, dtype1, dtype2, device, coalesced):
    """Compare a binary sparse op against the same op on dense tensors.

    ``op_name`` is looked up as a Tensor method, with a trailing '_' when
    ``inplace`` is set. Operands are built by
    ``self._test_sparse_op_input_tensors`` for ``dtype1`` and ``dtype2``. The
    dense-dense result is the reference for the sparse/sparse, dense/sparse
    and sparse/scalar combinations. Combinations expected to be unsupported
    are asserted to raise RuntimeError. Complex dtypes are skipped.
    """
    if dtype1.is_complex or dtype2.is_complex:
        return

    suffix = '_' if inplace else ''
    err = "{} {}({}, {})".format(
        " coalesced" if coalesced else "uncoalesced",
        op_name + suffix, dtype1, dtype2)

    def op(t1, t2):
        return getattr(t1, op_name + suffix)(t2)

    add_sub = op_name in ('add', 'sub')

    (dense1, sparse1) = self._test_sparse_op_input_tensors(device, dtype1, coalesced)
    (dense2, sparse2) = self._test_sparse_op_input_tensors(device, dtype2, coalesced, op_name != 'div')

    common_dtype = torch.result_type(dense1, dense2)
    if self.device_type == 'cpu' and common_dtype == torch.half:
        # Use the operands built above. ``s1``/``d2`` are only bound further
        # down, so referencing them here raised NameError instead of
        # exercising the op.
        self.assertRaises(RuntimeError, lambda: op(sparse1, dense2))

    # Skip inplace tests that would fail due to inability to cast to the output type.
    # Some of these would also raise errors due to not being a supported op.
    if inplace and not torch.can_cast(common_dtype, dtype1):
        self.assertRaises(RuntimeError, lambda: op(dense1, sparse2))
        self.assertRaises(RuntimeError, lambda: op(sparse1, sparse2))
        self.assertRaises(RuntimeError, lambda: op(sparse1, dense2))
        return

    expected = op(dense1.clone(), dense2)
    precision = self._get_precision(expected.dtype, coalesced)
    test_tensors = [expected, dense1, sparse1, dense2, sparse2]
    # Inplace ops mutate their first operand, so work on copies.
    e, d1, s1, d2, s2 = [x.clone() for x in test_tensors] if inplace else test_tensors

    # Test op(sparse, sparse)
    if op_name != 'div':
        sparse = op(s1, s2)
        self.assertEqual(sparse.dtype, e.dtype)
        self.assertEqual(e, sparse.to_dense(), atol=precision, message=err)
    else:
        # sparse division only supports division by a scalar
        self.assertRaises(RuntimeError, lambda: op(s1, s2).to_dense())

    # Test op(dense, sparse)
    if add_sub:
        if inplace:
            e, d1, s1, d2, s2 = [x.clone() for x in test_tensors]
        dense_sparse = op(d1, s2)
        self.assertEqual(e, dense_sparse, atol=precision, message=err)
    else:
        # sparse division only supports division by a scalar
        # mul: Didn't find kernel to dispatch to for operator 'aten::_nnz'
        self.assertRaises(RuntimeError, lambda: op(d1, s2))

    # Test op(sparse, dense) not supported for any ops:
    # add(sparse, dense) is not supported. Use add(dense, sparse) instead.
    # sparse division only supports division by a scalar
    # mul: Didn't find kernel to dispatch to for operator 'aten::_nnz'.
    self.assertRaises(RuntimeError, lambda: op(s1, d2))

    # Test op(sparse, scalar)
    if not add_sub and not (self.device_type == 'cpu' and dtype1 == torch.half):
        if inplace:
            e, d1, s1, d2, s2 = [x.clone() for x in test_tensors]
        scalar = d2.view(d2.numel())[0].item()

        sparse = op(s1, scalar)
        dense_scalar = op(d1, scalar)
        self.assertEqual(sparse.dtype, dense_scalar.dtype)
        self.assertEqual(dense_scalar, sparse.to_dense(), atol=precision, message=err)
    else:
        # add(sparse, dense) is not supported. Use add(dense, sparse) instead.
        # "mul_cpu" / "div_cpu" not implemented for 'Half'
        self.assertRaises(RuntimeError, lambda: op(s1, d2.view(d2.numel())[0].item()))
def test_can_cast(self, device): self.assertTrue(torch.can_cast(torch.double, torch.float)) self.assertFalse(torch.can_cast(torch.float, torch.int))
def test_jit_alias_remapping(self, device, dtype, op):
    """Check that JIT-compiled op aliases are remapped to the original op.

    Only the first sample input is used. Its non-tensor args/kwargs are baked
    into the generated source as literals. For each alias in ``op.aliases``
    and each of its non-None variants, the variant is TorchScripted and then
    traced. The resulting graph must contain the original op name and must
    not contain the variant's name. Inplace variants whose result dtype
    cannot be cast back to ``dtype`` must raise instead.
    """
    # Required to avoid undefined value: tensor error in JIT compilation of the function template
    tensor = torch.tensor  # noqa: F841 - referenced by name from the generated script

    samples = op.sample_inputs(device, dtype, requires_grad=True)
    if len(samples) == 0:
        self.skipTest("Skipped! No sample inputs!")

    # NOTE: only tests on first sample
    sample = samples[0]

    # [Scripting Data Preparation]
    # Prepare data for test scripting
    # Below we prepare strings of args/kwargs with and without type annotations.
    # These strings are inserted into function template strings which is then torch scripted.
    # - args string is ["t0"] corresponding to the "input" tensor required by the op
    # - args_kw is the value of args and strings of kwargs used to call the op (without type annotations), for example,
    # ["to", "1.0", "(1,)", "True", "tensor(1.0)"] -> def fn(t0): return variant(t0, 1.0, (1,), True, tensor(1.0))
    args = ["t0"]

    def quote_strs(v):
        if isinstance(v, str):
            return f"'{v}'"
        return str(v)

    # Positional string args need quotes as well; other positional values keep
    # their f-string formatting.
    args_kw = args + \
        [f"'{v}'" if isinstance(v, str) else f"{v}" for v in sample.args] + \
        [f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]

    original_name = op.aten_name
    original_name_inplace = original_name + "_"
    expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype
    # An inplace variant cannot store a promoted result (e.g. int -> float).
    inplace_must_fail = not torch.can_cast(expected_dtype, dtype)

    for a_op in op.aliases:
        inplace = a_op.inplace_variant
        method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
        # Use a tuple, not a generator: both the scripting loop and the tracing
        # loop iterate it. A generator is exhausted after the first loop, so
        # tracing would silently never run.
        variants = tuple(v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant)
                         if v is not None)

        # Test scripting:
        for variant in variants:
            variant_name = variant.__name__
            if variant in method_or_inplace:
                fn_template = '''
                    def _fn(t0):
                        return t0.{alias_name}({args_kw})
                '''
                # remove the first input tensor
                script = fn_template.format(
                    args_kw=", ".join(args_kw[1:]),
                    alias_name=variant_name,
                )
            else:
                fn_template = '''
                    def _fn({args}):
                        return variant({args_kw})
                '''
                script = fn_template.format(
                    args=", ".join(args),
                    args_kw=", ".join(args_kw),
                )
            # NOTE(review): the template's free names ``variant``/``tensor`` are
            # presumably resolved from this frame by CompilationUnit, so keep
            # them as locals here.
            scripted = torch.jit.CompilationUnit(script)._fn

            if variant is inplace and inplace_must_fail:
                try:
                    inp = clone_input_helper(sample.input)
                    scripted(inp)
                except Exception:
                    continue
                self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")

            inp = clone_input_helper(sample.input)
            scripted(inp)
            inp = clone_input_helper(sample.input)
            graph = scripted.graph_for(inp)
            FileCheck().check(op.aten_name).check_not(variant_name).run(graph)

        # Test tracing:
        for variant in variants:
            variant_name = variant.__name__
            op_name = original_name_inplace if variant is inplace else original_name
            if variant is inplace and inplace_must_fail:
                # The expected failure was already asserted while scripting.
                continue

            # Only the input tensor is an example input. The args/kwargs are
            # closed over, because torch.jit.trace's second positional argument
            # is the example-input tuple and later positionals are trace options.
            def _fn(t0):
                return variant(t0, *sample.args, **sample.kwargs)

            traced = torch.jit.trace(_fn, (clone_input_helper(sample.input),))
            traced(clone_input_helper(sample.input))
            graph = traced.graph_for(clone_input_helper(sample.input))
            FileCheck().check(op_name).check_not(variant_name).run(graph)
def test_jit_alias_remapping(self, device, dtype, op):
    """Check that JIT-compiled op aliases are remapped to the original op.

    Only the first sample input is used. Its input tensors become parameters
    ``t0..tN`` and its positional args become ``s0..sM``. For each alias in
    ``op.aliases`` and each of its non-None variants, the variant is
    TorchScripted and then traced. The resulting graph must contain the
    original op name and must not contain the variant's name. Inplace
    variants whose result dtype cannot be cast back to ``dtype`` must raise
    instead.
    """
    samples = op.sample_inputs(device, dtype, requires_grad=True)
    if len(samples) == 0:
        self.skipTest("Skipped! No sample inputs!")

    # NOTE: only tests on first sample
    sample = samples[0]

    def as_literal(v):
        # Strings must be quoted to be valid literals in the generated source;
        # other values use their default f-string formatting.
        return f"'{v}'" if isinstance(v, str) else f"{v}"

    # Prepare data for test scripting.
    # Positional args are annotated with their Python type, since unannotated
    # TorchScript parameters are assumed to be Tensors.
    tensor_params = [f"t{i}" for i in range(len(sample.input))]
    args = tensor_params + [f"s{i}" for i in range(len(sample.args))]
    args_annot_kw = tensor_params + \
        [f"s{i}: {type(v).__name__}" for i, v in enumerate(sample.args)] + \
        [f"{k}: {type(v).__name__} = {as_literal(v)}" for k, v in sample.kwargs.items()]
    args_kw = args + [f"{k}={as_literal(v)}" for k, v in sample.kwargs.items()]

    original_name = op.name
    original_name_inplace = original_name + "_"
    expected_dtype = op(*sample.input, *sample.args, **sample.kwargs).dtype
    # An inplace variant cannot store a promoted result (e.g. int -> float).
    inplace_must_fail = not torch.can_cast(expected_dtype, dtype)

    for a_op in op.aliases:
        inplace = a_op.inplace_variant
        method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
        # Use a tuple, not a generator: both the scripting loop and the tracing
        # loop iterate it. A generator is exhausted after the first loop, so
        # tracing would silently never run.
        variants = tuple(v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant)
                         if v is not None)

        # Test scripting:
        for variant in variants:
            variant_name = variant.__name__
            op_name = original_name_inplace if variant is inplace else original_name

            if variant in method_or_inplace:
                fn_template = '''
                    def _fn(t0{c}{args_annot_kw}):
                        return t0.{alias_name}({args_kw})
                '''
                # remove the first input tensor; the separator is needed as
                # soon as there is at least one remaining parameter.
                script = fn_template.format(
                    c=", " if len(args_kw) > 1 else "",
                    args_annot_kw=", ".join(args_annot_kw[1:]),
                    args_kw=", ".join(args_kw[1:]),
                    alias_name=variant_name,
                )
            else:
                fn_template = '''
                    def _fn({args_annot_kw}):
                        return variant({args_kw})
                '''
                script = fn_template.format(
                    args_annot_kw=", ".join(args_annot_kw),
                    args_kw=", ".join(args_kw),
                )
            # NOTE(review): the template's free name ``variant`` is presumably
            # resolved from this frame by CompilationUnit, so keep it a local here.
            scripted = torch.jit.CompilationUnit(script)._fn

            if variant is inplace and inplace_must_fail:
                try:
                    inp = (clone_input_helper(t) for t in sample.input)
                    scripted(*inp, *sample.args, **sample.kwargs)
                except Exception:
                    continue
                self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")

            inp = (clone_input_helper(t) for t in sample.input)
            scripted(*inp, *sample.args, **sample.kwargs)
            inp = (clone_input_helper(t) for t in sample.input)
            graph = scripted.graph_for(*inp, *sample.args, **sample.kwargs)
            FileCheck().check(op_name).check_not(variant_name).run(graph)

        # Test tracing:
        for variant in variants:
            variant_name = variant.__name__
            op_name = original_name_inplace if variant is inplace else original_name
            if variant is inplace and inplace_must_fail:
                # The expected failure was already asserted while scripting.
                continue

            # Only the input tensors are example inputs. The args/kwargs are
            # closed over, because torch.jit.trace's second positional argument
            # is the example-input tuple and later positionals are trace options.
            def _fn(*tensors):
                return variant(*tensors, *sample.args, **sample.kwargs)

            traced = torch.jit.trace(
                _fn, tuple(clone_input_helper(t) for t in sample.input))
            traced(*(clone_input_helper(t) for t in sample.input))
            graph = traced.graph_for(*(clone_input_helper(t) for t in sample.input))
            FileCheck().check(op_name).check_not(variant_name).run(graph)