Пример #1
0
  def testIndexedSlicesWithShapeGradientInWhileLoop(self):
    """Checks gradients through a TensorArray filled inside tf.while_loop.

    `inputs` is a placeholder with a fully static shape ([num_steps]). Each
    loop iteration gathers element `i` of `inputs` and writes it into slot
    `i` of a TensorArray. The packed array is summed, so the result equals
    the sum of the fed values and d(sum)/d(inputs) is all ones.
    """
    with self.test_session() as sess:
      num_steps = 9
      feed_values = [4, 6, 0, 7, 0, 0, 1, 2, 0]  # sums to 20

      inputs = tf.placeholder(dtype="float32", shape=[num_steps])
      initial_outputs = tf.TensorArray(dtype="float32", size=num_steps)
      initial_i = tf.constant(0, dtype="int32")

      def Cond(i, _):
        return i < num_steps

      def Body(i, outputs):
        x = tf.gather(inputs, i)
        outputs = outputs.write(i, x)
        return i + 1, outputs

      _, outputs = tf.while_loop(Cond, Body, [initial_i, initial_outputs])

      outputs = tf.reduce_sum(outputs.pack())
      r = tf.gradients([outputs], [inputs])[0]
      # The gradient of tf.gather may be an IndexedSlices rather than a dense
      # Tensor; densify it so it can be fetched and compared elementwise.
      grad_wr_inputs = ops.convert_to_tensor(r)
      o, grad = sess.run([outputs, grad_wr_inputs],
                         feed_dict={inputs: feed_values})
      self.assertEqual(o, 20)
      self.assertAllEqual(grad, [1] * num_steps)
Пример #2
0
    def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self):
        """Checks while_loop gradients when the input shape is only known at run time.

        For each of float32 and float64:
        - `inputs` is a placeholder with no static shape.
        - The TensorArray is created with `dynamic_size=True` and an initial
          size of 1.
        - The loop runs `tf.size(inputs)` times, gathering element `i` and
          writing it into slot `i`.

        Summing the packed array gives the sum of the fed values, and
        d(sum)/d(inputs) should be all ones.
        """
        feed_values = [1, 3, 2]  # sums to 6
        for dtype in [dtypes.float32, dtypes.float64]:
            with self.test_session() as sess:
                inputs = tf.placeholder(dtype=dtype)
                initial_outputs = tf.TensorArray(dtype=dtype,
                                                 dynamic_size=True,
                                                 size=1)
                initial_i = tf.constant(0, dtype=dtypes.int32)

                # Cond/Body close over the per-iteration `inputs`. They are
                # consumed by tf.while_loop within the same iteration, so the
                # late-binding warning is benign here.
                def Cond(i, _):
                    return i < tf.size(inputs)  # pylint: disable=cell-var-from-loop

                def Body(i, outputs):
                    x = tf.gather(inputs, i)  # pylint: disable=cell-var-from-loop
                    outputs = outputs.write(i, x)
                    return i + 1, outputs

                _, outputs = tf.while_loop(Cond, Body,
                                           [initial_i, initial_outputs])

                outputs = tf.reduce_sum(outputs.pack())
                r = tf.gradients([outputs], [inputs])[0]
                # The gradient of tf.gather may be an IndexedSlices; densify it
                # so it can be fetched and compared elementwise.
                grad_wr_inputs = ops.convert_to_tensor(r)
                o, grad = sess.run([outputs, grad_wr_inputs],
                                   feed_dict={inputs: feed_values})
                self.assertEqual(o, 6)
                self.assertAllEqual(grad, [1] * 3)