Example #1
0
    def testDuplicateAccumulator(self):
        """Checks that gradient construction reuses the loop's own push-back
        accumulator rather than adding a second one for the same input."""
        loop_var = constant_op.constant(2.)

        acc_list = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                              element_shape=ScalarShape())

        def Cond(v, unused_tl):
            del unused_tl  # The condition only looks at the scalar.
            return v < 5.

        def Body(v, accum):
            # The body already pushes onto the list, so the gradient pass
            # should not create another accumulator for this input.
            accum = list_ops.tensor_list_push_back(accum, v)
            return v**2., accum

        ret = while_loop_v2(Cond, Body, [loop_var, acc_list])

        # Find the While op that was added to the default graph.
        for graph_op in ops.get_default_graph().get_operations():
            if graph_op.type == "While":
                while_op = graph_op

        body_graph = while_v2._get_body_graph(while_op)
        # body_graph.inputs: [counter_arg, x_arg, tl_arg, *accumulators]
        x_arg = body_graph.inputs[1]
        push_backs = [
            op for op in x_arg.consumers() if op.type == "TensorListPushBack"
        ]
        self.assertEqual(len(push_backs), 1)

        grad = gradients_impl.gradients(ret[0], loop_var)
        with self.cached_session() as sess:
            self.assertEqual(sess.run(ret[0]), 16.)
            self.assertSequenceEqual(self.evaluate(grad), [32.])
 def GetAccumulatorForInputAtIndex(while_op, idx):
   """Returns the While op output that holds the accumulator fed by the
   body input at position `idx`."""
   graph = while_v2._get_body_graph(while_op)
   arg = graph.inputs[idx]
   # The accumulator is written by the TensorListPushBack consuming `arg`.
   push_backs = [op for op in arg.consumers()
                 if op.type == "TensorListPushBack"]
   push_back = push_backs[0]
   out_idx = graph.outputs.index(push_back.outputs[0])
   return while_op.outputs[out_idx]
  def testDuplicateAccumulator(self):
    """The while_v2 gradient must reuse the loop's existing push-back
    accumulator instead of creating a duplicate."""
    init_x = constant_op.constant(2.)

    empty_tl = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(v, unused_tl):
      del unused_tl  # The condition only depends on the scalar.
      return v < 5.

    def Body(v, lst):
      # The body already accumulates into the list; gradient construction
      # should not add a second accumulator for this input.
      lst = list_ops.tensor_list_push_back(lst, v)
      return v**2., lst

    ret = while_loop_v2(Cond, Body, [init_x, empty_tl])

    # Locate the While op in the default graph.
    for candidate in ops.get_default_graph().get_operations():
      if candidate.type == "While":
        while_op = candidate

    body_graph = while_v2._get_body_graph(while_op)
    # body_graph.inputs: [counter_arg, x_arg, tl_arg, *accumulators]
    arg = body_graph.inputs[1]
    num_push_backs = len(
        [c for c in arg.consumers() if c.type == "TensorListPushBack"])
    self.assertEqual(num_push_backs, 1)

    grad = gradients_impl.gradients(ret[0], init_x)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(sess.run(grad), [32.])
 def GetAccumulatorForInputAtIndex(while_op, idx):
   """Finds the accumulator output of `while_op` written from body input `idx`."""
   body = while_v2._get_body_graph(while_op)
   # The push-back op consuming this input produces the accumulated list.
   consumers = body.inputs[idx].consumers()
   push_back = [c for c in consumers if c.type == "TensorListPushBack"][0]
   return while_op.outputs[body.outputs.index(push_back.outputs[0])]
Example #5
0
    def testTensorListOutputElementShape(self, shape):
        """Verifies that the accumulator TensorList output of a While op
        preserves the element shape of the accumulated loop variable."""
        self.skipTest("b/115982901")
        scalar = constant_op.constant(2.)
        ph = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
        ret = while_loop_v2(lambda v, u: v < 8., lambda v, u: (v * v, u),
                            [scalar, ph])

        # Locate the TensorList output of the While op that accumulates
        # the values of the placeholder loop variable.
        while_op = ret[0].op
        body_graph = while_v2._get_body_graph(while_op)
        # body_graph.inputs: [counter_arg, x_arg, y_arg, *accumulators]
        arg = body_graph.inputs[2]
        push_back = [
            op for op in arg.consumers() if op.type == "TensorListPushBack"
        ][0]
        out_idx = body_graph.outputs.index(push_back.outputs[0])
        accumulator_out = while_op.outputs[out_idx]

        _, popped = list_ops.tensor_list_pop_back(accumulator_out,
                                                  element_dtype=dtypes.float32)
        self.assertEqual(popped.shape, tensor_shape.TensorShape(shape))