def testReadSavedModelInvalid(self):
    """Checks that `read_saved_model` raises IOError for the given directory.

    The directory path is built under the test temp dir; this test does not
    write anything to it before calling `read_saved_model`.
    """
    import re  # Local import: only needed to escape the path for the regex.

    saved_model_dir = os.path.join(test.get_temp_dir(), "invalid_saved_model")
    # The path is interpolated into a regular expression, so it must be
    # escaped; otherwise regex metacharacters in the temp dir (for example
    # backslashes in Windows paths) would alter what the pattern matches.
    expected_message = (
        "SavedModel file does not exist at: %s" % re.escape(saved_model_dir))
    with self.assertRaisesRegex(IOError, expected_message):
        saved_model_utils.read_saved_model(saved_model_dir)
def _show_defined_functions(saved_model_dir):
  """Prints the functions defined on a SavedModel and how each can be called.

  Nothing is printed when none of the model's meta graphs has the
  `object_graph_def` field set.

  Args:
    saved_model_dir: Directory containing the SavedModel to inspect.
  """
  saved_model_pb = saved_model_utils.read_saved_model(saved_model_dir)
  if not any(meta_graph.HasField('object_graph_def')
             for meta_graph in saved_model_pb.meta_graphs):
    return

  # Load the model while a newly created graph is installed as the default.
  with ops_lib.Graph().as_default():
    loaded = load.load(saved_model_dir)

  print('\nDefined Functions:', end='')
  graph_view = save._AugmentedGraphView(loaded)  # pylint: disable=protected-access
  # Sort by function name so the listing order is deterministic.
  named_functions = sorted(
      graph_view.list_functions(loaded).items(), key=lambda item: item[0])
  for name, function in named_functions:
    print('\n  Function Name: \'%s\'' % name)
    candidates = function._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
    ordered_candidates = sorted(candidates, key=lambda fn: fn.name)
    # Options are numbered starting at 1.
    for option_number, concrete in enumerate(ordered_candidates, 1):
      positional, named = concrete.structured_input_signature
      print('    Option #%d' % option_number)
      print('      Callable with:')
      _print_args(positional, indent=4)
      if named:
        _print_args(named, 'Named Argument', indent=4)
示例#3
0
def scan(args):
    """Entry point for the scan command.

    When `args.tag_set` is set, only the meta graph selected by that tag set
    is scanned; otherwise every meta graph in the SavedModel is scanned.

    Args:
      args: A namespace parsed from command line. Reads `args.dir` and
        `args.tag_set`.
    """
    if args.tag_set:
        targets = [
            saved_model_utils.get_meta_graph_def(args.dir, args.tag_set)
        ]
    else:
        targets = saved_model_utils.read_saved_model(args.dir).meta_graphs

    for target in targets:
        scan_meta_graph_def(target)
示例#4
0
def scan(args):
  """Entry point for the scan command.

  Scans the single meta graph selected by `args.tag_set` when it is set, and
  every meta graph of the SavedModel at `args.dir` otherwise.

  Args:
    args: A namespace parsed from command line.
  """
  if not args.tag_set:
    # No tag set given: walk all meta graphs stored in the SavedModel.
    whole_model = saved_model_utils.read_saved_model(args.dir)
    for candidate in whole_model.meta_graphs:
      scan_meta_graph_def(candidate)
    return

  selected = saved_model_utils.get_meta_graph_def(args.dir, args.tag_set)
  scan_meta_graph_def(selected)
  def testReadSavedModelValid(self):
    """Saves a model with one TRAINING-tagged meta graph and reads it back."""
    export_dir = os.path.join(test.get_temp_dir(), "valid_saved_model")
    model_builder = saved_model_builder.SavedModelBuilder(export_dir)
    with self.session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      model_builder.add_meta_graph_and_variables(sess,
                                                 [tag_constants.TRAINING])
    model_builder.save()

    # The proto read back must hold exactly one meta graph with one tag.
    meta_graphs = saved_model_utils.read_saved_model(export_dir).meta_graphs
    self.assertEqual(len(meta_graphs), 1)
    tags = meta_graphs[0].meta_info_def.tags
    self.assertEqual(len(tags), 1)
    self.assertEqual(tags[0], tag_constants.TRAINING)
示例#6
0
  def testReadSavedModelValid(self):
    """Round-trips a SavedModel and checks its single meta graph's tag."""
    target_dir = os.path.join(test.get_temp_dir(), "valid_saved_model")
    writer = saved_model_builder.SavedModelBuilder(target_dir)
    training_tags = [tag_constants.TRAINING]
    with self.session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      writer.add_meta_graph_and_variables(sess, training_tags)
    writer.save()

    loaded_pb = saved_model_utils.read_saved_model(target_dir)
    # Exactly one meta graph is expected, carrying exactly the TRAINING tag.
    self.assertEqual(len(loaded_pb.meta_graphs), 1)
    only_meta_graph = loaded_pb.meta_graphs[0]
    self.assertEqual(len(only_meta_graph.meta_info_def.tags), 1)
    self.assertEqual(only_meta_graph.meta_info_def.tags[0],
                     tag_constants.TRAINING)
示例#7
0
def _show_defined_functions(saved_model_dir):
    """Prints the callable concrete and polymorphic functions of a SavedModel.

    Returns without printing when none of the model's meta graphs has the
    `object_graph_def` field set. Children of the loaded object that are
    neither a `defun.ConcreteFunction` nor a `def_function.Function` are
    skipped.

    Args:
      saved_model_dir: Directory containing the SavedModel to inspect.
    """
    saved_model_pb = saved_model_utils.read_saved_model(saved_model_dir)
    if not any(meta_graph.HasField('object_graph_def')
               for meta_graph in saved_model_pb.meta_graphs):
        return

    # Load the model while a newly created graph is installed as the default.
    with ops_lib.Graph().as_default():
        loaded = load.load(saved_model_dir)

    print('\nDefined Functions:', end='')
    graph_view = save._AugmentedGraphView(loaded)  # pylint: disable=protected-access
    # Sort children by name so the listing order is deterministic.
    ordered_children = sorted(
        graph_view.list_children(loaded), key=lambda ref: ref.name)
    for name, child in ordered_children:
        if isinstance(child, defun.ConcreteFunction):
            candidates = [child]
        elif isinstance(child, def_function.Function):
            candidates = list(
                child._list_all_concrete_functions_for_serialization())  # pylint: disable=protected-access
        else:
            # Not a function-like child; nothing to report for it.
            continue

        print('\n  Function Name: \'%s\'' % name)
        candidates = sorted(candidates, key=lambda fn: fn.name)
        # Options are numbered starting at 1.
        for option_number, concrete in enumerate(candidates, 1):
            positional, named = None, None
            signature = concrete.structured_input_signature
            if signature:
                positional, named = signature
            else:
                # For pure ConcreteFunctions we might have nothing better
                # than _arg_keywords.
                arg_keywords = concrete._arg_keywords  # pylint: disable=protected-access
                if arg_keywords:
                    positional = arg_keywords
            if positional:
                print('    Option #%d' % option_number)
                print('      Callable with:')
                _print_args(positional, indent=4)
            if named:
                _print_args(named, 'Named Argument', indent=4)
 def testReadSavedModelInvalid(self):
   """Checks that `read_saved_model` raises IOError for the given directory.

   The directory path is built under the test temp dir; this test does not
   write anything to it before calling `read_saved_model`.
   """
   import re  # Local import: only needed to escape the path for the regex.

   saved_model_dir = os.path.join(test.get_temp_dir(), "invalid_saved_model")
   # Escape the path before embedding it in the pattern so that regex
   # metacharacters in the temp dir (e.g. Windows backslashes) match literally.
   expected_message = (
       "SavedModel file does not exist at: %s" % re.escape(saved_model_dir))
   # `assertRaisesRegexp` is a deprecated alias that was removed in
   # Python 3.12; `assertRaisesRegex` is the supported spelling.
   with self.assertRaisesRegex(IOError, expected_message):
     saved_model_utils.read_saved_model(saved_model_dir)