# NOTE: imports assumed from the surrounding test-utility module.  Helpers
# used below (randn, force_tuple, half_test, recomputation_test,
# clear_no_need_grad_tester, create_function_nnp,
# compute_analytical_and_numerical_grad, ArrayDiffStats) are defined
# elsewhere in that module and are not reproduced here.
import copy

import numpy as np

import nnabla as nn
from nnabla.testing import assert_allclose


def function_tester(rng,
                    func,
                    ref_func,
                    inputs,
                    func_args=[],
                    func_kwargs={},
                    atol_f=1e-6,
                    atol_b=1e-3,
                    atol_accum=1e-6,
                    dstep=1e-3,
                    backward=None,
                    ctx=None,
                    func_name=None,
                    ref_grad=None,
                    disable_half_test=False,
                    atol_half=1e-1,
                    insert_identity=[],
                    disable_clear_no_need_grad_test=False,
                    auto_forward=False):
    """ Automatic testing of forward/backward pass of `func` by comparing it
    to the reference implementation in `ref_func`.

    Syntax of `ref_func`: inputs, parameters
    Syntax of `ref_grad`: inputs, output grads, parameters
    """

    if ctx is None:
        ctx = nn.Context()
    if backward is None:
        backward = [True for _ in inputs]

    # Create Variables
    # print('create_variable')

    def create_variables(inputs, backward):
        vinputs = []
        for i, b in zip(inputs, backward):
            if i is None:
                vinputs += [None]
                continue
            vinputs += [nn.Variable(i.shape, need_grad=b)]
            vinputs[-1].data.cast(i.dtype)[...] = i
        return vinputs

    # Half test
    if not disable_half_test:
        finputs = create_variables(inputs, backward)
        hinputs = create_variables(inputs, backward)
        half_test(rng,
                  func,
                  finputs,
                  hinputs,
                  func_args,
                  func_kwargs,
                  backward,
                  ctx,
                  func_name,
                  atol=atol_half)

    vinputs = create_variables(inputs, backward)
    # Checking forward
    # print('checking forward')
    with nn.context_scope(ctx), nn.auto_forward():
        o = func(*(vinputs + func_args), **func_kwargs)
    rinputs = copy.deepcopy(inputs)  # inputs for ref_func
    refs = ref_func(*(rinputs + func_args), **func_kwargs)

    refs = force_tuple(refs)
    o = force_tuple(o)
    assert len(o) == len(refs)
    for i, ref in enumerate(refs):
        res = o[i].d
        assert_allclose(ref,
                        res,
                        atol=atol_f,
                        err_msg="{} forward test fails".format(func_name))

    # Checking recomputation
    vinputs = create_variables(inputs, backward)
    recomputation_test(rng, func, vinputs, func_args, func_kwargs, ctx)

    # Checking forward(clear_no_need_grad=True)
    if not disable_clear_no_need_grad_test:
        clear_no_need_grad_tester(rng, func, inputs, func_args, func_kwargs,
                                  backward, atol_f, ctx, func_name,
                                  insert_identity, auto_forward)

    # Checking function name
    try:
        import function_test_callback
        result = create_function_nnp(vinputs, o, func_name, func_args,
                                     func_kwargs)
        if result is not None:
            function_test_callback.callback(func_name, *result)
    except (UnboundLocalError, IndexError, ImportError):
        pass

    # print('checking function name')
    if func_name is not None:
        assert o[0].parent.name == func_name

    # Checking backward
    # print('checking backward')
    if True not in backward:
        return

    # NNabla backward
    for v in vinputs:
        if v is None:
            continue
        if len(v.shape) == 0:
            v.g = randn(rng)
            continue
        v.g = randn(rng, *v.shape)
    # Verify grad
    vinputs = create_variables(inputs, backward)
    rinputs = copy.deepcopy(inputs)
    rinputs = [
        rinput if test else None for rinput, test in zip(rinputs, backward)
    ]
    vgrads = [randn(rng, *o_.shape) for o_ in o]

    def reset_ograds():
        '''
        Reset the output gradients before every backward call.
        This is required because the output gradients may be
        modified in-place during the backward operation.
        '''
        for ovar, g in zip(o, vgrads):
            ovar.g = g

    agrads, ngrads = compute_analytical_and_numerical_grad(o[0].parent,
                                                           vinputs,
                                                           o,
                                                           rinputs,
                                                           vgrads,
                                                           epsilon=dstep,
                                                           rng=rng,
                                                           ref_grad=ref_grad)
    if ref_grad is not None:
        rinputs = copy.deepcopy(inputs)
        doutputs = copy.deepcopy(vgrads)
        ngrads = ref_grad(*(rinputs + doutputs + func_args),
                          **func_kwargs,
                          need_grad_flags=backward)

    assert_allclose(
        ngrads,
        agrads,
        atol=atol_b,
        err_msg="{} backward w/o accumulation test fails".format(func_name))

    # Check if need_grad works
    for v, b in zip(vinputs, backward):
        if not b or v is None:
            continue
        v.grad.zero()
        v.need_grad = False
        reset_ograds()
        try:
            o[0].parent.forward(list(filter(lambda x: x is not None, vinputs)),
                                o)
            o[0].parent.backward(
                list(filter(lambda x: x is not None, vinputs)), o)
        except RuntimeError:
            continue  # TODO
        assert np.all(v.g == 0)

    # test accum=False
    for i in range(len(vinputs)):
        if vinputs[i] is None:
            continue
        v = vinputs[i]
        v.need_grad = backward[i]

    for i in range(len(vinputs)):
        if vinputs[i] is None:
            continue
        v = vinputs[i]

        if not backward[i]:
            continue
        f = o[0].parent

        # Prepare function inputs
        finputs = list(filter(lambda x: x is not None, vinputs))

        # Save accum gradient result
        g = randn(rng, *v.shape)
        v.g = g
        reset_ograds()
        f.forward(finputs, o)
        f.backward(finputs, o)
        true_g = v.g - g

        # Check accum=False
        accum = [j != i for j, vv in enumerate(vinputs) if vv is not None]
        v.g = randn(rng, *v.shape)
        reset_ograds()
        f.forward(finputs, o)
        f.backward(finputs, o, accum)
        assert_allclose(
            v.g,
            true_g,
            atol=atol_accum,
            err_msg="{} backward w/ accumulation test fails.".format(
                func_name))

        # Check accum=False with NaN gradient
        v.g = np.float32('nan')
        reset_ograds()
        f.forward(finputs, o)
        f.backward(finputs, o, accum)
        assert not np.any(np.isnan(v.g))
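

# Usage sketch (illustrative addition, not part of the original module):
# a minimal pytest-style case driving `function_tester` above.  `ref_relu`
# and `ref_grad_relu` are hypothetical NumPy reference implementations;
# their signatures follow the call conventions described in the docstring.
def _example_test_relu_forward_backward():
    import numpy as np
    import nnabla.functions as F

    def ref_relu(x):
        # Reference forward pass computed with NumPy.
        return np.maximum(x, 0.0)

    def ref_grad_relu(x, dy, need_grad_flags=None):
        # Reference gradient w.r.t. the single input, flattened to match the
        # layout compared against the analytical gradient above.
        return (dy * (x > 0)).flatten()

    rng = np.random.RandomState(313)
    inputs = [rng.randn(2, 3, 4).astype(np.float32)]
    function_tester(rng, F.relu, ref_relu, inputs,
                    func_name='ReLU', ref_grad=ref_grad_relu,
                    backward=[True], disable_half_test=True)
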
Example #2
def function_tester(rng,
                    func,
                    ref_func,
                    inputs,
                    func_args=[],
                    func_kwargs={},
                    atol_f=1e-6,
                    atol_b=1e-3,
                    atol_accum=1e-6,
                    dstep=1e-3,
                    backward=None,
                    ctx=None,
                    func_name=None,
                    ref_grad=None,
                    disable_half_test=False,
                    atol_half=1e-1):
    """ Automatic testing of forward/backward pass of `func` by comparing it
    to the reference implementation in `ref_func`.

    Syntax of `ref_func`: inputs, parameters
    Syntax of `ref_grad`: inputs, output grads, parameters
    """

    if ctx is None:
        ctx = nn.Context()
    if backward is None:
        backward = [True for _ in inputs]

    # Create Variables
    # print('create_variable')

    def create_variables(inputs, backward):
        vinputs = []
        for i, b in zip(inputs, backward):
            if i is None:
                vinputs += [None]
                continue
            vinputs += [nn.Variable(i.shape, need_grad=b)]
            vinputs[-1].data.cast(i.dtype)[...] = i
        return vinputs

    # Half test
    if not disable_half_test:
        finputs = create_variables(inputs, backward)
        hinputs = create_variables(inputs, backward)
        half_test(rng,
                  func,
                  finputs,
                  hinputs,
                  func_args,
                  func_kwargs,
                  backward,
                  ctx,
                  func_name,
                  atol=atol_half)

    vinputs = create_variables(inputs, backward)
    # Checking forward
    # print('checking forward')
    with nn.context_scope(ctx), nn.auto_forward():
        o = func(*(vinputs + func_args), **func_kwargs)
    rinputs = copy.deepcopy(inputs)  # inputs for ref_func
    refs = ref_func(*(rinputs + func_args), **func_kwargs)

    refs = force_tuple(refs)
    o = force_tuple(o)
    assert len(o) == len(refs)
    for i, ref in enumerate(refs):
        res = o[i].d
        assert np.allclose(ref, res,
                           atol=atol_f), str(ArrayDiffStats(ref, res))

    # Checking function name
    try:
        import function_test_callback
        result = create_function_nnp(vinputs, o, func_name, func_args,
                                     func_kwargs)
        if result is not None:
            function_test_callback.callback(func_name, *result)
    except (UnboundLocalError, IndexError, ImportError):
        pass

    # print('checking function name')
    if func_name is not None:
        assert o[0].parent.name == func_name

    # Checking backward
    # print('checking backward')
    if True not in backward:
        return

    # NNabla backward
    for v in vinputs:
        if v is None:
            continue
        if len(v.shape) == 0:
            v.g = rng.randn()
            continue
        v.g = rng.randn(*v.shape).astype(v.data.dtype)
    # Verify grad
    vinputs = create_variables(inputs, backward)
    rinputs = copy.deepcopy(inputs)
    rinputs = [
        rinput if test else None for rinput, test in zip(rinputs, backward)
    ]
    vgrads = [rng.randn(*o_.shape) for o_ in o]

    agrads, ngrads = compute_analytical_and_numerical_grad(o[0].parent,
                                                           vinputs,
                                                           o,
                                                           rinputs,
                                                           vgrads,
                                                           epsilon=dstep,
                                                           rng=rng,
                                                           ref_grad=ref_grad)
    if ref_grad is not None:
        rinputs = copy.deepcopy(inputs)
        doutputs = [o_.g for o_ in o]
        ngrads = ref_grad(*(rinputs + doutputs + func_args), **func_kwargs)

    assert np.allclose(ngrads, agrads,
                       atol=atol_b), str(ArrayDiffStats(ngrads, agrads))

    # Check if need_grad works
    for v, b in zip(vinputs, backward):
        if not b or v is None:
            continue
        v.g = 0
        v.need_grad = False
        try:
            o[0].parent.backward(
                list(filter(lambda x: x is not None, vinputs)), o)
        except RuntimeError:
            continue  # TODO
        assert np.all(v.g == 0)

    # test accum=False
    for i in range(len(vinputs)):
        if vinputs[i] is None:
            continue
        v = vinputs[i]
        v.need_grad = backward[i]

    for i in range(len(vinputs)):
        if vinputs[i] is None:
            continue
        v = vinputs[i]

        if not backward[i]:
            continue
        f = o[0].parent

        # If input's grad is inplaced, the test doesn't work correctly.
        if f.inplace_grad(i):
            continue

        # Prepare function inputs
        finputs = list(filter(lambda x: x is not None, vinputs))

        # Save accum gradient result
        g = rng.randn(*v.shape)
        v.g = g
        f.forward(finputs, o)
        f.backward(finputs, o)
        true_g = v.g - g

        # Check accum=False
        accum = [j != i for j, vv in enumerate(vinputs) if vv is not None]
        v.g = rng.randn(*v.shape)
        f.forward(finputs, o)
        f.backward(finputs, o, accum)
        assert np.allclose(v.g, true_g,
                           atol=atol_accum), str(ArrayDiffStats(v.g, true_g))

        # Check accum=False with NaN gradient
        v.g = np.float32('nan')
        f.forward(finputs, o)
        f.backward(finputs, o, accum)
        assert not np.any(np.isnan(v.g))
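

# Usage sketch (illustrative addition, not part of the original source):
# driving the variant above with keyword arguments forwarded through
# `func_kwargs` and the finite-difference gradient check (no `ref_grad`).
# `ref_softmax` is a hypothetical NumPy reference implementation.
def _example_test_softmax_forward_backward():
    import numpy as np
    import nnabla.functions as F

    def ref_softmax(x, axis):
        # Numerically stable NumPy softmax along `axis`.
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    rng = np.random.RandomState(313)
    inputs = [rng.randn(2, 3, 4).astype(np.float32)]
    # atol_b is loosened because the backward check here relies on a
    # numerical (finite-difference) gradient.
    function_tester(rng, F.softmax, ref_softmax, inputs,
                    func_kwargs=dict(axis=1), func_name='Softmax',
                    disable_half_test=True, atol_b=1e-2)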