def testReadSavedModelInvalid(self):
  """Checks that reading from a path this test never populates raises IOError.

  The expected error message must name the directory that was passed in.
  """
  missing_dir = os.path.join(test.get_temp_dir(), "invalid_saved_model")
  expected_message = "SavedModel file does not exist at: %s" % missing_dir
  with self.assertRaisesRegex(IOError, expected_message):
    saved_model_utils.read_saved_model(missing_dir)
def _show_defined_functions(saved_model_dir):
  """Prints the functions exposed by a SavedModel and how each can be called.

  Nothing is printed unless at least one MetaGraphDef in the SavedModel has its
  `object_graph_def` field set. Functions are listed in name order, and for
  each one every concrete function is printed as a numbered option.

  Args:
    saved_model_dir: Directory containing the SavedModel to inspect.
  """
  saved_model_proto = saved_model_utils.read_saved_model(saved_model_dir)
  if not any(
      meta_graph_def.HasField('object_graph_def')
      for meta_graph_def in saved_model_proto.meta_graphs):
    return

  with ops_lib.Graph().as_default():
    trackable_object = load.load(saved_model_dir)

  print('\nDefined Functions:', end='')
  graph_view = save._AugmentedGraphView(trackable_object)  # pylint: disable=protected-access
  functions_by_name = graph_view.list_functions(trackable_object)
  for name, function in sorted(
      functions_by_name.items(), key=lambda item: item[0]):
    print('\n Function Name: \'%s\'' % name)
    serializable = function._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
    # Sort by name so the option numbering is stable between runs.
    ordered = sorted(serializable, key=lambda fn: fn.name)
    for index, concrete_function in enumerate(ordered, 1):
      args, kwargs = concrete_function.structured_input_signature
      print(' Option #%d' % index)
      print(' Callable with:')
      _print_args(args, indent=4)
      if kwargs:
        _print_args(kwargs, 'Named Argument', indent=4)
def scan(args):
  """Entry point for the scan command.

  When `args.tag_set` is truthy only the MetaGraphDef matching that tag set is
  scanned; otherwise every MetaGraphDef in the SavedModel is scanned.

  Args:
    args: A namespace parsed from command line. Reads `args.dir` and
      `args.tag_set`.
  """
  if args.tag_set:
    meta_graph_defs = [
        saved_model_utils.get_meta_graph_def(args.dir, args.tag_set)
    ]
  else:
    meta_graph_defs = saved_model_utils.read_saved_model(args.dir).meta_graphs

  for meta_graph_def in meta_graph_defs:
    scan_meta_graph_def(meta_graph_def)
def testReadSavedModelValid(self):
  """Saves a model with one TRAINING-tagged meta graph and reads it back.

  The SavedModel proto that is read back must contain exactly one meta graph,
  carrying exactly one tag equal to `tag_constants.TRAINING`.
  """
  export_dir = os.path.join(test.get_temp_dir(), "valid_saved_model")
  model_builder = saved_model_builder.SavedModelBuilder(export_dir)
  with self.session(graph=ops.Graph()) as sess:
    self._init_and_validate_variable(sess, "v", 42)
    model_builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
  model_builder.save()

  loaded_meta_graphs = saved_model_utils.read_saved_model(
      export_dir).meta_graphs
  self.assertEqual(len(loaded_meta_graphs), 1)

  # Only index into the meta graph after its presence has been asserted.
  loaded_tags = loaded_meta_graphs[0].meta_info_def.tags
  self.assertEqual(len(loaded_tags), 1)
  self.assertEqual(loaded_tags[0], tag_constants.TRAINING)
def _show_defined_functions(saved_model_dir):
  """Prints the functions attached to a SavedModel and how each can be called.

  Nothing is printed unless at least one MetaGraphDef in the SavedModel has its
  `object_graph_def` field set. Children of the loaded object are visited in
  name order; only `defun.ConcreteFunction` and `def_function.Function`
  children are reported, everything else is skipped.

  Args:
    saved_model_dir: Directory containing the SavedModel to inspect.
  """
  saved_model_proto = saved_model_utils.read_saved_model(saved_model_dir)
  if not any(
      meta_graph_def.HasField('object_graph_def')
      for meta_graph_def in saved_model_proto.meta_graphs):
    return

  with ops_lib.Graph().as_default():
    trackable_object = load.load(saved_model_dir)

  print('\nDefined Functions:', end='')
  graph_view = save._AugmentedGraphView(trackable_object)  # pylint: disable=protected-access
  children = sorted(
      graph_view.list_children(trackable_object), key=lambda ref: ref.name)
  for name, child in children:
    if isinstance(child, defun.ConcreteFunction):
      concrete_functions = [child]
    elif isinstance(child, def_function.Function):
      concrete_functions = list(
          child._list_all_concrete_functions_for_serialization())  # pylint: disable=protected-access
    else:
      # Not a function-like child; nothing to report for it.
      continue

    print('\n Function Name: \'%s\'' % name)
    ordered = sorted(concrete_functions, key=lambda fn: fn.name)
    for index, concrete_function in enumerate(ordered, 1):
      args, kwargs = None, None
      signature = concrete_function.structured_input_signature
      if signature:
        args, kwargs = signature
      elif concrete_function._arg_keywords:  # pylint: disable=protected-access
        # For pure ConcreteFunctions we might have nothing better than
        # _arg_keywords.
        args = concrete_function._arg_keywords  # pylint: disable=protected-access

      if args:
        print(' Option #%d' % index)
        print(' Callable with:')
        _print_args(args, indent=4)
      if kwargs:
        _print_args(kwargs, 'Named Argument', indent=4)
def testReadSavedModelInvalid(self):
  """Checks that reading from a path this test never populates raises IOError.

  The expected error message must name the directory that was passed in.
  """
  saved_model_dir = os.path.join(test.get_temp_dir(), "invalid_saved_model")
  # assertRaisesRegexp is a deprecated unittest alias that was removed in
  # Python 3.12; assertRaisesRegex is the supported spelling.
  with self.assertRaisesRegex(
      IOError, "SavedModel file does not exist at: %s" % saved_model_dir):
    saved_model_utils.read_saved_model(saved_model_dir)