def get_nn_functional_compiled_fn_and_inputs(name,
                                             self_size,
                                             args,
                                             variant_name='',
                                             *extra_args):
    """Build a scripted call to torch.nn.functional.<name> and its example
    inputs from a (name, self_size, args) functional test entry."""
    test_name = 'test_nn_' + name

    if variant_name != '':
        test_name = test_name + '_' + variant_name

    no_grad = variant_name == 'inplace'

    self_variable = create_input((self_size, ))[0][0]
    kwargs = None

    # need to record this because methods can change the size (e.g. unsqueeze)
    args_variable, kwargs_variable = create_input(args)

    self_tensor = deepcopy(self_variable.data)
    args_tensor = deepcopy(unpack_variables(args_variable))

    f_args_variable = (self_variable, ) + args_variable
    f_args_tensor = (self_tensor, ) + args_tensor
    # Compile the functional call with the JIT emit hooks temporarily disabled.
    with torch.jit._disable_emit_hooks():
        script_fn, inputs = gen_script_fn_and_args(name, "nn_functional",
                                                   *f_args_variable)
    return script_fn, inputs
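
# A minimal usage sketch for the helper above. The ('relu', (2, 3), ()) entry is
# an illustrative stand-in for a row of the nn functional test list, not a real
# entry, and it assumes the torch.testing._internal helpers used above are
# importable in the current environment.
import torch

entry_name, entry_self_size, entry_args = 'relu', (2, 3), ()  # hypothetical entry

script_fn, inputs = get_nn_functional_compiled_fn_and_inputs(
    entry_name, entry_self_size, entry_args)

# Run the scripted functional and sanity-check it against the eager op.
scripted_out = script_fn(*inputs)
eager_out = torch.nn.functional.relu(inputs[0])
assert torch.allclose(scripted_out, eager_out)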
Example #2
    def test_aliases(self):
        # Checks that op aliases are normalized to their canonical names in
        # traced and scripted graphs. Correctness of the aliased ops is not
        # checked here; the common method registry tests in test_jit.py cover that.

        op_registry = {}
        for op in method_tests():
            op_registry[op[0]] = op

        for alias, mapping in op_alias_mappings.items():
            assert alias in op_registry, "Test not found for {} alias".format(alias)

            name, self_size, args, kwargs, output_process_fn = get_defaults(*op_registry[alias])

            def fn(*inputs, **kwargs):
                attr = getattr(inputs[0], name)
                output = attr(*inputs[1:], **kwargs)
                return output_process_fn(output)

            self_variable = create_input((self_size,))[0][0]
            args_variable, kwargs_variable = create_input(args, requires_grad=False, call_kwargs=kwargs)

            traced_fn = create_traced_fn(self, fn)
            inputs = (self_variable,) + args_variable
            traced_fn(*inputs, **kwargs)
            last_graph = traced_fn.last_graph
            FileCheck().check(mapping).check_not(alias).run(last_graph)

            script_fn = create_script_fn(self, name, 'method', output_process_fn)
            script_fn(*inputs, **kwargs)
            last_graph = script_fn.last_graph
            FileCheck().check(mapping).check_not(alias).run(last_graph)
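
# A standalone sketch of the same FileCheck pattern used in test_aliases above:
# script a function that calls an alias (`absolute` is an alias of `abs`) and
# check that the canonical op name appears in the graph. Whether `check_not`
# holds depends on the alias-normalization pass of the installed PyTorch
# version, so treat the expected graph contents as an assumption.
import torch
from torch.testing import FileCheck

@torch.jit.script
def calls_alias(x):
    return x.absolute()

# Assumes aten::absolute is rewritten to aten::abs during compilation.
FileCheck().check("aten::abs").check_not("aten::absolute").run(str(calls_alias.graph))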
def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs):
    """Build a scripted nn module and example inputs from a module test spec
    given as kwargs; return None for unsupported or excluded configurations."""
    name = get_nn_module_name_from_kwargs(**kwargs)

    if 'desc' in kwargs and 'eval' in kwargs['desc']:
        # eval() is not supported, so skip these tests
        return

    test_name = get_nn_mod_test_name(**kwargs)

    if test_name in EXCLUDE_SCRIPT_MODULES:
        return
    if 'constructor' in kwargs:
        nn_module = kwargs['constructor']
    else:
        nn_module = getattr(torch.nn, name)

    if "FunctionalModule" in str(nn_module):
        return

    if 'constructor_args_fn' in kwargs:
        constructor_args = kwargs['constructor_args_fn']()
    else:
        constructor_args = kwargs.get('constructor_args', ())

    # Set up inputs from an input_fn or from a tuple of sizes (input_size)
    input_dtype = torch.double
    if 'input_fn' in kwargs:
        input = kwargs['input_fn']()
        if isinstance(input, torch.Tensor):
            input = (input, )

        if all(tensor.is_complex() for tensor in input):
            input_dtype = torch.cdouble
    else:
        input = (kwargs['input_size'], )

    # Extra parameters to forward()
    if 'extra_args' in kwargs:
        input = input + kwargs['extra_args']

    if 'target_size' in kwargs:
        input = input + (kwargs['target_size'], )
    elif 'target_fn' in kwargs:
        if torch.is_tensor(input):
            input = (input, )
        input = input + (kwargs['target_fn'](), )

    args_variable, kwargs_variable = create_input(input, dtype=input_dtype)
    f_args_variable = deepcopy(unpack_variables(args_variable))
    out_var = deepcopy(f_args_variable)

    mod = create_script_module(
        None, nn_module, constructor_args, *f_args_variable)(*f_args_variable)

    return mod, out_var
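
# A hedged usage sketch for the helper above. The spec below mirrors the
# new-style nn module test dicts that this helper reads (module_name,
# constructor_args, input_size); the Tanh entry is illustrative, not taken
# from the real test database, and the helper returns None for skipped specs.
spec = dict(module_name='Tanh', constructor_args=(), input_size=(2, 3))

result = try_get_nn_module_compiled_mod_and_inputs(**spec)
if result is not None:
    mod, inputs = result
    out = mod(*inputs)      # run the compiled module on the generated inputs
    print(type(mod).__name__, tuple(out.shape))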