Example #1
0
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
    """Select between `true_fn()` and `false_fn()` according to `pred`.

    A predicate that is a `variables.Variable` is always passed to
    `control_flow_ops.cond`. Every other predicate is passed to
    `control_flow_ops.smart_cond`. That function calls the chosen branch
    directly when `pred` is a Python bool or has a constant value, and
    otherwise builds a `tf.cond` that routes dynamically to both branches.

    Arguments:
      pred: Scalar that decides whether the result of `true_fn` or of
        `false_fn` is returned.
      true_fn: Callable run when `pred` is true.
      false_fn: Callable run when `pred` is false.
      name: Optional name prefix used when a `tf.cond` is created.

    Returns:
      The tensors returned by whichever of `true_fn` / `false_fn` is chosen.

    Raises:
      TypeError: If `true_fn` or `false_fn` is not callable.
    """
    # NOTE(review): variables presumably skip smart_cond because their value
    # can change after graph construction and must not be folded into a
    # constant -- confirm against the change that introduced this branch.
    if isinstance(pred, variables.Variable):
        branch_op = control_flow_ops.cond
    else:
        branch_op = control_flow_ops.smart_cond
    return branch_op(pred, true_fn=true_fn, false_fn=false_fn, name=name)
Example #2
0
 def testSmartCondFalse(self):
   """A literal False predicate makes smart_cond return false_fn's result.

   false_fn computes 3 * 3, so the evaluated result must be 9.
   """
   with ops.Graph().as_default(), session.Session():
     true_operand = constant_op.constant(4)
     false_operand = constant_op.constant(3)

     def on_true():
       return math_ops.multiply(true_operand, 16)

     def on_false():
       return math_ops.multiply(false_operand, 3)

     result = control_flow_ops.smart_cond(False, on_true, on_false)
     self.assertEqual(result.eval(), 9)
Example #3
0
 def testSmartCondTrue(self):
   """A literal True predicate makes smart_cond return true_fn's result.

   true_fn computes 2 * 16, so the evaluated result must be 32.
   """
   with ops.Graph().as_default(), session.Session():
     true_operand = constant_op.constant(2)
     false_operand = constant_op.constant(5)

     def on_true():
       return math_ops.multiply(true_operand, 16)

     def on_false():
       return math_ops.multiply(false_operand, 5)

     result = control_flow_ops.smart_cond(True, on_true, on_false)
     self.assertEqual(result.eval(), 32)
Example #4
0
 def testSmartCondMissingArg2(self):
   """smart_cond must raise TypeError when false_fn is omitted.

   Only true_fn is supplied (positionally). The call is expected to be
   rejected with a TypeError instead of silently building a one-sided
   conditional.
   """
   # This method was previously defined twice with identical bodies. The
   # first definition was shadowed by the second and never ran, so a single
   # definition is kept.
   with ops.Graph().as_default():
     with session.Session():
       x = constant_op.constant(1)
       with self.assertRaises(TypeError):
         control_flow_ops.smart_cond(True, lambda: x)