Example #1
0
 def _createWhile(self, name):
     """Build a small counting loop with while_v2 and return its loop op.

     Helper for testDefaultName. The loop starts from the constant 0 and
     adds 1 while the value is below 3. `name` is part of the signature
     but is not used here.
     """
     counter_init = [constant_op.constant(0)]
     loop_out = while_v2.while_loop(lambda k: k < 3, lambda k: k + 1,
                                    counter_init)
     # Step back from the returned tensor to the op feeding it, which the
     # assertion below checks is the loop op itself.
     loop_op = loop_out.op.inputs[0].op
     self.assertEqual(loop_op.type, "While")
     return loop_op
 def _createWhile(self, name):
   """Return the While op of a loop that counts from 0 up to 3.

   Helper for testDefaultName; the `name` argument is ignored.
   """
   def below_three(i):
     return i < 3

   def step(i):
     return i + 1

   result = while_v2.while_loop(cond=below_three,
                                body=step,
                                loop_vars=[constant_op.constant(0)])
   # The first input of the returned tensor's op is produced by the loop op.
   loop_op = result.op.inputs[0].op
   self.assertEqual(loop_op.type, "While")
   return loop_op
Example #3
0
 def _createWhile(self, name):
     """Build a small counting loop with while_v2 and return its loop op.

     Helper for testDefaultName. The loop starts from the constant 0 and
     adds 1 while the value is below 3. `name` is part of the signature
     but is not used here.

     Returns:
         The functional while op that produced the loop's output.
     """
     output = while_v2.while_loop(lambda i: i < 3,
                                  lambda i: i + 1,
                                  [constant_op.constant(0)],
                                  return_same_structure=False)
     # The code reads `output.op`, i.e. it relies on return_same_structure=False
     # handing back the single loop var as a bare tensor. The first input of
     # that tensor's op is taken to be the loop op.
     while_op = output.op.inputs[0].op
     if compat.forward_compatible(2019, 8, 23):
         # Inside the forward-compatibility window this loop is expected to
         # be emitted as StatelessWhile.
         self.assertEqual(while_op.type, "StatelessWhile")
     else:
         # Previously nothing was checked on this path, so the helper could
         # silently return an unrelated op. Still confirm we reached a
         # functional while op before handing it to the caller.
         self.assertIn(while_op.type, ("While", "StatelessWhile"))
     return while_op
Example #4
0
 def while_loop(self,
                cond,
                body,
                loop_vars,
                shape_invariants=None,
                parallel_iterations=10,
                back_prop=False,
                swap_memory=False,
                maximum_iterations=None,
                return_same_structure=False,
                use_while_v2=False):
     """Build a while loop while a 'while' entry sits on Module.global_scope.

     Pushes 'while' onto ``Module.global_scope`` for the duration of loop
     construction. The entry is always popped again, even if building the
     loop raises.

     Args:
         cond: Loop-condition callable.
         body: Loop-body callable.
         loop_vars: Initial loop variables.
         shape_invariants: Forwarded to ``tf.while_loop`` and, on TF 1.13, to
             ``while_v2.while_loop``.
         parallel_iterations: Forwarded to ``tf.while_loop`` only.
         back_prop: Forwarded to ``tf.while_loop`` only. Note that the default
             here is False.
         swap_memory: Forwarded to ``tf.while_loop`` only.
         maximum_iterations: Forwarded to ``tf.while_loop`` and, on TF 1.13, to
             ``while_v2.while_loop``.
         return_same_structure: Forwarded to ``tf.while_loop`` and, on TF 1.13,
             to ``while_v2.while_loop``.
         use_while_v2: If true, build with ``while_v2.while_loop``; otherwise
             with ``tf.while_loop``.

     Returns:
         Whatever the selected while-loop builder returns.
     """
     Module.global_scope.append('while')
     try:
         if use_while_v2:
             if tf.__version__.startswith('1.13'):
                 x = while_v2.while_loop(
                     cond=cond,
                     body=body,
                     loop_vars=loop_vars,
                     shape_invariants=shape_invariants,
                     maximum_iterations=maximum_iterations,
                     return_same_structure=return_same_structure)
             else:
                 # NOTE(review): on every TF version other than 1.13 only
                 # cond/body/loop_vars are passed, so shape_invariants,
                 # maximum_iterations and return_same_structure are silently
                 # ignored here. Confirm this is intended for the versions
                 # in use.
                 x = while_v2.while_loop(cond=cond,
                                         body=body,
                                         loop_vars=loop_vars)
         else:
             x = tf.while_loop(cond=cond,
                               body=body,
                               loop_vars=loop_vars,
                               shape_invariants=shape_invariants,
                               parallel_iterations=parallel_iterations,
                               back_prop=back_prop,
                               swap_memory=swap_memory,
                               maximum_iterations=maximum_iterations,
                               return_same_structure=return_same_structure)
     finally:
         # Pop even when graph construction raises, so a failed call does not
         # leave a stale 'while' entry on the shared scope stack.
         Module.global_scope.pop()
     return x
Example #5
0
    def testForwardPassRewrite(self):
        """Requesting gradients rewrites the forward While op exactly once."""
        start = constant_op.constant(1.0, name="x")

        def keep_going(v):
            return v < 10.0

        def double(v):
            return v * 2.0

        result = while_v2.while_loop(keep_going, double, [start])[0]
        forward_op = result.op.inputs[0].op
        self.assertEqual(forward_op.type, "While")
        # Before any gradient request:
        # outputs = [loop_counter, x]
        self.assertLen(forward_op.outputs, 2)

        # The first gradient request makes the forward loop also output the
        # 2.0 intermediate and x:
        # outputs = [loop_counter, x, 2.0_accumulator, x_accumulator]
        gradients_impl.gradients(result, start)
        self.assertLen(forward_op.outputs, 4)

        # A second request must not rewrite the forward loop again.
        gradients_impl.gradients(result, start)
        self.assertLen(forward_op.outputs, 4)
Example #6
0
  def testForwardPassRewrite(self):
    """Requesting gradients rewrites the forward While op exactly once."""
    x_init = constant_op.constant(1.0, name="x")
    loop_outputs = while_v2.while_loop(lambda v: v < 10.0,
                                       lambda v: v * 2.0,
                                       [x_init])
    output = loop_outputs[0]
    forward_op = output.op.inputs[0].op
    self.assertEqual(forward_op.type, "While")
    # Before any gradient request:
    # outputs = [loop_counter, max_iters, x]
    self.assertLen(forward_op.outputs, 3)

    # The first gradient request makes the forward loop also output the
    # 2.0 intermediate and x:
    # outputs = [loop_counter, max_iters, x, 2.0_accumulator, x_accumulator]
    gradients_impl.gradients(output, x_init)
    self.assertLen(forward_op.outputs, 5)

    # A second request must not rewrite the forward loop again.
    gradients_impl.gradients(output, x_init)
    self.assertLen(forward_op.outputs, 5)
Example #7
0
    def F():
      """Build a while_v2 loop under the "foo" name scope.

      While cond and body are traced, assert that the name scopes they open
      resolve to "foo/while/cond" and "foo/while/body" respectively.
      """

      def check_scope(expected_name_scope):
        # Compare the current TF name scope with the expected full path.
        actual_name_scope = ops.get_name_scope()
        assert actual_name_scope == expected_name_scope, (
            "%s does not match %s" % (actual_name_scope, expected_name_scope))

      def Cond(unused_i):
        with ops.name_scope("cond"):
          check_scope("foo/while/cond")
        return False

      def Body(i):
        with ops.name_scope("body"):
          check_scope("foo/while/body")
        return i

      # Only the loop construction has to happen inside the "foo" scope;
      # defining the Python callables above does not touch TF scopes.
      with ops.name_scope("foo"):
        return while_v2.while_loop(Cond, Body, [0.])
    def test_composite_tensor(self):
        """Exercise the operator as a CompositeTensor.

        Covers a nest flatten/pack round trip, passing the operator into a
        `tf.function`, carrying it as a while_v2 loop variable, and encoding
        its `TypeSpec`. `shapes_info`, `dtype` and `use_placeholder` come
        from an enclosing scope that is not visible here.
        """
        with self.session(graph=ops.Graph()) as sess:
            sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
            operator, mat = self.operator_and_matrix(
                shapes_info, dtype, use_placeholder=use_placeholder)
            self.assertIsInstance(operator, composite_tensor.CompositeTensor)

            # Round-trip through the flattened list of components.
            components = nest.flatten(operator, expand_composites=True)
            rebuilt = nest.pack_sequence_as(operator, components,
                                            expand_composites=True)
            self.assertIsInstance(rebuilt, type(operator))

            # Pass the rebuilt operator into a `tf.function`.
            x = self.make_x(operator, adjoint=False)
            matmul_fn = def_function.function(lambda op: op.matmul(x))
            op_y = matmul_fn(rebuilt)
            mat_y = math_ops.matmul(mat, x)

            if not use_placeholder:
                self.assertAllEqual(mat_y.shape, op_y.shape)

            # Carry the operator as a while_v2 loop variable; each iteration
            # reconstructs it from its own parameters.
            def body(op):
                rebuilt_op = type(op)(**op.parameters)
                return (rebuilt_op,)

            (op_out,) = while_v2.while_loop(cond=lambda _: True,
                                            body=body,
                                            loop_vars=(operator,),
                                            maximum_iterations=3)
            loop_y = op_out.matmul(x)

            op_y_, loop_y_, mat_y_ = sess.run([op_y, loop_y, mat_y])
            self.assertAC(op_y_, mat_y_)
            self.assertAC(loop_y_, mat_y_)

            # The operator's `TypeSpec` must be encodable.
            coder = nested_structure_coder.StructureCoder()
            coder.encode_structure(operator._type_spec)  # pylint: disable=protected-access
 def model(x):
     """Square `x` repeatedly while it is below 4 and return the result.

     Built with while_v2 as the loop named "while_1", returning a bare
     tensor (return_same_structure=False). The result equals x**2 only when
     exactly one iteration runs, i.e. x < 4 and x*x >= 4. If x >= 4 the
     loop does not run and x is returned unchanged. For 1 < |x| < 2 several
     squarings happen. For |x| <= 1 the value never reaches 4, so the loop
     does not terminate.
     """
     def keep_squaring(v):
         return v < 4.

     def square(v):
         return v * v

     return while_v2.while_loop(cond=keep_squaring,
                                body=square,
                                loop_vars=[x],
                                return_same_structure=False,
                                name="while_1")