Example #1
def GetAccumulatorForInputAtIndex(while_op, idx):
  """Returns the TensorList accumulator output of `while_op` for the body input at `idx`."""
  body_graph = while_v2._get_graph(while_op, "body")
  y_input_t = body_graph.inputs[idx]
  push_back_node = [c for c in y_input_t.consumers()
                    if c.type == "TensorListPushBack"][0]
  output_idx = body_graph.outputs.index(push_back_node.outputs[0])
  return while_op.outputs[output_idx]
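
For context, here is a minimal, hypothetical driver for the helper above; the loop construction and variable names below are not part of the original snippet. It assumes TensorFlow's internal test-style modules, v1 behavior (identity-based tensor `==`, which the helper's `list.index` lookup relies on), and that `GetAccumulatorForInputAtIndex` is already defined as shown. The accumulator outputs only exist after a gradient pass has rewritten the forward While op.

import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import while_v2

tf.compat.v1.disable_v2_behavior()  # assumption: v1 semantics (identity tensor `==`)

with ops.Graph().as_default():
  x = constant_op.constant(2.)
  ret = while_v2.while_loop(lambda v: v < 8., lambda v: v * 2., [x])[0]
  # Building the gradient is what inserts the TensorListPushBack accumulators
  # into the body graph and adds the matching outputs to the forward While op.
  unused_grad = gradients_impl.gradients(ret, [x])
  # `ret` is separated from the While op by an Identity, so step over it.
  while_op = ret.op.inputs[0].op
  # while_v2 prepends extra loop variables (e.g. the loop counter), so look up
  # the input index of `x` rather than assuming it is 0.
  x_index = [i for i, inp in enumerate(while_op.inputs) if inp is x][0]
  accumulator = GetAccumulatorForInputAtIndex(while_op, x_index)
  print(accumulator.dtype)  # a variant-dtype TensorList handle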
Example #2
  def testDuplicateAccumulator(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      # There is an accumulator in the loop already so we should not add
      # another.
      tl = list_ops.tensor_list_push_back(tl, x)
      return x**2., tl

    ret = while_loop_v2(
        Cond, Body, [x, tensor_list], return_same_structure=False)

    for op in ops.get_default_graph().get_operations():
      if op.type == "While" or op.type == "StatelessWhile":
        while_op = op

    body_graph = while_v2._get_graph(while_op, "body")
    x_input_index = [i for i, inp in enumerate(while_op.inputs) if inp == x][0]
    x_input_t = body_graph.inputs[x_input_index]
    accumulator_count = len(
        [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
    self.assertEqual(accumulator_count, 1)

    grad = gradients_impl.gradients(ret[0], x)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])
Example #3
  def testDuplicateAccumulator(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      # There is an accumulator in the loop already so we should not add
      # another.
      tl = list_ops.tensor_list_push_back(tl, x)
      return x**2., tl

    ret = while_loop_v2(
        Cond, Body, [x, tensor_list], return_same_structure=False)

    for op in ops.get_default_graph().get_operations():
      if op.type == "While":
        while_op = op

    body_graph = while_v2._get_graph(while_op, "body")
    x_input_index = [i for i, inp in enumerate(while_op.inputs) if inp == x][0]
    x_input_t = body_graph.inputs[x_input_index]
    accumulator_count = len(
        [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
    self.assertEqual(accumulator_count, 1)

    grad = gradients_impl.gradients(ret[0], x)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])
Example #4
  def testDoNotAccumulateConstNodes(self):

    def Body(v):
      return v * 2.0

    v0 = constant_op.constant(2.)
    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
    # Gradients computation has the side-effect of updating the forward op
    # which is what we want to test.
    unused_grad = gradients_impl.gradients(ret, [v0])[0]
    # ret is separated from the `While` op by an `Identity` so we skip over
    # that.
    forward_while_op = ret.op.inputs[0].op
    body_graph = while_v2._get_graph(forward_while_op, "body")
    push_back_nodes = [
        o for o in body_graph.get_operations() if o.type == "TensorListPushBack"
    ]
    # Gradient of `Mul` requires accumulating both its inputs. But since one
    # of those is a Const (2.0), we should have just one accumulator.
    self.assertLen(push_back_nodes, 1)
Example #5
  def _assertNotAccumulated(self, while_op, index):
    """Asserts that `while_op` input at `index` is not accumulated."""
    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
    placeholder = body_graph.inputs[index]
    self.assertNotIn("TensorListPushBack",
                     [op.type for op in placeholder.consumers()])
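
A hypothetical usage sketch, not from the original source: a test class that builds a loop with a pass-through variable `y`, takes a gradient, and then calls the helper on `y`'s input index. The class and method names are made up, the helper from the snippet above is repeated so the class is self-contained, and it assumes the same TensorFlow version as that snippet (three-argument `_get_graph`) plus v1-style graph construction; whether a purely pass-through loop variable really gets no accumulator depends on the while_v2 gradient rewrite.

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import while_v2
from tensorflow.python.platform import test


class WhileV2AccumulationTest(test.TestCase):

  def _assertNotAccumulated(self, while_op, index):
    """Asserts that `while_op` input at `index` is not accumulated."""
    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
    placeholder = body_graph.inputs[index]
    self.assertNotIn("TensorListPushBack",
                     [op.type for op in placeholder.consumers()])

  def testPassThroughInputNotAccumulated(self):
    with ops.Graph().as_default():
      x = constant_op.constant(2.)
      y = constant_op.constant(3.)

      def Body(x, y):
        # `y` is threaded through unchanged, so the backward pass should have
        # no reason to record its per-iteration values.
        return x * 2., y

      ret = while_v2.while_loop(lambda x, y: x < 8., Body, [x, y])
      # Accumulators are only added when the gradient rewrites the forward op.
      gradients_impl.gradients(ret[0], x)
      while_op = ret[0].op.inputs[0].op
      y_index = [i for i, inp in enumerate(while_op.inputs) if inp is y][0]
      self._assertNotAccumulated(while_op, y_index)


if __name__ == "__main__":
  test.main()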