def custom_rules_test_base(self, device, dtype, op, allow_eager_fail=False):
    """Run dtype analysis on the traced graph of ``op`` and check it against eager.

    The first sample input of ``op`` is executed eagerly to obtain a reference
    result.  The op's function variant is then traced with the same sample,
    dtype propagation is run on the traced graph, and the eager result plus
    the graph are handed to ``self.assert_output_dtype_equal``.

    Args:
        device: Device forwarded to ``op.sample_inputs``.
        dtype: Dtype forwarded to ``op.sample_inputs``.
        op: Operator entry; must provide ``sample_inputs`` and ``get_op`` and
            be callable.
        allow_eager_fail: When true, any exception raised while building the
            sample or running the op eagerly ends the check silently.
            Otherwise the exception is re-raised.
    """
    try:
        sample_input = first_sample(
            self, op.sample_inputs(device, dtype, requires_grad=False)
        )
        input_args = [sample_input.input, *sample_input.args]
        expected_res = op(*input_args, **sample_input.kwargs)
    except Exception as e:
        if not allow_eager_fail:
            raise e
        return

    traced_fn = create_traced_fn(self, op.get_op())

    # The trace is only generated once the traced function is actually run.
    traced_fn(sample_input.input, *sample_input.args, **sample_input.kwargs)

    # Run the dtype analysis.  Note: this is a cached graph.
    graph = traced_fn.graph

    # Tensor inputs are collected positional-first, then keyword values.
    positional_tensors = [t for t in input_args if isinstance(t, torch.Tensor)]
    keyword_tensors = [
        v for v in sample_input.kwargs.values() if isinstance(v, torch.Tensor)
    ]
    self.prop_dtype_on_graph(graph, positional_tensors + keyword_tensors)
    self.assert_output_dtype_equal(expected_res, graph)
def test_aliases(self):
    """Check that op aliases are normalized in traced and scripted graphs.

    For every ``alias -> mapping`` pair in ``op_alias_mappings`` the method
    test registered under the alias name is run through both tracing and
    scripting, and ``FileCheck`` verifies that the resulting graph contains
    ``mapping`` and does not contain ``alias`` afterwards.
    """
    # Only alias normalization is tested here; other properties such as
    # correctness are covered for the common method registry in test_jit.py.
    op_registry = {entry[0]: entry for entry in method_tests()}

    for alias, mapping in op_alias_mappings.items():
        assert alias in op_registry, "Test not found for {} alias".format(alias)

        name, self_size, args, kwargs, output_process_fn = get_defaults(
            *op_registry[alias]
        )

        def fn(*inputs, **call_kwargs):
            # Invoke the alias as a method on the first input.
            bound = getattr(inputs[0], name)
            return output_process_fn(bound(*inputs[1:], **call_kwargs))

        self_variable = create_input((self_size,))[0][0]
        args_variable, kwargs_variable = create_input(
            args, requires_grad=False, call_kwargs=kwargs
        )

        def run_and_check(jit_fn):
            # Execute once so `last_graph` is populated, then inspect it.
            jit_fn(*inputs, **kwargs)
            FileCheck().check(mapping).check_not(alias).run(jit_fn.last_graph)

        traced_fn = create_traced_fn(self, fn)
        inputs = (self_variable,) + args_variable
        run_and_check(traced_fn)

        run_and_check(create_script_fn(self, name, 'method', output_process_fn))
def test_variant_consistency_jit(self, device, dtype, op):
    """Compare scripted and traced variants of ``op`` with its eager function.

    For every sample input and every available variant (function, method), a
    scripted and a traced callable are built and checked against the eager
    function via ``check_against_reference``.  For float32/int32 the alias
    annotation of the schema is checked as well, and for float32 the
    autodiff nodes of both generated graphs are asserted.

    Args:
        device: Device forwarded to ``op.sample_inputs``.
        dtype: Dtype under test.
        op: Operator entry describing the op and its variants.
    """
    _requires_grad = op.supports_autograd and (
        dtype.is_floating_point or op.supports_complex_autograd)

    for sample in op.sample_inputs(device, dtype, requires_grad=_requires_grad):
        # Acquire the variants to test.
        # TODO: inplace tests currently fail, fix and add inplace variant
        func = op.get_op()
        variants = {
            'function': func,
            'method': op.get_method(),
        }

        # Test traced and scripted consistency.
        for func_type, variant in variants.items():
            if variant is None:
                continue

            # Name used to access the variant from script.
            name = (op.name + '_') if func_type == 'inplace' else op.name

            # Run with disable_autodiff_subgraph_inlining(True) to test
            # autodiff support: the context manager forces the graph to
            # contain DifferentiableGraph nodes if they are present.
            with disable_autodiff_subgraph_inlining():

                def out_fn(output):
                    # Post-process the output for autograd when requested.
                    process = sample.output_process_fn_grad
                    return output if process is None else process(output)

                def check_jit_fn(jit_fn):
                    # Forward, grad, and grad grad against the eager function.
                    check_against_reference(
                        self, jit_fn, func, out_fn,
                        (sample.input, ) + sample.args,
                        sample.kwargs,
                        no_grad=not _requires_grad)

                script_fn = create_script_fn(self, name, func_type)
                check_jit_fn(script_fn)

                traced_fn = create_traced_fn(self, variant)
                check_jit_fn(traced_fn)

                # Check the alias annotation schema for correctness (inputs
                # that aren't supposed to be modified must stay untouched).
                # Only float32 and int32 are checked: the schema isn't
                # affected by dtype, so covering all dtypes would be excessive.
                if dtype in (torch.float32, torch.int32):
                    check_alias_annotation(
                        name, (sample.input, ) + sample.args, sample.kwargs,
                        func_type=func_type, aten_name=op.aten_name)

                # Autodifferentiation of nodes for traced and scripted
                # graphs only needs to be checked once per sample.
                if dtype is torch.float32:
                    if IS_SANDCASTLE:
                        # Sandcastle doesn't fuse nodes, so nodes that would
                        # normally sit in FusionGroups inside the
                        # DifferentiableGraphs are expected unfused.
                        nonfusible_nodes = (op.autodiff_nonfusible_nodes
                                            + op.autodiff_fusible_nodes)
                        fusible_nodes = []
                    else:
                        nonfusible_nodes = op.autodiff_nonfusible_nodes
                        fusible_nodes = op.autodiff_fusible_nodes

                    for jit_fn in (traced_fn, script_fn):
                        self.assertAutodiffNode(
                            jit_fn.last_graph, op.assert_autodiffed,
                            nonfusible_nodes, fusible_nodes)
def test_variant_consistency_jit(self, device, dtype, op):
    """Compare scripted and traced variants of ``op`` with a method-call reference.

    For every sample input and every tested variant (the op itself and its
    method variant), a scripted and a traced callable are built and checked
    via ``check_against_reference`` against ``fn``, which invokes the variant
    by name as a method on the first input.  For float32/int32 the alias
    annotation is checked, and for float32 the autodiff nodes of both
    generated graphs are asserted.  The test is skipped when the op yields
    no sample inputs.

    Args:
        device: Device forwarded to ``op.sample_inputs``.
        dtype: Dtype under test.
        op: Operator entry describing the op and its variants.
    """
    samples = op.sample_inputs(device, dtype, requires_grad=True)
    if len(samples) == 0:
        self.skipTest("Skipped! No sample inputs!")

    # bfloat16 grad doesn't work for some operators.  This depends only on
    # `op` and `dtype`, so it is computed once, not per sample/variant.
    if op.skip_bfloat16_grad:
        dtypes_to_grad_check = floating_and_complex_types_and(torch.half)
    else:
        dtypes_to_grad_check = floating_and_complex_types_and(
            torch.half, torch.bfloat16)
    no_grad = dtype not in dtypes_to_grad_check

    for sample in samples:
        # Acquire the variants to test.
        method = op.get_method()
        inplace = op.get_inplace()

        # TODO: inplace tests currently fail; add `inplace` to the tuple
        # below once they pass (the 'inplace' branch further down is kept
        # ready for that).
        variants = tuple(v for v in (op, method) if v is not None)

        # Test traced and scripted consistency.
        for variant in variants:
            # Name and kind used to access the variant from script.
            if variant is op:
                name = op.name
                func_type = 'function'
            elif variant is method:
                name = op.name
                func_type = 'method'
            else:
                assert variant is inplace
                name = op.name + "_"
                func_type = 'inplace'

            # Run with disable_autodiff_subgraph_inlining(True) to test
            # autodiff support: the context manager forces the graph to
            # contain DifferentiableGraph nodes if they are present.
            with disable_autodiff_subgraph_inlining():

                def fn(*inputs, **kwargs):
                    # Reference: call the variant as a method on inputs[0].
                    attr = getattr(inputs[0], name)
                    output = attr(*inputs[1:], **kwargs)
                    return op.output_func(output)

                # Check scripted forward, grad, and grad grad.
                script_fn = create_script_fn(self, name, func_type, op.output_func)
                check_against_reference(
                    self, script_fn, fn,
                    (*sample.input, ) + sample.args,
                    sample.kwargs,
                    no_grad=no_grad)

                # Check traced forward, grad, and grad grad.
                traced_fn = create_traced_fn(self, variant)
                check_against_reference(
                    self, traced_fn, fn,
                    (*sample.input, ) + sample.args,
                    sample.kwargs,
                    no_grad=no_grad)

                # Check the alias annotation schema for correctness (inputs
                # that aren't supposed to be modified must stay untouched).
                # Only float32 and int32 are checked: the schema isn't
                # affected by dtype, so covering all dtypes would be excessive.
                if dtype in (torch.float32, torch.int32):
                    check_alias_annotation(
                        name, (*sample.input, ) + sample.args, sample.kwargs)

                # Autodifferentiation of nodes for traced and scripted
                # graphs only needs to be checked once per sample.
                if dtype is torch.float32:
                    if IS_SANDCASTLE:
                        # Sandcastle doesn't fuse nodes, so nodes that would
                        # normally sit in FusionGroups inside the
                        # DifferentiableGraphs are expected unfused.
                        nonfusible_nodes = (op.autodiff_nonfusible_nodes
                                            + op.autodiff_fusible_nodes)
                        fusible_nodes = []
                    else:
                        nonfusible_nodes = op.autodiff_nonfusible_nodes
                        fusible_nodes = op.autodiff_fusible_nodes

                    self.assertAutodiffNode(
                        traced_fn.last_graph, op.assert_autodiffed,
                        nonfusible_nodes, fusible_nodes)
                    self.assertAutodiffNode(
                        script_fn.last_graph, op.assert_autodiffed,
                        nonfusible_nodes, fusible_nodes)
def test_variant_consistency_jit(self, device, dtype, op):
    """Compare scripted and traced variants of ``op`` with its eager function.

    For every sample input and every usable variant (function, method), the
    scripted callable (when the op supports scripting) and the traced
    callable (unless the op only has a "fake" function variant) are checked
    against the eager function.  For float32 the alias annotation, the
    traced graph's shape analysis and the autodiff nodes are checked too.
    The test fails if no variant ended up being exercised.

    Args:
        device: Device forwarded to ``op.sample_inputs``.
        dtype: Dtype under test.
        op: Operator entry describing the op and its variants.
    """
    _requires_grad = op.supports_autograd and (
        dtype.is_floating_point
        or op.supports_complex_autograd(torch.device(device).type))

    include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
    samples = op.sample_inputs(
        device, dtype,
        requires_grad=_requires_grad,
        include_conjugated_inputs=include_conjugated_inputs)

    # Acquire the variants to test.
    # TODO: inplace tests currently fail, fix and add inplace variant
    func = op.get_op()
    method = op.get_method()
    variants = {'function': func, 'method': method}

    # TODO: find better way to standardize on op registration itself..
    has_fake_function = op.name in ("resize_", 'resize_as_')
    if has_fake_function:
        variants = {'method': getattr(torch.Tensor, op.name)}
        samples = op.sample_inputs(device, dtype, requires_grad=False)

    support_script = op.supports_scripting
    # TODO: fix tracing here
    supports_tracing = not has_fake_function

    tested = False
    for sample in samples:
        # Test traced and scripted consistency.
        for func_type, variant in variants.items():
            # Scripting and check_alias_analysis do not work with lambdas.
            # Lambdas are typically used to simulate methods without
            # functional variants, so rely on the other variant for now.
            if variant is None or is_lambda(variant):
                continue

            tested = True

            # Name used to access the variant from script.
            name = (op.name + '_') if func_type == 'inplace' else op.name

            # Run with disable_autodiff_subgraph_inlining(True) to test
            # autodiff support: the context manager forces the graph to
            # contain DifferentiableGraph nodes if they are present.
            with disable_autodiff_subgraph_inlining():

                def out_fn(output):
                    # Post-process the output for autograd when requested.
                    process = sample.output_process_fn_grad
                    return output if process is None else process(output)

                def get_sample():
                    # Ops whose name ends in '_' get a cloned input per call.
                    if op.name[-1] == '_':
                        return clone_input_helper(sample.input)
                    return sample.input

                def check_jit_fn(jit_fn):
                    # Forward, grad, and grad grad against the eager function.
                    check_against_reference(
                        self, jit_fn, func, out_fn,
                        (get_sample(), ) + sample.args,
                        sample.kwargs,
                        no_grad=not _requires_grad,
                        no_gradgrad=not op.supports_gradgrad)

                if support_script:
                    script_fn = create_script_fn(self, name, func_type)
                    check_jit_fn(script_fn)

                if op.assert_jit_shape_analysis:
                    self.assertTrue(supports_tracing)

                if supports_tracing:
                    traced_fn = create_traced_fn(self, variant)
                    check_jit_fn(traced_fn)

                # Check the alias annotation schema for correctness (inputs
                # that aren't supposed to be modified must stay untouched).
                # Only float32 is checked: the schema isn't affected by
                # dtype, so covering all dtypes would be excessive.
                if dtype == torch.float32:
                    # TODO: no reason why we cant run this with tracing graph
                    if support_script and op.name != "rsub":
                        check_alias_annotation(
                            name, (get_sample(), ) + sample.args, sample.kwargs,
                            func_type=func_type, aten_name=op.aten_name)

                    # TODO: use script graph as well
                    checked_shape_analysis = False
                    if supports_tracing:
                        out = variant(get_sample(), *sample.args, **sample.kwargs)

                        # Right now a tensor output and a tuple of tensor
                        # outputs are supported.
                        # TODO: list of tensor outputs
                        tuple_of_tensors = isinstance(out, tuple) and all(
                            isinstance(elem, torch.Tensor) for elem in out)

                        if isinstance(out, torch.Tensor) or tuple_of_tensors:
                            sizes = ([elem.size() for elem in out]
                                     if tuple_of_tensors else out.size())
                            self.checkShapeAnalysis(
                                sizes, traced_fn.graph,
                                op.assert_jit_shape_analysis)
                            checked_shape_analysis = True

                    if op.assert_jit_shape_analysis:
                        self.assertTrue(checked_shape_analysis)

                # Autodifferentiation of nodes for traced and scripted
                # graphs only needs to be checked once per sample.
                if dtype is torch.float32:
                    if IS_SANDCASTLE:
                        # Sandcastle doesn't fuse nodes, so nodes that would
                        # normally sit in FusionGroups inside the
                        # DifferentiableGraphs are expected unfused.
                        nonfusible_nodes = (op.autodiff_nonfusible_nodes
                                            + op.autodiff_fusible_nodes)
                        fusible_nodes = []
                    else:
                        nonfusible_nodes = op.autodiff_nonfusible_nodes
                        fusible_nodes = op.autodiff_fusible_nodes

                    if supports_tracing:
                        self.assertAutodiffNode(
                            traced_fn.last_graph, op.assert_autodiffed,
                            nonfusible_nodes, fusible_nodes)
                    if support_script:
                        self.assertAutodiffNode(
                            script_fn.last_graph, op.assert_autodiffed,
                            nonfusible_nodes, fusible_nodes)

    assert tested, "JIT Test does not execute any logic"
def indiv_variant_test_jit(self, device, dtype, op, sample, func_type, variant,
                           has_fake_function):
    """Check one variant of ``op`` on one sample under scripting and tracing.

    The scripted callable (when the op supports scripting) and the traced
    callable (when the op supports tracing and has no "fake" function
    variant) are checked against the eager function from ``op.get_op()``.
    For float32 the alias annotation, the traced graph's shape analysis and
    the autodiff nodes are checked as well.

    Args:
        device: Device whose type selects the supported backward dtypes.
        dtype: Dtype under test.
        op: Operator entry describing the op.
        sample: Sample input providing ``input``, ``args``, ``kwargs`` and
            ``output_process_fn_grad``.
        func_type: Kind of variant; ``'inplace'`` appends ``'_'`` to the op
            name used for scripting.
        variant: Callable variant of the op to trace and invoke.
        has_fake_function: When true, tracing is skipped for this op.
    """
    _requires_grad = (
        dtype in op.supported_backward_dtypes(torch.device(device).type))

    support_script = op.supports_scripting

    # Name used to access the variant from script.
    name = (op.name + '_') if func_type == 'inplace' else op.name

    # Run with disable_autodiff_subgraph_inlining(True) to test autodiff
    # support: the context manager forces the graph to contain
    # DifferentiableGraph nodes if they are present.
    with disable_autodiff_subgraph_inlining():

        def out_fn(output):
            # Post-process the output for autograd when requested.
            process = sample.output_process_fn_grad
            return output if process is None else process(output)

        def get_sample():
            # Ops whose name ends in '_' get a cloned input per call.
            if op.name[-1] == '_':
                return clone_input_helper(sample.input)
            return sample.input

        def check_jit_fn(jit_fn):
            # Forward, grad, and grad grad against the eager function.
            check_against_reference(
                self, jit_fn, op.get_op(), out_fn,
                (get_sample(), ) + sample.args,
                sample.kwargs,
                no_grad=not _requires_grad,
                no_gradgrad=not op.supports_gradgrad)

        if support_script:
            script_fn = create_script_fn(self, name, func_type)
            check_jit_fn(script_fn)

        # TODO: fix tracing here
        supports_tracing = op.supports_tracing and not has_fake_function
        if op.assert_jit_shape_analysis:
            self.assertTrue(supports_tracing)

        if supports_tracing:
            traced_fn = create_traced_fn(self, variant)
            check_jit_fn(traced_fn)

        # Check the alias annotation schema for correctness (inputs that
        # aren't supposed to be modified must stay untouched).  Only float32
        # is checked: the schema isn't affected by dtype, so covering all
        # dtypes would be excessive.
        if dtype == torch.float32:
            # TODO: no reason why we cant run this with tracing graph
            if support_script and op.name != "rsub":
                check_alias_annotation(
                    name, (get_sample(), ) + sample.args, sample.kwargs,
                    func_type=func_type, aten_name=op.aten_name)

            # TODO: use script graph as well
            checked_shape_analysis = False
            if supports_tracing:
                out = variant(get_sample(), *sample.args, **sample.kwargs)

                # Right now a tensor output and a tuple of tensor outputs
                # are supported.
                # TODO: list of tensor outputs
                tuple_of_tensors = isinstance(out, tuple) and all(
                    isinstance(elem, torch.Tensor) for elem in out)

                if isinstance(out, torch.Tensor) or tuple_of_tensors:
                    sizes = ([elem.size() for elem in out]
                             if tuple_of_tensors else out.size())
                    self.checkShapeAnalysis(
                        sizes, traced_fn.graph, op.assert_jit_shape_analysis)
                    checked_shape_analysis = True

            if op.assert_jit_shape_analysis:
                self.assertTrue(checked_shape_analysis)

        # Autodifferentiation of nodes for traced and scripted graphs only
        # needs to be checked once per sample.
        if dtype is torch.float32:
            if IS_SANDCASTLE:
                # Sandcastle doesn't fuse nodes, so nodes that would normally
                # sit in FusionGroups inside the DifferentiableGraphs are
                # expected unfused.
                nonfusible_nodes = (op.autodiff_nonfusible_nodes
                                    + op.autodiff_fusible_nodes)
                fusible_nodes = []
            else:
                nonfusible_nodes = op.autodiff_nonfusible_nodes
                fusible_nodes = op.autodiff_fusible_nodes

            if supports_tracing:
                self.assertAutodiffNode(
                    traced_fn.last_graph, op.assert_autodiffed,
                    nonfusible_nodes, fusible_nodes)
            if support_script:
                self.assertAutodiffNode(
                    script_fn.last_graph, op.assert_autodiffed,
                    nonfusible_nodes, fusible_nodes)