      def Run(branch, x, fetch_by_name, use_gpu=use_gpu):  # use_gpu is captured from the enclosing test case
        with ops.Graph().as_default() as g:
          @function.Defun(dtypes.float32)
          def two(x):
            return -1, x * 2

          @function.Defun(dtypes.float32)
          def three(x):
            return 0, x * 3

          @function.Defun(dtypes.float32)
          def four(x):
            return 1, x * 4

          outputs = gen_functional_ops.case(branch, input=[x],
                                            Tout=[dtypes.int32, dtypes.float32],
                                            branches=[two, three, four],
                                            name="my_case")

          # `outputs` is the list of output tensors of the Case op. We
          # arbitrarily choose the 0th tensor to get the Case op and set the
          # lowering attribute on it.
          outputs[0].op._set_attr("_lower_using_switch_merge",
                                  attr_value_pb2.AttrValue(b=True))
          outputs = array_ops.identity_n(outputs)
        with self.session(graph=g, use_gpu=use_gpu) as sess:
          return sess.run("my_case:1" if fetch_by_name else outputs[1])
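For orientation, here is a minimal standalone sketch of the same three-way dispatch through the public tf.switch_case wrapper, which lowers to the Case op exercised above; a plain TF 2.x install is assumed and the snippet is not part of the original test:

import tensorflow as tf

x = tf.constant(2.0)
branch = tf.constant(1)
# The branch_fns mirror two/three/four above; index 1 selects x * 3.
out = tf.switch_case(branch,
                     branch_fns=[lambda: x * 2.0,
                                 lambda: x * 3.0,
                                 lambda: x * 4.0])
print(out.numpy())  # 6.0; an out-of-range index falls through to the last branch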
Example #2
      def my_graph(pa, pb):
        with ipu.scopes.ipu_scope("/device:IPU:0"):

          @eager_function.defun
          def b0(x, y):
            return x + y

          @eager_function.defun
          def b1(x, y):
            return x - y

          @eager_function.defun
          def b2(x, y):
            return x * y

          v = variable_scope.get_variable('b0',
                                          dtype=dtypes.float32,
                                          initializer=[1., 5.])

          branches = [
              f.get_concrete_function(array_ops.zeros_like(pb),
                                      array_ops.zeros_like(v))
              for f in [b0, b1, b2]
          ]

          c_out = gen_functional_ops.case(pa,
                                          input=[pb, v],
                                          Tout=[dtypes.float32],
                                          branches=branches)

          return [c_out[0]]
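A plausible driver for my_graph, sketched here as an assumption rather than taken from the original test, follows the usual Graphcore pattern: placeholders live on the CPU, the graph is built by calling my_graph (which places its ops on the IPU itself), and a session feeds the branch index. The IPU system configuration step the real tests perform is elided.

import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables

with ops.device('cpu'):
  pa = array_ops.placeholder(np.int32, [], name="a")     # branch selector
  pb = array_ops.placeholder(np.float32, [2], name="b")  # branch operand

out = my_graph(pa, pb)

with session.Session() as sess:
  sess.run(variables.global_variables_initializer())
  # pa=1 selects b1(x, y) = x - y, i.e. pb - v = [2.-1., 4.-5.] = [1., -1.]
  print(sess.run(out, {pa: 1, pb: [2., 4.]}))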
Example #3
      def my_graph(pa, pb, pc):
        with ipu.scopes.ipu_scope("/device:IPU:0"):

          @eager_function.defun
          def b0(x, y):
            return x + y

          @eager_function.defun
          def b1(x, y):
            return x - y

          @eager_function.defun
          def b2(x, y):
            return x * y

          branches = [
              f.get_concrete_function(array_ops.zeros_like(pb),
                                      array_ops.zeros_like(pc))
              for f in [b0, b1, b2]
          ]

          c_out = gen_functional_ops.case(pa,
                                          input=[pb, pc],
                                          Tout=[dtypes.float32],
                                          branches=branches)

          return [c_out[0]]
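A driver analogous to the one sketched after Example #2 applies here too, with a third CPU placeholder for pc fed alongside pa and pb; the only difference is that both branch operands are now fed externally instead of one being read from a variable.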
Example #4
def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
    """Creates an `Case` op from `branch_index`, branch graphs and inputs.

  Note that this modifies `branch_graphs` to make the inputs match, and to
  output all intermediates values so they're available for the gradient
  computation.

  `branch_graphs` need not have the same input types, but they must
  have the same outpute types.

  Args:
    branch_index: integer Tensor
    branch_graphs: List of FuncGraph
    branch_inputs: List of lists of Tensors to be passed to corresponding
      branch_graph as input.
    name: the name for the Case op.

  Returns:
    A list of Tensors which are the outputs of the Case op. Does not include
    added intermediate outputs.
  """
    _make_indexed_slices_indices_types_match(_CASE, branch_graphs)
    _check_same_outputs(_CASE, branch_graphs)

    # Add inputs to branch_graphs to make them match. Note that this modifies the
    # graphs in `branch_graphs`.
    case_inputs = _make_inputs_match(branch_graphs, branch_inputs)

    # Create the Case op.
    with ops.control_dependencies(
            sum((list(bg.control_captures) for bg in branch_graphs), [])):
        tensors = gen_functional_ops.case(
            branch_index,
            case_inputs, [t.dtype for t in branch_graphs[0].outputs],
            [util.create_new_tf_function(g) for g in branch_graphs],
            output_shapes=_get_output_shapes(
                *[g.outputs for g in branch_graphs]),
            name=name)

    case_op, tensors = _get_op_and_outputs(tensors)

    if case_op is not None:
        util.maybe_set_lowering_attr(case_op)
        util.maybe_propagate_compile_time_consts_in_xla(case_op)
        _set_read_only_resource_inputs_attr(case_op, branch_graphs)
        # Prevent fetching since the variant outputs can't be fetched directly.
        case_op.graph.prevent_fetching(case_op)

    # Return identities for each output of the Case op, rather than the output of
    # the Case op directly. This makes pruning work if the output of switch_case()
    # is fetched: the lowering pass converts the Case outputs into IdentityN
    # outputs, which if fetched will cause all ops in the taken branch to be run
    # (since it takes all merge ops as input). After lowering, each output
    # identity op will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = [array_ops.identity(t) for t in tensors]

    return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors)
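As a hedged illustration of the repacking that _pack_sequence_as performs at the end, the public tf.switch_case wrapper (which reaches _build_case via cond_v2.indexed_case in TF 2.x) lets every branch return the same nested structure, and the caller gets that structure back rather than a flat tensor list:

import tensorflow as tf

@tf.function
def f(idx, x):
  return tf.switch_case(
      idx,
      branch_fns=[lambda: {'sign': tf.constant(-1), 'val': x * 2.0},
                  lambda: {'sign': tf.constant(0), 'val': x * 3.0},
                  lambda: {'sign': tf.constant(1), 'val': x * 4.0}])

# All branches share one output structure; the selected branch fills it in.
print(f(tf.constant(2), tf.constant(2.0)))  # {'sign': 1, 'val': 8.0}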
Example #5
def f(branch, x):
    # `two`, `three` and `four` are assumed to be defined elsewhere in the
    # enclosing test as functions exposing get_concrete_function.
    tmpl = array_ops.zeros_like(x)
    return array_ops.identity(
        gen_functional_ops.case(branch,
                                input=[x],
                                Tout=[dtypes.float32],
                                branches=[
                                    f.get_concrete_function(tmpl)
                                    for f in (two, three, four)
                                ])[0])
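The snippet above references two, three and four without defining them; hypothetical definitions compatible with get_concrete_function (an assumption, not the test's originals) could look like:

from tensorflow.python.eager import def_function

@def_function.function
def two(x):
  return x * 2.0

@def_function.function
def three(x):
  return x * 3.0

@def_function.function
def four(x):
  return x * 4.0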
Example #6
      def Run(branch, x):
        @function.Defun(dtypes.float32)
        def two(x):
          return -1, x * 2

        @function.Defun(dtypes.float32)
        def three(x):
          return 0, x * 3

        @function.Defun(dtypes.float32)
        def four(x):
          return 1, x * 4

        outputs = gen_functional_ops.case(branch, input=[x],
                                          Tout=[dtypes.int32, dtypes.float32],
                                          branches=[two, three, four])

        # `outputs` is the list of output tensors of the Case op. We
        # arbitrarily choose the 0th tensor to get the Case op and set the
        # lowering attribute on it.
        outputs[0].op._set_attr("_lower_using_switch_merge",
                                attr_value_pb2.AttrValue(b=True))
        outputs = array_ops.identity_n(outputs)
        return outputs[1]