def test_bias_as_arg(self):
    """Script ``linear(x, weight, bias).relu() + 2`` with an Optional bias.

    The scripted function is compared against eager execution twice: first
    with ``bias=None`` and then with a real bias tensor, so both sides of the
    ``Optional[torch.Tensor]`` annotation are exercised.
    """
    with enable_profiling_mode_for_profiling_tests():
        def method1(x, weight, bias: Optional[torch.Tensor]):
            return torch.nn.functional.linear(x, weight, bias).relu() + 2

        size = 10
        x = torch.rand(size, size, requires_grad=True)
        weight = torch.rand(size, size, requires_grad=True)

        # Biases are produced lazily so that the random bias tensor is only
        # drawn after the bias-free case has been scripted and checked.
        bias_factories = (
            lambda: None,
            lambda: torch.rand(size, size, requires_grad=True),
        )
        for make_bias in bias_factories:
            inputs = (x, weight, make_bias())
            scripted = self.checkScript(method1, inputs)
            # check_types requires last_graph on scripted to be set, so we just skip it
            check_against_reference(self, scripted, method1, lambda out: out,
                                    inputs, check_types=False)
def test_constructed_bias(self):
    """Script a linear layer whose bias is computed inside the function.

    The bias is the elementwise product of two inputs, so the scripted graph
    must build it before calling ``linear``; the result is compared against
    eager execution.
    """
    with enable_profiling_mode_for_profiling_tests():
        def method1(x, weight, b1, b2):
            bias = b1 * b2
            return torch.nn.functional.linear(x, weight, bias)

        size = 10
        # Drawn in the order x, weight, b1, b2; all track gradients.
        x, weight, b1, b2 = (
            torch.rand(size, size, requires_grad=True) for _ in range(4)
        )
        inputs = (x, weight, b1, b2)

        scripted = self.checkScript(method1, inputs)
        # check_types requires last_graph on scripted to be set, so we just skip it
        check_against_reference(self, scripted, method1, lambda out: out,
                                inputs, check_types=False)
def test_bias_as_module_attr(self):
    """Script a module wrapping ``nn.Linear`` with and without a bias.

    Each scripted module is called a few times and then compared against
    its eager counterpart with ``check_against_reference``.
    """
    with enable_profiling_mode_for_profiling_tests():
        class M(torch.nn.Module):
            def __init__(self, has_bias):
                super().__init__()
                self.ll = torch.nn.Linear(10, 10, has_bias)

            def forward(self, x, y):
                return self.ll(x + y) * x + y

        def warm_up(scripted_module, inp):
            # Call the scripted module three times before comparing --
            # presumably so the profiling executor has run its profiling
            # passes first (TODO confirm).
            for _ in range(3):
                scripted_module(inp, inp)

        x = torch.rand(10, 10, requires_grad=True)

        no_bias = M(False)
        scripted_no_bias = torch.jit.script(no_bias)
        warm_up(scripted_no_bias, x)

        has_bias = M(True)
        check_against_reference(self, scripted_no_bias, no_bias, lambda out: out,
                                (x, x), check_types=False)

        scripted_has_bias = torch.jit.script(has_bias)
        warm_up(scripted_has_bias, x)
        check_against_reference(self, scripted_has_bias, has_bias, lambda out: out,
                                (x, x), check_types=False)
def test_variant_consistency_jit(self, device, dtype, op):
    """Check that scripted and traced variants of ``op`` agree with eager mode.

    For every sample input and every available variant (function and method;
    the inplace variant is not tested yet), the scripted and traced versions
    are compared against ``op.get_op()`` for forward, grad and grad grad. The
    alias annotation schema is checked for float32 and int32. For float32,
    the autodiff node placement of the traced and scripted graphs is also
    checked.
    """
    _requires_grad = op.supports_autograd and (
        dtype.is_floating_point or op.supports_complex_autograd)
    samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)

    # Acquires variants to test. They do not depend on the sample, so they are
    # looked up once rather than once per sample.
    func = op.get_op()
    method = op.get_method()
    variants = {
        # TODO: inplace tests currently fail, fix and add inplace variant
        'function': func,
        'method': method,
    }

    for sample in samples:
        def out_fn(output):
            # Processes the output for autograd
            if sample.output_process_fn_grad is not None:
                return sample.output_process_fn_grad(output)
            return output

        inputs = (sample.input,) + sample.args

        # Test traced and scripted consistency
        for func_type, variant in variants.items():
            if variant is None:
                continue

            # Create accessor for script function variant
            name = (op.name + '_') if func_type == 'inplace' else op.name

            # run with disable_autodiff_subgraph_inlining(True) to test
            # autodiff support. Context manager forces the graph to contain
            # DifferentiableGraph nodes if they are present
            with disable_autodiff_subgraph_inlining():
                # Check scripted forward, grad, and grad grad
                script_fn = create_script_fn(self, name, func_type)
                check_against_reference(self, script_fn, func, out_fn,
                                        inputs, sample.kwargs,
                                        no_grad=not _requires_grad)

                # Check traced forward, grad, and grad grad
                traced_fn = create_traced_fn(self, variant)
                check_against_reference(self, traced_fn, func, out_fn,
                                        inputs, sample.kwargs,
                                        no_grad=not _requires_grad)

                # Check alias annotation schema for correctness (make
                # sure inputs that aren't supposed to be modified aren't)
                # Note: only runs in float32 and int32 because schema isn't
                # affected by dtype, so running it on all dtypes would be excessive
                if dtype in (torch.float32, torch.int32):
                    check_alias_annotation(name, inputs, sample.kwargs,
                                           func_type=func_type, aten_name=op.aten_name)

                # Check autodifferentiation of nodes for traced and scripted
                # graphs, only need to check once per sample
                if dtype == torch.float32:
                    # Sandcastle doesn't fuse nodes
                    if IS_SANDCASTLE:
                        # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                        nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                        fusible_nodes = []
                    else:
                        nonfusible_nodes = op.autodiff_nonfusible_nodes
                        fusible_nodes = op.autodiff_fusible_nodes

                    self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed,
                                            nonfusible_nodes, fusible_nodes)
                    self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed,
                                            nonfusible_nodes, fusible_nodes)
def test_variant_consistency_jit(self, device, dtype, op):
    """Check that scripted and traced variants of ``op`` match an eager reference.

    Skips when ``op`` yields no sample inputs. For every sample, the function
    and method variants (the inplace variant is disabled, see TODO below) are
    scripted and traced and compared against an eager reference. That
    reference calls the method form of the op and post-processes the result
    with ``op.output_func``. The alias annotation schema is checked for
    float32 and int32. For float32, autodiff node placement is also checked.
    """
    samples = op.sample_inputs(device, dtype, requires_grad=True)
    if len(samples) == 0:
        self.skipTest("Skipped! No sample inputs!")

    # Acquires variants to test. None of this depends on the sample, so it is
    # resolved once up front instead of once per sample.
    method = op.get_method()
    inplace = op.get_inplace()
    # TODO: inplace tests currently fail; once fixed, test
    #   (op, method, inplace) instead of (op, method).
    variants = [v for v in (op, method) if v is not None]

    # bfloat16 grad doesn't work for some operators
    if op.skip_bfloat16_grad:
        dtypes_to_grad_check = floating_and_complex_types_and(torch.half)
    else:
        dtypes_to_grad_check = floating_and_complex_types_and(torch.half, torch.bfloat16)
    no_grad = dtype not in dtypes_to_grad_check

    for sample in samples:
        call_args = (*sample.input,) + sample.args

        # Test traced and scripted consistency
        for variant in variants:
            # Create accessor for script function variant
            if variant is op:
                name = op.name
                func_type = 'function'
            elif variant is method:
                name = op.name
                func_type = 'method'
            else:
                # variant is inplace (unreachable until inplace is re-enabled above)
                assert variant is inplace
                name = op.name + "_"
                func_type = 'inplace'

            # run with disable_autodiff_subgraph_inlining(True) to test
            # autodiff support. Context manager forces the graph to contain
            # DifferentiableGraph nodes if they are present
            with disable_autodiff_subgraph_inlining():
                def fn(*inputs, **kwargs):
                    # Eager reference: always dispatches through the method form
                    attr = getattr(inputs[0], name)
                    output = attr(*inputs[1:], **kwargs)
                    return op.output_func(output)

                # Check scripted forward, grad, and grad grad
                script_fn = create_script_fn(self, name, func_type, op.output_func)
                check_against_reference(
                    self, script_fn, fn, call_args, sample.kwargs, no_grad=no_grad)

                # Check traced forward, grad, and grad grad
                traced_fn = create_traced_fn(self, variant)
                check_against_reference(
                    self, traced_fn, fn, call_args, sample.kwargs, no_grad=no_grad)

                # Check alias annotation schema for correctness (make
                # sure inputs that aren't supposed to be modified aren't)
                # Note: only runs in float32 and int32 because schema isn't
                # affected by dtype, so running it on all dtypes would be excessive
                if dtype in (torch.float32, torch.int32):
                    check_alias_annotation(name, call_args, sample.kwargs)

                # Check autodifferentiation of nodes for traced and scripted
                # graphs, only need to check once per sample
                if dtype == torch.float32:
                    # Sandcastle doesn't fuse nodes
                    if IS_SANDCASTLE:
                        # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                        nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                        fusible_nodes = []
                    else:
                        nonfusible_nodes = op.autodiff_nonfusible_nodes
                        fusible_nodes = op.autodiff_fusible_nodes

                    self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed,
                                            nonfusible_nodes, fusible_nodes)
                    self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed,
                                            nonfusible_nodes, fusible_nodes)
def test_variant_consistency_jit(self, device, dtype, op):
    """Check that scripted and traced variants of ``op`` agree with eager mode.

    For every sample and every non-lambda variant (function and method; the
    inplace variant is not tested yet):

    * the scripted variant (if ``op.supports_scripting``) and the traced
      variant (unless ``op`` is one of the "fake function" ops) are compared
      against ``op.get_op()`` for forward, grad and grad grad;
    * for float32, the alias annotation schema is checked, jit shape analysis
      of the traced graph is checked, and autodiff node placement is checked.

    Fails if no variant was exercised at all.
    """
    _requires_grad = op.supports_autograd and (
        dtype.is_floating_point or op.supports_complex_autograd(torch.device(device).type))

    include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
    samples = op.sample_inputs(
        device, dtype, requires_grad=_requires_grad,
        include_conjugated_inputs=include_conjugated_inputs)

    # Acquires variants to test
    func = op.get_op()
    method = op.get_method()
    variants = {
        # TODO: inplace tests currently fail, fix and add inplace variant
        'function': func,
        'method': method,
    }

    # TODO: find better way to standardize on op registration itself..
    has_fake_function = op.name in ["resize_", 'resize_as_']

    if has_fake_function:
        variants = {'method': getattr(torch.Tensor, op.name)}
        # NOTE(review): these samples do not require grad while _requires_grad
        # is left unchanged; presumably these ops do not support autograd -- confirm.
        samples = op.sample_inputs(device, dtype, requires_grad=False)

    support_script = op.supports_scripting
    # TODO: fix tracing here
    supports_tracing = not has_fake_function
    # Ops whose name ends in '_' (presumably in-place ops) get a fresh clone
    # of the sample input for every use.
    clone_sample_input = op.name.endswith('_')

    tested = False
    for sample in samples:
        def out_fn(output):
            # Processes the output for autograd
            if sample.output_process_fn_grad is not None:
                return sample.output_process_fn_grad(output)
            return output

        def get_sample():
            if clone_sample_input:
                return clone_input_helper(sample.input)
            return sample.input

        # Test traced and scripted consistency
        for func_type, variant in variants.items():
            if variant is None:
                continue

            # scripting and check_alias_analysis do not work with lambdas
            # lambdas are typically used as a way to simulate methods without
            # functional variants, so rely on the other variant for testing
            # for now
            if is_lambda(variant):
                continue

            tested = True

            # Create accessor for script function variant
            name = (op.name + '_') if func_type == 'inplace' else op.name

            # run with disable_autodiff_subgraph_inlining(True) to test
            # autodiff support. Context manager forces the graph to contain
            # DifferentiableGraph nodes if they are present
            with disable_autodiff_subgraph_inlining():
                # Check scripted forward, grad, and grad grad
                if support_script:
                    script_fn = create_script_fn(self, name, func_type)
                    check_against_reference(
                        self, script_fn, func, out_fn,
                        (get_sample(),) + sample.args, sample.kwargs,
                        no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)

                # Check traced forward, grad, and grad grad
                if op.assert_jit_shape_analysis:
                    self.assertTrue(supports_tracing)

                if supports_tracing:
                    traced_fn = create_traced_fn(self, variant)
                    check_against_reference(
                        self, traced_fn, func, out_fn,
                        (get_sample(),) + sample.args, sample.kwargs,
                        no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)

                # Schema, shape and autodiff checks aren't affected by dtype,
                # so running them on all dtypes would be excessive.
                if dtype == torch.float32:
                    # Check alias annotation schema for correctness (make
                    # sure inputs that aren't supposed to be modified aren't)
                    # TODO: no reason why we cant run this with tracing graph
                    if support_script and op.name != "rsub":
                        check_alias_annotation(name, (get_sample(),) + sample.args, sample.kwargs,
                                               func_type=func_type, aten_name=op.aten_name)

                    # TODO: use script graph as well
                    checked_shape_analysis = False
                    if supports_tracing:
                        out = variant(get_sample(), *sample.args, **sample.kwargs)

                        # right now, tuple of outputs and tensor output supported
                        # TODO: list of tensor outputs
                        tuple_of_tensors = isinstance(out, tuple) and all(
                            isinstance(elem, torch.Tensor) for elem in out)

                        if isinstance(out, torch.Tensor) or tuple_of_tensors:
                            sizes = [elem.size() for elem in out] if tuple_of_tensors else out.size()
                            self.checkShapeAnalysis(sizes, traced_fn.graph, op.assert_jit_shape_analysis)
                            checked_shape_analysis = True
                    if op.assert_jit_shape_analysis:
                        self.assertTrue(checked_shape_analysis)

                    # Check autodifferentiation of nodes for traced and scripted
                    # graphs, only need to check once per sample
                    # Sandcastle doesn't fuse nodes
                    if IS_SANDCASTLE:
                        # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                        nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                        fusible_nodes = []
                    else:
                        nonfusible_nodes = op.autodiff_nonfusible_nodes
                        fusible_nodes = op.autodiff_fusible_nodes

                    if supports_tracing:
                        self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed,
                                                nonfusible_nodes, fusible_nodes)
                    if support_script:
                        self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed,
                                                nonfusible_nodes, fusible_nodes)

    # Use the TestCase assertion rather than a bare ``assert`` so this sanity
    # check still runs under ``python -O``.
    self.assertTrue(tested, "JIT Test does not execute any logic")
def indiv_variant_test_jit(self, device, dtype, op, sample, func_type, variant, has_fake_function):
    """Run the JIT consistency checks for a single (sample, variant) pair of ``op``.

    The scripted variant (if ``op.supports_scripting``) and the traced variant
    (if ``op.supports_tracing`` and not ``has_fake_function``) are compared
    against ``op.get_op()`` for forward, grad and grad grad. For float32, the
    alias annotation schema, jit shape analysis of the traced graph, and
    autodiff node placement are also checked.

    Args:
        device: device to test on (anything ``torch.device`` accepts).
        dtype: dtype under test; gradients are checked only if it is in
            ``op.supported_backward_dtypes`` for the device type.
        op: the op description under test.
        sample: one sample input produced for ``op``.
        func_type: variant kind, e.g. ``'function'``, ``'method'`` or ``'inplace'``.
        variant: the callable implementing that variant (used for tracing).
        has_fake_function: when true, tracing-based checks are skipped.
    """
    _requires_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
    support_script = op.supports_scripting
    # TODO: fix tracing here
    supports_tracing = op.supports_tracing and not has_fake_function

    # Create accessor for script function variant
    name = (op.name + '_') if func_type == 'inplace' else op.name
    # Ops whose name ends in '_' (presumably in-place ops) get a fresh clone
    # of the sample input for every use.
    clone_sample_input = op.name.endswith('_')

    def out_fn(output):
        # Processes the output for autograd
        if sample.output_process_fn_grad is not None:
            return sample.output_process_fn_grad(output)
        return output

    def get_sample():
        if clone_sample_input:
            return clone_input_helper(sample.input)
        return sample.input

    # run with disable_autodiff_subgraph_inlining(True) to test
    # autodiff support. Context manager forces the graph to contain
    # DifferentiableGraph nodes if they are present
    with disable_autodiff_subgraph_inlining():
        # Check scripted forward, grad, and grad grad
        if support_script:
            script_fn = create_script_fn(self, name, func_type)
            check_against_reference(self, script_fn, op.get_op(), out_fn,
                                    (get_sample(),) + sample.args, sample.kwargs,
                                    no_grad=not _requires_grad,
                                    no_gradgrad=not op.supports_gradgrad)

        # Check traced forward, grad, and grad grad
        if op.assert_jit_shape_analysis:
            self.assertTrue(supports_tracing)

        if supports_tracing:
            traced_fn = create_traced_fn(self, variant)
            check_against_reference(self, traced_fn, op.get_op(), out_fn,
                                    (get_sample(),) + sample.args, sample.kwargs,
                                    no_grad=not _requires_grad,
                                    no_gradgrad=not op.supports_gradgrad)

        # Schema, shape and autodiff checks aren't affected by dtype, so
        # running them on all dtypes would be excessive.
        if dtype == torch.float32:
            # Check alias annotation schema for correctness (make
            # sure inputs that aren't supposed to be modified aren't)
            # TODO: no reason why we cant run this with tracing graph
            if support_script and op.name != "rsub":
                check_alias_annotation(name, (get_sample(),) + sample.args, sample.kwargs,
                                       func_type=func_type, aten_name=op.aten_name)

            # TODO: use script graph as well
            checked_shape_analysis = False
            if supports_tracing:
                out = variant(get_sample(), *sample.args, **sample.kwargs)

                # right now, tuple of outputs and tensor output supported
                # TODO: list of tensor outputs
                tuple_of_tensors = isinstance(out, tuple) and all(
                    isinstance(elem, torch.Tensor) for elem in out)

                if isinstance(out, torch.Tensor) or tuple_of_tensors:
                    sizes = [elem.size() for elem in out] if tuple_of_tensors else out.size()
                    self.checkShapeAnalysis(sizes, traced_fn.graph, op.assert_jit_shape_analysis)
                    checked_shape_analysis = True
            if op.assert_jit_shape_analysis:
                self.assertTrue(checked_shape_analysis)

            # Check autodifferentiation of nodes for traced and scripted
            # graphs, only need to check once per sample
            # Sandcastle doesn't fuse nodes
            if IS_SANDCASTLE:
                # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                fusible_nodes = []
            else:
                nonfusible_nodes = op.autodiff_nonfusible_nodes
                fusible_nodes = op.autodiff_fusible_nodes

            if supports_tracing:
                self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed,
                                        nonfusible_nodes, fusible_nodes)
            if support_script:
                self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed,
                                        nonfusible_nodes, fusible_nodes)