Example #1
    def check_device(device, host="llvm"):
        ctx = tvm.context(device, 0)
        if not tvm.runtime.enabled(host):
            return
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return

        sout = te.create_schedule(out.op)
        mout = tvm.build(sout, [out] + inputs + args)
        out_shape = get_const_tuple(out.shape)

        l, h = data_range
        input_data = [
            tvm.nd.array(
                np.random.uniform(l, h, size=get_const_tuple(
                    input.shape)).astype(input.dtype)) for input in inputs
        ]
        arg_vals = [
            tvm.nd.array(
                np.random.uniform(l, h, size=get_const_tuple(
                    arg.shape)).astype(arg.dtype)) for arg in args
        ]

        ones = topi.full_like(out, 1.0)
        # provide a head of ones to sum over the output dimensions,
        # which is equivalent to grad(out.sum(), inputs)
        grads = te.gradient(out, inputs, head=ones)
        grad_sched = te.create_schedule([grad.op for grad in grads])
        mgrad = tvm.build(grad_sched, list(grads) + inputs + args)
        if assert_no_jacobian:
            # TODO(yzhliu): it would be better to visit the expression and do the assertion there
            lowered_ir = str(
                tvm.lower(grad_sched,
                          list(grads) + inputs + args,
                          simple_mode=True))
            assert "jacobian" not in lowered_ir, lowered_ir

        grad_data = [
            tvm.nd.empty(get_const_tuple(i.shape), g.dtype)
            for i, g in zip(inputs, grads)
        ]

        mgrad(*grad_data, *input_data, *arg_vals)
        g_res = [g.asnumpy() for g in grad_data]

        if desired_grads:
            assert isinstance(desired_grads, list)
            for actual, desired in zip(g_res, desired_grads):
                assert_allclose(actual, desired, rtol=0.1, atol=1e-2)
        else:

            def forward(*in_data):
                out_data = tvm.nd.empty(out_shape, out.dtype)
                mout(out_data, *[tvm.nd.array(d) for d in list(in_data)])
                return out_data.asnumpy().sum()

            check_numerical_grads(forward,
                                  [d.asnumpy() for d in input_data + arg_vals],
                                  g_res)
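
For reference, a minimal standalone sketch of the te.gradient call with a ones head that the helper above builds (assuming the standard tvm.te/topi API; the tensor names are made up). Differentiating with a head of ones is equivalent to grad(out.sum(), inputs):

import numpy as np
import tvm
from tvm import te, topi

# toy compute: y = x * x, so d(sum(y))/dx == 2 * x
x = te.placeholder((4,), name="x", dtype="float32")
y = topi.multiply(x, x)
ones = topi.full_like(y, 1.0)          # head of ones <=> differentiate sum(y)
[dx] = te.gradient(y, [x], head=ones)

s = te.create_schedule(dx.op)
f = tvm.build(s, [x, dx], target="llvm")

x_np = np.random.uniform(size=(4,)).astype("float32")
dx_nd = tvm.nd.empty((4,), "float32")
f(tvm.nd.array(x_np), dx_nd)
np.testing.assert_allclose(dx_nd.asnumpy(), 2 * x_np, rtol=1e-5)
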
Example #2
def test_check_numerical_grads():
    # Functions and their derivatives
    functions = [
        lambda x: (x * x * x, 3 * x * x),
        lambda x: (x * x, 2 * x),
        lambda x: (np.abs(x), np.sign(x)),
        lambda x: (np.log(np.abs(x)), 1 / x),
        lambda x: (np.sqrt(np.abs(x)), np.sign(x) / (2 * np.sqrt(np.abs(x)))),
        lambda x: (1 / x, -1 / (x * x)),
        lambda x: (np.sign(np.sin(1 / x)), np.zeros_like(x)),
        lambda x: (x * np.sin(1 / x), np.sin(1 / x) - np.cos(1 / x) / x),
        lambda x: (np.sin(1 / x), -np.cos(1 / x) / (x * x)),
    ]

    # Avoid values too close to 0, since our functions have singularities there
    min_x = 0.5

    for func in functions:
        x_input = np.random.uniform(min_x, 10, size=(3, 4))

        # We need a function returning a scalar, so sum the results
        func_forw = lambda x: np.sum(func(x)[0])
        grads = [func(x_input)[1]]

        check_numerical_grads(func_forw, [x_input], grads)

    # Check functions with multiple arguments
    for f1 in functions:
        for f2 in functions:
            x_input = np.random.uniform(min_x, 10, size=(3, 4))
            y_input = np.random.uniform(min_x, 10, size=(3, 4))

            func_forw = lambda x, y: np.sum(f1(x)[0] + f2(y)[0])
            grads = [f1(x_input)[1], f2(y_input)[1]]

            check_numerical_grads(func_forw, [x_input, y_input], grads)

            # Same thing but with keyword arguments
            func_forw = lambda x, y: np.sum(f1(x)[0] + f2(y)[0])
            grads = {'x': f1(x_input)[1], 'y': f2(y_input)[1]}

            check_numerical_grads(func_forw, {
                'x': x_input,
                'y': y_input
            }, grads)

    def _noise1(x, atol=1e-2, rtol=0.1):
        # Go in a random direction using twice the original tolerance to make
        # sure this results in an error
        sqrt_n = np.sqrt(float(np.prod(x.shape)))
        tol = 2 * (np.linalg.norm(x) * rtol + atol * sqrt_n)
        noise = np.random.normal(size=x.shape)
        noise = tol * noise / np.linalg.norm(noise)
        return x + noise

    def _noise2(x, atol=1e-2, rtol=0.1):
        # This noise affects just a single component
        sqrt_n = np.sqrt(float(np.prod(x.shape)))
        tol = 2 * (np.linalg.norm(x) * rtol + atol * sqrt_n)
        n = np.random.randint(np.prod(x.shape))
        noise = np.zeros_like(x)
        noise.reshape(-1)[n] = tol
        return x + noise

    # Add noise to gradients and check that the function throws
    for f1 in functions:
        for f2 in functions:
            x_input = np.random.uniform(min_x, 10, size=(3, 4))
            y_input = np.random.uniform(min_x, 10, size=(3, 4))

            func_forw = lambda x, y: np.sum(f1(x)[0] + f2(y)[0])
            grads = [_noise1(f1(x_input)[1]), _noise1(f2(y_input)[1])]

            try:
                check_numerical_grads(func_forw, [x_input, y_input], grads)
            except AssertionError:
                pass
            else:
                raise AssertionError(
                    "check_numerical_grads didn't raise an exception")

            func_forw = lambda x, y: np.sum(f1(x)[0] + f2(y)[0])
            grads = {
                'x': _noise2(f1(x_input)[1]),
                'y': _noise2(f2(y_input)[1])
            }

            try:
                check_numerical_grads(func_forw, {
                    'x': x_input,
                    'y': y_input
                }, grads)
            except AssertionError:
                pass
            else:
                raise AssertionError(
                    "check_numerical_grads didn't raise an exception")
Example #3
def check_function(symbol, forward=None, backward=None, grad_input_vars=None,
                   shape=None, dtype=None, in_range=None, values=None,
                   exclude_targets=None, only_targets=None,
                   additional_params=None,
                   numerical_grads=None, numerical_grads_params=None,
                   atol=1e-5, rtol=1e-5, quiet=False):
    """Compute the function and/or its gradients on a random input and raise
    an exception if the result doesn't match the reference implementation.

    Parameters
    ----------
    symbol : nnvm.Symbol
        A symbol representing the output.

    forward : Callable[..., List[numpy.ndarray]], optional
        A reference implementation to compare with.

    backward : Callable[..., List[numpy.ndarray] or Dict[str, numpy.ndarray]], optional
        A reference implementation of gradients. Besides the normal inputs it should
        also accept head_grads: a list of gradients of some scalar wrt the outputs,
        or just a single gradient if there is only one output.
        Should return either a dict mapping input variable names to the respective
        gradients, or a list of gradients wrt the variables from grad_input_vars in
        exactly the same order (alphabetical by default).

    grad_input_vars : List[nnvm.Symbol or str], optional
        A list of variables with respect to which the gradients will be computed.
        None (default) means that all input variables will be used, in alphabetical order.

    shape : Dict[nnvm.Symbol or str, Tuple[int]] or Tuple[int], optional
        A dict mapping input variable names to shapes, or just a single shape.
        By default shapes will be inferred from variables' attributes (see the Examples).
        Note that this parameter takes precedence over variables' attributes.

    dtype : Dict[nnvm.Symbol or str, str] or str, optional
        A dict mapping input variable names to dtypes, or just a single dtype.
        By default dtypes will be inferred from variables' attributes (see the Examples).
        If dtypes cannot be inferred for some variables then float32 will be used as a fallback.
        Note that this parameter takes precedence over variables' attributes.

    in_range : Dict[nnvm.Symbol or str, (float, float)] or (float, float), optional
        A dict mapping input variable names to ranges or just a single range
        (the same for all variables). Input values will be generated from
        uniform distributions on these ranges. `head_grads` can also be
        assigned a range this way.

    values : Dict[nnvm.Symbol or str, numpy.ndarray], optional
        A dict explicitly providing values for some variables instead of random generation.

    exclude_targets : Set[str], optional
        Skip compiling and running anything for these targets.

    only_targets : Set[str], optional
        Test only for those targets from `ctx_list()` that are also in this set.

    additional_params : dict, optional
        A dict of additional parameters which will be passed to forward and backward.

    numerical_grads : bool or 'if_possible', optional
        Whether to additionally check against numerically computed gradients. If 'if_possible'
        or None (the default) is passed, a gradient computation graph is created and the
        gradients are checked numerically only if that graph can be created; if it cannot
        (e.g. some operations have unimplemented gradients), only a warning is issued.
        Checking against numerical gradients is done via the `check_numerical_grads` function.

    numerical_grads_params : dict, optional
        Additional parameters for `check_numerical_grads`.

    atol : float, optional
        Absolute tolerance for `tvm.testing.assert_allclose`. NOT used for numerical gradients.

    rtol : float, optional
        Relative tolerance for `tvm.testing.assert_allclose`. NOT used for numerical gradients.

    quiet : bool, optional
        Don't dump additional information to stdout on failure.

    Examples
    --------
    .. code-block:: python

        x = sym.Variable("x", shape=(1, 2))
        y = sym.Variable("y", shape=(1, 2))

        # check the function and its gradients both numerically and using a reference function
        check_function(x + 2*y,
                       lambda x, y: x + 2*y,
                       lambda x, y, head_grads: {'x': head_grads, 'y': 2*head_grads})

        # just check gradients numerically
        check_function(x + 2*y, numerical_grads=True)

        # just check the forward computation
        check_function(x + 2*y, lambda x, y: x + 2*y, numerical_grads=False)

        # specifying dtype
        check_function(x + 2*y, lambda x, y: x + 2*y, dtype='float64')

        # dtypes can also be specified during variable creation with dtype codes
        x = sym.Variable("x", dtype=0)
        check_function(x + 1, shape=(2, 2), numerical_grads=True)
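
        # constraining generated inputs with in_range/values (a sketch based on the
        # parameter descriptions above): keep x positive so log is well defined,
        # and pin y to a fixed array instead of random generation
        x = sym.Variable("x", shape=(2, 3))
        y = sym.Variable("y", shape=(2, 3))
        check_function(sym.log(x) + y,
                       lambda x, y: np.log(x) + y,
                       in_range={'x': (0.1, 10.0)},
                       values={'y': np.zeros((2, 3), dtype='float32')},
                       numerical_grads=True)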
    """
    # validate and preprocess the input params
    if numerical_grads is None and forward is None and backward is None:
        raise ValueError("No reference function was passed to check_function. If you only want to "
                         "check gradients numerically, pass numerical_grads=True explicitly.")

    if numerical_grads is None:
        numerical_grads = 'if_possible'

    if numerical_grads not in [False, True, 'if_possible']:
        raise ValueError("numerical_grads must be a bool or 'if_possible', not {}"
                         .format(numerical_grads))

    if additional_params is None:
        additional_params = {}

    input_vars = symbol.list_input_variables()
    input_dict = {x.attr('name'): x for x in input_vars}

    if grad_input_vars is None:
        grad_input_vars = sorted(input_vars, key=lambda x: x.attr('name'))
    else:
        grad_input_vars = [input_dict[x] if isinstance(x, str) else x for x in grad_input_vars]

    in_range = _dict_var_to_dict_str(in_range)
    values = _dict_var_to_dict_str(values)

    out_len = len(symbol.list_output_names())

    # Infer the output shapes and dtypes, and preprocess the shape and dtype params
    forward_graph, shape, dtype, out_shapes, out_dtypes = \
        infer_shapes_dtypes(nnvm.graph.create(symbol), shape=shape, dtype=dtype,
                            fallback_dtype='float32')

    if not all(out_shapes) or not all(out_dtypes):
        if not quiet:
            print(forward_graph.ir(join_node_attrs=['shape', 'dtype']))
        raise ValueError("Could not infer shapes or dtypes for outputs.\n"
                         "out_shapes = {}\nout_dtypes = {}".format(out_shapes, out_dtypes))

    backward_graph = None

    # If we want gradients, we have to recreate the graph, but now with gradient computations
    # Note that here we need out_shapes for defining the shape of head grads, so we have to
    # create the graph twice
    if backward is not None or numerical_grads:
        try:
            head_grads_symbols = [nnvm.symbol.Variable("head_grads_" + str(i),
                                                       shape=out_shapes[i],
                                                       dtype=DTYPE_TO_TCODE[out_dtypes[i]])
                                  for i in range(out_len)]
            grad_symbols = graph_util.gradients([symbol], grad_input_vars,
                                                grad_ys=head_grads_symbols)
            # Sometimes grads do not depend on head_grads, so head_grads does not appear
            # in the variable list; adding it manually prevents this, making things a bit easier
            backward_graph = \
                nnvm.graph.create(nnvm.symbol.Group([symbol] + grad_symbols + head_grads_symbols))

            backward_graph, shape, dtype, out_shapes, out_dtypes = \
                infer_shapes_dtypes(backward_graph, shape=shape, dtype=dtype,
                                    fallback_dtype='float32')
        except nnvm._base.NNVMError as err:
            if backward is None and numerical_grads == "if_possible":
                logging.warning("Won't check gradients because: %s", str(err).split('\n', 1)[0])
                numerical_grads = False
                backward_graph = None
            else:
                raise

    main_graph = backward_graph if backward_graph is not None else forward_graph

    # Generate random data for inputs (including head_grads)

    np_inputs = {}

    for x in main_graph.symbol.list_input_variables():
        x_name = x.attr('name')
        x_shape = shape[x_name]
        x_dtype = dtype[x_name]

        if values is not None and x_name in values:
            np_inputs[x_name] = values[x_name].astype(x_dtype)
            continue

        low = -1.0
        high = 1.0
        if in_range is not None:
            if isinstance(in_range, dict):
                if x_name in in_range:
                    low = in_range[x_name][0]
                    high = in_range[x_name][1]
            else:
                low = in_range[0]
                high = in_range[1]

        np_inputs[x_name] = np.random.uniform(size=x_shape, low=low, high=high).astype(x_dtype)

    np_inputs_without_head_grads = {k: np_inputs[k] for k in np_inputs
                                    if not k.startswith('head_grads_')}

    nothing_was_done = True

    # Compute and compare the results
    for target, ctx in ctx_list():
        if exclude_targets is not None:
            if target in exclude_targets or str(target) in exclude_targets:
                logging.info("Skipping target = %s, ctx = %s", target, ctx)
                continue
        if only_targets is not None:
            if target not in only_targets and str(target) not in only_targets:
                logging.info("Skipping target = %s, ctx = %s", target, ctx)
                continue

        logging.info("Checking computation on target = %s, ctx = %s", target, ctx)

        debug_stage = None

        try:
            nnvm_res = None

            debug_stage = "compiling"
            main_function = graph_to_function(main_graph, target, ctx)

            # nnvm_res contains the output and gradients (if they are needed)
            debug_stage = "running"
            nnvm_res = main_function(**np_inputs)

            try:
                logging.debug("checking to_relay conversion")
                inputs = np_inputs_without_head_grads.copy()
                func, inputs = to_relay(main_graph, shape, dtype, params=inputs)
                with relay.build_config(opt_level=3):
                    graph, lib, params = relay.build(func, target=target)
                m = graph_runtime.create(graph, lib, ctx)
                m.set_input(**inputs)
                m.set_input(**params)
                m.run()
                for i in range(out_len):
                    relay_out = m.get_output(i).asnumpy()
                    tvm.testing.assert_allclose(nnvm_res[i], relay_out, atol=atol, rtol=rtol)
            except NotImplementedError as err:
                # the NNVM operator is not supported yet
                logging.warning(err)

            if backward_graph is not None:
                grad_var_names = [x.attr('name') for x in grad_input_vars]
                nnvm_grads = {x: v for x, v in zip(grad_var_names, nnvm_res[out_len:])}

            if forward is not None:
                nothing_was_done = False
                debug_stage = "checking forward computation"
                logging.debug(debug_stage)

                params = {}
                params.update(np_inputs_without_head_grads)
                params.update(additional_params)
                numpy_res = forward(**params)

                if isinstance(numpy_res, tuple):
                    numpy_res = list(numpy_res)

                if not isinstance(numpy_res, list):
                    numpy_res = [numpy_res]

                if len(numpy_res) != out_len:
                    raise ValueError("Forward function returned {} values, but "
                                     "the nnvm graph returns {} values"
                                     .format(len(numpy_res), out_len))

                for i in range(out_len):
                    tvm.testing.assert_allclose(nnvm_res[i], numpy_res[i], atol=atol, rtol=rtol)

            if backward is not None:
                nothing_was_done = False
                debug_stage = "checking gradients"
                logging.debug(debug_stage)

                np_head_grads = [np_inputs["head_grads_" + str(i)] for i in range(out_len)]

                if out_len == 1:
                    np_head_grads = np_head_grads[0]

                params = {'head_grads': np_head_grads}
                params.update(np_inputs_without_head_grads)
                params.update(additional_params)
                numpy_grads = backward(**params)

                if not isinstance(numpy_grads, dict):
                    if isinstance(numpy_grads, tuple):
                        numpy_grads = list(numpy_grads)
                    if not isinstance(numpy_grads, list):
                        numpy_grads = [numpy_grads]
                    numpy_grads = {x: v for x, v in zip(grad_var_names, numpy_grads)}
                    if len(numpy_grads) != len(grad_var_names):
                        raise ValueError("The backward function returns a list of gradients which "
                                         "does not contain gradients for these variables: {}"
                                         .format(set(grad_var_names) - set(numpy_grads)))

                for x_name in numpy_grads:
                    tvm.testing.assert_allclose(nnvm_grads[x_name], numpy_grads[x_name],
                                                atol=atol, rtol=rtol)

            if numerical_grads:
                nothing_was_done = False
                debug_stage = "checking gradients numerically"
                logging.debug(debug_stage)

                forward_function = graph_to_function(forward_graph, target, ctx)

                # Since the result may be non-scalar, we have to put another operation on top,
                # so we just multiply by the randomly generated head_grads and then sum everything.
                # This way we can reuse the gradient values which have already been computed.
                def scalar_function(**kwargs):
                    res = forward_function(**kwargs)
                    return np.sum([np.dot(np_inputs['head_grads_' + str(i)].ravel(), res[i].ravel())
                                   for i in range(out_len)])

                if numerical_grads_params is None:
                    numerical_grads_params = {}

                check_numerical_grads(
                    scalar_function,
                    input_values=np_inputs_without_head_grads,
                    grad_values=nnvm_grads,
                    **numerical_grads_params)

        except:
            if not quiet:
                print("\ncheck_function failed while {}, here is the main graph"
                      .format(debug_stage))
                print(main_graph.ir(join_node_attrs=['shape', 'dtype']))
                if nnvm_res is not None:
                    print("Generated inputs:")
                    print(np_inputs)
                    print()
            raise

    if nothing_was_done:
        logging.warning("Nothing was done in check_function. Check ctx_list().")