Example 1
def test_gradgrad_tensor(func, optimized, *args):
    """Test gradients of functions with TFE signatures."""
    def tangent_func():
        df = tangent.grad(func,
                          motion='joint',
                          optimized=optimized,
                          verbose=True)
        ddf = tangent.grad(df,
                           motion='joint',
                           optimized=optimized,
                           verbose=True)
        dxx = ddf(*args)
        # Use the shared conversion helper: ddf may return a single tensor
        # rather than a tuple, which iterating with t.numpy() would break on.
        return tensors_to_numpy(dxx)

    def reference_func():
        dxx = tfe.gradients_function(tfe.gradients_function(func))(*args)
        # Convert once: calling t.numpy() and then tensors_to_numpy would
        # convert the values twice.
        return tensors_to_numpy(dxx)

    def backup_reference_func():
        func_ = as_numpy_sig(func)
        args_ = tensors_to_numpy(args)
        return utils.numeric_grad(utils.numeric_grad(func_))(*args_)

    utils.assert_result_matches_reference(
        tangent_func, reference_func, backup_reference_func,
        tolerance=1e-2)  # extra loose bounds for 2nd order grad
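
A minimal, hypothetical driver for this helper; the square function, its sample input, and the assumption that eager execution and the helpers above are in scope are mine, not part of the source:

# Hypothetical usage sketch (all names below are assumed).
def square(x):
    return x * x

test_gradgrad_tensor(square, True, tf.constant(3.0))
# Checks the second derivative d^2/dx^2 (x^2) == 2.0 against nested
# tfe.gradients_function, with a doubled numeric gradient as fallback.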
Example 2
def test_rev_tensor(func, motion, optimized, preserve_result, wrt, *args):
    """Test gradients of functions with TFE signatures."""
    def tangent_func():
        y = func(*args)
        if isinstance(y, (tuple, list)):
            init_grad = tuple(tf.ones_like(t) for t in y)
        else:
            init_grad = tf.ones_like(y)
        df = grad(func,
                  motion=motion,
                  optimized=optimized,
                  preserve_result=preserve_result,
                  wrt=wrt,
                  verbose=True)
        if motion == 'joint':
            # TODO: This won't work if func has default args unspecified.
            dx = df(*args + (init_grad,))
        else:
            dx = df(*args, init_grad=init_grad)
        return tensors_to_numpy(dx)

    def reference_func():
        gradval = tensors_to_numpy(
            tfe.gradients_function(func, params=wrt)(*args))
        if preserve_result:
            val = tensors_to_numpy(func(*args))
            if isinstance(gradval, tuple):
                return gradval + (val,)
            return gradval, val
        else:
            return gradval

    def backup_reference_func():
        func_ = as_numpy_sig(func)
        args_ = tensors_to_numpy(args)
        # This test checks a first-order gradient, so a single numeric
        # derivative is the right fallback; nesting numeric_grad would
        # compute a second-order gradient instead.
        gradval = utils.numeric_grad(func_)(*args_)
        if preserve_result:
            val = tensors_to_numpy(func(*args))
            return gradval, val
        else:
            return gradval

    utils.assert_result_matches_reference(
        tangent_func,
        reference_func,
        backup_reference_func,
        # Some ops like tf.divide diverge significantly due to what looks like
        # numerical instability.
        tolerance=1e-5)
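
A hedged sketch of how test_rev_tensor might be invoked; the function, inputs, and parameter values below are assumptions for illustration:

# Hypothetical usage sketch: reverse-mode gradient w.r.t. both arguments.
def mul(a, b):
    return a * b

test_rev_tensor(mul, 'joint', True, False, (0, 1),
                tf.constant(2.0), tf.constant(5.0))
# Expected: d(ab)/da == b == 5.0 and d(ab)/db == a == 2.0.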
Example 3
def _test_gradgrad_array(func, optimized, *args):
    """Test gradients of functions with NumPy-compatible signatures."""
    def tangent_func():
        # Bind `np` in the function's globals to real NumPy, since
        # reference_func below swaps in autograd's wrapped version.
        func.__globals__['np'] = np
        df = tangent.grad(func, optimized=optimized, verbose=True)
        ddf = tangent.grad(df, optimized=optimized, verbose=True)
        return ddf(*args)

    def reference_func():
        # Swap in autograd's NumPy wrapper so ag_grad can trace the function.
        func.__globals__['np'] = ag_np
        return ag_grad(ag_grad(func))(*args)

    def backup_reference_func():
        return utils.numeric_grad(utils.numeric_grad(func))(*args)

    utils.assert_result_matches_reference(
        tangent_func, reference_func, backup_reference_func,
        tolerance=1e-2)  # extra loose bounds for 2nd order grad
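
For the NumPy-signature variant, a hedged usage sketch; the cubic function and its input are made up for illustration:

# Hypothetical usage sketch: second derivative of a scalar function.
def cubic(x):
    return x * x * x

_test_gradgrad_array(cubic, True, 2.0)
# Expected second derivative: 6 * x == 12.0 at x == 2.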
Example 4
def test_forward_tensor(func, wrt, *args):
  """Test gradients of functions with TFE signatures."""

  def tangent_func():
    df = jvp(func, wrt=wrt, optimized=True, verbose=True)
    args_ = args + tuple(tf.ones_like(args[i]) for i in wrt)  # seed gradient
    return tensors_to_numpy(df(*args_))

  def reference_func():
    return tensors_to_numpy(tfe.gradients_function(func, params=wrt)(*args))

  def backup_reference_func():
    func_ = as_numpy_sig(func)
    args_ = tensors_to_numpy(args)
    # First-order forward-mode test, so a single numeric derivative is the
    # correct fallback; nesting numeric_grad would give a 2nd-order gradient.
    return utils.numeric_grad(func_)(*args_)

  # TODO: Should results really be that far off?
  utils.assert_result_matches_reference(
      tangent_func, reference_func, backup_reference_func,
      tolerance=1e-4)
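
Last, a hedged sketch of driving the forward-mode test; the sine wrapper and the input value are assumptions:

# Hypothetical usage sketch: forward-mode (JVP) derivative of sin.
def sine(x):
    return tf.sin(x)

test_forward_tensor(sine, (0,), tf.constant(1.0))
# Seeds the tangent with tf.ones_like(x) and compares against the
# reverse-mode tfe reference, i.e. cos(1.0).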