Exemplo n.º 1
0
  def testFreezeModel(self):
    """Checks that `freeze_model` writes a frozen graph and a tfcompile config.

    The dummy model is saved with a single signature keyed 'func', backed by
    the model's `func2` method. It is then frozen through the saved_model_cli
    `freeze_model` command with `--variables_to_feed all`. For this command,
    `--output_prefix` is used as a directory: the test expects
    `frozen_graph.pb` and `config.pbtxt` to appear inside it.
    """
    if not test.is_built_with_xla():
      self.skipTest('Skipping test because XLA is not compiled in.')

    variables_to_feed = 'all'
    func_name = 'func2'
    saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
    dummy_model = self.AOTCompileDummyModel()
    # Keep the attribute name and the resolved tf.function in separate names.
    signature_fn = getattr(dummy_model, func_name)
    with self.cached_session():
      self.evaluate(dummy_model.var.initializer)
      self.evaluate(dummy_model.write_var.initializer)
      save.save(dummy_model, saved_model_dir,
                signatures={'func': signature_fn})

    self.parser = saved_model_cli.create_parser()
    output_prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
    # The signature key is passed explicitly. It must match the 'func' key
    # used in `save.save` above; the default serving key is not used here.
    args = [
        'freeze_model', '--dir', saved_model_dir, '--tag_set', 'serve',
        '--signature_def_key', 'func', '--output_prefix', output_prefix,
        '--variables_to_feed', variables_to_feed
    ]
    args = self.parser.parse_args(args)
    # Patch `logging.warn` so any warnings emitted while freezing go to the
    # mock instead of the log.
    with test.mock.patch.object(logging, 'warn'):
      saved_model_cli.freeze_model(args)
    self.assertTrue(
        file_io.file_exists(os.path.join(output_prefix, 'frozen_graph.pb')))
    self.assertTrue(
        file_io.file_exists(os.path.join(output_prefix, 'config.pbtxt')))
Exemplo n.º 2
0
    def testAOTCompileCPUFreezesAndCompiles(self, variables_to_feed, func,
                                            target_triple):
        """Checks that `aot_compile_cpu` freezes a SavedModel and compiles it.

        Args:
          variables_to_feed: Value passed to `--variables_to_feed`. When it is
            truthy, the generated setter for the read-only variable `my_var`
            is expected to take a `const float` argument.
          func: Name of the dummy-model attribute that is exported under the
            signature key 'func'.
          target_triple: Optional `--target_triple` value. It is left off the
            command line when falsy.
        """
        if not test.is_built_with_xla():
            self.skipTest('Skipping test because XLA is not compiled in.')

        saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
        dummy_model = self.AOTCompileDummyModel()
        # `func` stays the attribute name; `signature_fn` is what gets exported.
        signature_fn = getattr(dummy_model, func)
        with self.cached_session():
            self.evaluate(dummy_model.var.initializer)
            self.evaluate(dummy_model.write_var.initializer)
            save.save(dummy_model, saved_model_dir,
                      signatures={'func': signature_fn})

        self.parser = saved_model_cli.create_parser()
        output_prefix = os.path.join(test.get_temp_dir(),
                                     'aot_compile_cpu_dir/out')
        # The signature key is passed explicitly. It must match the 'func'
        # key used in `save.save` above; the default serving key is not used.
        args = [
            'aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
            '--signature_def_key', 'func', '--output_prefix', output_prefix,
            '--variables_to_feed', variables_to_feed, '--cpp_class',
            'Generated'
        ]
        if target_triple:
            args.extend(['--target_triple', target_triple])
        args = self.parser.parse_args(args)
        with test.mock.patch.object(logging, 'warn') as captured_warn:
            saved_model_cli.aot_compile_cpu(args)
        # Search every recorded `warn` call, not just the last one, so an
        # unrelated warning issued after the pruning notice cannot break the
        # test. Calls are joined with newlines, and `.` does not match a
        # newline, so a regex match cannot span two different calls.
        warn_calls = '\n'.join(
            str(call) for call in captured_warn.call_args_list)
        self.assertRegex(
            warn_calls,
            'Signature input key \'y\'.*has been pruned while freezing the graph.'
        )
        # Object file, header, metadata object and makefile fragment.
        for suffix in ('.o', '.h', '_metadata.o', '_makefile.inc'):
            self.assertTrue(
                file_io.file_exists('{}{}'.format(output_prefix, suffix)))
        header_contents = file_io.read_file_to_string(
            '{}.h'.format(output_prefix))
        self.assertIn('class Generated', header_contents)
        self.assertIn('arg_feed_x_data', header_contents)
        self.assertIn('result_fetch_res_data', header_contents)
        # arg_y is filtered out because the output does not use it.
        self.assertNotIn('arg_feed_y_data', header_contents)
        if variables_to_feed:
            # Read-only variables' setters preserve constness.
            self.assertIn('set_var_param_my_var_data(const float',
                          header_contents)
            self.assertNotIn('set_var_param_my_var_data(float',
                             header_contents)
        if signature_fn == dummy_model.func_write:
            # Writeable variables' setters do not preserve constness.
            self.assertIn('set_var_param_write_var_data(float',
                          header_contents)
            self.assertNotIn('set_var_param_write_var_data(const float',
                             header_contents)

        makefile_contents = file_io.read_file_to_string(
            '{}_makefile.inc'.format(output_prefix))
        self.assertIn('-D_GLIBCXX_USE_CXX11_ABI=', makefile_contents)
Exemplo n.º 3
0
  def testAOTCompileCPUWrongSignatureDefKey(self):
    """Checks that an unknown signature key makes aot_compile_cpu raise.

    The test passes `--signature_def_key MISSING`. It expects a ValueError
    whose message contains 'Unable to find signature_def'.
    """
    if not test.is_built_with_xla():
      self.skipTest('Skipping test because XLA is not compiled in.')

    self.parser = saved_model_cli.create_parser()
    model_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    output_prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir')
    cli_argv = [
        'aot_compile_cpu',
        '--dir', model_path,
        '--tag_set', 'serve',
        '--output_prefix', output_prefix,
        '--cpp_class', 'Compiled',
        # A signature key that the test expects the model not to have.
        '--signature_def_key', 'MISSING',
    ]
    parsed_args = self.parser.parse_args(cli_argv)
    with self.assertRaisesRegex(ValueError, 'Unable to find signature_def'):
      saved_model_cli.aot_compile_cpu(parsed_args)
Exemplo n.º 4
0
    def testAOTCompileCPUFreezesAndCompiles(self):
        """Checks that aot_compile_cpu compiles a model via its default signature.

        A one-variable model with a single `func2` tf.function is saved
        without explicit signatures. It is then compiled without passing
        `--signature_def_key`. The test checks the generated artifacts and
        the names that appear in the generated header.
        """
        if not test.is_built_with_xla():
            self.skipTest('Skipping test because XLA is not compiled in.')

        class DummyModel(tracking.AutoTrackable):
            """Model compatible with XLA compilation."""

            def __init__(self):
                self.var = variables.Variable(1.0, name='my_var')

            @def_function.function(input_signature=[
                tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
            ])
            def func2(self, x):
                # The input name `x` and the output key 'res' show up in the
                # generated header as arg_x_data / result_res_data.
                return {'res': x + self.var}

        model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
        model = DummyModel()
        with self.cached_session():
            self.evaluate(model.var.initializer)
            save.save(model, model_dir)

        self.parser = saved_model_cli.create_parser()
        prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
        # No --signature_def_key is given, so the default serving
        # signature_key is used.
        cli_args = self.parser.parse_args([
            'aot_compile_cpu', '--dir', model_dir, '--tag_set', 'serve',
            '--output_prefix', prefix, '--cpp_class', 'Generated'
        ])
        saved_model_cli.aot_compile_cpu(cli_args)

        for suffix in ('.o', '.h', '_metadata.o', '_makefile.inc'):
            self.assertTrue(file_io.file_exists('{}{}'.format(prefix, suffix)))

        header = file_io.read_file_to_string('{}.h'.format(prefix))
        for expected in ('class Generated', 'arg_x_data', 'result_res_data'):
            self.assertIn(expected, header)

        makefile = file_io.read_file_to_string(
            '{}_makefile.inc'.format(prefix))
        self.assertIn('-D_GLIBCXX_USE_CXX11_ABI=', makefile)