def test_import_tf_comp_with_while_loop_on_vulkan(self):
  """Runs the 2^x while-loop computation end-to-end on VULKAN_SPIRV.

  Vulkan counterpart of `test_import_tf_comp_with_while_loop`, which runs the
  same computation on the VMLA backend. It has its own name so that the two
  definitions do not shadow each other inside the class (only the last
  definition of a given method name is kept).

  The test is currently skipped: per the TODO(b/153499219) in the VMLA
  variant, compiling this loop for VULKAN fails. Remove the `skipTest` call
  once that is fixed.
  """
  # TODO(b/153499219): Drop this skip once the while loop compiles on Vulkan.
  self.skipTest('b/153499219: while-loop compilation fails on VULKAN_SPIRV.')

  @computations.tf_computation(tf.float32)
  def comp(x):
    # An example of a loop with variables that computes 2^x by counting from
    # x down to 0, and doubling the result in each iteration.
    a = tf.Variable(0.0)
    b = tf.Variable(1.0)
    with tf.control_dependencies([a.initializer, b.initializer]):
      with tf.control_dependencies([a.assign(x)]):
        cond_fn = lambda a, b: a > 0
        body_fn = lambda a, b: (a - 1.0, b * 2.0)
        return tf.while_loop(cond_fn, body_fn, (a, b))[1]

  module, mlir = self._import_compile_and_return_module_and_mlir(comp)

  # Not checking the full MLIR in the long generated body, just that we can
  # successfully ingest TF code containing a while loop here, end-to-end.
  self._assert_mlir_contains_pattern(
      mlir, ['func @fn(%arg0: tensor<f32>) -> tensor<f32>'])

  result = runtime.compile_and_run_on_args(module, backend_info.VULKAN_SPIRV,
                                           np.float32(5.0))
  self.assertEqual(result, 32.0)
def test_import_tf_comp_with_variable_assign_add_one(self):
  """Imports and runs a computation that calls `assign_add` on a variable.

  The variable starts at 1.0 and is incremented by the argument. The imported
  MLIR is checked for the flow.variable load/store sequence, and running it on
  VULKAN_SPIRV with 5.0 must produce 6.0.
  """

  @computations.tf_computation(tf.float32)
  def comp(x):
    var = tf.Variable(1.0)
    with tf.control_dependencies([var.initializer]):
      increment = var.assign_add(x)
      with tf.control_dependencies([increment]):
        return tf.identity(var)

  compiled, mlir_text = self._import_compile_and_return_module_and_mlir(comp)

  # TODO(b/153499219): Introduce the concept of local variables, so that code
  # like what's in this section below can be dramatically simplified.
  expected_mlir = [
      'flow.variable SOMETHING mutable dense<1.000000e+00> : tensor<f32>',
      'func @fn(%arg0: tensor<f32>) -> tensor<f32> SOMETHING {',
      ' %0 = flow.variable.address',
      ' %1 = mhlo.constant dense<1.000000e+00>',
      ' flow.variable.store.indirect %1, %0',
      ' %2 = flow.variable.load.indirect %0',
      ' %3 = mhlo.add %2, %arg0',
      ' flow.variable.store.indirect %3, %0',
      ' %4 = flow.variable.load.indirect %0',
      ' return %4',
      '}',
  ]
  self._assert_mlir_contains_pattern(mlir_text, expected_mlir)

  arg = np.float32(5.0)
  output = runtime.compile_and_run_on_args(compiled,
                                           backend_info.VULKAN_SPIRV, arg)
  self.assertEqual(output, 6.0)
def test_import_tf_comp_with_while_loop(self):
  """Imports and runs a TF computation containing a `tf.while_loop`.

  The loop computes 2^x using two variables, so an input of 5.0 must yield
  32.0. It runs on the VMLA backend (see the TODO below for why not Vulkan).
  """

  @computations.tf_computation(tf.float32)
  def comp(x):
    # Computes 2^x: count `a` down from x to 0 while doubling `b` on every
    # iteration.
    a = tf.Variable(0.0)
    b = tf.Variable(1.0)

    def keep_going(counter, unused_acc):
      return counter > 0

    def step(counter, acc):
      return (counter - 1.0, acc * 2.0)

    with tf.control_dependencies([a.initializer, b.initializer]):
      with tf.control_dependencies([a.assign(x)]):
        loop_outputs = tf.while_loop(keep_going, step, (a, b))
        return loop_outputs[1]

  compiled, mlir_text = self._import_compile_and_return_module_and_mlir(comp)

  # Only the function signature is checked, not the long generated body: the
  # point is that TF code containing a while loop can be ingested end-to-end.
  # Some form of looping support is needed in lieu of
  # `tf.data.Dataset.reduce()`.
  expected_mlir = ['func @fn(%arg0: tensor<f32>) -> tensor<f32>']
  self._assert_mlir_contains_pattern(mlir_text, expected_mlir)

  # TODO(b/153499219): Switch the backend to VULKAN_SPIRV after fixing these
  # compilation errors on VULKAN:
  # * [ERROR]: SPIRV type conversion failed: 'memref<i1>'
  # * [ERROR]: failed to legalize operation 'iree.placeholder'
  # * [ERROR]: failed to run translation of source executable to target
  #   executable for backend vulkan.
  arg = np.float32(5.0)
  output = runtime.compile_and_run_on_args(compiled, backend_info.VMLA, arg)
  self.assertEqual(output, 32.0)
def test_import_tf_comp_with_one_constant(self):
  """Imports and runs a no-argument computation returning the constant 99.0."""

  @computations.tf_computation
  def comp():
    return 99.0

  compiled, mlir_text = self._import_compile_and_return_module_and_mlir(comp)

  expected_mlir = [
      'func @fn() -> tensor<f32> SOMETHING {',
      ' %0 = mhlo.constant dense<9.900000e+01>',
      ' return %0',
      '}',
  ]
  self._assert_mlir_contains_pattern(mlir_text, expected_mlir)

  output = runtime.compile_and_run_on_args(compiled, backend_info.VULKAN_SPIRV)
  self.assertEqual(output, 99.0)
def test_import_tf_comp_with_add_one(self):
  """Imports and runs a computation that adds 1.0 to its float32 argument."""

  @computations.tf_computation(tf.float32)
  def comp(x):
    return x + 1.0

  compiled, mlir_text = self._import_compile_and_return_module_and_mlir(comp)

  expected_mlir = [
      'func @fn(%arg0: tensor<f32>) -> tensor<f32> SOMETHING {',
      ' %0 = mhlo.constant dense<1.000000e+00>',
      ' %1 = mhlo.add %arg0, %0',
      ' return %1',
      '}',
  ]
  self._assert_mlir_contains_pattern(mlir_text, expected_mlir)

  arg = np.float32(5.0)
  output = runtime.compile_and_run_on_args(compiled,
                                           backend_info.VULKAN_SPIRV, arg)
  self.assertEqual(output, 6.0)
def test_import_tf_comp_with_one_variable_constant(self):
  """Imports and runs a computation that returns a variable holding 99.0.

  The imported MLIR is expected to store the constant through a
  flow.variable address and load it back before returning.
  """

  @computations.tf_computation
  def comp():
    return tf.Variable(99.0)

  compiled, mlir_text = self._import_compile_and_return_module_and_mlir(comp)

  expected_mlir = [
      'func @fn() -> tensor<f32> SOMETHING {',
      ' %0 = flow.variable.address',
      ' %1 = mhlo.constant dense<9.900000e+01>',
      ' flow.variable.store.indirect %1, %0',
      ' %2 = flow.variable.load.indirect %0',
      ' return %2',
      '}',
  ]
  self._assert_mlir_contains_pattern(mlir_text, expected_mlir)

  output = runtime.compile_and_run_on_args(compiled, backend_info.VULKAN_SPIRV)
  self.assertEqual(output, 99.0)
def test_import_tf_comp_with_variable_add_one(self):
  """Imports and runs a computation adding its argument to a 1.0 variable.

  The variable is read (not mutated) after initialization, so an input of
  5.0 must produce 6.0.
  """

  @computations.tf_computation(tf.float32)
  def comp(x):
    var = tf.Variable(1.0)
    with tf.control_dependencies([var.initializer]):
      return tf.add(var, x)

  compiled, mlir_text = self._import_compile_and_return_module_and_mlir(comp)

  expected_mlir = [
      'func @fn(%arg0: tensor<f32>) -> tensor<f32> SOMETHING {',
      ' %0 = flow.variable.address',
      ' %1 = mhlo.constant dense<1.000000e+00>',
      ' flow.variable.store.indirect %1, %0',
      ' %2 = flow.variable.load.indirect %0',
      ' %3 = mhlo.add %2, %arg0',
      ' return %3',
      '}',
  ]
  self._assert_mlir_contains_pattern(mlir_text, expected_mlir)

  arg = np.float32(5.0)
  output = runtime.compile_and_run_on_args(compiled,
                                           backend_info.VULKAN_SPIRV, arg)
  self.assertEqual(output, 6.0)