def testFreezeModel(self):
  """Checks that `freeze_model` writes a frozen graph and a config file.

  A dummy model is saved with its `func2` function exported under the
  signature key 'func'. The `freeze_model` CLI subcommand is then run with
  `--variables_to_feed all`, and the test asserts that `frozen_graph.pb` and
  `config.pbtxt` exist under the output prefix.
  """
  if not test.is_built_with_xla():
    self.skipTest('Skipping test because XLA is not compiled in.')

  model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
  model = self.AOTCompileDummyModel()
  exported_fn = model.func2
  # Initialize both model variables inside a session, then save the model
  # with `func2` registered under the 'func' signature key.
  with self.cached_session():
    self.evaluate(model.var.initializer)
    self.evaluate(model.write_var.initializer)
    save.save(model, model_dir, signatures={'func': exported_fn})

  self.parser = saved_model_cli.create_parser()
  prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
  cli_argv = [
      'freeze_model',
      '--dir', model_dir,
      '--tag_set', 'serve',
      # Select the signature registered in `save.save` above explicitly.
      '--signature_def_key', 'func',
      '--output_prefix', prefix,
      '--variables_to_feed', 'all',
  ]
  parsed = self.parser.parse_args(cli_argv)
  # Patch out `logging.warn` so warnings emitted while freezing are not logged.
  with test.mock.patch.object(logging, 'warn'):
    saved_model_cli.freeze_model(parsed)

  for artifact in ('frozen_graph.pb', 'config.pbtxt'):
    self.assertTrue(file_io.file_exists(os.path.join(prefix, artifact)))
def testAOTCompileCPUFreezesAndCompiles(self, variables_to_feed, func,
                                        target_triple):
  """Checks that `aot_compile_cpu` freezes a saved model and emits C++ code.

  A dummy model is saved with the function named by `func` exported under
  signature key 'func', then compiled through the `aot_compile_cpu` CLI
  subcommand. The test checks the warning about the pruned input 'y', the
  generated object/header/makefile artifacts, and the generated header and
  makefile contents.

  Args:
    variables_to_feed: Value passed as `--variables_to_feed`. When non-empty,
      the header must declare a const-preserving setter for `my_var`.
    func: Attribute name of the dummy-model function to export. When it is
      'func_write', the header must declare a non-const setter for
      `write_var`.
    target_triple: Optional `--target_triple` value; the flag is omitted
      when this is falsy.
  """
  if not test.is_built_with_xla():
    self.skipTest('Skipping test because XLA is not compiled in.')

  saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
  dummy_model = self.AOTCompileDummyModel()
  # Keep `func` as the attribute name; bind the resolved function separately
  # so the 'func_write' check below is a plain string comparison.
  exported_fn = getattr(dummy_model, func)
  with self.cached_session():
    self.evaluate(dummy_model.var.initializer)
    self.evaluate(dummy_model.write_var.initializer)
    save.save(dummy_model, saved_model_dir,
              signatures={'func': exported_fn})

  self.parser = saved_model_cli.create_parser()
  output_prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
  argv = [
      'aot_compile_cpu',
      '--dir', saved_model_dir,
      '--tag_set', 'serve',
      # Passed explicitly; matches the signature key used in `save.save`.
      '--signature_def_key', 'func',
      '--output_prefix', output_prefix,
      '--variables_to_feed', variables_to_feed,
      '--cpp_class', 'Generated',
  ]
  if target_triple:
    argv.extend(['--target_triple', target_triple])
  args = self.parser.parse_args(argv)

  with test.mock.patch.object(logging, 'warn') as captured_warn:
    saved_model_cli.aot_compile_cpu(args)
  # `call_args` only holds the most recent warn() call; search every recorded
  # call so a later, unrelated warning cannot mask the pruning warning.
  self.assertRegex(
      str(captured_warn.call_args_list),
      'Signature input key \'y\'.*has been pruned while freezing the graph.')

  self.assertTrue(file_io.file_exists('{}.o'.format(output_prefix)))
  self.assertTrue(file_io.file_exists('{}.h'.format(output_prefix)))
  self.assertTrue(file_io.file_exists('{}_metadata.o'.format(output_prefix)))
  self.assertTrue(
      file_io.file_exists('{}_makefile.inc'.format(output_prefix)))

  header_contents = file_io.read_file_to_string('{}.h'.format(output_prefix))
  self.assertIn('class Generated', header_contents)
  self.assertIn('arg_feed_x_data', header_contents)
  self.assertIn('result_fetch_res_data', header_contents)
  # arg_y got filtered out as it's not used by the output.
  self.assertNotIn('arg_feed_y_data', header_contents)
  if variables_to_feed:
    # Read-only-variables' setters preserve constness.
    self.assertIn('set_var_param_my_var_data(const float', header_contents)
    self.assertNotIn('set_var_param_my_var_data(float', header_contents)
  if func == 'func_write':
    # Writeable variables setters do not preserve constness.
    self.assertIn('set_var_param_write_var_data(float', header_contents)
    self.assertNotIn('set_var_param_write_var_data(const float',
                     header_contents)

  makefile_contents = file_io.read_file_to_string(
      '{}_makefile.inc'.format(output_prefix))
  self.assertIn('-D_GLIBCXX_USE_CXX11_ABI=', makefile_contents)
def testAOTCompileCPUWrongSignatureDefKey(self):
  """Checks that an unknown `--signature_def_key` raises ValueError.

  The saved model at SAVED_MODEL_PATH is compiled with the signature key
  'MISSING', and `aot_compile_cpu` must raise a ValueError whose message
  matches 'Unable to find signature_def'.
  """
  if not test.is_built_with_xla():
    self.skipTest('Skipping test because XLA is not compiled in.')

  self.parser = saved_model_cli.create_parser()
  model_path = test.test_src_dir_path(SAVED_MODEL_PATH)
  out_dir = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir')
  cli_argv = [
      'aot_compile_cpu',
      '--dir', model_path,
      '--tag_set', 'serve',
      '--output_prefix', out_dir,
      '--cpp_class', 'Compiled',
      '--signature_def_key', 'MISSING',
  ]
  parsed = self.parser.parse_args(cli_argv)

  with self.assertRaisesRegex(ValueError, 'Unable to find signature_def'):
    saved_model_cli.aot_compile_cpu(parsed)
def testAOTCompileCPUFreezesAndCompiles(self):
  """Compiles a one-variable model without an explicit signature key.

  A model with a single `func2` tf.function (fixed (2, 2) float32 input) is
  saved without a `signatures` argument. `aot_compile_cpu` is then invoked
  without `--signature_def_key`, relying on the default serving signature
  key. The test checks the generated artifacts, header and makefile fragment.
  """
  if not test.is_built_with_xla():
    self.skipTest('Skipping test because XLA is not compiled in.')

  class DummyModel(tracking.AutoTrackable):
    """Model compatible with XLA compilation."""

    def __init__(self):
      self.var = variables.Variable(1.0, name='my_var')

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
    ])
    def func2(self, x):
      return {'res': x + self.var}

  model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
  model = DummyModel()
  # Initialize the variable in a session before saving the model.
  with self.cached_session():
    self.evaluate(model.var.initializer)
    save.save(model, model_dir)

  self.parser = saved_model_cli.create_parser()
  prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
  # No --signature_def_key: the CLI's default serving signature key is used.
  parsed = self.parser.parse_args([
      'aot_compile_cpu',
      '--dir', model_dir,
      '--tag_set', 'serve',
      '--output_prefix', prefix,
      '--cpp_class', 'Generated',
  ])
  saved_model_cli.aot_compile_cpu(parsed)

  for suffix in ('.o', '.h', '_metadata.o', '_makefile.inc'):
    self.assertTrue(file_io.file_exists('{}{}'.format(prefix, suffix)))

  header_text = file_io.read_file_to_string('{}.h'.format(prefix))
  for expected in ('class Generated', 'arg_x_data', 'result_res_data'):
    self.assertIn(expected, header_text)

  makefile_text = file_io.read_file_to_string(
      '{}_makefile.inc'.format(prefix))
  self.assertIn('-D_GLIBCXX_USE_CXX11_ABI=', makefile_text)