def get_nn_functional_compiled_fn_and_inputs(name,
                                             self_size,
                                             args,
                                             variant_name='',
                                             *extra_args):
    """Create inputs for an ``nn_functional`` op and generate its script fn.

    Args:
        name: Name of the functional; forwarded unchanged to
            ``gen_script_fn_and_args``.
        self_size: Spec for the first argument. It is wrapped in a
            one-element tuple and handed to ``create_input``.
        args: Specs for the remaining positional arguments, handed to
            ``create_input`` as-is.
        variant_name: Not used by this function; kept so existing callers
            that pass it keep working.
        *extra_args: Not used by this function; kept for call compatibility.

    Returns:
        The ``(script_fn, inputs)`` pair returned by
        ``gen_script_fn_and_args``.
    """
    # The "self" input is created before the remaining arguments, matching
    # the order in which the final argument tuple is assembled below.
    self_variable = create_input((self_size, ))[0][0]
    # create_input returns an (args, kwargs) pair; only the args are needed.
    args_variable, _ = create_input(args)

    f_args_variable = (self_variable, ) + args_variable
    with torch.jit._disable_emit_hooks():
        script_fn, inputs = gen_script_fn_and_args(name, "nn_functional",
                                                   *f_args_variable)
    return script_fn, inputs
def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs):
    """Build a script module for an nn-module test spec, plus its inputs.

    The test spec is read entirely from ``kwargs``; positional ``args`` are
    accepted but not used.

    Keys read from ``kwargs``:
        desc: Optional description. Specs whose description contains
            ``'eval'`` are skipped.
        constructor: Optional module constructor. When absent, the module is
            looked up on ``torch.nn`` by the name returned from
            ``get_nn_module_name_from_kwargs``.
        constructor_args_fn / constructor_args: Constructor arguments, either
            produced by calling the function or taken directly (default
            ``()``).
        input_fn / input_size: Either a callable producing the input
            tensor(s), or a size spec. ``input_size`` is required when
            ``input_fn`` is absent.
        extra_args: Optional tuple appended to the inputs.
        target_size / target_fn: Optional target appended to the inputs;
            ``target_size`` takes precedence when both are present.

    Returns:
        ``None`` when the spec is skipped (eval description, name listed in
        ``EXCLUDE_SCRIPT_MODULES``, or a constructor whose string form
        contains ``"FunctionalModule"``). Otherwise ``(mod, out_var)`` where
        ``mod`` is the result of calling the factory returned by
        ``create_script_module`` on the inputs and ``out_var`` is a deep copy
        of those inputs taken before that call.

    Raises:
        KeyError: If neither ``input_fn`` nor ``input_size`` is provided.
    """
    name = get_nn_module_name_from_kwargs(**kwargs)

    if 'desc' in kwargs and 'eval' in kwargs['desc']:
        # eval() is not supported, so skip these tests
        return None

    test_name = get_nn_mod_test_name(**kwargs)
    if test_name in EXCLUDE_SCRIPT_MODULES:
        return None

    if 'constructor' in kwargs:
        nn_module = kwargs['constructor']
    else:
        nn_module = getattr(torch.nn, name)

    if "FunctionalModule" in str(nn_module):
        return None

    if 'constructor_args_fn' in kwargs:
        constructor_args = kwargs['constructor_args_fn']()
    else:
        constructor_args = kwargs.get('constructor_args', ())

    # Set up inputs from tuple of sizes or constructor fn
    input_dtype = torch.double
    if 'input_fn' in kwargs:
        inputs = kwargs['input_fn']()
        if isinstance(inputs, torch.Tensor):
            inputs = (inputs, )

        if all(tensor.is_complex() for tensor in inputs):
            input_dtype = torch.cdouble
    else:
        inputs = (kwargs['input_size'], )

    # Extra parameters to forward()
    if 'extra_args' in kwargs:
        inputs = inputs + kwargs['extra_args']

    # A bare tensor was already wrapped in a tuple above, so the target can
    # be appended directly in both cases.
    if 'target_size' in kwargs:
        inputs = inputs + (kwargs['target_size'], )
    elif 'target_fn' in kwargs:
        inputs = inputs + (kwargs['target_fn'](), )

    # create_input returns an (args, kwargs) pair; only the args are needed.
    args_variable, _ = create_input(inputs, dtype=input_dtype)
    f_args_variable = deepcopy(unpack_variables(args_variable))
    # Copy the inputs before the module sees them, so the returned inputs
    # are independent of the tensors passed to the module call below.
    out_var = deepcopy(f_args_variable)

    mod = create_script_module(
        None, nn_module, constructor_args, *f_args_variable)(*f_args_variable)

    return mod, out_var