def while_loop(cond, body, loop_vars, **kwargs):
    """Run `tf.while_loop` over loop variables held in a Namespace tree.

    The tree is flattened before being handed to `tf.while_loop`, and the
    flat values are rebuilt into a tree (using `loop_vars` as the template)
    before every call to `cond` and `body`.

    Args:
      cond: as in `tf.while_loop`, except that it receives one argument, the
        `loop_vars` tree.
      body: as in `tf.while_loop`, except that it receives the `loop_vars`
        tree as its single argument and returns a tree; it is allowed to
        modify the tree it is given. The returned tree is flattened and
        passed back to `tf.while_loop`.
      loop_vars: as in `tf.while_loop`, but given as a Namespace tree.
      **kwargs: forwarded unchanged to `tf.while_loop`.

    Returns:
      A Namespace tree, structured like `loop_vars`, holding the final values
      of the loop variables.
    """

    def _as_tree(flat_values):
        # `loop_vars` serves only as the structural template here.
        return NS.UnflattenLike(loop_vars, flat_values)

    def _flat_cond(*flat_values):
        return cond(_as_tree(flat_values))

    def _flat_body(*flat_values):
        updated_tree = body(_as_tree(flat_values))
        return NS.Flatten(updated_tree)

    final_flat = tf.while_loop(
        cond=_flat_cond,
        body=_flat_body,
        loop_vars=NS.Flatten(loop_vars),
        **kwargs)
    return NS.UnflattenLike(loop_vars, final_flat)
def testCopy(self):
    """Check that `NS.Copy` gives an equal tree whose leaves are the same objects."""
    original = NS(v=2, w=NS(x=1, y=NS(z=0)))
    duplicate = NS.Copy(original)

    self.assertEqual(original, duplicate)

    # Compare leaves pairwise by identity, not just by value.
    # NOTE(review): the leaves are small ints, which CPython caches, so this
    # identity check would presumably also pass for a deep copy -- confirm
    # whether distinct leaf objects would make the test stronger.
    duplicate_leaves = NS.Flatten(duplicate)
    original_leaves = NS.Flatten(original)
    leaves_are_shared = all(
        copied is source
        for copied, source in zip(duplicate_leaves, original_leaves))
    self.assertTrue(leaves_are_shared)
def testFlattenUnflatten(self):
    """Check that flattening a tree and unflattening it again gives an equal tree."""
    original = NS(v=2, w=NS(x=1, y=NS(z=0)))
    # The original tree doubles as the structural template for rebuilding.
    rebuilt = NS.UnflattenLike(original, NS.Flatten(original))
    self.assertEqual(original, rebuilt)
def wrapped_branch():
    """Call `fn` and return its result passed through `NS.Flatten`.

    `fn` is not defined in this block; presumably it is captured from an
    enclosing scope -- verify against the surrounding code.
    """
    branch_tree = fn()
    return NS.Flatten(branch_tree)
def _body(*flat_vars):
    """Apply `body` to the tree form of `flat_vars` and return the flattened result.

    `body` and `loop_vars` are not defined in this block; presumably they are
    captured from an enclosing scope -- verify against the surrounding code.
    """
    # Rebuild the tree using `loop_vars` as the structural template.
    tree_in = NS.UnflattenLike(loop_vars, flat_vars)
    tree_out = body(tree_in)
    return NS.Flatten(tree_out)