Code Example #1
def cond_v2(pred, true_fn, false_fn, name="cond"):
    """Like tf.cond, except emits a single If op."""
    if isinstance(pred, bool):
        raise TypeError("pred must not be a Python bool", pred)

    if not name:
        name = "cond"

    with ops.name_scope(name) as scope:
        true_name = util.unique_fn_name(scope, "true")
        false_name = util.unique_fn_name(scope, "false")

        # Automatic control dependencies are added in defuns, but not in v1
        # graphs. Propagate that behavior here.
        add_control_dependencies = ops.get_default_graph()._add_control_dependencies
        pred = ops.convert_to_tensor(pred)
        if (tensor_util.is_tensor(pred)
                and (pred.shape.dims is None or pred.shape.dims)):
            pred = array_ops.squeeze_v2(pred)

        true_graph = func_graph_module.func_graph_from_py_func(
            true_name,
            true_fn,
            [],
            {},
            func_graph=util.CondBranchFuncGraph(
                true_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
            add_control_dependencies=add_control_dependencies,
            op_return_value=pred)
        false_graph = func_graph_module.func_graph_from_py_func(
            false_name,
            false_fn,
            [],
            {},
            func_graph=util.CondBranchFuncGraph(
                false_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
            add_control_dependencies=add_control_dependencies,
            op_return_value=pred)

        verify_captures(_COND, [true_graph, false_graph])
        return _build_cond(pred,
                           true_graph,
                           false_graph,
                           true_graph.external_captures,
                           false_graph.external_captures,
                           building_gradient=False,
                           name=scope)
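
A minimal usage sketch for the example above (hypothetical, not part of the source; assumes TF 1.x graph mode, where this function is importable as `tensorflow.python.ops.cond_v2.cond_v2`):

import tensorflow as tf
from tensorflow.python.ops import cond_v2

# The predicate must be a boolean tensor; a Python bool raises TypeError.
pred = tf.placeholder(tf.bool, shape=[], name="pred")
out = cond_v2.cond_v2(pred,
                      true_fn=lambda: tf.constant(1),
                      false_fn=lambda: tf.constant(-1))

with tf.Session() as sess:
  print(sess.run(out, feed_dict={pred: True}))   # 1
  print(sess.run(out, feed_dict={pred: False}))  # -1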
Code Example #2
File: cond_v2.py Project: Harryi0/tinyML
def indexed_case(branch_index, branch_fns, name="indexed_case"):
  """Like conv_v2, except emits a Case op instead of an If."""
  if isinstance(branch_index, int):
    raise TypeError("branch_index must not be a Python int", branch_index)

  with ops.name_scope(name) as scope:
    branch_names = [
        util.unique_fn_name(scope, "branch{}".format(b))
        for b in range(len(branch_fns))
    ]

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
    branch_index = ops.convert_to_tensor(branch_index, name="branch_index")

    branch_graphs = []
    for branch_name, branch_fn in zip(branch_names, branch_fns):
      branch_graphs.append(
          func_graph_module.func_graph_from_py_func(
              branch_name,
              branch_fn,
              [],
              {},
              func_graph=util.CondBranchFuncGraph(
                  branch_name,
                  collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
              add_control_dependencies=add_control_dependencies,
              op_return_value=branch_index))

    verify_captures(_CASE, branch_graphs)
    return _build_case(
        branch_index,
        branch_graphs, [g.external_captures for g in branch_graphs],
        name=scope)
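
A hedged usage sketch for `indexed_case` (hypothetical; same TF 1.x assumptions as above). The branch index must be a scalar integer tensor, and each branch is a zero-argument callable:

import tensorflow as tf  # assumes the indexed_case defined above is in scope

idx = tf.placeholder(tf.int32, shape=[], name="branch_index")
out = indexed_case(idx, [lambda: tf.constant(10),
                         lambda: tf.constant(20),
                         lambda: tf.constant(30)])  # one Case op, 3 branches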
Code Example #3
def cond_v2(pred, true_fn, false_fn, name="cond"):
    """Like tf.cond, except emits a single If op."""
    if isinstance(pred, bool):
        raise TypeError("pred must not be a Python bool", pred)

    if not name:
        name = "cond"

    with ops.name_scope(name) as scope:
        true_name = util.unique_fn_name(scope, "true")
        false_name = util.unique_fn_name(scope, "false")

        # Automatic control dependencies are added in defuns, but not in v1
        # graphs. Propagate that behavior here.
        add_control_dependencies = ops.get_default_graph()._add_control_dependencies
        pred = ops.convert_to_tensor(pred)

        true_graph = func_graph_module.func_graph_from_py_func(
            true_name,
            true_fn,
            [],
            {},
            func_graph=util.CondBranchFuncGraph(
                true_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
            add_control_dependencies=add_control_dependencies,
            op_return_value=pred)
        false_graph = func_graph_module.func_graph_from_py_func(
            false_name,
            false_fn,
            [],
            {},
            func_graph=util.CondBranchFuncGraph(
                false_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
            add_control_dependencies=add_control_dependencies,
            op_return_value=pred)

        outputs = _build_cond(pred,
                              true_graph,
                              false_graph,
                              true_graph.external_captures,
                              false_graph.external_captures,
                              name=scope)

        return func_graph_module.pack_sequence_as(
            true_graph.structured_outputs, outputs)
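
The final `pack_sequence_as` call restores the caller's nested output structure over the flat list of If outputs. A minimal illustration using the public `tf.nest` equivalent (hypothetical values; the internal `func_graph_module.pack_sequence_as` additionally handles composite tensors):

from tensorflow.python.util import nest

# Flat outputs are repacked into the structure of the template:
print(nest.pack_sequence_as({"a": 0, "b": 1}, ["x", "y"]))
# {'a': 'x', 'b': 'y'}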
Code Example #4
File: cond_v2.py Project: fraudies/tensorflow
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      # Find the outermost graph for uniquing function names.
      # TODO(jpienaar): Make this work in eager mode.
      graph = ops.get_default_graph()
      while isinstance(graph, function.FuncGraph):
        graph = graph.outer_graph

      true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
      false_name = graph.unique_name(("%sfalse" % scope).replace("/", "_"))

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
    pred = ops.convert_to_tensor(pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            true_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            false_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)

    verify_captures(true_graph, false_graph)
    return _build_cond(pred, true_graph, false_graph,
                       true_graph.external_captures,
                       false_graph.external_captures,
                       name=scope)
Code Example #5
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()
    pred = ops.convert_to_tensor(pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            true_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            false_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=scope)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    util.maybe_set_lowering_attr(tensors[0].op)

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = tuple(array_ops.identity(t) for t in tensors)

    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              tensors[:num_cond_outputs])
Code Example #6
def _create_grad_func(func_graph, grads, name):
  """Returns the FuncGraph representation of _grad_fn."""
  return func_graph_module.func_graph_from_py_func(
      name,
      lambda: _grad_fn(func_graph, grads), [], {},
      func_graph=util.CondBranchFuncGraph(name, read_only_collections=False))
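
For context, a sketch of how such a helper might be invoked when differentiating an If op (hypothetical call site, assuming a `util.unique_grad_fn_name` helper as used elsewhere in this module; the real gradient registration lives in cond_v2.py):

# Build one gradient FuncGraph per branch of the forward If op.
true_grad_graph = _create_grad_func(
    true_graph, grads, util.unique_grad_fn_name(true_graph.name))
false_grad_graph = _create_grad_func(
    false_graph, grads, util.unique_grad_fn_name(false_graph.name))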
Code Example #7
File: cond_v2.py Project: w62651515/tensorflow
def cond_v2(pred, true_fn, false_fn, name="cond"):
    """Like tf.cond, except emits a single If op."""
    if isinstance(pred, bool):
        raise TypeError("pred must not be a Python bool", pred)

    if not name:
        name = "cond"

    with ops.name_scope(name) as scope:
        with ops.name_scope(None):
            # Find the outermost graph for uniquing function names.
            # TODO(jpienaar): Make this work in eager mode.
            graph = ops.get_default_graph()
            while isinstance(graph, function.FuncGraph):
                graph = graph.outer_graph

            true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
            false_name = graph.unique_name(
                ("%sfalse" % scope).replace("/", "_"))

        true_graph = function.func_graph_from_py_func(
            true_name,
            true_fn, [], {},
            func_graph=util.CondBranchFuncGraph(true_name))
        false_graph = function.func_graph_from_py_func(
            false_name,
            false_fn, [], {},
            func_graph=util.CondBranchFuncGraph(false_name))
        _check_same_outputs(true_graph, false_graph)

        # Add inputs to true_graph and false_graph to make them match. Note that
        # this modifies true_graph and false_graph.
        cond_inputs = _make_inputs_match(true_graph, false_graph,
                                         true_graph.external_captures,
                                         false_graph.external_captures)

        # Add all intermediate tensors as function outputs so they're available for
        # the gradient computation.

        true_intermediates = _get_intermediates(true_graph)
        false_intermediates = _get_intermediates(false_graph)

        # Save the original number of outputs to return to the caller.
        num_cond_outputs = len(true_graph.outputs)

        # Make the number/type of new intermediate outputs match.
        extra_true_outputs, extra_false_outputs = _pad_params(
            true_graph, false_graph, true_intermediates, false_intermediates)

        true_graph.outputs.extend(extra_true_outputs)
        false_graph.outputs.extend(extra_false_outputs)

        # Create the If op.
        tensors = gen_functional_ops._if(  # pylint: disable=protected-access
            pred,
            cond_inputs, [t.dtype for t in true_graph.outputs],
            util.create_new_tf_function(true_graph),
            util.create_new_tf_function(false_graph),
            output_shapes=_get_output_shapes(true_graph.outputs,
                                             false_graph.outputs),
            name=scope)

        # Set the flag to enable lowering on the `if` op if necessary
        # Lowering allows cond_v2 to avoid some of the limitations of Functions,
        # allowing users to specify devices & colocation inside of cond_v2 branches,
        # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
        # This brings cond_v2 closer to feature parity with tf.cond.
        #
        # However, we do not lower `If` in the XLA context because it is easier for
        # XLA to apply its own optimizations when dealing with un-lowered `If`
        # operators than with lowered switch/merge control flow.
        #
        # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
        if_op = tensors[0].op
        if not control_flow_util.IsInXLAContext(if_op):
            # pylint: disable=protected-access
            if_op._set_attr("_lower_using_switch_merge",
                            attr_value_pb2.AttrValue(b=True))
            # pylint: enable=protected-access

        result = tuple(tensors[:num_cond_outputs])
        if len(result) == 1:
            return result[0]
        else:
            return result
Code Example #8
def cond_v2(pred, true_fn, false_fn, name="cond"):
    """Like tf.cond, except emits a single If op."""
    if isinstance(pred, bool):
        raise TypeError("pred must not be a Python bool", pred)

    if not name:
        name = "cond"

    with ops.name_scope(name) as scope:
        true_name = util.unique_fn_name(scope, "true")
        false_name = util.unique_fn_name(scope, "false")

        # Automatic control dependencies are added in defuns, but not in v1
        # graphs. Propagate that behavior here.
        add_control_dependencies = util.in_defun()
        pred = ops.convert_to_tensor(pred)

        true_graph = func_graph_module.func_graph_from_py_func(
            true_name,
            true_fn, [], {},
            func_graph=util.CondBranchFuncGraph(true_name,
                                                read_only_collections=False),
            add_control_dependencies=add_control_dependencies,
            op_return_value=pred)
        false_graph = func_graph_module.func_graph_from_py_func(
            false_name,
            false_fn, [], {},
            func_graph=util.CondBranchFuncGraph(false_name,
                                                read_only_collections=False),
            add_control_dependencies=add_control_dependencies,
            op_return_value=pred)
        _check_same_outputs(true_graph, false_graph)

        # Add inputs to true_graph and false_graph to make them match. Note that
        # this modifies true_graph and false_graph.
        cond_inputs = _make_inputs_match(true_graph, false_graph,
                                         true_graph.external_captures,
                                         false_graph.external_captures)

        # Add all intermediate tensors as function outputs so they're available for
        # the gradient computation.

        true_intermediates = _get_intermediates(true_graph)
        false_intermediates = _get_intermediates(false_graph)

        # Save the original number of outputs to return to the caller.
        num_cond_outputs = len(true_graph.outputs)

        # Make the number/type of new intermediate outputs match.
        extra_true_outputs, extra_false_outputs = _pad_params(
            true_graph, false_graph, true_intermediates, false_intermediates)

        true_graph.outputs.extend(extra_true_outputs)
        false_graph.outputs.extend(extra_false_outputs)

        # Create the If op.
        tensors = gen_functional_ops._if(  # pylint: disable=protected-access
            pred,
            cond_inputs, [t.dtype for t in true_graph.outputs],
            util.create_new_tf_function(true_graph),
            util.create_new_tf_function(false_graph),
            output_shapes=_get_output_shapes(true_graph.outputs,
                                             false_graph.outputs),
            name=scope)

        # Set the flag to enable lowering on the `if` op if necessary
        # Lowering allows cond_v2 to avoid some of the limitations of Functions,
        # allowing users to specify devices & colocation inside of cond_v2 branches,
        # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
        # This brings cond_v2 closer to feature parity with tf.cond.
        #
        # However, we do not lower `If` in the XLA context because it is easier for
        # XLA to apply its own optimizations when dealing with un-lowered `If`
        # operators than with lowered switch/merge control flow.
        #
        # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
        if_op = tensors[0].op
        if not control_flow_util.IsInXLAContext(if_op):
            # pylint: disable=protected-access
            if_op._set_attr("_lower_using_switch_merge",
                            attr_value_pb2.AttrValue(b=True))
            # pylint: enable=protected-access

        # Return identities for each output of the If op, rather than the output of
        # the If op directly. This makes pruning work if the output of cond() is
        # fetched: the lowering pass converts the If outputs into IdentityN outputs,
        # which if fetched will cause all ops in the taken branch to be run (since
        # it takes all merge ops as input). After lowering, each output identity op
        # will end up with only the appropriate merge op as input.
        # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
        # correct output structure
        tensors = tuple(array_ops.identity(t) for t in tensors)

        result = tuple(tensors[:num_cond_outputs])
        if len(result) == 1:
            return result[0]
        else:
            return result
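
The final unwrap means a single-output cond returns the tensor itself rather than a 1-tuple. A hedged illustration of the observable behavior (hypothetical, assuming the `cond_v2` defined above and a boolean tensor `pred`):

# One output: the bare tensor is returned.
out = cond_v2(pred, lambda: tf.constant(1), lambda: tf.constant(2))
assert not isinstance(out, tuple)

# Two outputs: a tuple is returned.
outs = cond_v2(pred,
               lambda: (tf.constant(1), tf.constant(2)),
               lambda: (tf.constant(3), tf.constant(4)))
assert isinstance(outs, tuple) and len(outs) == 2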