def test_graph_model_to_saved_model_accepts_signature_key_map(self):
    """graph_model_to_saved_model should accept signature key map"""
    model_dir = testutils.get_path_to(testutils.MULTI_HEAD_PATH)
    export_dir = tempfile.mkdtemp(suffix='.saved_model')
    try:
        tags = [tf.saved_model.SERVING]
        # the '' signature exposes one output, 'debug' exposes two
        signature_map = {
            '': {api.SIGNATURE_OUTPUTS: ['Identity']},
            'debug': {api.SIGNATURE_OUTPUTS: ['Identity', 'Identity_1']}}
        # map raw graph output names to the keys used in the signatures
        signature_keys = api.RenameMap([
            ('Identity', 'output'), ('Identity_1', 'autoencoder_output')
        ])
        api.graph_model_to_saved_model(model_dir, export_dir,
                                       tags=tags,
                                       signature_def_map=signature_map,
                                       signature_key_map=signature_keys)
        # try and load the model
        meta_graph_def = load_meta_graph(export_dir, tags)
        self.assertIsNotNone(meta_graph_def)
        # both signatures must exist; otherwise the loop below would
        # pass without checking anything
        self.assertEqual(len(meta_graph_def.signature_def), 2)
        # the signatures should contain the renamed keys
        for signature in meta_graph_def.signature_def.values():
            self.assertIn('output', signature.outputs)
            self.assertEqual(signature.outputs['output'].name, 'Identity:0')
        # report a missing 'debug' signature as a test failure rather than
        # a KeyError
        self.assertIn('debug', meta_graph_def.signature_def)
        signature = meta_graph_def.signature_def['debug']
        self.assertIn('autoencoder_output', signature.outputs)
        self.assertEqual(signature.outputs['autoencoder_output'].name,
                         'Identity_1:0')
    finally:
        if os.path.exists(export_dir):
            shutil.rmtree(export_dir)
def test_graph_model_to_saved_model_accepts_signature_map(self):
    """graph_model_to_saved_model should accept signature map"""
    source_dir = testutils.get_path_to(testutils.MULTI_HEAD_PATH)
    target_dir = tempfile.mkdtemp(suffix='.saved_model')
    try:
        serving_tags = [tf.saved_model.SERVING]
        # request a default ('') signature with one output and a 'debug'
        # signature with both model heads
        requested = {
            '': {api.SIGNATURE_OUTPUTS: ['Identity']},
            'debug': {api.SIGNATURE_OUTPUTS: ['Identity', 'Identity_1']},
        }
        api.graph_model_to_saved_model(source_dir, target_dir,
                                       tags=serving_tags,
                                       signature_def_map=requested)

        # the exported model must load back for the same tags
        graph_def = load_meta_graph(target_dir, serving_tags)
        self.assertIsNotNone(graph_def)

        signatures = graph_def.signature_def
        # exactly the two requested signatures, each of them valid
        self.assertEqual(len(signatures), 2)
        for sig in signatures.values():
            self.assertTrue(is_valid_signature(sig))

        # '' is exported under the default serving key with a single output
        default_sig = signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        self.assertEqual(len(default_sig.outputs), 1)

        # 'debug' is kept by name and carries two outputs
        self.assertIn('debug', signatures.keys())
        self.assertEqual(len(signatures['debug'].outputs), 2)
    finally:
        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)
def test_graph_model_to_saved_model(self):
    """graph_model_to_saved_model should save valid SavedModel"""
    source_dir = testutils.get_path_to(testutils.SIMPLE_MODEL_PATH_NAME)
    target_dir = tempfile.mkdtemp(suffix='.saved_model')
    try:
        # NOTE(review): 'serving_default' is conventionally a signature key,
        # not a tag; here it is used as an arbitrary tag string - confirm
        # this is intentional
        model_tags = ['serving_default']
        api.graph_model_to_saved_model(source_dir, target_dir,
                                       tags=model_tags)
        self.assertTrue(os.path.exists(target_dir))
        # tf.saved_model.contains_saved_model is not a strong enough check,
        # so actually load the exported model
        loaded = tf.saved_model.load(target_dir, tags=model_tags)
        self.assertIsNotNone(loaded.graph)
    finally:
        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)
def convert(arguments):
    """
    Convert a TensorflowJS-model to a TensorFlow-model.

    Args:
        arguments: List of command-line arguments

    Raises:
        ValueError: if the input or output path is missing, or if the
            requested output format is not supported
    """
    args = get_arg_parser().parse_args(arguments)

    if args.show_version:
        print("\ntfjs_graph_converter {}\n".format(version.VERSION))
        print("Dependency versions:")
        print(" tensorflow {}".format(tf.version.VERSION))
        print(" tensorflowjs {}".format(tfjs.__version__))
        return

    def info(message, end=None):
        """Print a progress message unless --silence was given."""
        if not args.silence:
            # flush so that partial lines (e.g. "Converting....", printed
            # with end=" ") appear before the conversion starts
            print(message, end=end, flush=True)

    if not args.input_path:
        raise ValueError(
            "Missing input_path argument. For usage, use the --help flag.")
    if not args.output_path:
        raise ValueError(
            "Missing output_path argument. For usage, use the --help flag.")

    info("TensorFlow.js Graph Model Converter\n")
    info("Graph model: {}".format(args.input_path))
    info("Output: {}".format(args.output_path))
    info("Target format: {}".format(args.output_format))
    info("\nConverting....", end=" ")

    start_time = time.perf_counter()

    if args.output_format == common.CLI_FROZEN_MODEL:
        api.graph_model_to_frozen_graph(args.input_path, args.output_path)
    elif args.output_format == common.CLI_SAVED_MODEL:
        # pass tags by keyword so the call does not depend on the
        # positional order of graph_model_to_saved_model's parameters
        api.graph_model_to_saved_model(args.input_path, args.output_path,
                                       tags=args.saved_model_tags)
    else:
        raise ValueError("Unsupported output format: {}".format(
            args.output_format))

    end_time = time.perf_counter()
    info("Done.")
    info("Conversion took {0:.3f}s".format(end_time - start_time))
def test_graph_model_to_saved_model(self):
    """graph_model_to_saved_model should save valid SavedModel"""
    source_dir = testutils.get_path_to(testutils.PRELU_MODEL_PATH)
    target_dir = tempfile.mkdtemp(suffix='.saved_model')
    try:
        serving_tags = [tf.saved_model.SERVING]
        api.graph_model_to_saved_model(source_dir, target_dir,
                                       tags=serving_tags)
        # the export directory must exist and be recognised as a SavedModel
        self.assertTrue(os.path.exists(target_dir))
        self.assertTrue(tf.saved_model.contains_saved_model(target_dir))
        # reading back the meta graph for the same tags must succeed
        graph_def = load_meta_graph(target_dir, serving_tags)
        self.assertIsNotNone(graph_def)
        # exactly one signature is expected, and it must be valid
        signatures = list(graph_def.signature_def.values())
        self.assertEqual(len(signatures), 1)
        self.assertTrue(is_valid_signature(signatures[0]))
    finally:
        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)
def convert(arguments):
    """
    Convert a TensorflowJS-model to a TensorFlow-model.

    Args:
        arguments: List of command-line arguments

    Raises:
        ValueError: on a missing input/output path, on SavedModel-only
            options given without --outputs, or on an unsupported format
    """
    args = get_arg_parser().parse_args(arguments)

    if args.show_version:
        print(f"\ntfjs_graph_converter {version.VERSION}\n")
        print("Dependency versions:")
        print(f" tensorflow {tf.version.VERSION}")
        print(f" tensorflowjs {tfjs.__version__}")
        return

    def info(message, end=None):
        """Print a progress message unless --silence was given."""
        if args.silence:
            return
        print(message, end=end, flush=True)

    if not args.input_path:
        raise ValueError(
            "Missing input_path argument. For usage, use the --help flag.")
    if not args.output_path:
        raise ValueError(
            "Missing output_path argument. For usage, use the --help flag.")

    target_format = args.output_format
    if target_format == common.CLI_SAVED_MODEL and args.outputs is None:
        # these SavedModel options only make sense together with --outputs;
        # checked in the same order as the option list below
        dependent_options = (
            (args.signature_key, common.CLI_SIGNATURE_KEY),
            (args.method_name, common.CLI_METHOD_NAME),
            (args.rename, common.CLI_RENAME),
        )
        for value, option_name in dependent_options:
            if value is not None:
                raise ValueError(f'--{option_name} requires '
                                 f'--{common.CLI_OUTPUTS} to be specified')

    info("TensorFlow.js Graph Model Converter\n")
    info(f"Graph model: {args.input_path}")
    info(f"Output: {args.output_path}")
    info(f"Target format: {target_format}")
    info("\nConverting....", end=" ")

    started = time.perf_counter()

    if target_format == common.CLI_FROZEN_MODEL:
        api.graph_model_to_frozen_graph(args.input_path, args.output_path,
                                        args.compat_mode)
    elif target_format == common.CLI_SAVED_MODEL:
        api.graph_model_to_saved_model(
            args.input_path, args.output_path,
            tags=args.saved_model_tags,
            signature_def_map=_get_signature(args),
            signature_key_map=_get_signature_keys(args),
            compat_mode=args.compat_mode)
    else:
        raise ValueError(f"Unsupported output format: {target_format}")

    finished = time.perf_counter()
    info("Done.")
    info(f"Conversion took {finished-started:.3f}s")