def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name='', *extra_args):
    """Script the ``nn_functional`` op ``name`` and return it with its inputs.

    Args:
        name: name of the functional op, forwarded to ``gen_script_fn_and_args``.
        self_size: size spec for the first ("self") input, passed to ``create_input``.
        args: specs for the remaining positional inputs, passed to ``create_input``.
        variant_name: accepted for signature compatibility only; it does not
            influence the returned values.
        *extra_args: accepted for signature compatibility only; ignored.

    Returns:
        The ``(script_fn, inputs)`` pair produced by ``gen_script_fn_and_args``.
    """
    self_variable = create_input((self_size,))[0][0]
    # create_input also returns converted kwargs; none are forwarded to the op here.
    args_variable, _ = create_input(args)

    f_args_variable = (self_variable,) + args_variable
    # Generate the scripted function with JIT emit hooks disabled.
    with torch.jit._disable_emit_hooks():
        script_fn, inputs = gen_script_fn_and_args(name, "nn_functional", *f_args_variable)
    return script_fn, inputs
def test_aliases(self):
    """Verify that op aliases are rewritten to their canonical op in compiled graphs.

    For every ``(alias, mapping)`` pair in ``op_alias_mappings``, the matching
    ``method_tests()`` entry is executed once through a traced wrapper and once
    through a scripted method call. After each run, the compiled function's
    ``last_graph`` must contain ``mapping`` and must not contain ``alias`` after it.

    Output correctness is deliberately not checked: the common method registry
    is exercised for that in test_jit.py.
    """
    # Index the method-test registry by op name (a later duplicate overwrites an earlier one).
    op_registry = {entry[0]: entry for entry in method_tests()}

    for alias, mapping in op_alias_mappings.items():
        assert alias in op_registry, "Test not found for {} alias".format(alias)
        name, self_size, args, kwargs, output_process_fn = get_defaults(*op_registry[alias])

        def fn(*inputs, **kwargs):
            # Invoke method `name` on the first input, passing the remaining inputs.
            bound_method = getattr(inputs[0], name)
            return output_process_fn(bound_method(*inputs[1:], **kwargs))

        self_variable = create_input((self_size,))[0][0]
        args_variable, _ = create_input(args, requires_grad=False, call_kwargs=kwargs)
        inputs = (self_variable,) + args_variable

        def assert_alias_normalized(compiled):
            # Run once so the compiled wrapper records its graph, then inspect it.
            compiled(*inputs, **kwargs)
            FileCheck().check(mapping).check_not(alias).run(compiled.last_graph)

        assert_alias_normalized(create_traced_fn(self, fn))
        assert_alias_normalized(create_script_fn(self, name, 'method', output_process_fn))
def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs):
    """Build a scripted ``torch.nn`` module from a test spec and call it on its inputs.

    ``kwargs`` is an nn-module test description. Besides whatever
    ``get_nn_module_name_from_kwargs`` / ``get_nn_mod_test_name`` consume, this
    function reads the keys ``desc``, ``constructor``, ``constructor_args_fn``,
    ``constructor_args``, ``input_fn``, ``input_size``, ``extra_args``,
    ``target_size`` and ``target_fn``. Positional ``args`` are accepted for
    signature compatibility and ignored.

    Returns:
        ``(mod, inputs_copy)``. ``mod`` is the value returned by calling the
        function that ``create_script_module`` produces on the inputs.
        ``inputs_copy`` is a deep copy of those inputs, taken before that call.

        Returns ``None`` when the spec is skipped, in any of these cases:
        - its ``desc`` mentions ``eval``;
        - its test name is in ``EXCLUDE_SCRIPT_MODULES``;
        - the module is a ``FunctionalModule``.
    """
    name = get_nn_module_name_from_kwargs(**kwargs)

    if 'desc' in kwargs and 'eval' in kwargs['desc']:
        # eval() is not supported, so skip these tests
        return None

    if get_nn_mod_test_name(**kwargs) in EXCLUDE_SCRIPT_MODULES:
        return None

    if 'constructor' in kwargs:
        nn_module = kwargs['constructor']
    else:
        nn_module = getattr(torch.nn, name)

    if "FunctionalModule" in str(nn_module):
        return None

    if 'constructor_args_fn' in kwargs:
        constructor_args = kwargs['constructor_args_fn']()
    else:
        constructor_args = kwargs.get('constructor_args', ())

    # Set up inputs from tuple of sizes or constructor fn
    input_dtype = torch.double
    if 'input_fn' in kwargs:
        module_inputs = kwargs['input_fn']()
        if isinstance(module_inputs, torch.Tensor):
            module_inputs = (module_inputs,)
        # Switch to complex only when every provided input tensor is complex.
        if all(tensor.is_complex() for tensor in module_inputs):
            input_dtype = torch.cdouble
    else:
        module_inputs = (kwargs['input_size'],)

    # Extra parameters to forward()
    if 'extra_args' in kwargs:
        module_inputs = module_inputs + kwargs['extra_args']

    # A single Tensor was already wrapped into a tuple above, so plain tuple
    # concatenation is sufficient here.
    if 'target_size' in kwargs:
        module_inputs = module_inputs + (kwargs['target_size'],)
    elif 'target_fn' in kwargs:
        module_inputs = module_inputs + (kwargs['target_fn'](),)

    # create_input also returns converted kwargs, which are not needed here.
    args_variable, _ = create_input(module_inputs, dtype=input_dtype)
    f_args_variable = deepcopy(unpack_variables(args_variable))
    # Copy taken before the module is invoked, so it is independent of
    # anything that call does to f_args_variable.
    out_var = deepcopy(f_args_variable)

    mod = create_script_module(None, nn_module, constructor_args, *f_args_variable)(*f_args_variable)
    return mod, out_var