def _createWhile(self, name):
  """Builds a three-step counting while_v2 loop and returns its While op.

  Helper for testDefaultName. `name` is accepted for the caller's
  convenience but is not used when building the loop.
  """
  start = constant_op.constant(0)
  result = while_v2.while_loop(lambda i: i < 3, lambda i: i + 1, [start])
  # NOTE(review): assumes the loop output's producing op takes the While op's
  # output as its first input — confirm against while_v2.
  loop_op = result.op.inputs[0].op
  self.assertEqual(loop_op.type, "While")
  return loop_op
def _createWhile(self, name):
  """Helper for testDefaultName: builds a while_v2 loop and returns its op.

  Args:
    name: Accepted for the caller's convenience; not used when building the
      loop.

  Returns:
    The functional while op that implements the loop.
  """
  output = while_v2.while_loop(
      lambda i: i < 3,
      lambda i: i + 1, [constant_op.constant(0)],
      return_same_structure=False)
  while_op = output.op.inputs[0].op
  if compat.forward_compatible(2019, 8, 23):
    self.assertEqual(while_op.type, "StatelessWhile")
  else:
    # Without this branch the helper would verify nothing outside the
    # forward-compatibility window. NOTE(review): assumes while_v2 falls back
    # to the stateful "While" op there — confirm.
    self.assertEqual(while_op.type, "While")
  return while_op
def while_loop(self, cond, body, loop_vars, shape_invariants=None,
               parallel_iterations=10, back_prop=False, swap_memory=False,
               maximum_iterations=None, return_same_structure=False,
               use_while_v2=False):
  """Builds a while loop under a 'while' entry of `Module.global_scope`.

  Dispatches to `while_v2.while_loop` when `use_while_v2` is true, otherwise
  to `tf.while_loop`. The 'while' entry is pushed before the loop is built
  and is always popped afterwards, even if building the loop raises.

  Args:
    cond: Loop condition callable.
    body: Loop body callable.
    loop_vars: Initial loop variables.
    shape_invariants: Forwarded to `tf.while_loop`; forwarded to `while_v2`
      only when the TF version string starts with '1.13'.
    parallel_iterations: Forwarded to `tf.while_loop` only.
    back_prop: Forwarded to `tf.while_loop` only. Defaults to False here.
    swap_memory: Forwarded to `tf.while_loop` only.
    maximum_iterations: Forwarded to `tf.while_loop`; forwarded to `while_v2`
      only on TF 1.13.
    return_same_structure: Forwarded to `tf.while_loop`; forwarded to
      `while_v2` only on TF 1.13.
    use_while_v2: Build the loop with `while_v2` instead of `tf.while_loop`.

  Returns:
    Whatever the selected while-loop builder returns.
  """
  Module.global_scope.append('while')
  # try/finally so an exception raised while building the loop cannot leave
  # a stale 'while' entry on Module.global_scope.
  try:
    if not use_while_v2:
      return tf.while_loop(cond=cond, body=body, loop_vars=loop_vars,
                           shape_invariants=shape_invariants,
                           parallel_iterations=parallel_iterations,
                           back_prop=back_prop,
                           swap_memory=swap_memory,
                           maximum_iterations=maximum_iterations,
                           return_same_structure=return_same_structure)
    if tf.__version__.startswith('1.13'):
      return while_v2.while_loop(
          cond=cond, body=body, loop_vars=loop_vars,
          shape_invariants=shape_invariants,
          maximum_iterations=maximum_iterations,
          return_same_structure=return_same_structure)
    # NOTE(review): presumably while_v2.while_loop has an incompatible
    # signature on other TF versions, so only the core arguments are passed;
    # shape_invariants, maximum_iterations and return_same_structure are
    # silently ignored here — confirm.
    return while_v2.while_loop(cond=cond, body=body, loop_vars=loop_vars)
  finally:
    Module.global_scope.pop()
def testForwardPassRewrite(self):
  """Gradient requests rewrite the forward While op exactly once."""
  x = constant_op.constant(1.0, name="x")
  loop_results = while_v2.while_loop(lambda x: x < 10.0, lambda x: x * 2.0,
                                     [x])
  output = loop_results[0]
  while_op = output.op.inputs[0].op
  self.assertEqual(while_op.type, "While")
  # Before any gradient: outputs = [loop_counter, x]
  self.assertLen(while_op.outputs, 2)
  # The first gradient call rewrites while_op to also output the 2.0
  # intermediate: [loop_counter, x, 2.0_accumulator, x_accumulator].
  # The second call must not rewrite it again, so the count stays at 4.
  for _ in range(2):
    gradients_impl.gradients(output, x)
    self.assertLen(while_op.outputs, 4)
def testForwardPassRewrite(self):
  """Gradient requests rewrite the forward While op exactly once."""
  x = constant_op.constant(1.0, name="x")
  loop_results = while_v2.while_loop(lambda x: x < 10.0, lambda x: x * 2.0,
                                     [x])
  output = loop_results[0]
  while_op = output.op.inputs[0].op
  self.assertEqual(while_op.type, "While")
  # Before any gradient: outputs = [loop_counter, max_iters, x]
  self.assertLen(while_op.outputs, 3)
  # The first gradient call rewrites while_op to also output the 2.0
  # intermediate:
  # [loop_counter, max_iters, x, 2.0_accumulator, x_accumulator].
  # The second call must not rewrite it again, so the count stays at 5.
  for _ in range(2):
    gradients_impl.gradients(output, x)
    self.assertLen(while_op.outputs, 5)
def F():
  """Builds a while_v2 loop under "foo" and checks cond/body name scopes."""

  def check_scope(expected_name_scope):
    # Compares the currently active name scope against the expected one.
    actual_name_scope = ops.get_name_scope()
    assert actual_name_scope == expected_name_scope, (
        "%s does not match %s" % (actual_name_scope, expected_name_scope))

  def Cond(unused_i):
    with ops.name_scope("cond"):
      check_scope("foo/while/cond")
      return False

  def Body(i):
    with ops.name_scope("body"):
      check_scope("foo/while/body")
      return i

  # The loop must be built inside "foo" so cond/body see "foo/while/...".
  with ops.name_scope("foo"):
    return while_v2.while_loop(Cond, Body, [0.])
def test_composite_tensor(self):
  """Checks the operator round-trips as a CompositeTensor.

  Covers nest flatten/pack, passing the operator into a tf.function and
  through a while_v2 loop, and encoding its TypeSpec.

  NOTE(review): `shapes_info`, `dtype` and `use_placeholder` are free names,
  presumably bound by an enclosing test-factory scope — confirm.
  """
  with self.session(graph=ops.Graph()) as sess:
    sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
    operator, mat = self.operator_and_matrix(
        shapes_info, dtype, use_placeholder=use_placeholder)
    self.assertIsInstance(operator, composite_tensor.CompositeTensor)

    # Round-trip through the flattened component representation.
    components = nest.flatten(operator, expand_composites=True)
    rebuilt = nest.pack_sequence_as(operator, components,
                                    expand_composites=True)
    self.assertIsInstance(rebuilt, type(operator))

    # Feed the rebuilt operator into a `tf.function`.
    x = self.make_x(operator, adjoint=False)
    apply_to_x = def_function.function(lambda op: op.matmul(x))
    op_y = apply_to_x(rebuilt)
    mat_y = math_ops.matmul(mat, x)
    if not use_placeholder:
      self.assertAllEqual(mat_y.shape, op_y.shape)

    # Carry the operator through a while_v2 loop, rebuilding it each step.
    def body(op):
      return (type(op)(**op.parameters),)

    [op_out] = while_v2.while_loop(
        cond=lambda _: True,
        body=body,
        loop_vars=(operator,),
        maximum_iterations=3)
    loop_y = op_out.matmul(x)

    op_y_, loop_y_, mat_y_ = sess.run([op_y, loop_y, mat_y])
    self.assertAC(op_y_, mat_y_)
    self.assertAC(loop_y_, mat_y_)

    # The operator's TypeSpec must be encodable.
    nested_structure_coder.StructureCoder().encode_structure(
        operator._type_spec)  # pylint: disable=protected-access
def model(x):
  """Repeatedly squares `x` while the value is below 4 and returns it.

  The loop stops at the first value >= 4, so the result is x**2 only when a
  single squaring suffices (x < 4 and |x| >= 2). x >= 4 is returned
  unchanged, 1 < |x| < 2 is squared several times, and |x| <= 1 never
  reaches 4: with no maximum_iterations, that loop does not terminate.
  NOTE(review): with return_same_structure=False the single loop variable is
  presumably returned directly rather than wrapped in a list — confirm.
  """
  return while_v2.while_loop(
      lambda v: v < 4.,
      lambda v: v * v, [x],
      return_same_structure=False,
      name="while_1")