Example #1
0
    def test_convert_frozen_model(self):
        """Converts the frozen test graph and checks the emitted tfjs files.

        Asserts that model.json has a topology with a non-None ``versions``
        entry, a signature with outputs, and a first weights-manifest group
        listing exactly ``group1-shard1of1.bin``; also asserts that weight
        shard files exist on disk.
        """
        self.create_frozen_model()
        tfjs_path = os.path.join(self._tmp_dir, FROZEN_MODEL_DIR)

        tf_saved_model_conversion_v2.convert_tf_frozen_model(
            os.path.join(tfjs_path, 'model.frozen'), 'Softmax', tfjs_path)

        # Check model.json and weights manifest.
        with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
            model_json = json.load(f)
        self.assertTrue(model_json['modelTopology'])
        self.assertIsNotNone(model_json['modelTopology']['versions'])
        signature = model_json['userDefinedMetadata']['signature']
        self.assertIsNotNone(signature)
        # Frozen model signature has no input nodes, so only outputs are
        # checked here.
        self.assertIsNotNone(signature['outputs'])

        weights_manifest = model_json['weightsManifest']
        self.assertCountEqual(weights_manifest[0]['paths'],
                              ['group1-shard1of1.bin'])
        self.assertIn('weights', weights_manifest[0])
        self.assertTrue(glob.glob(os.path.join(tfjs_path, 'group*-*')))
def convert_savedmodel(savedmodel_dir, output_nodes, output_dir, tags=None):
    """Convert a SavedModel to a TensorFlow.js model via a frozen graph.

    The SavedModel is first frozen to ``tmp_frozengraph.pb`` inside
    ``savedmodel_dir``. That intermediate file is then converted into
    ``output_dir``, and it is always deleted afterwards, whether or not the
    conversion succeeded.

    Args:
      savedmodel_dir: Directory containing the SavedModel.
      output_nodes: Iterable of output node names (strings); they are joined
        with commas for the frozen-graph converter.
      output_dir: Destination directory for the converted model.
      tags: Optional tags forwarded to ``freeze_savedmodel``.

    Raises:
      Any exception raised while freezing or converting is re-raised after
      cleanup.
    """
    frozen_filename = os.path.join(savedmodel_dir, 'tmp_frozengraph.pb')
    # On failure, only remove output_dir if this call created it. Deleting a
    # pre-existing directory would destroy unrelated user files.
    output_dir_preexisted = os.path.exists(output_dir)
    try:
        # convert to frozen graph file
        try:
            freeze_savedmodel(savedmodel_dir=savedmodel_dir,
                              output_nodes=output_nodes,
                              frozen_file=frozen_filename,
                              tags=tags)
        except Exception as e:
            print(f"Error freezing graph: {e}")
            raise

        # now convert the frozen graph to the tfjs model
        try:
            convert_tf_frozen_model(frozen_model_path=frozen_filename,
                                    output_node_names=",".join(output_nodes),
                                    output_dir=output_dir)
        except Exception as e:
            print(f"Error converting frozen graph: {e}")
            if not output_dir_preexisted and os.path.isdir(output_dir):
                shutil.rmtree(output_dir)
            raise
    finally:
        # The intermediate frozen graph is never needed after this call.
        if os.path.isfile(frozen_filename):
            os.remove(frozen_filename)
def convert_localization(frozen_model, labels_path, output_path):
    """Convert an object-localization frozen graph to a tfjs graph model.

    The graph is converted keeping the ``Postprocessor/ExpandDims_1`` and
    ``Postprocessor/Slice`` output nodes. Quantization is disabled, op
    checking is not skipped, and debug ops are stripped. The labels file is
    then copied (not moved) into ``output_path``.
    """
    output_node_names = "Postprocessor/ExpandDims_1,Postprocessor/Slice"
    converter = tf_saved_model_conversion_v2.convert_tf_frozen_model
    converter(frozen_model, output_node_names, output_path,
              quantization_dtype=None,
              skip_op_check=False,
              strip_debug_ops=True)

    # copy2 leaves the source labels file in place and preserves its metadata.
    shutil.copy2(labels_path, output_path)
def convert_classification(frozen_model, labels_path, output_path):
    """Convert an image-classification frozen graph to a tfjs graph model.

    The graph is converted keeping the ``final_result`` output node.
    Quantization is disabled, op checking is not skipped, and debug ops are
    stripped. The labels text file is then rewritten as a JSON list in
    ``output_path/labels.json``: each line is stripped of surrounding
    whitespace and empty lines are dropped.

    Args:
      frozen_model: Path to the frozen graph to convert.
      labels_path: Path to a UTF-8 text file with one label per line.
      output_path: Directory the tfjs model and labels.json are written to.
    """
    tf_saved_model_conversion_v2.convert_tf_frozen_model(
        frozen_model,
        "final_result",
        output_path,
        quantization_dtype=None,
        skip_op_check=False,
        strip_debug_ops=True,
    )

    # Read with an explicit encoding so non-ASCII labels decode the same way
    # regardless of the platform's locale default.
    with open(labels_path, "r", encoding="utf-8") as f:
        raw_lines = f.read().splitlines()
    labels = [stripped for stripped in (s.strip() for s in raw_lines)
              if stripped]

    labels_json_path = os.path.join(output_path, "labels.json")
    with open(labels_json_path, "w", encoding="utf-8") as f:
        json.dump(labels, f)
    def test_convert_frozen_model_with_metadata(self):
        """Checks that user metadata is embedded in the generated model.json.

        Converts the frozen test graph with ``metadata={'key': {'a': 1}}`` and
        asserts that the value round-trips under
        ``userDefinedMetadata['key']``.
        """
        self.create_frozen_model()
        tfjs_path = os.path.join(self._tmp_dir, FROZEN_MODEL_DIR)

        metadata_json = {'a': 1}
        tf_saved_model_conversion_v2.convert_tf_frozen_model(
            os.path.join(tfjs_path, 'model.frozen'),
            'Softmax',
            tfjs_path,
            metadata={'key': metadata_json})

        # The metadata should be stored verbatim under its key in model.json.
        with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
            model_json = json.load(f)
        self.assertEqual(metadata_json,
                         model_json['userDefinedMetadata']['key'])
  def test_convert_frozen_model(self):
    """Converts the frozen test graph and checks model.json and weight shards.

    Asserts that model.json has a non-empty topology and that the first
    weights-manifest group lists exactly ``group1-shard1of1.bin`` and has a
    ``weights`` entry. Also asserts that shard files exist on disk.
    """
    self.create_frozen_model()
    tfjs_path = os.path.join(self._tmp_dir, FROZEN_MODEL_DIR)

    tf_saved_model_conversion_v2.convert_tf_frozen_model(
        os.path.join(tfjs_path, 'model.frozen'),
        'Softmax',
        tfjs_path)

    # Check model.json and weights manifest.
    with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
      model_json = json.load(f)
    self.assertTrue(model_json['modelTopology'])
    weights_manifest = model_json['weightsManifest']
    self.assertCountEqual(weights_manifest[0]['paths'],
                          ['group1-shard1of1.bin'])
    self.assertIn('weights', weights_manifest[0])
    self.assertTrue(glob.glob(os.path.join(tfjs_path, 'group*-*')))
Example #7
0
def convert(arguments):
  """Parses command-line ``arguments`` and runs the requested conversion.

  If ``args.show_version`` is set, only tensorflowjs / keras / tensorflow
  version information is printed. Otherwise the paths and flags are
  validated, and the standardized (input_format, output_format) pair is
  dispatched to the matching converter.

  Args:
    arguments: List of command-line argument strings.

  Raises:
    ValueError: If input_path or output_path is missing, a flag is used with
      an incompatible format, weight_shard_size_bytes is not a positive int,
      or the format pair is unsupported.
  """
  args = get_arg_parser().parse_args(arguments)
  if args.show_version:
    print('\ntensorflowjs %s\n' % version.version)
    print('Dependency versions:')
    print('  keras %s' % tf.keras.__version__)
    print('  tensorflow %s' % tf.__version__)
    return

  # `not x` already covers None, so no separate `is None` check is needed.
  if not args.input_path:
    raise ValueError(
        'Missing input_path argument. For usage, use the --help flag.')
  if not args.output_path:
    raise ValueError(
        'Missing output_path argument. For usage, use the --help flag.')

  input_format, output_format = _standardize_input_output_formats(
      args.input_format, args.output_format)

  weight_shard_size_bytes = 1024 * 1024 * 4  # Default: 4 MiB per shard.
  if args.weight_shard_size_bytes is not None:
    if (output_format not in
        (common.TFJS_LAYERS_MODEL, common.TFJS_GRAPH_MODEL)):
      raise ValueError(
          'The --weight_shard_size_bytes flag is only supported when '
          'output_format is tfjs_layers_model or tfjs_graph_model.')

    if not (isinstance(args.weight_shard_size_bytes, int) and
            args.weight_shard_size_bytes > 0):
      raise ValueError(
          'Expected weight_shard_size_bytes to be a positive integer, '
          'but got %s' % args.weight_shard_size_bytes)
    weight_shard_size_bytes = args.weight_shard_size_bytes

  quantization_dtype = (
      quantization.QUANTIZATION_BYTES_TO_DTYPES[args.quantization_bytes]
      if args.quantization_bytes else None)

  if not args.output_node_names and input_format == common.TF_FROZEN_MODEL:
    raise ValueError(
        'The --output_node_names flag is required for "tf_frozen_model"')

  if (args.signature_name and input_format not in
      (common.TF_SAVED_MODEL, common.TF_HUB_MODEL)):
    raise ValueError(
        'The --signature_name flag is applicable only to "tf_saved_model" and '
        '"tf_hub" input format, but the current input format is '
        '"%s".' % input_format)

  # TODO(cais, piyu): More conversion logics can be added as additional
  #   branches below.
  # NOTE(review): the tfjs-input and tf_frozen_model branches below derive the
  # dtype via _parse_quantization_bytes instead of `quantization_dtype`;
  # confirm the two agree before unifying them.
  if (input_format == common.KERAS_MODEL and
      output_format == common.TFJS_LAYERS_MODEL):
    dispatch_keras_h5_to_tfjs_layers_model_conversion(
        args.input_path, output_dir=args.output_path,
        quantization_dtype=quantization_dtype,
        split_weights_by_layer=args.split_weights_by_layer,
        weight_shard_size_bytes=weight_shard_size_bytes)
  elif (input_format == common.KERAS_MODEL and
        output_format == common.TFJS_GRAPH_MODEL):
    dispatch_keras_h5_to_tfjs_graph_model_conversion(
        args.input_path, output_dir=args.output_path,
        quantization_dtype=quantization_dtype,
        skip_op_check=args.skip_op_check,
        strip_debug_ops=args.strip_debug_ops,
        weight_shard_size_bytes=weight_shard_size_bytes)
  elif (input_format == common.KERAS_SAVED_MODEL and
        output_format == common.TFJS_LAYERS_MODEL):
    dispatch_keras_saved_model_to_tensorflowjs_conversion(
        args.input_path, args.output_path,
        quantization_dtype=quantization_dtype,
        split_weights_by_layer=args.split_weights_by_layer,
        weight_shard_size_bytes=weight_shard_size_bytes)
  elif (input_format == common.TF_SAVED_MODEL and
        output_format == common.TFJS_GRAPH_MODEL):
    tf_saved_model_conversion_v2.convert_tf_saved_model(
        args.input_path, args.output_path,
        signature_def=args.signature_name,
        saved_model_tags=args.saved_model_tags,
        quantization_dtype=quantization_dtype,
        skip_op_check=args.skip_op_check,
        strip_debug_ops=args.strip_debug_ops,
        weight_shard_size_bytes=weight_shard_size_bytes)
  elif (input_format == common.TF_HUB_MODEL and
        output_format == common.TFJS_GRAPH_MODEL):
    tf_saved_model_conversion_v2.convert_tf_hub_module(
        args.input_path, args.output_path,
        signature=args.signature_name,
        saved_model_tags=args.saved_model_tags,
        quantization_dtype=quantization_dtype,
        skip_op_check=args.skip_op_check,
        strip_debug_ops=args.strip_debug_ops,
        weight_shard_size_bytes=weight_shard_size_bytes)
  elif (input_format == common.TFJS_LAYERS_MODEL and
        output_format == common.KERAS_MODEL):
    dispatch_tensorflowjs_to_keras_h5_conversion(args.input_path,
                                                 args.output_path)
  elif (input_format == common.TFJS_LAYERS_MODEL and
        output_format == common.KERAS_SAVED_MODEL):
    dispatch_tensorflowjs_to_keras_saved_model_conversion(args.input_path,
                                                          args.output_path)
  elif (input_format == common.TFJS_LAYERS_MODEL and
        output_format == common.TFJS_LAYERS_MODEL):
    dispatch_tensorflowjs_to_tensorflowjs_conversion(
        args.input_path, args.output_path,
        quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
        weight_shard_size_bytes=weight_shard_size_bytes)
  elif (input_format == common.TFJS_LAYERS_MODEL and
        output_format == common.TFJS_GRAPH_MODEL):
    dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
        args.input_path, args.output_path,
        quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
        skip_op_check=args.skip_op_check,
        strip_debug_ops=args.strip_debug_ops,
        weight_shard_size_bytes=weight_shard_size_bytes)
  elif (input_format == common.TF_FROZEN_MODEL and
        output_format == common.TFJS_GRAPH_MODEL):
    tf_saved_model_conversion_v2.convert_tf_frozen_model(
        args.input_path, args.output_node_names, args.output_path,
        quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
        skip_op_check=args.skip_op_check,
        strip_debug_ops=args.strip_debug_ops,
        weight_shard_size_bytes=weight_shard_size_bytes)
  else:
    raise ValueError(
        'Unsupported input_format - output_format pair: %s - %s' %
        (input_format, output_format))
Example #8
0
File: converter.py | Project: caisq/tfjs-1
def _dispatch_converter(input_format, output_format, args,
                        quantization_dtype_map, weight_shard_size_bytes,
                        metadata_map):
    """Route a conversion request to the converter for its format pair.

    Only the ``args`` attributes used by the selected converter are read.
    ``quantization_dtype_map``, ``weight_shard_size_bytes`` and
    ``metadata_map`` are forwarded only to the converters that are passed
    them below.

    Raises:
      ValueError: If (input_format, output_format) has no converter.
    """
    # TODO(cais, piyu): More conversion logics can be added as additional
    #   branches below.
    requested = (input_format, output_format)

    if requested == (common.KERAS_MODEL, common.TFJS_LAYERS_MODEL):
        dispatch_keras_h5_to_tfjs_layers_model_conversion(
            args.input_path, output_dir=args.output_path,
            quantization_dtype_map=quantization_dtype_map,
            split_weights_by_layer=args.split_weights_by_layer,
            weight_shard_size_bytes=weight_shard_size_bytes,
            metadata=metadata_map)
        return

    if requested == (common.KERAS_MODEL, common.TFJS_GRAPH_MODEL):
        dispatch_keras_h5_to_tfjs_graph_model_conversion(
            args.input_path, output_dir=args.output_path,
            quantization_dtype_map=quantization_dtype_map,
            skip_op_check=args.skip_op_check,
            strip_debug_ops=args.strip_debug_ops,
            weight_shard_size_bytes=weight_shard_size_bytes,
            control_flow_v2=args.control_flow_v2,
            experiments=args.experiments,
            metadata=metadata_map)
        return

    if requested == (common.KERAS_SAVED_MODEL, common.TFJS_LAYERS_MODEL):
        dispatch_keras_saved_model_to_tensorflowjs_conversion(
            args.input_path, args.output_path,
            quantization_dtype_map=quantization_dtype_map,
            split_weights_by_layer=args.split_weights_by_layer,
            weight_shard_size_bytes=weight_shard_size_bytes,
            metadata=metadata_map)
        return

    if requested == (common.TF_SAVED_MODEL, common.TFJS_GRAPH_MODEL):
        tf_saved_model_conversion_v2.convert_tf_saved_model(
            args.input_path, args.output_path,
            signature_def=args.signature_name,
            saved_model_tags=args.saved_model_tags,
            quantization_dtype_map=quantization_dtype_map,
            skip_op_check=args.skip_op_check,
            strip_debug_ops=args.strip_debug_ops,
            weight_shard_size_bytes=weight_shard_size_bytes,
            control_flow_v2=args.control_flow_v2,
            experiments=args.experiments,
            metadata=metadata_map)
        return

    if requested == (common.TF_HUB_MODEL, common.TFJS_GRAPH_MODEL):
        tf_saved_model_conversion_v2.convert_tf_hub_module(
            args.input_path, args.output_path,
            signature=args.signature_name,
            saved_model_tags=args.saved_model_tags,
            quantization_dtype_map=quantization_dtype_map,
            skip_op_check=args.skip_op_check,
            strip_debug_ops=args.strip_debug_ops,
            weight_shard_size_bytes=weight_shard_size_bytes,
            control_flow_v2=args.control_flow_v2,
            experiments=args.experiments,
            metadata=metadata_map)
        return

    # Reverse conversions from a tfjs layers model back to Keras formats.
    if requested == (common.TFJS_LAYERS_MODEL, common.KERAS_MODEL):
        dispatch_tensorflowjs_to_keras_h5_conversion(
            args.input_path, args.output_path)
        return

    if requested == (common.TFJS_LAYERS_MODEL, common.KERAS_SAVED_MODEL):
        dispatch_tensorflowjs_to_keras_saved_model_conversion(
            args.input_path, args.output_path)
        return

    if requested == (common.TFJS_LAYERS_MODEL, common.TFJS_LAYERS_MODEL):
        dispatch_tensorflowjs_to_tensorflowjs_conversion(
            args.input_path, args.output_path,
            quantization_dtype_map=quantization_dtype_map,
            weight_shard_size_bytes=weight_shard_size_bytes)
        return

    if requested == (common.TFJS_LAYERS_MODEL, common.TFJS_GRAPH_MODEL):
        dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
            args.input_path, args.output_path,
            quantization_dtype_map=quantization_dtype_map,
            skip_op_check=args.skip_op_check,
            strip_debug_ops=args.strip_debug_ops,
            weight_shard_size_bytes=weight_shard_size_bytes,
            control_flow_v2=args.control_flow_v2,
            experiments=args.experiments,
            metadata=metadata_map)
        return

    if requested == (common.TF_FROZEN_MODEL, common.TFJS_GRAPH_MODEL):
        tf_saved_model_conversion_v2.convert_tf_frozen_model(
            args.input_path, args.output_node_names, args.output_path,
            quantization_dtype_map=quantization_dtype_map,
            skip_op_check=args.skip_op_check,
            strip_debug_ops=args.strip_debug_ops,
            weight_shard_size_bytes=weight_shard_size_bytes,
            experiments=args.experiments,
            metadata=metadata_map)
        return

    raise ValueError(
        'Unsupported input_format - output_format pair: %s - %s' %
        (input_format, output_format))