# Example #1 (score: 0)
  def testCustomGrad(self):
    """Gradients of a custom-grad-wrapped fn match what grad_fn fabricates."""

    def fn(a, b, c):
      return core_layers.dense(a, 10, use_bias=False) + math_ops.matmul(b, c)

    def grad_fn(inputs, trainable_variables, unused_outputs,
                unused_grad_outputs):
      # Fabricate recognizable constants: input i gets gradient i+1, and
      # variable j continues the count at len(inputs)+j+1.
      grad_inputs = []
      for idx, inp in enumerate(inputs):
        grad_inputs.append(array_ops.ones_like(inp) * (idx + 1.))
      grad_vars = []
      offset = len(inputs) + 1.
      for idx, var in enumerate(trainable_variables):
        grad_vars.append(array_ops.ones_like(var) * (idx + offset))
      return grad_inputs, grad_vars

    a = random_ops.random_uniform([11, 6])
    b = random_ops.random_uniform([11, 7])
    c = random_ops.random_uniform([7, 10])
    w = random_ops.random_uniform([6, 10])
    wrapped = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)
    loss = math_ops.reduce_mean(wrapped(a, b, c))
    # The dense layer's kernel is the only trainable variable created so far.
    kernel = variables.trainable_variables()[0]
    grads = gradients_impl.gradients(loss, [a, b, c, kernel])
    # Expected gradients follow the same i+1 numbering; `w` stands in for
    # the dense kernel's [6, 10] shape.
    expected_grads = []
    for idx, ref in enumerate([a, b, c, w]):
      expected_grads.append(array_ops.ones_like(ref) * (idx + 1.))
    with self.test_session() as sess:
      sess.run(variables.global_variables_initializer())
      g_val, eg_val = sess.run([grads, expected_grads])
      for got, want in zip(g_val, eg_val):
        self.assertAllClose(got, want)
    def testCustomGrad(self):
        """The custom grad_fn's fabricated gradients flow to inputs and vars."""

        def fn(a, b, c):
            dense_out = core_layers.dense(a, 10, use_bias=False)
            return dense_out + math_ops.matmul(b, c)

        def grad_fn(inputs, trainable_variables, unused_outputs,
                    unused_grad_outputs):
            # Gradient for input i is the constant i+1; variable gradients
            # continue the same numbering after the inputs.
            num_inputs = len(inputs)
            grad_inputs = [
                array_ops.ones_like(t) * (i + 1.)
                for i, t in enumerate(inputs)
            ]
            grad_vars = [
                array_ops.ones_like(t) * (i + num_inputs + 1.)
                for i, t in enumerate(trainable_variables)
            ]
            return grad_inputs, grad_vars

        a = random_ops.random_uniform([11, 6])
        b = random_ops.random_uniform([11, 7])
        c = random_ops.random_uniform([7, 10])
        # `w` is only a shape stand-in for the dense kernel in the
        # expected-gradient list below.
        w = random_ops.random_uniform([6, 10])
        wrapped_fn = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)
        out = wrapped_fn(a, b, c)
        loss = math_ops.reduce_mean(out)
        kernel = variables.trainable_variables()[0]
        grads = gradients_impl.gradients(loss, [a, b, c, kernel])
        expected_grads = [
            array_ops.ones_like(t) * (i + 1.)
            for i, t in enumerate([a, b, c, w])
        ]
        with self.test_session() as sess:
            sess.run(variables.global_variables_initializer())
            g_val, eg_val = sess.run([grads, expected_grads])
            for actual, expected in zip(g_val, eg_val):
                self.assertAllClose(actual, expected)
    def testCorrectness(self):
        """A grad_fn that delegates to tf.gradients reproduces true grads."""

        w = random_ops.random_uniform([6, 10])

        def fn(a, b, c):
            # Pin the dense kernel to `w` so the plain and wrapped copies of
            # fn start from identical weights.
            dense = core_layers.dense(
                a,
                10,
                use_bias=False,
                kernel_initializer=lambda shape, dtype, partition_info: w)
            return dense + math_ops.matmul(b, c)

        def grad_fn(inputs, trainable_variables, outputs, grad_outputs):
            # Forward straight to the real gradient computation, so the
            # custom-grad path must agree with vanilla backprop.
            y = outputs[0]
            dy = grad_outputs[0]
            grad_inputs = gradients_impl.gradients(y, inputs, grad_ys=dy)
            grad_vars = gradients_impl.gradients(
                y, trainable_variables, grad_ys=dy)
            return grad_inputs, grad_vars

        custom_fn = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)

        a = random_ops.random_uniform([11, 6])
        b = random_ops.random_uniform([11, 7])
        c = random_ops.random_uniform([7, 10])

        # Build the plain graph first, then the wrapped one: the first dense
        # kernel created belongs to `fn`, the second to `custom_fn`.
        out = fn(a, b, c)
        custom_out = custom_fn(a, b, c)
        self.assertEqual(out.get_shape().as_list(),
                         custom_out.get_shape().as_list())

        loss = math_ops.reduce_mean(out)
        custom_loss = math_ops.reduce_mean(custom_out)

        plain_kernel = variables.trainable_variables()[0]
        custom_kernel = variables.trainable_variables()[1]
        grads = gradients_impl.gradients(loss, [a, b, c, plain_kernel])
        custom_grads = gradients_impl.gradients(
            custom_loss, [a, b, c, custom_kernel])

        with self.test_session() as sess:
            sess.run(variables.global_variables_initializer())
            out_val, custom_out_val, grads_val, custom_grads_val = sess.run(
                [out, custom_out, grads, custom_grads])
            self.assertAllClose(out_val, custom_out_val)
            for plain_g, custom_g in zip(grads_val, custom_grads_val):
                self.assertAllClose(plain_g, custom_g)
# Example #4 (score: 0)
  def testCorrectness(self):
    """Wrapping with a pass-through grad_fn alters neither outputs nor grads."""

    w = random_ops.random_uniform([6, 10])

    def fn(a, b, c):
      # Both dense layers (plain and wrapped) are initialized from the same
      # `w`, so their outputs are directly comparable.
      dense = core_layers.dense(
          a,
          10,
          use_bias=False,
          kernel_initializer=lambda shape, dtype, partition_info: w)
      return dense + math_ops.matmul(b, c)

    def grad_fn(inputs, trainable_variables, outputs, grad_outputs):
      # Simply delegate to tf.gradients; the wrapper should then be a no-op
      # with respect to both values and gradients.
      y = outputs[0]
      dy = grad_outputs[0]
      return (gradients_impl.gradients(y, inputs, grad_ys=dy),
              gradients_impl.gradients(y, trainable_variables, grad_ys=dy))

    custom_fn = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)

    a = random_ops.random_uniform([11, 6])
    b = random_ops.random_uniform([11, 7])
    c = random_ops.random_uniform([7, 10])

    # Plain graph is built first, wrapped graph second — this fixes the
    # order of the two kernels in trainable_variables() below.
    out = fn(a, b, c)
    custom_out = custom_fn(a, b, c)
    self.assertEqual(out.get_shape().as_list(),
                     custom_out.get_shape().as_list())

    loss = math_ops.reduce_mean(out)
    custom_loss = math_ops.reduce_mean(custom_out)

    grads = gradients_impl.gradients(
        loss, [a, b, c, variables.trainable_variables()[0]])
    custom_grads = gradients_impl.gradients(
        custom_loss, [a, b, c, variables.trainable_variables()[1]])

    with self.test_session() as sess:
      sess.run(variables.global_variables_initializer())
      out_val, custom_out_val, grads_val, custom_grads_val = sess.run(
          [out, custom_out, grads, custom_grads])
      self.assertAllClose(out_val, custom_out_val)
      for got, want in zip(grads_val, custom_grads_val):
        self.assertAllClose(got, want)