Esempio n. 1
0
def cond(pred, fn1, fn2, prototype, **kwargs):
    """Like `tf.cond` but with structured collections of variables.

    Each branch's Namespace tree is flattened with `NS.Flatten` before being
    handed to `tf.cond`. The flat result is then rebuilt with
    `NS.UnflattenLike`, using `prototype` as the structural template.

    Args:
      pred: boolean Tensor, as in `tf.cond`.
      fn1: a callable representing the `then` branch as in `tf.cond`, but
        may return an arbitrary Namespace tree.
      fn2: a callable representing the `else` branch as in `tf.cond`, but
        may return an arbitrary Namespace tree.
      prototype: an example Namespace tree indicating the structure of the
        values returned from `fn1` and `fn2`.
      **kwargs: passed onto `tf.cond`.

    Returns:
      Like `tf.cond`, except structured like `prototype`.
    """
    def flattening(branch):
        # Adapt a tree-returning branch so that tf.cond only ever sees a
        # flat list of leaves.
        def flat_branch():
            return NS.Flatten(branch())

        return flat_branch

    flat = tf.cond(pred, flattening(fn1), flattening(fn2), **kwargs)

    # tf.cond unpacks a singleton list returned by the branches into a bare
    # value, so re-wrap anything that is not already a sequence.
    leaves = flat if isinstance(flat, (tuple, list)) else [flat]

    # The trees built inside the branches are not kept after flattening, so
    # `prototype` supplies the structure needed to unflatten.
    return NS.UnflattenLike(prototype, leaves)
Esempio n. 2
0
def while_loop(cond, body, loop_vars, **kwargs):
    """Like `tf.while_loop` but with structured `loop_vars`.

    The loop variables are flattened with `NS.Flatten` for `tf.while_loop`.
    They are rebuilt into a tree (using `loop_vars` as the template) before
    each call to `cond`/`body` and once more on the final result.

    Args:
      cond: as in `tf.while_loop`, but takes a single `loop_vars` argument
        (a Namespace tree shaped like `loop_vars`).
      body: as in `tf.while_loop`, but takes and returns a single `loop_vars`
        tree which it is allowed to modify. As `tf.while_loop` requires, the
        returned tree must flatten to match `NS.Flatten(loop_vars)`.
      loop_vars: as in `tf.while_loop`, but consists of a Namespace tree.
        Also serves as the structural template for unflattening.
      **kwargs: passed onto `tf.while_loop`.

    Returns:
      A Namespace tree structured like `loop_vars` containing the final
      values of the loop variables.
    """
    def _cond(*flat_vars):
        # tf.while_loop passes the flattened loop variables positionally;
        # rebuild the tree before handing it to the user's predicate.
        return cond(NS.UnflattenLike(loop_vars, flat_vars))

    def _body(*flat_vars):
        # Rebuild the tree for the user's body, then re-flatten its result
        # so tf.while_loop sees the same flat layout it was given.
        return NS.Flatten(body(NS.UnflattenLike(loop_vars, flat_vars)))

    results = tf.while_loop(cond=_cond,
                            body=_body,
                            loop_vars=NS.Flatten(loop_vars),
                            **kwargs)
    # With TF1's default `return_same_structure=False`, tf.while_loop returns
    # a bare Tensor (not a one-element list) when there is exactly one loop
    # variable. Re-wrap it so UnflattenLike always receives a sequence; this
    # is the same singleton-unpacking behavior that tf.cond exhibits.
    if not isinstance(results, (tuple, list)):
        results = [results]
    return NS.UnflattenLike(loop_vars, results)
Esempio n. 3
0
 def testFlattenUnflatten(self):
     """Flattening a nested NS and unflattening it against itself round-trips."""
     original = NS(v=2, w=NS(x=1, y=NS(z=0)))
     roundtrip = NS.UnflattenLike(original, NS.Flatten(original))
     self.assertEqual(original, roundtrip)
Esempio n. 4
0
 def _body(*flat_vars):
     """Rebuild the loop-variable tree, run `body` on it, and re-flatten."""
     structured = NS.UnflattenLike(loop_vars, flat_vars)
     updated = body(structured)
     return NS.Flatten(updated)
Esempio n. 5
0
 def _cond(*flat_vars):
     """Evaluate `cond` on the loop variables restored to their tree shape."""
     structured = NS.UnflattenLike(loop_vars, flat_vars)
     return cond(structured)