Example #1
0
 def test_graph_model_to_saved_model_accepts_signature_key_map(self):
     """graph_model_to_saved_model should accept signature key map"""
     model_dir = testutils.get_path_to(testutils.MULTI_HEAD_PATH)
     export_dir = tempfile.mkdtemp(suffix='.saved_model')
     try:
         serving_tags = [tf.saved_model.SERVING]
         signature_defs = {
             '': {api.SIGNATURE_OUTPUTS: ['Identity']},
             'debug': {api.SIGNATURE_OUTPUTS: ['Identity', 'Identity_1']},
         }
         key_renames = api.RenameMap([
             ('Identity', 'output'),
             ('Identity_1', 'autoencoder_output'),
         ])
         api.graph_model_to_saved_model(
             model_dir, export_dir, tags=serving_tags,
             signature_def_map=signature_defs,
             signature_key_map=key_renames)
         # re-load the exported model and inspect its signatures
         meta_graph = load_meta_graph(export_dir, serving_tags)
         # every signature must expose the renamed default output key
         for sig in meta_graph.signature_def.values():
             self.assertIn('output', sig.outputs)
             self.assertEqual(sig.outputs['output'].name, 'Identity:0')
         # the debug signature must also expose the renamed 2nd output
         debug_sig = meta_graph.signature_def['debug']
         self.assertIn('autoencoder_output', debug_sig.outputs)
         self.assertEqual(debug_sig.outputs['autoencoder_output'].name,
                          'Identity_1:0')
     finally:
         if os.path.exists(export_dir):
             shutil.rmtree(export_dir)
Example #2
0
 def test_graph_model_to_saved_model_accepts_signature_map(self):
     """graph_model_to_saved_model should accept signature map"""
     model_dir = testutils.get_path_to(testutils.MULTI_HEAD_PATH)
     export_dir = tempfile.mkdtemp(suffix='.saved_model')
     try:
         serving_tags = [tf.saved_model.SERVING]
         signature_defs = {
             '': {api.SIGNATURE_OUTPUTS: ['Identity']},
             'debug': {api.SIGNATURE_OUTPUTS: ['Identity', 'Identity_1']},
         }
         api.graph_model_to_saved_model(
             model_dir, export_dir, tags=serving_tags,
             signature_def_map=signature_defs)
         # re-load the exported model
         meta_graph = load_meta_graph(export_dir, serving_tags)
         self.assertIsNotNone(meta_graph)
         # both requested signatures must be present and valid
         self.assertEqual(len(meta_graph.signature_def), 2)
         for sig in meta_graph.signature_def.values():
             self.assertTrue(is_valid_signature(sig))
         # the default signature carries exactly one output
         default_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
         default_sig = meta_graph.signature_def[default_key]
         self.assertEqual(len(default_sig.outputs), 1)
         # the debug signature is present and carries two outputs
         self.assertIn('debug', meta_graph.signature_def.keys())
         self.assertEqual(
             len(meta_graph.signature_def['debug'].outputs), 2)
     finally:
         if os.path.exists(export_dir):
             shutil.rmtree(export_dir)
Example #3
0
 def test_graph_model_to_saved_model(self):
     """graph_model_to_saved_model should save valid SavedModel"""
     model_dir = testutils.get_path_to(testutils.SIMPLE_MODEL_PATH_NAME)
     export_dir = tempfile.mkdtemp(suffix='.saved_model')
     try:
         serving_tags = ['serving_default']
         api.graph_model_to_saved_model(model_dir, export_dir,
                                        tags=serving_tags)
         self.assertTrue(os.path.exists(export_dir))
         # actually re-import the model: tf.saved_model.contains_saved_model
         # alone would not prove the export is loadable
         imported = tf.saved_model.load(export_dir, tags=serving_tags)
         self.assertIsNotNone(imported.graph)
     finally:
         if os.path.exists(export_dir):
             shutil.rmtree(export_dir)
Example #4
0
def convert(arguments):
    """
    Convert a TensorflowJS-model to a TensorFlow-model.

    Args:
        arguments: List of command-line arguments

    Raises:
        ValueError: if the input path or the output path is missing, or
            if an unsupported output format is requested
    """
    args = get_arg_parser().parse_args(arguments)
    if args.show_version:
        print("\ntfjs_graph_converter {}\n".format(version.VERSION))
        print("Dependency versions:")
        print("    tensorflow {}".format(tf.version.VERSION))
        print("    tensorflowjs {}".format(tfjs.__version__))
        return

    def info(message, end=None):
        # flush=True so the progress message (printed with end=" ", i.e.
        # without a newline) is visible *before* the potentially
        # long-running conversion instead of appearing only afterwards
        if not args.silence:
            print(message, end=end, flush=True)

    if not args.input_path:
        raise ValueError(
            "Missing input_path argument. For usage, use the --help flag.")
    if not args.output_path:
        raise ValueError(
            "Missing output_path argument. For usage, use the --help flag.")

    info("TensorFlow.js Graph Model Converter\n")
    info("Graph model:    {}".format(args.input_path))
    info("Output:         {}".format(args.output_path))
    info("Target format:  {}".format(args.output_format))
    info("\nConverting....", end=" ")

    start_time = time.perf_counter()

    if args.output_format == common.CLI_FROZEN_MODEL:
        api.graph_model_to_frozen_graph(args.input_path, args.output_path)
    elif args.output_format == common.CLI_SAVED_MODEL:
        # pass tags by keyword, consistent with the rest of the code base
        api.graph_model_to_saved_model(args.input_path, args.output_path,
                                       tags=args.saved_model_tags)
    else:
        raise ValueError("Unsupported output format: {}".format(
            args.output_format))

    end_time = time.perf_counter()
    info("Done.")
    info("Conversion took {0:.3f}s".format(end_time - start_time))
Example #5
0
 def test_graph_model_to_saved_model(self):
     """graph_model_to_saved_model should save valid SavedModel"""
     model_dir = testutils.get_path_to(testutils.PRELU_MODEL_PATH)
     export_dir = tempfile.mkdtemp(suffix='.saved_model')
     try:
         serving_tags = [tf.saved_model.SERVING]
         api.graph_model_to_saved_model(model_dir, export_dir,
                                        tags=serving_tags)
         self.assertTrue(os.path.exists(export_dir))
         self.assertTrue(tf.saved_model.contains_saved_model(export_dir))
         # re-load the exported model and check its single signature
         meta_graph = load_meta_graph(export_dir, serving_tags)
         self.assertIsNotNone(meta_graph)
         self.assertEqual(len(meta_graph.signature_def), 1)
         signatures = list(meta_graph.signature_def.values())
         self.assertTrue(is_valid_signature(signatures[0]))
     finally:
         if os.path.exists(export_dir):
             shutil.rmtree(export_dir)
Example #6
0
def convert(arguments):
    """
    Convert a TensorflowJS-model to a TensorFlow-model.

    Args:
        arguments: List of command-line arguments
    """
    args = get_arg_parser().parse_args(arguments)
    if args.show_version:
        print(f"\ntfjs_graph_converter {version.VERSION}\n")
        print("Dependency versions:")
        print(f"    tensorflow {tf.version.VERSION}")
        print(f"    tensorflowjs {tfjs.__version__}")
        return

    def info(message, end=None):
        # progress output; flushed so partial lines show up immediately
        if not args.silence:
            print(message, end=end, flush=True)

    if not args.input_path:
        raise ValueError(
            "Missing input_path argument. For usage, use the --help flag.")
    if not args.output_path:
        raise ValueError(
            "Missing output_path argument. For usage, use the --help flag.")
    if args.output_format == common.CLI_SAVED_MODEL:
        # these options are only meaningful with explicit --outputs
        dependent_options = (
            (args.signature_key, common.CLI_SIGNATURE_KEY),
            (args.method_name, common.CLI_METHOD_NAME),
            (args.rename, common.CLI_RENAME),
        )
        for value, flag in dependent_options:
            if value is not None and args.outputs is None:
                raise ValueError(f'--{flag} requires '
                                 f'--{common.CLI_OUTPUTS} to be specified')

    info("TensorFlow.js Graph Model Converter\n")
    info(f"Graph model:    {args.input_path}")
    info(f"Output:         {args.output_path}")
    info(f"Target format:  {args.output_format}")
    info("\nConverting....", end=" ")

    started = time.perf_counter()

    if args.output_format == common.CLI_FROZEN_MODEL:
        api.graph_model_to_frozen_graph(args.input_path, args.output_path,
                                        args.compat_mode)
    elif args.output_format == common.CLI_SAVED_MODEL:
        api.graph_model_to_saved_model(
            args.input_path,
            args.output_path,
            tags=args.saved_model_tags,
            signature_def_map=_get_signature(args),
            signature_key_map=_get_signature_keys(args),
            compat_mode=args.compat_mode)
    else:
        raise ValueError(f"Unsupported output format: {args.output_format}")

    finished = time.perf_counter()
    info("Done.")
    info(f"Conversion took {finished - started:.3f}s")