Beispiel #1
0
        def broadcast_send_recv(device_id):
            """Select the broadcast-send or broadcast-receive branch.

            Args:
                device_id: Branch index handed to `switch_case`; 0 maps to the
                    sending function and 1 to the receiving function.

            Returns:
                Whatever `control_flow_ops.switch_case` returns for the two
                branches.
            """
            c = constant_op.constant([2])

            # Sender and receiver are built with the same collective settings.
            group_size = 2
            group_key = 1
            instance_key = 1

            @def_function.function
            def send():
                sent = collective_ops.broadcast_send(
                    c * 3,
                    c.shape,
                    c.dtype,
                    group_size=group_size,
                    group_key=group_key,
                    instance_key=instance_key)
                # The returned identity carries a control dependency on the
                # send op (presumably so fetching the result runs the send).
                with ops.control_dependencies([sent.op]):
                    return array_ops.identity(c)

            @def_function.function
            def recv():
                return collective_ops.broadcast_recv(
                    c.shape,
                    c.dtype,
                    group_size=group_size,
                    group_key=group_key,
                    instance_key=instance_key)

            branches = {0: send, 1: recv}
            return control_flow_ops.switch_case(device_id, branch_fns=branches)
    def testSwitchCaseConstPropagation(self):
        """Run a `switch_case` whose chosen branch indexes `x` by a placeholder.

        The branch index is the constant 2, so `branch2` (which returns
        `x[p]`) is the branch expected to produce the asserted value 7.

        The test is currently skipped unconditionally; see b/127846988.
        """
        self.skipTest("b/127846988")
        with self.session() as sess, self.test_scope():
            xla_context = control_flow_ops.XLAControlFlowContext()
            xla_context.Enter()
            # Always leave the XLA control-flow context, even when graph
            # construction or the assertion below raises.
            try:
                x = array_ops.placeholder(dtypes.float32)
                p = array_ops.placeholder(dtypes.int32)

                def branch0():
                    return 5.

                def branch1():
                    return 15.

                # TODO(b/129021699): Wrapping this in a tf.function does not
                # work.
                def branch2():
                    # This emits a StridedSlice op which expects the index to
                    # be a compile-time const.
                    return x[p]

                output = control_flow_ops.switch_case(
                    constant_op.constant(2), {
                        0: branch0,
                        1: branch1,
                        2: branch2,
                    })

                self.assertAllEqual(
                    7., sess.run(output, feed_dict={
                        x: [0., 1., 7.],
                        p: 2,
                    }))
            finally:
                xla_context.Exit()
Beispiel #3
0
        def build_functional_op(v):
            """Build a two-branch `switch_case` whose second branch reads `v`.

            Args:
                v: Object exposing `handle` and `dtype` (presumably a resource
                    variable -- confirm against the caller).

            Returns:
                The output of `control_flow_ops.switch_case` called with the
                constant branch index 0.
            """
            def zeros_branch():
                # Scalar zero with the same dtype as `v`.
                return array_ops.zeros([], v.dtype)

            def read_branch():
                return gen_resource_variable_ops.read_variable_op(
                    v.handle, v.dtype)

            branch_index = constant_op.constant(0)
            branches = [zeros_branch, read_branch]
            return control_flow_ops.switch_case(branch_index, branches)
            def f():
                """Write to a TensorArray from one of three branches and stack it.

                Returns:
                    The result of calling `stack()` on the `switch_case`
                    output; the branch index is the constant 1.
                """
                ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)

                def make_writer(value):
                    # Zero-argument branch that writes `value` at index 0.
                    return lambda: ta.write(0, value)

                branch_fns = {
                    index: make_writer(value)
                    for index, value in enumerate((5., 10., 15.))
                }
                selected = constant_op.constant(1)
                written = control_flow_ops.switch_case(selected, branch_fns)
                return written.stack()
Beispiel #5
0
    def run_inference(x):
      """Run `inference_fn` on `x` under one of two TPU device scopes.

      The branch is chosen from `inference_iteration`: the counter is
      incremented by one and the result modulo 2 is the branch index.

      Args:
        x: Input forwarded unchanged to `inference_fn`.

      Returns:
        The output of `control_flow_ops.switch_case` over the two branches.
      """

      def on_device(device, core_index):
        # Builds a zero-argument branch that calls `inference_fn` inside the
        # given device scope.
        def branch():
          with ops.device(device):
            return inference_fn(x, core_index)
        return branch

      branch_fns = {
          0: on_device("/device:TPU:0", 0),
          1: on_device("/device:TPU:1", 1),
      }
      next_iteration = inference_iteration.assign_add(1, use_locking=True)
      branch_index = next_iteration % 2
      return control_flow_ops.switch_case(branch_index, branch_fns)
Beispiel #6
0
        def switch_case_test(branch_index):
            """Build a `switch_case` over constant-returning branches.

            Args:
                branch_index: Index passed to `switch_case`. Branch 0 yields
                    the constant 17 and branch 1 the constant 31; the
                    `default` branch yields -1.

            Returns:
                The output of `control_flow_ops.switch_case`.
            """
            def returns_constant(value):
                # Zero-argument branch producing a constant tensor of `value`.
                def branch():
                    return array_ops.constant(value)
                return branch

            indexed_branches = {
                0: returns_constant(17),
                1: returns_constant(31),
            }
            return control_flow_ops.switch_case(
                branch_index,
                branch_fns=indexed_branches,
                default=returns_constant(-1))
Beispiel #7
0
        def switch_case_test():
            """Build a `switch_case` with a constant index and an uncompilable branch.

            Returns:
                The output of `control_flow_ops.switch_case` with the constant
                branch index 0.
            """
            def compilable_branch():
                return array_ops.constant(17)

            def uncompilable_branch():
                # Some operations that XLA cannot compile.
                contents = io_ops.read_file('/tmp/bmp')
                image_ops.decode_image(contents)
                return array_ops.constant(31)

            branch_index = array_ops.constant(0)
            indexed_branches = {0: compilable_branch, 1: uncompilable_branch}

            # This tests that we do not try to compile all branches if the
            # branch index is trivially constant.
            return control_flow_ops.switch_case(
                branch_index,
                branch_fns=indexed_branches,
                default=uncompilable_branch)
Beispiel #8
0
 def case_fn(x):
   """Run a two-branch `switch_case` on `x` with the constant index 1.

   Branch 0 returns `x` unchanged and branch 1 returns `x + 1`.
   """
   def passthrough():
     return x

   def plus_one():
     return x + 1

   selected = constant_op.constant(1)
   return control_flow_ops.switch_case(selected, [passthrough, plus_one])
Beispiel #9
0
 def test_eager_switch_case_input(self):
     """Call `switch_case` in eager mode with an index taken from a Keras input.

     The test makes no assertion; it only exercises the call with ten
     branches that each return the constant 1.0.
     """
     with context.eager_mode():
         task = keras.Input(shape=(), dtype=dtypes.int32)
         branch_index = task[0]
         branches = [lambda: constant_op.constant(1.0) for _ in range(10)]
         control_flow_ops.switch_case(branch_index, branches)
Beispiel #10
0
 def model(i, x):
   """Apply `branch0`, `branch1` or `branch2` to `x`, selected by `i`.

   Args:
     i: Branch index passed to `switch_case`.
     x: Value forwarded to the selected branch function.

   Returns:
     The output of `control_flow_ops.switch_case`.
   """
   branch_fns = [
       lambda: branch0(x),
       lambda: branch1(x),
       lambda: branch2(x),
   ]
   return control_flow_ops.switch_case(i, branch_fns)