Exemple #1
0
  def testVjp(self, has_aux):
    """Checks `extensions.vjp` against gradients from `tf.GradientTape`.

    Args:
      has_aux: Whether the differentiated function returns an auxiliary value
        next to its primary output; forwarded to `extensions.vjp` as
        `has_aux`.
    """
    arg_shapes = (tf.TensorShape([10]), tf.TensorShape([1, 10]))
    out_shape = tf.TensorShape([])
    dtype = np.float32

    def f(a, b):
      total = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)
      # The auxiliary value is a constant that does not depend on `a` or `b`.
      return (total, tf_np.asarray(1)) if has_aux else total

    rng = tf.random.Generator.from_seed(1234)

    def sample(shape):
      return uniform(rng, shape, dtype)

    # One structure of inputs plus two output cotangents of the same shape.
    x, cotangents = tf.nest.map_structure(
        sample, [arg_shapes, [out_shape, out_shape]])
    tf_x = to_tf(x)

    vjp_result = extensions.vjp(f, *x, has_aux=has_aux)
    if not has_aux:
      y, vjp = vjp_result
    else:
      y, vjp, aux = vjp_result

    # Persistent, because the tape is queried once per cotangent below.
    with tf.GradientTape(persistent=True) as tape:
      tape.watch(tf_x)
      reference = f(*x)
      if not has_aux:
        expected_y = reference
      else:
        expected_y, expected_aux = reference
        self.assertAllClose(to_tf(expected_aux), to_tf(aux))

    self.assertAllClose(to_tf(expected_y), to_tf(y))
    for cotangent in cotangents:
      expected_dx = tape.gradient(
          to_tf(expected_y), tf_x, output_gradients=to_tf(cotangent))
      actual_dx = to_tf(vjp(cotangent))
      self.assertAllClose(expected_dx, actual_dx)
Exemple #2
0
  def testCustomGrad(self):
    """Checks that `extensions.vjp` uses the rule given to `custom_grad`."""
    arg_shapes = (tf.TensorShape([10]), tf.TensorShape([1, 10]))
    out_shape = tf.TensorShape([])
    dtype = np.float32
    # Factors the custom backward pass applies to the first and second input.
    scale1 = 5.0
    scale2 = 6.0

    def fwd(a, b):
      return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)

    @extensions.custom_grad
    def f(a, b):
      primal = fwd(a, b)

      # This backward pass is not the analytic derivative of `fwd`, so the
      # comparison at the end only passes if this rule is the one applied.
      def backward(dy):
        grad_a = dy * scale1 * a
        grad_b = dy * scale2 * b
        return grad_a, grad_b

      return primal, backward

    rng = tf.random.Generator.from_seed(1234)

    def sample(shape):
      return uniform(rng, shape, dtype)

    x, dy = tf.nest.map_structure(sample, [arg_shapes, out_shape])

    # Reference values computed directly, without going through `f`.
    expected_y = fwd(*x)
    expected_dx = (dy * scale1 * x[0], dy * scale2 * x[1])

    y, pullback = extensions.vjp(f, *x)
    dx = pullback(dy)

    self.assertAllClose(to_tf(expected_y), to_tf(y))
    self.assertAllClose(to_tf(expected_dx), to_tf(dx))
Exemple #3
0
 def run(x, dy):
   """Returns `(dx, y)` for `f` at `x`, with `dx` the VJP applied to `dy`.

   NOTE(review): `f` is taken from an enclosing scope that is not visible in
   this excerpt.
   """
   primal_out, pullback = extensions.vjp(f, *x)
   return pullback(dy), primal_out