def broadcast_send_recv(device_id):
  c = constant_op.constant([2])

  @def_function.function
  def send():
    s0 = collective_ops.broadcast_send(
        c * 3, c.shape, c.dtype, group_size=2, group_key=1, instance_key=1)
    with ops.control_dependencies([s0.op]):
      return array_ops.identity(c)

  @def_function.function
  def recv():
    r0 = collective_ops.broadcast_recv(
        c.shape, c.dtype, group_size=2, group_key=1, instance_key=1)
    return r0

  return control_flow_ops.switch_case(
      device_id, branch_fns={0: send, 1: recv})

def testSwitchCaseConstPropagation(self):
  self.skipTest("b/127846988")
  with self.session() as sess, self.test_scope():
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()

    x = array_ops.placeholder(dtypes.float32)
    p = array_ops.placeholder(dtypes.int32)

    def branch0():
      return 5.

    def branch1():
      return 15.

    # TODO(b/129021699): Wrapping this in a tf.function does not work.
    def branch2():
      # This emits a StridedSlice op which expects the index to be a
      # compile-time const.
      return x[p]

    output = control_flow_ops.switch_case(
        constant_op.constant(2), {
            0: branch0,
            1: branch1,
            2: branch2,
        })
    self.assertAllEqual(
        7., sess.run(output, feed_dict={
            x: [0., 1., 7.],
            p: 2,
        }))
    xla_context.Exit()

def build_functional_op(v):

  def branch0():
    return array_ops.zeros([], v.dtype)

  def branch1():
    return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype)

  return control_flow_ops.switch_case(
      constant_op.constant(0), [branch0, branch1])

def f():
  ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
  output = control_flow_ops.switch_case(
      constant_op.constant(1), {
          0: lambda: ta.write(0, 5.),
          1: lambda: ta.write(0, 10.),
          2: lambda: ta.write(0, 15.),
      })
  return output.stack()

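# A minimal, self-contained sketch of the same pattern using the public
# tf.switch_case API; the values written below are illustrative assumptions.
# Only the selected branch (index 1 here) runs, so stack() returns its write.
import tensorflow as tf

@tf.function
def stack_selected_branch():
  ta = tf.TensorArray(dtype=tf.float32, size=1)
  out = tf.switch_case(
      tf.constant(1), {
          0: lambda: ta.write(0, 5.),
          1: lambda: ta.write(0, 10.),
          2: lambda: ta.write(0, 15.),
      })
  return out.stack()

# stack_selected_branch() is expected to yield [10.].
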
def run_inference(x):

  def do_inference(device, inference_fn, i):
    with ops.device(device):
      return inference_fn(x, i)

  branch_fns = {
      0: (lambda: do_inference("/device:TPU:0", inference_fn, 0)),
      1: (lambda: do_inference("/device:TPU:1", inference_fn, 1)),
  }
  branch_index = inference_iteration.assign_add(1, use_locking=True) % 2
  return control_flow_ops.switch_case(branch_index, branch_fns)

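# Hypothetical illustration of the same round-robin dispatch idea with the
# public API, minus the TPU device placement: a counter variable selects the
# branch on each call. The variable name and branch bodies are assumptions.
import tensorflow as tf

inference_iteration = tf.Variable(0, dtype=tf.int32)

@tf.function
def run_inference_round_robin(x):
  branch_index = inference_iteration.assign_add(1, use_locking=True) % 2
  return tf.switch_case(
      branch_index,
      {0: lambda: x * 2.0, 1: lambda: x * 3.0})

# The first call selects branch 1 (counter becomes 1), the second selects
# branch 0 (counter becomes 2), and so on, alternating between the two.
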
def switch_case_test(branch_index):

  def f1():
    return array_ops.constant(17)

  def f2():
    return array_ops.constant(31)

  def f3():
    return array_ops.constant(-1)

  return control_flow_ops.switch_case(
      branch_index, branch_fns={0: f1, 1: f2}, default=f3)

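# A short usage sketch, assuming the public tf.switch_case API: with
# branch_fns={0: ..., 1: ...} and a default branch, any index outside the
# mapped keys falls through to the default.
import tensorflow as tf

@tf.function
def pick(branch_index):
  return tf.switch_case(
      branch_index,
      branch_fns={0: lambda: tf.constant(17), 1: lambda: tf.constant(31)},
      default=lambda: tf.constant(-1))

# pick(tf.constant(0)) -> 17, pick(tf.constant(1)) -> 31,
# pick(tf.constant(5)) -> -1 (default branch).
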
def switch_case_test():
  branch_index = array_ops.constant(0)

  def f1():
    return array_ops.constant(17)

  def f2():
    # Some operations that XLA cannot compile.
    image_ops.decode_image(io_ops.read_file('/tmp/bmp'))
    return array_ops.constant(31)

  # This tests that we do not try to compile all branches if the branch
  # index is trivially constant.
  return control_flow_ops.switch_case(
      branch_index, branch_fns={0: f1, 1: f2}, default=f2)

def case_fn(x):
  branch_index = constant_op.constant(1)
  branches = [lambda: x, lambda: x + 1]
  case_out = control_flow_ops.switch_case(branch_index, branches)
  return case_out

def test_eager_switch_case_input(self):
  with context.eager_mode():
    task = keras.Input(shape=(), dtype=dtypes.int32)
    control_flow_ops.switch_case(
        task[0], [lambda: constant_op.constant(1.0) for _ in range(10)])

def model(i, x):
  return control_flow_ops.switch_case(
      i, [lambda: branch0(x), lambda: branch1(x), lambda: branch2(x)])

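# A runnable sketch of this dispatch pattern with the public API; branch0,
# branch1 and branch2 are hypothetical stand-ins, since their bodies are not
# shown above.
import tensorflow as tf

def branch0(x):
  return x * 0.5

def branch1(x):
  return x + 1.0

def branch2(x):
  return -x

@tf.function
def model(i, x):
  return tf.switch_case(
      i, [lambda: branch0(x), lambda: branch1(x), lambda: branch2(x)])

# model(tf.constant(2), tf.constant(4.0)) runs branch2 and returns -4.0.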