Exemplo n.º 1
0
    def test_import_tf_comp_with_while_loop(self):
        # Imports, compiles and runs a TF computation containing a
        # `tf.while_loop` over variables, end-to-end on VULKAN_SPIRV.
        @computations.tf_computation(tf.float32)
        def comp(x):
            # An example of a loop with variables that doubles `b` once per
            # iteration while counting `a` down from x by steps of 1.0 until it
            # is no longer positive. For a non-negative whole number x this
            # yields 2^x (in general 2^ceil(x) for x > 0, and 1 for x <= 0).
            a = tf.Variable(0.0)
            b = tf.Variable(1.0)
            with tf.control_dependencies([a.initializer, b.initializer]):
                with tf.control_dependencies([a.assign(x)]):

                    # Named functions rather than lambdas bound to names
                    # (PEP 8, E731); parameter names differ from the captured
                    # `a`/`b` variables to avoid shadowing them.
                    def cond_fn(counter, unused_acc):
                        return counter > 0

                    def body_fn(counter, acc):
                        return (counter - 1.0, acc * 2.0)

                    # Element [1] of the final loop state is the accumulator.
                    return tf.while_loop(cond_fn, body_fn, (a, b))[1]

        module, mlir = self._import_compile_and_return_module_and_mlir(comp)

        # Not checking the full MLIR in the long generated body, just that we can
        # successfully ingest TF code containing a while loop here, end-to-end. We
        # need some form of looping support in lieu of `tf.data.Dataset.reduce()`.
        self._assert_mlir_contains_pattern(
            mlir, ['func @fn(%arg0: tensor<f32>) -> tensor<f32>'])

        result = runtime.compile_and_run_on_args(module,
                                                 backend_info.VULKAN_SPIRV,
                                                 np.float32(5.0))
        # x = 5.0 runs the body five times: 2^5 == 32.
        self.assertEqual(result, 32.0)
Exemplo n.º 2
0
    def test_import_tf_comp_with_variable_assign_add_one(self):
        # A variable starting at 1.0 is incremented in place by the argument and
        # its updated value is returned, so an input of 5.0 must produce 6.0.
        @computations.tf_computation(tf.float32)
        def comp(x):
            var = tf.Variable(1.0)
            # The initializer runs first; the `assign_add` is created inside the
            # first dependency scope and must run before the final read.
            with tf.control_dependencies([var.initializer]), \
                    tf.control_dependencies([var.assign_add(x)]):
                return tf.identity(var)

        module, mlir = self._import_compile_and_return_module_and_mlir(comp)

        # TODO(b/153499219): Introduce the concept of local variables, so that code
        # like what's in this section below can be dramatically simplified.
        expected_mlir_lines = [
            'flow.variable SOMETHING mutable dense<1.000000e+00> : tensor<f32>',
            'func @fn(%arg0: tensor<f32>) -> tensor<f32> SOMETHING {',
            '  %0 = flow.variable.address',
            '  %1 = mhlo.constant dense<1.000000e+00>',
            '  flow.variable.store.indirect %1, %0',
            '  %2 = flow.variable.load.indirect %0',
            '  %3 = mhlo.add %2, %arg0',
            '  flow.variable.store.indirect %3, %0',
            '  %4 = flow.variable.load.indirect %0',
            '  return %4',
            '}',
        ]
        self._assert_mlir_contains_pattern(mlir, expected_mlir_lines)

        backend = backend_info.VULKAN_SPIRV
        arg = np.float32(5.0)
        result = runtime.compile_and_run_on_args(module, backend, arg)
        self.assertEqual(result, 6.0)
Exemplo n.º 3
0
    def test_import_tf_comp_with_while_loop(self):
        # Imports, compiles and runs a TF computation containing a
        # `tf.while_loop` over variables, end-to-end on the VMLA backend.
        @computations.tf_computation(tf.float32)
        def comp(x):
            # An example of a loop with variables that doubles `b` once per
            # iteration while counting `a` down from x by steps of 1.0 until it
            # is no longer positive. For a non-negative whole number x this
            # yields 2^x (in general 2^ceil(x) for x > 0, and 1 for x <= 0).
            a = tf.Variable(0.0)
            b = tf.Variable(1.0)
            with tf.control_dependencies([a.initializer, b.initializer]):
                with tf.control_dependencies([a.assign(x)]):

                    # Named functions rather than lambdas bound to names
                    # (PEP 8, E731); parameter names differ from the captured
                    # `a`/`b` variables to avoid shadowing them.
                    def cond_fn(counter, unused_acc):
                        return counter > 0

                    def body_fn(counter, acc):
                        return (counter - 1.0, acc * 2.0)

                    # Element [1] of the final loop state is the accumulator.
                    return tf.while_loop(cond_fn, body_fn, (a, b))[1]

        module, mlir = self._import_compile_and_return_module_and_mlir(comp)

        # Not checking the full MLIR in the long generated body, just that we can
        # successfully ingest TF code containing a while loop here, end-to-end. We
        # need some form of looping support in lieu of `tf.data.Dataset.reduce()`.
        self._assert_mlir_contains_pattern(
            mlir, ['func @fn(%arg0: tensor<f32>) -> tensor<f32>'])

        # TODO(b/153499219): Switch the backend to VULKAN_SPIRV after fixing these
        # compilation errors on VULKAN:
        # * [ERROR]: SPIRV type conversion failed: 'memref<i1>'
        # * [ERROR]: failed to legalize operation 'iree.placeholder'
        # * [ERROR]: failed to run translation of source executable to target
        #            executable for backend vulkan.
        result = runtime.compile_and_run_on_args(module, backend_info.VMLA,
                                                 np.float32(5.0))
        # x = 5.0 runs the body five times: 2^5 == 32.
        self.assertEqual(result, 32.0)
Exemplo n.º 4
0
    def test_import_tf_comp_with_one_constant(self):
        # A zero-argument computation that returns the literal 99.0 should lower
        # to a single `mhlo.constant` and evaluate to that value.
        @computations.tf_computation
        def comp():
            return 99.0

        module, mlir = self._import_compile_and_return_module_and_mlir(comp)

        expected_mlir_lines = [
            'func @fn() -> tensor<f32> SOMETHING {',
            '  %0 = mhlo.constant dense<9.900000e+01>',
            '  return %0',
            '}',
        ]
        self._assert_mlir_contains_pattern(mlir, expected_mlir_lines)

        backend = backend_info.VULKAN_SPIRV
        result = runtime.compile_and_run_on_args(module, backend)
        self.assertEqual(result, 99.0)
Exemplo n.º 5
0
    def test_import_tf_comp_with_add_one(self):
        # A one-argument computation adding the constant 1.0 should lower to an
        # `mhlo.constant` feeding an `mhlo.add`, and map 5.0 to 6.0.
        @computations.tf_computation(tf.float32)
        def comp(x):
            return x + 1.0

        module, mlir = self._import_compile_and_return_module_and_mlir(comp)

        expected_mlir_lines = [
            'func @fn(%arg0: tensor<f32>) -> tensor<f32> SOMETHING {',
            '  %0 = mhlo.constant dense<1.000000e+00>',
            '  %1 = mhlo.add %arg0, %0',
            '  return %1',
            '}',
        ]
        self._assert_mlir_contains_pattern(mlir, expected_mlir_lines)

        backend = backend_info.VULKAN_SPIRV
        arg = np.float32(5.0)
        result = runtime.compile_and_run_on_args(module, backend, arg)
        self.assertEqual(result, 6.0)
Exemplo n.º 6
0
    def test_import_tf_comp_with_one_variable_constant(self):
        # Returning a freshly created variable should lower to a store of its
        # initial constant followed by a load, and evaluate to 99.0.
        @computations.tf_computation
        def comp():
            return tf.Variable(99.0)

        module, mlir = self._import_compile_and_return_module_and_mlir(comp)

        expected_mlir_lines = [
            'func @fn() -> tensor<f32> SOMETHING {',
            '  %0 = flow.variable.address',
            '  %1 = mhlo.constant dense<9.900000e+01>',
            '  flow.variable.store.indirect %1, %0',
            '  %2 = flow.variable.load.indirect %0',
            '  return %2',
            '}',
        ]
        self._assert_mlir_contains_pattern(mlir, expected_mlir_lines)

        backend = backend_info.VULKAN_SPIRV
        result = runtime.compile_and_run_on_args(module, backend)
        self.assertEqual(result, 99.0)
Exemplo n.º 7
0
    def test_import_tf_comp_with_variable_add_one(self):
        # Reading a variable initialized to 1.0 and adding the argument should
        # lower to store/load plus `mhlo.add`, mapping 5.0 to 6.0.
        @computations.tf_computation(tf.float32)
        def comp(x):
            var = tf.Variable(1.0)
            # The read inside `tf.add` must happen after initialization.
            with tf.control_dependencies([var.initializer]):
                return tf.add(var, x)

        module, mlir = self._import_compile_and_return_module_and_mlir(comp)

        expected_mlir_lines = [
            'func @fn(%arg0: tensor<f32>) -> tensor<f32> SOMETHING {',
            '  %0 = flow.variable.address',
            '  %1 = mhlo.constant dense<1.000000e+00>',
            '  flow.variable.store.indirect %1, %0',
            '  %2 = flow.variable.load.indirect %0',
            '  %3 = mhlo.add %2, %arg0',
            '  return %3',
            '}',
        ]
        self._assert_mlir_contains_pattern(mlir, expected_mlir_lines)

        backend = backend_info.VULKAN_SPIRV
        arg = np.float32(5.0)
        result = runtime.compile_and_run_on_args(module, backend, arg)
        self.assertEqual(result, 6.0)