def smart_cond(pred, true_fn=None, false_fn=None, name=None):
    """Select between `true_fn()` and `false_fn()` according to `pred`.

    When `pred` is a Python bool or has a statically known value, only the
    matching branch is invoked; otherwise the choice is deferred to a graph
    conditional (`tf.cond`) that routes dynamically to either branch.

    A `variables.Variable` predicate is always sent straight to
    `control_flow_ops.cond` rather than through `control_flow_ops.smart_cond`.
    NOTE(review): presumably so that a variable's (mutable) value is never
    folded into a constant — confirm against the callers.

    Arguments:
        pred: Scalar deciding which branch's result is returned.
        true_fn: Callable invoked when `pred` is true.
        false_fn: Callable invoked when `pred` is false.
        name: Optional name prefix used when a `tf.cond` is built.

    Returns:
        The tensors returned by whichever branch callable was selected.

    Raises:
        TypeError: If `true_fn` or `false_fn` is not callable (raised by the
            `control_flow_ops` routine this function delegates to).
    """
    if isinstance(pred, variables.Variable):
        branch_op = control_flow_ops.cond
    else:
        branch_op = control_flow_ops.smart_cond
    return branch_op(pred, true_fn=true_fn, false_fn=false_fn, name=name)
def testSmartCondFalse(self):
    """A literal `False` predicate must yield the false branch's value (3 * 3)."""
    with ops.Graph().as_default(), session.Session():
        four = constant_op.constant(4)
        three = constant_op.constant(3)

        def if_true():
            return math_ops.multiply(four, 16)

        def if_false():
            return math_ops.multiply(three, 3)

        result = control_flow_ops.smart_cond(False, if_true, if_false)
        self.assertEqual(result.eval(), 9)
def testSmartCondTrue(self):
    """A literal `True` predicate must yield the true branch's value (2 * 16)."""
    with ops.Graph().as_default(), session.Session():
        two = constant_op.constant(2)
        five = constant_op.constant(5)

        def if_true():
            return math_ops.multiply(two, 16)

        def if_false():
            return math_ops.multiply(five, 5)

        result = control_flow_ops.smart_cond(True, if_true, if_false)
        self.assertEqual(result.eval(), 32)
def testSmartCondMissingArg2(self):
    """Calling `smart_cond` without a `false_fn` must raise `TypeError`."""
    with ops.Graph().as_default(), session.Session():
        one = constant_op.constant(1)

        def true_branch():
            return one

        with self.assertRaises(TypeError):
            control_flow_ops.smart_cond(True, true_branch)