def cond_v2(pred, true_fn, false_fn, name="cond"): """Like tf.cond, except emits a single If op.""" if isinstance(pred, bool): raise TypeError("pred must not be a Python bool", pred) if not name: name = "cond" with ops.name_scope(name) as scope: true_name = util.unique_fn_name(scope, "true") false_name = util.unique_fn_name(scope, "false") # Automatic control dependencies are added in defuns, but not in v1 # graphs. Propagate that behavior here. add_control_dependencies = ops.get_default_graph( )._add_control_dependencies pred = ops.convert_to_tensor(pred) if (tensor_util.is_tensor(pred) and (pred.shape.dims is None or pred.shape.dims)): pred = array_ops.squeeze_v2(pred) true_graph = func_graph_module.func_graph_from_py_func( true_name, true_fn, [], {}, func_graph=util.CondBranchFuncGraph( true_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access add_control_dependencies=add_control_dependencies, op_return_value=pred) false_graph = func_graph_module.func_graph_from_py_func( false_name, false_fn, [], {}, func_graph=util.CondBranchFuncGraph( false_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access add_control_dependencies=add_control_dependencies, op_return_value=pred) verify_captures(_COND, [true_graph, false_graph]) return _build_cond(pred, true_graph, false_graph, true_graph.external_captures, false_graph.external_captures, building_gradient=False, name=scope)
def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
  """Extra `cond` wrapper that can handle the extra counter loop_var."""
  # Convert the flow variables in `args` to TensorArrays. `args` should
  # already have the same structure as `orig_loop_vars` but currently there
  # is no nest.zip so we call `_pack_sequence_as` which flattens both
  # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
  # and packs it into the structure of `orig_loop_vars`.
  pred = cond(*_pack_sequence_as(orig_loop_vars, args))
  if (tensor_util.is_tensor(pred) and
      (pred.shape.dims is None or pred.shape.dims)):
    pred = array_ops.squeeze_v2(pred)

  if maximum_iterations is None:
    return pred
  else:
    return math_ops.logical_and(
        loop_counter < maximum_iterations_arg, pred)
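# Hedged illustration of the behavior this wrapper implements at the public
# tf.while_loop level: when maximum_iterations is set, the internal counter
# predicate is ANDed with the user predicate, so the loop stops at the cap
# even while the user condition is still true. Values here are illustrative.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

i = tf.constant(0)
final_i = tf.while_loop(
    cond=lambda i: i < 100,   # the user predicate alone would allow 100 steps
    body=lambda i: i + 1,
    loop_vars=[i],
    maximum_iterations=5)     # wrapped_cond ANDs in `loop_counter < 5`

with tf.Session() as sess:
  # The counter only reaches 5 before the iteration cap stops the loop.
  print(sess.run(final_i))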
def _testCond(self, true_fn, false_fn, train_vals, feed_dict=None):
  if not feed_dict:
    feed_dict = {}
  with self.session(graph=ops.get_default_graph()) as sess:
    pred = array_ops.placeholder(dtypes.bool, name="pred")

    expected = control_flow_ops.cond(
        array_ops.squeeze_v2(pred), true_fn, false_fn, name="expected")
    actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual")

    expected_grad = gradients_impl.gradients(expected, train_vals)
    actual_grad = gradients_impl.gradients(actual, train_vals)

    sess_run_args = {pred: True}
    sess_run_args.update(feed_dict)
    expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run(
        (expected, actual, expected_grad, actual_grad), sess_run_args)
    self.assertEqual(expected_val, actual_val)
    self.assertEqual(expected_grad_val, actual_grad_val)

    sess_run_args = {pred: [[True]]}
    sess_run_args.update(feed_dict)
    expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run(
        (expected, actual, expected_grad, actual_grad), sess_run_args)
    self.assertEqual(expected_val, actual_val)
    self.assertEqual(expected_grad_val, actual_grad_val)

    sess_run_args = {pred: False}
    sess_run_args.update(feed_dict)
    expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run(
        (expected, actual, expected_grad, actual_grad), sess_run_args)
    self.assertEqual(expected_val, actual_val)
    self.assertEqual(expected_grad_val, actual_grad_val)

    sess_run_args = {pred: [[False]]}
    sess_run_args.update(feed_dict)
    expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run(
        (expected, actual, expected_grad, actual_grad), sess_run_args)
    self.assertEqual(expected_val, actual_val)
    self.assertEqual(expected_grad_val, actual_grad_val)
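# Hedged illustration of how the helper above might be exercised from a test
# case: the method name, branch functions, and constants are assumptions, and
# it presumes the surrounding test module imports constant_op as usual.
# Gradients are checked against different subsets of the captured inputs.
def testCondBasic(self):
  x = constant_op.constant(1.0, name="x")
  y = constant_op.constant(2.0, name="y")

  def true_fn():
    return x * 2.0

  def false_fn():
    return y * 3.0

  self._testCond(true_fn, false_fn, [x])
  self._testCond(true_fn, false_fn, [x, y])
  self._testCond(true_fn, false_fn, [y])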