def test_variant_consistency_eager(self, device, dtype, op):
    """Check that the method and inplace variants of ``op`` agree with its
    function variant on every sample input.

    Forward results are compared for every available variant. Backward is
    compared via ``self.check_variant_backward`` when ``test_backward`` holds
    (for the inplace variant only if ``op.test_inplace_grad``). Inplace
    variants whose result dtype cannot be cast to ``dtype`` must raise.

    This version splats and iterates ``sample.input`` (``*sample.input``), so
    it treats it as a sequence of inputs.
    """
    # Backward is exercised for every non-complex dtype, and for complex
    # dtypes only when op.test_complex_grad is set.
    test_backward = op.test_complex_grad or not dtype.is_complex
    samples = op.sample_inputs(device, dtype, requires_grad=test_backward)

    if len(samples) == 0:
        self.skipTest("Skipped! No sample inputs!")

    for sample in samples:
        # Acquires variants to test
        method = op.get_method()
        inplace = op.get_inplace()
        # Re-created for each sample: the generator is consumed by the
        # variant loop below.
        variants = (v for v in (method, inplace) if v is not None)

        # Computes expected forward

        # below calls op's function variant
        expected_forward = op(*sample.input, *sample.args, **sample.kwargs)

        # Computes expected backward
        # NOTE: backward may fail for some dtypes
        exception_during_backwards = False
        expected_grad = None
        try:
            expected_forward.sum().backward()
            # NOTE(review): sample.input is used as a sequence above; if it is
            # a tuple, ``.grad`` raises AttributeError here. The broad except
            # below would then record that as a backward failure, and
            # gradients would never actually be compared. Confirm the type of
            # SampleInput.input.
            expected_grad = sample.input.grad
            sample.input.grad = None
        except Exception as e:
            exception_during_backwards = True

        # Test eager consistency
        for variant in variants:
            # Verifies that inplace operations that promote int->float fail
            # on tensors with integer dtypes.
            if (variant is inplace and not torch.can_cast(expected_forward.dtype, dtype)):
                try:
                    variant_forward = variant(
                        *(clone_input_helper(input) for input in sample.input),
                        *sample.args,
                        **sample.kwargs)
                except Exception as e:
                    # Expected outcome: the inplace op refused the promotion.
                    continue

                self.fail(
                    "Inplace operation on integer tensor that should be promoted to float didn't fail!"
                )

            # Compares variant's forward
            # Note: copy the tensor-type inputs when testing inplace operation
            variant_forward = variant(
                *(clone_input_helper(input) if variant is inplace else input
                  for input in sample.input),
                *sample.args,
                **sample.kwargs)
            self.assertEqual(variant_forward, expected_forward)

            # Compares variant's backward
            if test_backward and (variant is not inplace or op.test_inplace_grad):
                self.check_variant_backward(sample.input, variant_forward, expected_grad, exception_during_backwards)
def _test_inplace_preserve_storage(samples, variants):
    """Check that inplace variants return a tensor sharing the input's storage.

    For each sample, the ``data_ptr()`` of the (first) input tensor passed to
    a variant is compared with the ``data_ptr()`` of the tensor the variant
    returns. A sample is skipped when the function variant's output dtype
    differs from the input dtype, because an inplace op cannot store such a
    result in its input. Only that sample is skipped; the remaining samples
    are still checked.

    Relies on ``op``, ``inplace_ops`` and ``self`` from the enclosing scope.

    Args:
        samples: iterable of sample inputs.
        variants: iterable of callables to check. It is materialized once, so
            a generator is not exhausted after the first sample.
    """
    variants = tuple(variants)
    for sample in samples:
        expected_forward = op(sample.input, *sample.args, **sample.kwargs)
        tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]

        # Skips this sample (not the remaining ones) if the output dtype is
        # not the same as the input dtype.
        if (isinstance(expected_forward, torch.Tensor)
                and expected_forward.dtype != tensor.dtype):
            continue

        for variant in variants:
            # Inplace variants get a copy so the shared sample input is not modified.
            cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
            inp_tensor = cloned if isinstance(cloned, torch.Tensor) else cloned[0]
            data_ptr = inp_tensor.data_ptr()
            variant_forward = variant(cloned, *sample.args, **sample.kwargs)
            # TODO Support non-tensor outputs if they exist for inplace ops
            if isinstance(variant_forward, torch.Tensor):
                self.assertEqual(data_ptr, variant_forward.data_ptr(), atol=0, rtol=0)
            else:
                self.fail("Non-tensor outputs for inplace ops are not supported")
def _test_consistency_helper(samples, variants):
    """Compare every variant of ``op`` with its function variant.

    For each sample, the function variant's output is the reference. When
    ``op.supports_autograd`` holds, the output is a single tensor, and the
    dtype is floating point (or complex autograd is supported on this device
    type), the gradient it leaves on the first input tensor is the reference
    gradient too. Each variant is then run (inplace ones on a cloned input)
    and compared with those references. Inplace variants are skipped when the
    output dtype differs from the input dtype. On samples marked
    ``broadcasts_input``, inplace variants must raise ``RuntimeError``.

    Relies on ``op``, ``dtype``, ``device``, ``inplace_ops`` and ``self`` from
    the enclosing scope.
    """
    for sample in samples:
        # TODO: Check grad for all Tensors requiring grad if sample.input is TensorList
        if isinstance(sample.input, torch.Tensor):
            first_tensor = sample.input
        else:
            first_tensor = sample.input[0]

        # Reference forward value from the function variant.
        first_tensor.grad = None
        reference_out = op(sample.input, *sample.args, **sample.kwargs)
        reference_grad = None

        out_is_tensor = isinstance(reference_out, torch.Tensor)
        # An inplace variant cannot hold an output of a different dtype than its input.
        inplace_unsupported = out_is_tensor and reference_out.dtype is not first_tensor.dtype

        # TODO: backward consistency only supported for single tensor outputs
        # TODO: backward consistency only checked on sample.input, not all
        #       tensor inputs
        # TODO: update to handle checking grads of all tensor inputs as
        #       derived from each tensor output
        if (op.supports_autograd and out_is_tensor
                and (dtype.is_floating_point
                     or op.supports_complex_autograd(torch.device(device).type))):
            reference_out.sum().backward()
            reference_grad = first_tensor.grad

        for variant in variants:
            is_inplace = variant in inplace_ops
            if is_inplace and inplace_unsupported:
                continue

            first_tensor.grad = None
            # Inplace variants get a copy so the shared sample input stays untouched.
            arg = clone_input_helper(sample.input) if is_inplace else sample.input

            if is_inplace and sample.broadcasts_input:
                # Inplace variants are expected to raise on inputs marked as broadcasting.
                with self.assertRaises(RuntimeError):
                    variant(arg, *sample.args, **sample.kwargs)
                continue

            out = variant(arg, *sample.args, **sample.kwargs)
            self.assertEqual(reference_out, out)

            # Backward is only compared when a reference gradient exists and,
            # for inplace variants, when inplace autograd is supported.
            if reference_grad is None:
                continue
            if is_inplace and not op.supports_inplace_autograd:
                continue
            out.sum().backward()
            self.assertEqual(reference_grad, first_tensor.grad)
def test_variant_consistency_eager(self, device, dtype, op):
    """Check that the method, inplace and alias variants of ``op`` agree with
    its function variant on every sample input.

    Forward results are always compared. When ``op.supports_autograd`` holds
    and the output is a single tensor, the gradient on ``sample.input`` is
    compared too. For inplace variants this happens only if
    ``op.supports_inplace_autograd``. Inplace variants run on a cloned input.
    """
    samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)

    # The variants do not depend on the sample, so gather them once rather
    # than once per sample.
    # Acquires variants (method variant, inplace variant, aliases)
    method = op.get_method()
    inplace = op.get_inplace()
    # list of all inplace ops: inplace variant + alias inplace variants if exist
    inplace_ops = [inplace]
    aliases = []
    for a_op in op.aliases:
        aliases.append(a_op.op)
        aliases.append(a_op.method_variant)
        aliases.append(a_op.inplace_variant)
        inplace_ops.append(a_op.inplace_variant)
    aliases = tuple(aliases)
    inplace_ops = tuple(v for v in inplace_ops if v is not None)
    # A tuple, not a generator: it is re-iterated for every sample.
    variants = tuple(v for v in (method, inplace) + aliases if v is not None)

    for sample in samples:
        # Computes function forward and backward values
        sample.input.grad = None
        expected_forward = op(sample.input, *sample.args, **sample.kwargs)
        expected_grad = None

        # TODO: backward consistency only supported for single tensor outputs
        # TODO: backward consistency only checked on sample.input, not all
        #       tensor inputs
        # TODO: update to handle checking grads of all tensor inputs as
        #       derived from each tensor output
        if op.supports_autograd and isinstance(expected_forward, torch.Tensor):
            expected_forward.sum().backward()
            expected_grad = sample.input.grad

        # Test eager consistency
        for variant in variants:
            # Compares variant's forward
            # Note: copies the to-be-modified input when testing the inplace variant
            sample.input.grad = None
            cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
            variant_forward = variant(cloned, *sample.args, **sample.kwargs)
            self.assertEqual(expected_forward, variant_forward)

            # Compares variant's backward
            if expected_grad is not None and (
                    variant not in inplace_ops or op.supports_inplace_autograd):
                variant_forward.sum().backward()
                self.assertEqual(expected_grad, sample.input.grad)
def test_jit_alias_remapping(self, device, dtype, op):
    """Check that scripting and tracing remap ``op``'s aliases to the original op.

    Only the first sample input is used. For each alias in ``op.aliases``,
    every available variant (function, method, inplace) is wrapped in a small
    function. That function is scripted and then traced. The resulting graph
    must contain the original op's name and must not contain the alias name.
    Inplace variants whose result dtype cannot be cast to ``dtype`` must fail
    when scripted.

    NOTE: the scripted source text calls the local ``variant`` by name, so the
    scripting loop variable must keep that name.
    """
    samples = op.sample_inputs(device, dtype, requires_grad=True)
    if len(samples) == 0:
        self.skipTest("Skipped! No sample inputs!")

    # NOTE: only tests on first sample
    sample = samples[0]

    # [Scripting Data Preparation]
    # Prepare data for test scripting
    # Below we prepare strings of args/kwargs with and without type annotations.
    # These strings are inserted into function template strings, which are then torch scripted.
    # - args is ["t0"], corresponding to the "input" tensor required by the op
    # - args_annot_kw lists the parameters of the template function signature, for example,
    #   ["t0", "s0: float", "s1: bool", "max: float = 1.0", "min: float = 0.0"] ->
    #   def fn(t0, s0: float, s1: bool, max: float = 1.0, min: float = 0.0)
    # - args_kw lists the args/kwargs used to call the op; same as args_annot_kw but
    #   without type annotations
    args = ["t0"]
    args_annot_kw = args + \
        [f"s{i}: {type(v).__name__}" for i, v in enumerate(sample.args)] + \
        [f"{k}: {type(v).__name__} = {v}" for k, v in sample.kwargs.items()]
    args_kw = args + \
        [f"s{i}" for i in range(len(sample.args))] + \
        [f"{k}={v}" for k, v in sample.kwargs.items()]

    # Prepare data for test tracing: only the input tensor is a traced input;
    # args/kwargs are bound into the traced function. Tensors among them are
    # detached, because the tracer refuses to bake a grad-requiring tensor
    # into the graph as a constant.
    trace_args = tuple(v.detach() if isinstance(v, torch.Tensor) else v
                       for v in sample.args)
    trace_kwargs = {k: v.detach() if isinstance(v, torch.Tensor) else v
                    for k, v in sample.kwargs.items()}

    original_name = op.aten_name
    original_name_inplace = original_name + "_"
    expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype
    inplace_must_fail = not torch.can_cast(expected_dtype, dtype)

    for a_op in op.aliases:
        inplace = a_op.inplace_variant
        method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
        # A tuple, not a generator: it is iterated twice (scripting, then
        # tracing). A generator would leave the tracing loop empty.
        variants = tuple(v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant)
                         if v is not None)

        # Test scripting:
        for variant in variants:
            variant_name = variant.__name__

            if variant in method_or_inplace:
                fn_template = '''
                    def _fn(t0{c}{args_annot_kw}):
                        return t0.{alias_name}({args_kw})
                '''
                # remove the first input tensor (it is the method's receiver);
                # a separator is needed whenever any parameter follows t0
                script = fn_template.format(
                    c=", " if len(args_annot_kw) > 1 else "",
                    args_annot_kw=", ".join(args_annot_kw[1:]),
                    args_kw=", ".join(args_kw[1:]),
                    alias_name=variant_name,
                )
            else:
                fn_template = '''
                    def _fn({args_annot_kw}):
                        return variant({args_kw})
                '''
                script = fn_template.format(
                    args_annot_kw=", ".join(args_annot_kw),
                    args_kw=", ".join(args_kw),
                )
            scripted = torch.jit.CompilationUnit(script)._fn

            if variant is inplace and inplace_must_fail:
                try:
                    inp = clone_input_helper(sample.input)
                    scripted(inp, *sample.args, **sample.kwargs)
                except Exception:
                    # Expected: the inplace op cannot store the promoted result.
                    continue
                self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")

            inp = clone_input_helper(sample.input)
            scripted(inp, *sample.args, **sample.kwargs)
            inp = clone_input_helper(sample.input)
            graph = scripted.graph_for(inp, *sample.args, **sample.kwargs)
            FileCheck().check(op.aten_name).check_not(variant_name).run(graph)

        # Test tracing:
        for variant in variants:
            # The inplace-promotion failure case is already checked by scripting.
            if variant is inplace and inplace_must_fail:
                continue

            variant_name = variant.__name__
            op_name = original_name_inplace if variant is inplace else original_name

            # torch.jit.trace takes its example inputs as ONE tuple; the
            # previous ``trace(_fn, *inp)`` fed args/kwargs into trace's own
            # ``optimize``/``check_trace`` parameters.
            def _fn(t0):
                return variant(t0, *trace_args, **trace_kwargs)

            traced = torch.jit.trace(_fn, (clone_input_helper(sample.input), ))
            traced(clone_input_helper(sample.input))
            graph = traced.graph_for(clone_input_helper(sample.input))
            FileCheck().check(op_name).check_not(variant_name).run(graph)
def test_variant_consistency_eager(self, device, dtype, op):
    """Check that the method, inplace and alias variants of ``op`` agree with
    its function variant on every sample input.

    Forward results are compared for every variant. Inplace variants run on a
    cloned input and are skipped for a sample whose function-variant output
    dtype differs from the input dtype. When the output is a single tensor
    and autograd applies, the gradient on the first input tensor is compared
    too. For inplace variants this happens only if
    ``op.supports_inplace_autograd``.
    """
    # Acquires variants (method variant, inplace variant, aliases)
    method = op.get_method()
    inplace = op.get_inplace()

    # list of all inplace ops: inplace variant + alias inplace variants if exist
    inplace_ops = [inplace]
    aliases = []
    for a_op in op.aliases:
        aliases.append(a_op.op)
        aliases.append(a_op.method_variant)
        aliases.append(a_op.inplace_variant)
        inplace_ops.append(a_op.inplace_variant)
    aliases = tuple(aliases)
    inplace_ops = tuple(v for v in inplace_ops if v is not None)
    # A tuple, not a generator: it is iterated once per sample below. A
    # generator built here would be exhausted after the first sample, leaving
    # every later sample untested.
    variants = tuple(v for v in (method, inplace) + aliases if v is not None)

    # NOTE(review): elsewhere ``supports_complex_autograd`` is called as a
    # method taking a device type; if it is a method here too, this truthiness
    # test is always true. Confirm it is a plain flag in this version.
    _requires_grad = (op.supports_autograd and
                      (dtype.is_floating_point or op.supports_complex_autograd))
    samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
    for sample in samples:
        # TODO: Check grad for all Tensors requiring grad if sample.input is TensorList
        tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]

        # Computes function forward and backward values
        tensor.grad = None
        expected_forward = op(sample.input, *sample.args, **sample.kwargs)
        expected_grad = None

        # Skips inplace variants if the output dtype is not the same as
        # the input dtype
        skip_inplace = (isinstance(expected_forward, torch.Tensor) and
                        expected_forward.dtype != tensor.dtype)

        # TODO: backward consistency only supported for single tensor outputs
        # TODO: backward consistency only checked on first input Tensor
        # TODO: update to handle checking grads of all tensor inputs as
        #       derived from each tensor output
        if (op.supports_autograd and isinstance(expected_forward, torch.Tensor) and
                (dtype.is_floating_point or op.supports_complex_autograd)):
            expected_forward.sum().backward()
            expected_grad = tensor.grad

        # Test eager consistency
        for variant in variants:
            # Skips inplace ops
            if variant in inplace_ops and skip_inplace:
                continue

            # Compares variant's forward
            # Note: copies the to-be-modified input when testing the inplace variant
            tensor.grad = None
            cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
            variant_forward = variant(cloned, *sample.args, **sample.kwargs)
            self.assertEqual(expected_forward, variant_forward)

            # Compares variant's backward
            if expected_grad is not None and (
                    variant not in inplace_ops or op.supports_inplace_autograd):
                variant_forward.sum().backward()
                self.assertEqual(expected_grad, tensor.grad)
def test_jit_alias_remapping(self, device, dtype, op):
    """Check that scripting and tracing remap ``op``'s aliases to the original op.

    Only the first sample input is used. The sample's args/kwargs are inlined
    as literals into a small scripted wrapper around each alias variant
    (function, method, inplace). The same variant is then traced. The
    resulting graph must contain the original op's name and must not contain
    the alias name. Inplace variants whose result dtype cannot be cast to
    ``dtype`` must fail when scripted.

    NOTE: the scripted source text refers to the locals ``variant`` and
    ``tensor`` by name, so both must keep those names.
    """
    # Required to avoid undefined value: tensor error in JIT compilation of the function template
    # (tensor values inlined below are spelled "tensor(...)" in the script text).
    tensor = torch.tensor  # noqa: F841

    samples = op.sample_inputs(device, dtype, requires_grad=True)
    if len(samples) == 0:
        self.skipTest("Skipped! No sample inputs!")

    # NOTE: only tests on first sample
    sample = samples[0]

    # [Scripting Data Preparation]
    # Prepare data for test scripting
    # Below we prepare strings of args/kwargs without type annotations.
    # These strings are inserted into function template strings, which are then torch scripted.
    # - args is ["t0"], corresponding to the "input" tensor required by the op
    # - args_kw lists the literal values of args and the kwargs used to call the op, for example,
    #   ["t0", "1.0", "(1,)", "True", "tensor(1.0)"] -> def fn(t0): return variant(t0, 1.0, (1,), True, tensor(1.0))
    args = ["t0"]

    def quote_strs(v):
        # repr() gives a valid, properly escaped string literal.
        if isinstance(v, str):
            return repr(v)
        return str(v)

    # Positional args are quoted too; previously a str positional arg was
    # inlined bare and read as an undefined name.
    args_kw = args + \
        [quote_strs(v) for v in sample.args] + \
        [f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]

    # Prepare data for test tracing: only the input tensor is a traced input;
    # args/kwargs are bound into the traced function. Tensors among them are
    # detached, because the tracer refuses to bake a grad-requiring tensor
    # into the graph as a constant.
    trace_args = tuple(v.detach() if isinstance(v, torch.Tensor) else v
                       for v in sample.args)
    trace_kwargs = {k: v.detach() if isinstance(v, torch.Tensor) else v
                    for k, v in sample.kwargs.items()}

    original_name = op.aten_name
    original_name_inplace = original_name + "_"
    expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype
    inplace_must_fail = not torch.can_cast(expected_dtype, dtype)

    for a_op in op.aliases:
        inplace = a_op.inplace_variant
        method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
        # A tuple, not a generator: it is iterated twice (scripting, then
        # tracing). A generator would leave the tracing loop empty.
        variants = tuple(v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant)
                         if v is not None)

        # Test scripting:
        for variant in variants:
            variant_name = variant.__name__

            if variant in method_or_inplace:
                # All args are inlined literals, so the signature is just t0.
                fn_template = '''
                    def _fn(t0):
                        return t0.{alias_name}({args_kw})
                '''
                # remove the first input tensor (it is the method's receiver)
                script = fn_template.format(
                    args_kw=", ".join(args_kw[1:]),
                    alias_name=variant_name,
                )
            else:
                fn_template = '''
                    def _fn({args}):
                        return variant({args_kw})
                '''
                script = fn_template.format(
                    args=", ".join(args),
                    args_kw=", ".join(args_kw),
                )
            scripted = torch.jit.CompilationUnit(script)._fn

            if variant is inplace and inplace_must_fail:
                try:
                    inp = clone_input_helper(sample.input)
                    scripted(inp)
                except Exception:
                    # Expected: the inplace op cannot store the promoted result.
                    continue
                self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")

            inp = clone_input_helper(sample.input)
            scripted(inp)
            inp = clone_input_helper(sample.input)
            graph = scripted.graph_for(inp)
            FileCheck().check(op.aten_name).check_not(variant_name).run(graph)

        # Test tracing:
        for variant in variants:
            # The inplace-promotion failure case is already checked by scripting.
            if variant is inplace and inplace_must_fail:
                continue

            variant_name = variant.__name__
            op_name = original_name_inplace if variant is inplace else original_name

            # torch.jit.trace takes its example inputs as ONE tuple; the
            # previous ``trace(_fn, *inp)`` fed args/kwargs into trace's own
            # ``optimize``/``check_trace`` parameters.
            def _fn(t0):
                return variant(t0, *trace_args, **trace_kwargs)

            traced = torch.jit.trace(_fn, (clone_input_helper(sample.input), ))
            traced(clone_input_helper(sample.input))
            graph = traced.graph_for(clone_input_helper(sample.input))
            FileCheck().check(op_name).check_not(variant_name).run(graph)
def get_sample():
    """Return the input to pass to ``op`` for the current ``sample``.

    Returns a clone of ``sample.input`` when ``op.name`` ends with ``'_'``
    (an inplace op, which modifies its input). Otherwise returns
    ``sample.input`` itself. Uses ``endswith`` so an empty name yields the
    original input instead of raising ``IndexError``.

    Relies on ``sample`` and ``op`` from the enclosing scope.
    """
    if op.name.endswith('_'):
        # Inplace ops write into their input, so hand them a copy.
        return clone_input_helper(sample.input)
    return sample.input
def test_jit_alias_remapping(self, device, dtype, op):
    """Check that scripting and tracing remap ``op``'s aliases to ``op.name``.

    Only the first sample input is used; this version treats its ``input`` as
    a sequence of tensors. For each alias variant (function, method, inplace),
    a small wrapper is scripted and then traced. The resulting graph must
    contain the original op name and must not contain the alias name.
    Inplace variants whose result dtype cannot be cast to ``dtype`` must fail
    when scripted.

    NOTE: the scripted source text calls the local ``variant`` by name, so the
    scripting loop variable must keep that name.
    """
    samples = op.sample_inputs(device, dtype, requires_grad=True)
    if len(samples) == 0:
        self.skipTest("Skipped! No sample inputs!")

    # NOTE: only tests on first sample
    sample = samples[0]

    def _clone_inputs():
        # Fresh copies of the input tensors, so inplace variants never modify sample.input.
        return tuple(clone_input_helper(t) for t in sample.input)

    # Prepare data for test scripting:
    # t{i} names the i-th input tensor and s{i} the i-th positional arg;
    # kwargs keep their own names.
    args = [f"t{i}" for i in range(len(sample.input))] + \
        [f"s{i}" for i in range(len(sample.args))]
    args_annot_kw = args + [f"{k}: {type(v).__name__} = {v}" for k, v in sample.kwargs.items()]
    args_kw = args + [f"{k}={v}" for k, v in sample.kwargs.items()]

    # Prepare data for test tracing: only the input tensors are traced inputs;
    # args/kwargs are bound into the traced function. Tensors among them are
    # detached, because the tracer refuses to bake a grad-requiring tensor
    # into the graph as a constant.
    trace_args = tuple(v.detach() if isinstance(v, torch.Tensor) else v
                       for v in sample.args)
    trace_kwargs = {k: v.detach() if isinstance(v, torch.Tensor) else v
                    for k, v in sample.kwargs.items()}

    original_name = op.name
    original_name_inplace = original_name + "_"
    expected_dtype = op(*sample.input, *sample.args, **sample.kwargs).dtype
    inplace_must_fail = not torch.can_cast(expected_dtype, dtype)

    for a_op in op.aliases:
        inplace = a_op.inplace_variant
        method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
        # A tuple, not a generator: it is iterated twice (scripting, then
        # tracing). A generator would leave the tracing loop empty.
        variants = tuple(v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant)
                         if v is not None)

        # Test scripting:
        for variant in variants:
            variant_name = variant.__name__
            op_name = original_name_inplace if variant is inplace else original_name

            if variant in method_or_inplace:
                fn_template = '''
                    def _fn(t0{c}{args_annot_kw}):
                        return t0.{alias_name}({args_kw})
                '''
                # remove the first input tensor (it is the method's receiver).
                # A separator is needed whenever ANY parameter follows t0; the
                # old "> 1" check produced e.g. "def _fn(t0t1)" when exactly
                # one did.
                script = fn_template.format(
                    c=", " if len(args_annot_kw) > 1 else "",
                    args_annot_kw=", ".join(args_annot_kw[1:]),
                    args_kw=", ".join(args_kw[1:]),
                    alias_name=variant_name,
                )
            else:
                fn_template = '''
                    def _fn({args_annot_kw}):
                        return variant({args_kw})
                '''
                script = fn_template.format(
                    args_annot_kw=", ".join(args_annot_kw),
                    args_kw=", ".join(args_kw),
                )
            scripted = torch.jit.CompilationUnit(script)._fn

            if variant is inplace and inplace_must_fail:
                try:
                    scripted(*_clone_inputs(), *sample.args, **sample.kwargs)
                except Exception:
                    # Expected: the inplace op cannot store the promoted result.
                    continue
                self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")

            scripted(*_clone_inputs(), *sample.args, **sample.kwargs)
            graph = scripted.graph_for(*_clone_inputs(), *sample.args, **sample.kwargs)
            FileCheck().check(op_name).check_not(variant_name).run(graph)

        # Test tracing:
        for variant in variants:
            # The inplace-promotion failure case is already checked by scripting.
            if variant is inplace and inplace_must_fail:
                continue

            variant_name = variant.__name__
            op_name = original_name_inplace if variant is inplace else original_name

            # torch.jit.trace takes its example inputs as ONE tuple; the
            # previous ``trace(_fn, *inp)`` fed extra inputs and args/kwargs
            # into trace's own keyword parameters.
            def _fn(*inputs):
                return variant(*inputs, *trace_args, **trace_kwargs)

            traced = torch.jit.trace(_fn, _clone_inputs())
            traced(*_clone_inputs())
            graph = traced.graph_for(*_clone_inputs())
            FileCheck().check(op_name).check_not(variant_name).run(graph)
def _test_consistency_helper(samples, variants):
    """Compare each variant of ``op`` with its function variant on ``samples``.

    For every sample, the function variant's output is the reference. When
    the output is a single tensor and ``dtype`` is in
    ``op.supported_backward_dtypes`` for this device type, the gradient it
    leaves on the first input tensor is the reference gradient too. Each
    variant is then run and compared:

    * inplace variants are skipped when the reference output dtype differs
      from the input dtype, and otherwise run on a cloned input;
    * inplace variants must raise ``RuntimeError`` on samples marked
      ``broadcasts_input``;
    * gradients are compared for inplace variants only when
      ``op.supports_inplace_autograd``.

    ``sample.output_process_fn_grad``, when set, is applied to outputs before
    they are summed for backward.

    Relies on ``op``, ``dtype``, ``device``, ``inplace_ops`` and ``self`` from
    the enclosing scope.
    """
    for sample in samples:
        # TODO: Check grad for all Tensors requiring grad if sample.input is TensorList
        tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]

        # Computes function forward and backward values
        tensor.grad = None
        expected_forward = op(sample.input, *sample.args, **sample.kwargs)
        expected_grad = None

        output_process_fn_grad = sample.output_process_fn_grad or (lambda x: x)

        # Skips inplace variants if the output dtype is not the same as
        # the input dtype
        skip_inplace = (isinstance(expected_forward, torch.Tensor) and
                        expected_forward.dtype != tensor.dtype)

        # TODO: backward consistency only supported for single tensor outputs
        # TODO: backward consistency only checked on sample.input, not all
        #       tensor inputs
        # TODO: update to handle checking grads of all tensor inputs as
        #       derived from each tensor output
        if (isinstance(expected_forward, torch.Tensor) and
                dtype in op.supported_backward_dtypes(torch.device(device).type)):
            output_process_fn_grad(expected_forward).sum().backward()
            expected_grad = tensor.grad

        # Test eager consistency
        for variant in variants:
            # Skips inplace ops
            if variant in inplace_ops and skip_inplace:
                continue

            # Compares variant's forward
            # Note: copies the to-be-modified input when testing the inplace variant
            tensor.grad = None
            cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input

            if variant in inplace_ops and sample.broadcasts_input:
                # The message names the real sample flag (`broadcasts_input`)
                # and has balanced backticks.
                with self.assertRaises(
                        RuntimeError,
                        msg=('inplace variant either incorrectly allowed '
                             'resizing or you have marked the sample {}'
                             ' incorrectly with `broadcasts_input=True`'.format(
                                 sample.summary()))):
                    variant(cloned, *sample.args, **sample.kwargs)
                continue

            variant_forward = variant(cloned, *sample.args, **sample.kwargs)
            self.assertEqual(expected_forward, variant_forward)

            # Compares variant's backward
            if expected_grad is not None and \
                    (variant not in inplace_ops or op.supports_inplace_autograd):
                output_process_fn_grad(variant_forward).sum().backward()
                self.assertEqual(expected_grad, tensor.grad)