def Run(branch, x, fetch_by_name, use_gpu=use_gpu):
  """Builds a three-way Case op in a fresh graph, marks it for lowering, runs it.

  The branches are `two`, `three` and `four`. They return `(-1, x * 2)`,
  `(0, x * 3)` and `(1, x * 4)` respectively. The Case op carries the
  `_lower_using_switch_merge=True` attribute, and its outputs are wrapped in
  an IdentityN before fetching.

  Args:
    branch: branch index passed to the Case op.
    x: float32 value forwarded to the selected branch.
    fetch_by_name: when True, fetch the raw Case output by the tensor name
      "my_case:1". Otherwise fetch the IdentityN-wrapped float output.
    use_gpu: forwarded to `self.session`. The default is captured from the
      enclosing scope at definition time.

  Returns:
    The evaluated float32 (second) output of the Case op.
  """
  with ops.Graph().as_default() as graph:

    @function.Defun(dtypes.float32)
    def two(x):
      return -1, x * 2

    @function.Defun(dtypes.float32)
    def three(x):
      return 0, x * 3

    @function.Defun(dtypes.float32)
    def four(x):
      return 1, x * 4

    case_outputs = gen_functional_ops.case(
        branch,
        input=[x],
        Tout=[dtypes.int32, dtypes.float32],
        branches=[two, three, four],
        name="my_case")

    # All outputs share one producing op, so the first output reaches the
    # Case op as well as any other would. Tag that op for lowering.
    lowering_attr = attr_value_pb2.AttrValue(b=True)
    case_outputs[0].op._set_attr("_lower_using_switch_merge", lowering_attr)
    wrapped_outputs = array_ops.identity_n(case_outputs)

  # self.session(graph=graph, ...) binds the session to `graph` explicitly.
  target = "my_case:1" if fetch_by_name else wrapped_outputs[1]
  with self.session(graph=graph, use_gpu=use_gpu) as sess:
    return sess.run(target)
def my_graph(pa, pb):
  """Builds, under the IPU device scope, a Case op combining `pb` with a variable.

  There are three branches: `x + y`, `x - y` and `x * y`. `x` is `pb` and `y`
  is a float32 variable initialized to `[1., 5.]`.

  Args:
    pa: branch index for the Case op.
    pb: first operand of every branch. Presumably shape-compatible with the
      2-element variable. TODO confirm against callers.

  Returns:
    A one-element list containing the Case op's float32 output.
  """
  with ipu.scopes.ipu_scope("/device:IPU:0"):

    @eager_function.defun
    def b0(x, y):
      return x + y

    @eager_function.defun
    def b1(x, y):
      return x - y

    @eager_function.defun
    def b2(x, y):
      return x * y

    # The variable name 'b0' is just a variable-scope name. It is unrelated to
    # the branch function `b0` above.
    v = variable_scope.get_variable(
        'b0', dtype=dtypes.float32, initializer=[1., 5.])

    # Trace each branch into a concrete function. Fresh zeros_like templates
    # are created per branch, exactly one pair per traced function.
    branches = []
    for branch_fn in (b0, b1, b2):
      pb_template = array_ops.zeros_like(pb)
      v_template = array_ops.zeros_like(v)
      branches.append(branch_fn.get_concrete_function(pb_template, v_template))

    case_results = gen_functional_ops.case(
        pa, input=[pb, v], Tout=[dtypes.float32], branches=branches)
    first_output = case_results[0]
    return [first_output]
def my_graph(pa, pb, pc):
  """Builds, under the IPU device scope, a Case op combining `pb` and `pc`.

  There are three branches: `x + y`, `x - y` and `x * y`, applied as
  `branch(pb, pc)`.

  Args:
    pa: branch index for the Case op.
    pb: first operand forwarded to the selected branch.
    pc: second operand forwarded to the selected branch.

  Returns:
    A one-element list containing the Case op's float32 output.
  """
  with ipu.scopes.ipu_scope("/device:IPU:0"):

    @eager_function.defun
    def b0(x, y):
      return x + y

    @eager_function.defun
    def b1(x, y):
      return x - y

    @eager_function.defun
    def b2(x, y):
      return x * y

    # Trace each branch with its own zeros_like templates of the operands.
    branches = []
    for branch_fn in (b0, b1, b2):
      pb_template = array_ops.zeros_like(pb)
      pc_template = array_ops.zeros_like(pc)
      branches.append(branch_fn.get_concrete_function(pb_template, pc_template))

    case_results = gen_functional_ops.case(
        pa, input=[pb, pc], Tout=[dtypes.float32], branches=branches)
    first_output = case_results[0]
    return [first_output]
def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
  """Creates a `Case` op from `branch_index`, branch graphs and inputs.

  Note that this modifies `branch_graphs` to make the inputs match, and to
  output all intermediate values so they're available for the gradient
  computation.

  `branch_graphs` need not have the same input types, but they must have the
  same output types.

  Args:
    branch_index: integer Tensor
    branch_graphs: List of FuncGraph
    branch_inputs: List of lists of Tensors to be passed to corresponding
      branch_graph as input.
    name: the name for the Case op.

  Returns:
    The outputs of the Case op, each wrapped in an identity and packed into
    the structure of `branch_graphs[0].structured_outputs`. Does not include
    added intermediate outputs.
  """
  _make_indexed_slices_indices_types_match(_CASE, branch_graphs)
  _check_same_outputs(_CASE, branch_graphs)

  # Add inputs to branch_graphs to make them match. Note that this modifies the
  # graphs in `branch_graphs`.
  case_inputs = _make_inputs_match(branch_graphs, branch_inputs)

  # Control captures of every branch become control dependencies of the Case
  # op. Flatten them with a comprehension. `sum(lists, [])` copies the
  # accumulator on every step and is quadratic in the number of captures.
  control_captures = [
      capture for bg in branch_graphs for capture in bg.control_captures
  ]

  # Create the Case op.
  with ops.control_dependencies(control_captures):
    tensors = gen_functional_ops.case(
        branch_index,
        case_inputs,
        [t.dtype for t in branch_graphs[0].outputs],
        [util.create_new_tf_function(g) for g in branch_graphs],
        output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
        name=name)

  case_op, tensors = _get_op_and_outputs(tensors)

  if case_op is not None:
    util.maybe_set_lowering_attr(case_op)
    util.maybe_propagate_compile_time_consts_in_xla(case_op)
    _set_read_only_resource_inputs_attr(case_op, branch_graphs)
    # Prevent fetching since the variant outputs can't be fetched directly.
    case_op.graph.prevent_fetching(case_op)

  # Return identities for each output of the Case op, rather than the output of
  # the Case op directly. This makes pruning work if the output of
  # switch_case() is fetched: the lowering pass converts the Case outputs into
  # IdentityN outputs, which if fetched will cause all ops in the taken branch
  # to be run (since it takes all merge ops as input). After lowering, each
  # output identity op will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors)
def f(branch, x):
  """Runs `x` through a Case op over the enclosing `two`, `three`, `four`.

  Each branch is traced into a concrete function. The tracing template is a
  zeros tensor shaped like `x`. `two`, `three` and `four` come from the
  enclosing scope and are presumably `defun`-decorated; verify in caller.

  Args:
    branch: branch index for the Case op.
    x: float32 input forwarded to the selected branch.

  Returns:
    An identity of the Case op's single float32 output.
  """
  template = array_ops.zeros_like(x)
  # Use a distinct loop name so the comprehension does not shadow `f` itself.
  concrete_branches = [
      branch_fn.get_concrete_function(template)
      for branch_fn in (two, three, four)
  ]
  case_outputs = gen_functional_ops.case(
      branch,
      input=[x],
      Tout=[dtypes.float32],
      branches=concrete_branches)
  return array_ops.identity(case_outputs[0])
def Run(branch, x):
  """Builds a three-way Case op marked for lowering and returns its float output.

  The branches return `(-1, x * 2)`, `(0, x * 3)` and `(1, x * 4)`. The Case
  op gets `_lower_using_switch_merge=True`, and its outputs are wrapped in an
  IdentityN.

  Args:
    branch: branch index for the Case op.
    x: float32 input forwarded to the selected branch.

  Returns:
    The IdentityN-wrapped float32 (second) output of the Case op.
  """

  @function.Defun(dtypes.float32)
  def two(x):
    return -1, x * 2

  @function.Defun(dtypes.float32)
  def three(x):
    return 0, x * 3

  @function.Defun(dtypes.float32)
  def four(x):
    return 1, x * 4

  case_outputs = gen_functional_ops.case(
      branch,
      input=[x],
      Tout=[dtypes.int32, dtypes.float32],
      branches=[two, three, four])

  # Any output tensor leads back to the single Case op. Use the first one to
  # attach the lowering attribute.
  case_op = case_outputs[0].op
  case_op._set_attr("_lower_using_switch_merge",
                    attr_value_pb2.AttrValue(b=True))
  return array_ops.identity_n(case_outputs)[1]
def f(branch, x):
  """Evaluates a Case op that dispatches `x` to `two`, `three` or `four`.

  The enclosing-scope functions `two`, `three` and `four` (presumably
  `defun`-decorated; TODO confirm) are traced into concrete functions. The
  tracing template is a zeros tensor shaped like `x`.

  Args:
    branch: branch index for the Case op.
    x: float32 input forwarded to the selected branch.

  Returns:
    An identity of the Case op's single float32 output.
  """
  signature_template = array_ops.zeros_like(x)
  traced = []
  for candidate in (two, three, four):
    traced.append(candidate.get_concrete_function(signature_template))
  results = gen_functional_ops.case(
      branch, input=[x], Tout=[dtypes.float32], branches=traced)
  return array_ops.identity(results[0])