示例#1
0
def while_loop(cond, body, loop_vars, **kwargs):
    """Like `tf.while_loop` but with structured `loop_vars`.

    Args:
      cond: as in `tf.while_loop`, but takes a single `loop_vars` argument.
      body: as in `tf.while_loop`, but takes and returns a single `loop_vars`
        tree which it is allowed to modify.
      loop_vars: as in `tf.while_loop`, but consists of a Namespace tree.
      **kwargs: passed onto `tf.while_loop`.

    Returns:
      A Namespace tree structure containing the final values of the loop
      variables.
    """
    def _cond(*flat_vars):
        # tf.while_loop hands us the flat leaves; rebuild the tree for `cond`.
        return cond(NS.UnflattenLike(loop_vars, flat_vars))

    def _body(*flat_vars):
        # Rebuild the tree for `body`, then flatten its result back into the
        # flat form tf.while_loop was given.
        return NS.Flatten(body(NS.UnflattenLike(loop_vars, flat_vars)))

    flat_loop_vars = NS.Flatten(loop_vars)
    flat_final_vars = tf.while_loop(cond=_cond,
                                    body=_body,
                                    loop_vars=flat_loop_vars,
                                    **kwargs)
    # TF1's tf.while_loop (default `return_same_structure=False`) returns the
    # bare value rather than a one-element list when there is exactly one
    # loop variable.  Re-wrap it so UnflattenLike always receives a sequence.
    # NOTE(review): assumes tree leaves are tensors/TensorArrays rather than
    # Python lists/tuples themselves -- confirm against callers.
    if (len(flat_loop_vars) == 1
            and not isinstance(flat_final_vars, (list, tuple))):
        flat_final_vars = [flat_final_vars]
    return NS.UnflattenLike(loop_vars, flat_final_vars)
示例#2
0
 def testCopy(self):
     """NS.Copy returns an equal tree whose leaves are the very same objects."""
     original = NS(v=2, w=NS(x=1, y=NS(z=0)))
     duplicate = NS.Copy(original)
     self.assertEqual(original, duplicate)
     # Pair up leaves of the copy and the source; each pair must be identical.
     leaf_pairs = zip(NS.Flatten(duplicate), NS.Flatten(original))
     self.assertTrue(all(new is old for new, old in leaf_pairs))
示例#3
0
 def testFlattenUnflatten(self):
     """Flattening a tree and unflattening it like itself reproduces the tree."""
     tree = NS(v=2, w=NS(x=1, y=NS(z=0)))
     self.assertEqual(tree, NS.UnflattenLike(tree, NS.Flatten(tree)))
示例#4
0
 def wrapped_branch():
     """Call `fn` and return its result flattened with NS.Flatten."""
     # NOTE(review): `fn` is a free variable from an enclosing scope not shown.
     return NS.Flatten(fn())
示例#5
0
 def _body(*flat_vars):
     """Rebuild the `loop_vars` tree, apply `body`, and flatten the result."""
     # `loop_vars` and `body` are free variables from the enclosing scope.
     tree = NS.UnflattenLike(loop_vars, flat_vars)
     return NS.Flatten(body(tree))