Пример #1
0
def cond_v2(pred, true_fn, false_fn, name="cond"):
    """Build a conditional that is emitted as a single If op (cf. tf.cond).

    Args:
        pred: The branch predicate. Python bools are rejected; anything else
            is converted to a tensor and squeezed unless its static shape is
            already known to have no dimensions.
        true_fn: Python callable handed to `func_graph_from_py_func` to build
            the "true" branch graph.
        false_fn: Python callable handed to `func_graph_from_py_func` to build
            the "false" branch graph.
        name: Name scope to build under; a falsy value falls back to "cond".

    Returns:
        The result of `_build_cond` for the two branch graphs.

    Raises:
        TypeError: If `pred` is a Python bool.
    """
    if isinstance(pred, bool):
        raise TypeError("pred must not be a Python bool", pred)

    with ops.name_scope(name or "cond") as scope:
        # Reserve both branch names up front, "true" first, before any other
        # graph work happens.
        true_name = util.unique_fn_name(scope, "true")
        false_name = util.unique_fn_name(scope, "false")

        # Defuns add automatic control dependencies while v1 graphs do not;
        # mirror whatever the enclosing graph is configured to do.
        outer_graph = ops.get_default_graph()
        add_control_dependencies = outer_graph._add_control_dependencies  # pylint: disable=protected-access

        pred = ops.convert_to_tensor(pred)
        if tensor_util.is_tensor(pred):
            static_dims = pred.shape.dims
            # `None` dims or a non-empty dims list both trigger the squeeze.
            if static_dims is None or static_dims:
                pred = array_ops.squeeze_v2(pred)

        def build_branch_graph(branch_name, branch_fn):
            """Build the function graph for a single branch of the cond."""
            branch_graph = util.CondBranchFuncGraph(
                branch_name,
                collections=ops.get_default_graph()._collections)  # pylint: disable=protected-access
            return func_graph_module.func_graph_from_py_func(
                branch_name,
                branch_fn,
                [],
                {},
                func_graph=branch_graph,
                add_control_dependencies=add_control_dependencies,
                op_return_value=pred)

        true_graph = build_branch_graph(true_name, true_fn)
        false_graph = build_branch_graph(false_name, false_fn)

        verify_captures(_COND, [true_graph, false_graph])
        return _build_cond(
            pred,
            true_graph,
            false_graph,
            true_graph.external_captures,
            false_graph.external_captures,
            building_gradient=False,
            name=scope)
Пример #2
0
    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      """Wrap `cond` so it tolerates the extra loop-counter loop variables.

      Args:
        loop_counter: Iteration counter compared against
          `maximum_iterations_arg`.
        maximum_iterations_arg: Upper bound for `loop_counter`; only consulted
          when the enclosing `maximum_iterations` is not None.
        *args: Flat loop variables; repacked to the structure of
          `orig_loop_vars` before being passed to `cond`.

      Returns:
        The (possibly squeezed) predicate from `cond`, and-ed with
        `loop_counter < maximum_iterations_arg` when a maximum is set.
      """
      # There is no nest.zip, so `_pack_sequence_as` does the work: it
      # flattens `orig_loop_vars` and `args`, turns the flows in `args` into
      # TensorArrays, and rebuilds the structure of `orig_loop_vars`.
      keep_going = cond(*_pack_sequence_as(orig_loop_vars, args))

      # Squeeze tensor predicates unless their static dims are known to be
      # empty.
      needs_squeeze = tensor_util.is_tensor(keep_going) and (
          keep_going.shape.dims is None or keep_going.shape.dims)
      if needs_squeeze:
        keep_going = array_ops.squeeze_v2(keep_going)

      if maximum_iterations is not None:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg, keep_going)
      return keep_going
Пример #3
0
  def _testCond(self, true_fn, false_fn, train_vals, feed_dict=None):
    """Assert cond_v2 agrees with control_flow_ops.cond on values and grads.

    Both conds are built from the same bool placeholder (the reference cond
    receives it squeezed) and evaluated with the predicate fed as True,
    [[True]], False and [[False]], in that order.

    Args:
      true_fn: Callable used as the true branch of both conds.
      false_fn: Callable used as the false branch of both conds.
      train_vals: Values the gradients of both conds are taken against.
      feed_dict: Optional extra feeds merged into every `sess.run` call.
    """
    extra_feeds = feed_dict or {}
    with self.session(graph=ops.get_default_graph()) as sess:
      pred = array_ops.placeholder(dtypes.bool, name="pred")

      expected = control_flow_ops.cond(
          array_ops.squeeze_v2(pred), true_fn, false_fn, name="expected")
      actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual")

      expected_grad = gradients_impl.gradients(expected, train_vals)
      actual_grad = gradients_impl.gradients(actual, train_vals)
      fetches = (expected, actual, expected_grad, actual_grad)

      # Feed each branch both as a scalar and as a nested single-element list.
      for pred_value in (True, [[True]], False, [[False]]):
        feeds = {pred: pred_value}
        feeds.update(extra_feeds)
        (expected_val, actual_val,
         expected_grad_val, actual_grad_val) = sess.run(fetches, feeds)
        self.assertEqual(expected_val, actual_val)
        self.assertEqual(expected_grad_val, actual_grad_val)