コード例 #1
0
  def testJVPFunctionWithBatchOfTangents(self):
    """Checks `_jvp_dispatch` with `use_batch=True` for Add and Mul.

    Each tangent carries a leading batch dimension of size 3. The expected
    JVP row i is the JVP computed from tangent row i of every input.
    """
    # Add(x, y): the JVP is the sum of the two tangent batches.
    sum_result = (constant_op.constant(4.),)
    sum_primals = (constant_op.constant(1.), constant_op.constant(3.))
    sum_tangents = (
        constant_op.constant([1., 2., 3.]),
        constant_op.constant([4., 5., 6.]),
    )
    batched_jvp = forwardprop._jvp_dispatch(
        op_name="Add",
        attr_tuple=(),
        inputs=sum_primals,
        outputs=sum_result,
        tangents=sum_tangents,
        use_batch=True)

    # Comparing against a list of constants instead of evaluating first and
    # comparing plain lists keeps the expected structure explicit.
    expected_sum = constant_op.constant([1. + 4., 2. + 5., 3. + 6.])
    self.assertAllClose([expected_sum], batched_jvp)

    # Mul(x, y): the JVP is dx * y + dy * x, with x = 4 and y = 5.
    product_result = (constant_op.constant([20.]),)
    product_primals = (constant_op.constant([4.]), constant_op.constant([5.]))
    product_tangents = (
        constant_op.constant([[1.], [0.], [1.]]),
        constant_op.constant([[0.], [1.], [1.]]),
    )
    batched_jvp = forwardprop._jvp_dispatch(
        op_name="Mul",
        attr_tuple=(),
        inputs=product_primals,
        outputs=product_result,
        tangents=product_tangents,
        use_batch=True)
    expected_product = constant_op.constant([[5.], [4.], [5. + 4.]])
    self.assertAllClose([expected_product], batched_jvp)
コード例 #2
0
  def testJVPFunctionRaisesError(self):
    """Checks that batched `_jvp_dispatch` rejects mismatched tangent shapes.

    Both primals are scalars. The first tangent has shape (2,), while the
    second has shape (2, 1), which carries an extra trailing dimension. The
    dispatch is expected to raise a ValueError that mentions the expected
    tangent shape.
    """
    sum_outputs = (constant_op.constant(6.),)

    # assertRaisesRegex uses re.search, so no leading ".*" is needed. The
    # previous trailing "*" only quantified the final "e" (it also matched
    # "...of shap"), so the phrase is now matched literally.
    with self.assertRaisesRegex(ValueError, r"was expected to be of shape"):
      forwardprop._jvp_dispatch(
          op_name="Add",
          attr_tuple=(),
          inputs=(constant_op.constant(2.), constant_op.constant(4.)),
          outputs=sum_outputs,
          tangents=(constant_op.constant([1., 2.]),
                    constant_op.constant([[1.], [2.]])),
          use_batch=True)
コード例 #3
0
  def testJVPFunction(self):
    """Checks unbatched `_jvp_dispatch` results for Add and Mul."""
    # Add(1, 3): the JVP is simply the sum of the two tangents.
    add_result = (constant_op.constant(4.),)
    add_primals = (constant_op.constant(1.), constant_op.constant(3.))
    add_tangents = (
        constant_op.constant(1.),
        constant_op.constant(5.),
    )
    [add_jvp] = forwardprop._jvp_dispatch(
        op_name="Add",
        attr_tuple=(),
        inputs=add_primals,
        outputs=add_result,
        tangents=add_tangents)
    self.assertAllClose(1. + 5., self.evaluate(add_jvp))

    # Mul(4, 5): the product rule gives dx * y + dy * x.
    mul_result = (constant_op.constant([20.]),)
    mul_primals = (constant_op.constant([4.]), constant_op.constant([5.]))
    mul_tangents = (
        constant_op.constant([2.]),
        constant_op.constant([3.]),
    )
    [mul_jvp] = forwardprop._jvp_dispatch(
        op_name="Mul",
        attr_tuple=(),
        inputs=mul_primals,
        outputs=mul_result,
        tangents=mul_tangents)
    self.assertAllClose([2. * 5. + 3. * 4.], self.evaluate(mul_jvp))