コード例 #1
0
    def testReadWeightsWithIncorrectTypeInWeightsManifestRaisesError(self):
        """read_weights must reject a weight-entry dict given as manifest."""
        random_values = np.random.rand(1, 100).astype(np.float32)
        groups = [[{'name': 'weight1', 'data': random_values}]]
        write_weights.write_weights(groups, self._tmp_dir)

        # Hand over one weight-entry dict where the manifest argument goes.
        wrong_manifest = groups[0][0]
        with self.assertRaises(ValueError):
            read_weights.read_weights(wrong_manifest, self._tmp_dir)
コード例 #2
0
  def testReadWeightsWithIncorrectTypeInWeightsManifestRaisesError(self):
    """A weight-entry dict passed as the manifest must raise ValueError."""
    entry = dict(name='weight1',
                 data=np.random.rand(1, 100).astype(np.float32))
    groups = [[entry]]

    write_weights.write_weights(groups, self._tmp_dir)

    # groups[0][0] is a single weight entry, not a weights manifest.
    self.assertRaises(
        ValueError, read_weights.read_weights, groups[0][0], self._tmp_dir)
コード例 #3
0
    def testReadOneGroupEmptyStrings(self):
        """Round-trip object arrays holding empty strings or no elements."""
        written_entries = [
            {'name': 'weight1', 'data': np.array(['', ''], 'object')},
            {'name': 'weight2', 'data': np.array([], 'object')},
            {'name': 'weight3', 'data': np.array([[]], 'object')},
        ]
        groups = [written_entries]

        manifest = write_weights.write_weights(groups, self._tmp_dir)

        # Load everything back through `read_weights`.
        read_output = read_weights.read_weights(manifest, self._tmp_dir)
        self.assertEqual(1, len(read_output))
        group = read_output[0]
        self.assertEqual(3, len(group))

        # The empty strings are expected back as encoded (bytes) values.
        empty_bytes = u''.encode('utf-8')
        expectations = [
            ('weight1', np.array([empty_bytes, empty_bytes], 'object')),
            ('weight2', np.array([], 'object')),
            ('weight3', np.array([[]], 'object')),
        ]
        for weight, (wanted_name, wanted_data) in zip(group, expectations):
            self.assertEqual(wanted_name, weight['name'])
            np.testing.assert_array_equal(weight['data'], wanted_data)
コード例 #4
0
  def testReadCyrillicStringUnicodeAndEncoded(self):
    """Cyrillic text, unicode or UTF-8 bytes, is read back as UTF-8 bytes."""
    unicode_text = u'здраво'
    encoded_text = u'поздрав'.encode('utf-8')
    groups = [[
        # String is stored as unicode.
        {'name': 'weight1', 'data': np.array([unicode_text], 'object')},
        # String is stored encoded.
        {'name': 'weight2', 'data': np.array([encoded_text], 'object')},
    ]]

    manifest = write_weights.write_weights(groups, self._tmp_dir)

    # Load everything back through `read_weights`.
    read_output = read_weights.read_weights(manifest, self._tmp_dir)
    self.assertEqual(1, len(read_output))
    group = read_output[0]
    self.assertEqual(2, len(group))

    expectations = (
        ('weight1', unicode_text.encode('utf-8')),
        ('weight2', encoded_text),
    )
    for position, (wanted_name, wanted_bytes) in enumerate(expectations):
      weight = group[position]
      self.assertEqual(wanted_name, weight['name'])
      np.testing.assert_array_equal(
          weight['data'], np.array([wanted_bytes], 'object'))
コード例 #5
0
ファイル: api.py プロジェクト: ducky777/posenet-python
def _convert_graph_model_to_graph(model_json: Dict[str, Any],
                                  base_path: str) -> tf.Graph:
    """
    Build a TF Graph from a parsed TFJS graph-model JSON document

    Args:
        model_json: Dictionary parsed from the TFJS model JSON file
        base_path:  Path handed to `read_weights` to locate the weights

    Returns:
        The graph produced by `_create_graph`

    Raises:
        ValueError: if the topology key or the weights manifest key is
            missing from `model_json`
    """
    if ARTIFACT_MODEL_TOPOLOGY_KEY in model_json:
        topology = model_json[ARTIFACT_MODEL_TOPOLOGY_KEY]
    else:
        raise ValueError(
            f"model_json is missing key '{ARTIFACT_MODEL_TOPOLOGY_KEY}'")

    if ARTIFACT_WEIGHTS_MANIFEST_KEY in model_json:
        weights_manifest = model_json[ARTIFACT_WEIGHTS_MANIFEST_KEY]
    else:
        raise ValueError(
            f"{ARTIFACT_MODEL_JSON_FILE_NAME} is missing key "
            f"'{ARTIFACT_WEIGHTS_MANIFEST_KEY}'")

    weight_list = read_weights(weights_manifest, base_path, flatten=True)
    graph_def = _convert_graph_def(topology)

    # Index the flat list of weight entries by weight name.
    name_key = common.TFJS_NAME_KEY
    data_key = common.TFJS_DATA_KEY
    weight_dict = {entry[name_key]: entry[data_key] for entry in weight_list}

    graph_def, weight_modifiers = _replace_unsupported_operations(graph_def)
    return _create_graph(graph_def, weight_dict, weight_modifiers)
コード例 #6
0
ファイル: api.py プロジェクト: dallarosa/tfjs-to-tf
def _convert_graph_model_to_graph(model_json, base_path):
    """
    Convert TFJS JSON model to TF Graph

    Args:
        model_json: JSON dict from TFJS model file
        base_path:  Path to the model file (where to find the model weights)

    Returns:
        TF Graph for inference or saving

    Raises:
        ValueError: if the model topology key or the weights manifest key
            is missing from `model_json`
    """
    # Use the `not in` operator rather than negating an `in` expression.
    if tfjs_common.ARTIFACT_MODEL_TOPOLOGY_KEY not in model_json:
        raise ValueError("model_json is missing key '{}'".format(
            tfjs_common.ARTIFACT_MODEL_TOPOLOGY_KEY))

    topology = model_json[tfjs_common.ARTIFACT_MODEL_TOPOLOGY_KEY]

    if tfjs_common.ARTIFACT_WEIGHTS_MANIFEST_KEY not in model_json:
        raise ValueError("model_json is missing key '{}'".format(
            tfjs_common.ARTIFACT_WEIGHTS_MANIFEST_KEY))

    weights_manifest = model_json[tfjs_common.ARTIFACT_WEIGHTS_MANIFEST_KEY]
    # flatten=True is requested so the result can be turned into one dict.
    weight_list = read_weights(weights_manifest, base_path, flatten=True)

    graph_def = _convert_graph_def(topology)
    weight_dict = _convert_weight_list_to_dict(weight_list)

    return _create_graph(graph_def, weight_dict)
コード例 #7
0
    def testReadOneGroupWithShards(self):
        """Write one random (1, 100) float32 weight and read it back."""
        groups = [[dict(name='weight1',
                        data=np.random.rand(1, 100).astype(np.float32))]]

        manifest = write_weights.write_weights(groups, self._tmp_dir)

        # Load everything back through `read_weights`.
        read_output = read_weights.read_weights(manifest, self._tmp_dir)
        self.assertEqual(1, len(read_output))
        read_group = read_output[0]
        self.assertEqual(1, len(read_group))
        read_weight = read_group[0]
        self.assertEqual('weight1', read_weight['name'])
        written_data = groups[0][0]['data']
        self.assertTrue(np.allclose(written_data, read_weight['data']))
コード例 #8
0
    def testReadOneGroup(self):
        """Round-trip a single float32 weight through write and read."""
        source_weight = {
            'name': 'weight1',
            'data': np.array([1, 2, 3], 'float32'),
        }
        groups = [[source_weight]]

        manifest = write_weights.write_weights(groups, self._tmp_dir)

        # Load everything back through `read_weights`.
        read_output = read_weights.read_weights(manifest, self._tmp_dir)
        self.assertEqual(1, len(read_output))
        only_group = read_output[0]
        self.assertEqual(1, len(only_group))
        only_weight = only_group[0]
        self.assertEqual('weight1', only_weight['name'])
        values_match = np.allclose(groups[0][0]['data'], only_weight['data'])
        self.assertTrue(values_match)
コード例 #9
0
    def testReadQuantizedWeights(self):
        """Weights written with uint8 quantization read back close to input."""
        original = {
            'name': 'weight1',
            'data': np.array([0, 1, 2, 3], 'float32'),
        }
        groups = [[original]]

        # With quantization requested, the manifest is parsed from JSON text.
        manifest = json.loads(
            write_weights.write_weights(
                groups, self._tmp_dir, quantization_dtype=np.uint8))

        # Load everything back through `read_weights`.
        read_output = read_weights.read_weights(manifest, self._tmp_dir)
        self.assertEqual(1, len(read_output))
        first_group = read_output[0]
        self.assertEqual(1, len(first_group))
        restored = first_group[0]
        self.assertEqual('weight1', restored['name'])
        self.assertTrue(
            np.allclose(groups[0][0]['data'], restored['data']))
コード例 #10
0
ファイル: read_weights_test.py プロジェクト: caisq/tfjs-1
  def testReadStringScalar(self):
    """Round-trip a zero-dimensional object array holding UTF-8 bytes."""
    encoded = u'abc'.encode('utf-8')
    groups = [[{'name': 'weight1', 'data': np.array(encoded, 'object')}]]

    manifest = write_weights.write_weights(groups, self._tmp_dir)

    # Load everything back through `read_weights`.
    read_output = read_weights.read_weights(manifest, self._tmp_dir)
    self.assertEqual(1, len(read_output))
    first_group = read_output[0]
    self.assertEqual(1, len(first_group))
    restored = first_group[0]
    self.assertEqual('weight1', restored['name'])
    expected = np.array(encoded, 'object')
    np.testing.assert_array_equal(restored['data'], expected)
コード例 #11
0
ファイル: read_weights_test.py プロジェクト: caisq/tfjs-1
  def testReadBoolWeights(self):
    """Round-trip a boolean weight array through write and read."""
    flags = [True, False, True]
    groups = [[{'name': 'weight1', 'data': np.array(flags, 'bool')}]]

    manifest = write_weights.write_weights(groups, self._tmp_dir)

    # Load everything back through `read_weights`.
    read_output = read_weights.read_weights(manifest, self._tmp_dir)
    self.assertEqual(1, len(read_output))
    first_group = read_output[0]
    self.assertEqual(1, len(first_group))
    restored = first_group[0]
    self.assertEqual('weight1', restored['name'])
    np.testing.assert_array_equal(restored['data'], np.array(flags, 'bool'))
コード例 #12
0
  def testReadOneGroupFlattened(self):
    """With flatten=True the output is a flat list of weight entries."""
    groups = [[dict(name='weight1', data=np.array([1, 2, 3], 'float32'))]]

    # The writer's return value is parsed from JSON text into the manifest.
    manifest = json.loads(write_weights.write_weights(groups, self._tmp_dir))

    # Load everything back through `read_weights`, without group nesting.
    read_output = read_weights.read_weights(
        manifest, self._tmp_dir, flatten=True)
    self.assertEqual(1, len(read_output))
    restored = read_output[0]
    self.assertEqual('weight1', restored['name'])
    written_data = groups[0][0]['data']
    self.assertTrue(np.allclose(written_data, restored['data']))
コード例 #13
0
    def testReadOneGroupString(self):
        """A 2x2 array of str values is read back as UTF-8 encoded bytes."""
        groups = [[dict(
            name='weight1',
            data=np.array([['test', 'a'], ['b', 'c']], 'object'))]]

        manifest = write_weights.write_weights(groups, self._tmp_dir)

        # Load everything back through `read_weights`.
        read_output = read_weights.read_weights(manifest, self._tmp_dir)
        self.assertEqual(1, len(read_output))
        first_group = read_output[0]
        self.assertEqual(1, len(first_group))
        restored = first_group[0]
        self.assertEqual('weight1', restored['name'])

        expected_rows = [
            [text.encode('utf-8') for text in row]
            for row in ([u'test', u'a'], [u'b', u'c'])
        ]
        np.testing.assert_array_equal(
            restored['data'], np.array(expected_rows, 'object'))
コード例 #14
0
  def testReadOneGroupWithShards(self):
    """Write one random (1, 100) float32 weight and read it back."""
    random_entry = dict(
        name='weight1', data=np.random.rand(1, 100).astype(np.float32))
    groups = [[random_entry]]

    # The writer's return value is parsed from JSON text into the manifest.
    manifest = json.loads(write_weights.write_weights(groups, self._tmp_dir))

    # Load everything back through `read_weights`.
    read_output = read_weights.read_weights(manifest, self._tmp_dir)
    self.assertEqual(1, len(read_output))
    first_group = read_output[0]
    self.assertEqual(1, len(first_group))
    restored = first_group[0]
    self.assertEqual('weight1', restored['name'])
    written_data = groups[0][0]['data']
    self.assertTrue(np.allclose(written_data, restored['data']))
コード例 #15
0
ファイル: api.py プロジェクト: PowerOlive/tfjs-to-tf
def _convert_graph_model_to_graph(
    model_json: Dict[str, Any],
    base_path: str,
    compat_mode: CompatMode = CompatMode.NONE
) -> Tuple[tf.Graph, util.SignatureDef]:
    """
    Build a TF Graph and signature from a parsed TFJS graph-model document

    Args:
        model_json: Dictionary parsed from the TFJS model JSON file
        base_path:  Path handed to `read_weights` to locate the weights
        compat_mode: Forwarded to `_replace_unsupported_operations`; with
            `CompatMode.TFJS`, `convert_int64_to_int32` is also applied to
            the graph definition

    Returns:
        The result of `_set_signature_dtypes` for the created graph and its
        signature definition

    Raises:
        ValueError: if the topology key or the weights manifest key is
            missing from `model_json`
        ModelFormatError: if the model's format entry is not
            `TFJS_GRAPH_MODEL_FORMAT`
    """
    if ARTIFACT_MODEL_TOPOLOGY_KEY not in model_json:
        raise ValueError(
            f"model_json is missing key '{ARTIFACT_MODEL_TOPOLOGY_KEY}'")

    # A missing format entry is treated like an unknown (empty) format.
    model_format = model_json.get(FORMAT_KEY, '')
    if model_format != TFJS_GRAPH_MODEL_FORMAT:
        raise ModelFormatError(f"unsupported model format: '{model_format}'",
                               model_format)

    topology = model_json[ARTIFACT_MODEL_TOPOLOGY_KEY]

    if ARTIFACT_WEIGHTS_MANIFEST_KEY in model_json:
        weights_manifest = model_json[ARTIFACT_WEIGHTS_MANIFEST_KEY]
    else:
        raise ValueError(
            f"{ARTIFACT_MODEL_JSON_FILE_NAME} is missing key "
            f"'{ARTIFACT_WEIGHTS_MANIFEST_KEY}'")

    weight_list = read_weights(weights_manifest, base_path, flatten=True)
    graph_def = _convert_graph_def(topology)

    # Index the flat list of weight entries by weight name.
    name_key = common.TFJS_NAME_KEY
    data_key = common.TFJS_DATA_KEY
    weight_dict = {entry[name_key]: entry[data_key] for entry in weight_list}

    graph_def, weight_modifiers = _replace_unsupported_operations(
        graph_def, compat_mode)
    if compat_mode == CompatMode.TFJS:
        graph_def = convert_int64_to_int32(graph_def)
    graph = _create_graph(graph_def, weight_dict, weight_modifiers)

    # Fall back to an inferred signature if none could be extracted.
    signature_def = _extract_signature_def(model_json)
    if not signature_def:
        signature_def = util.infer_signature(graph)
    return _set_signature_dtypes(graph, signature_def)
コード例 #16
0
  def testReadQuantizedWeights(self):
    """Weights written with uint8 quantization read back close to input."""
    groups = [[dict(name='weight1', data=np.array([0, 1, 2, 3], 'float32'))]]

    # With quantization requested, the manifest is parsed from JSON text.
    manifest = json.loads(
        write_weights.write_weights(
            groups, self._tmp_dir, quantization_dtype=np.uint8))

    # Load everything back through `read_weights`.
    read_output = read_weights.read_weights(manifest, self._tmp_dir)
    self.assertEqual(1, len(read_output))
    first_group = read_output[0]
    self.assertEqual(1, len(first_group))
    restored = first_group[0]
    self.assertEqual('weight1', restored['name'])
    written_data = groups[0][0]['data']
    self.assertTrue(np.allclose(written_data, restored['data']))
コード例 #17
0
ファイル: read_weights_test.py プロジェクト: caisq/tfjs-1
  def testReadFloat16QuantizedWeights(self):
    """float16-quantized weights are read back with dtype float32."""
    groups = [[dict(name='weight1', data=np.array([0, 1, 2, 3], 'float32'))]]
    dtype_map = {'float16': '*'}

    manifest = write_weights.write_weights(
        groups, self._tmp_dir, quantization_dtype_map=dtype_map)

    # Load everything back through `read_weights`.
    read_output = read_weights.read_weights(manifest, self._tmp_dir)
    self.assertEqual(1, len(read_output))
    first_group = read_output[0]
    self.assertEqual(1, len(first_group))
    restored = first_group[0]
    self.assertEqual('weight1', restored['name'])
    # The restored array must be float32 again, not the quantized dtype.
    self.assertEqual(restored['data'].dtype, np.float32)
    written_data = groups[0][0]['data']
    self.assertTrue(np.allclose(written_data, restored['data']))
コード例 #18
0
ファイル: read_weights_test.py プロジェクト: zakir2k/tfjs
  def testReadEastAsianStringUnicodeAndEncoded(self):
    """Decoded text reads back as UTF-8; encoded bytes read back unchanged."""
    text = u'语言处理'
    utf8_bytes = text.encode('utf-8')
    utf16_bytes = text.encode('utf-16')

    # Each string tensor uses different encoding.
    groups = [[
        # Decoded.
        {'name': 'weight1', 'data': np.array([text], 'object')},
        # Encoded as utf-16.
        {'name': 'weight2', 'data': np.array([utf16_bytes], 'object')},
        # Encoded as utf-8.
        {'name': 'weight3', 'data': np.array([utf8_bytes], 'object')},
    ]]

    manifest = write_weights.write_weights(groups, self._tmp_dir)

    # Load everything back through `read_weights`.
    read_output = read_weights.read_weights(manifest, self._tmp_dir)
    self.assertEqual(1, len(read_output))
    group = read_output[0]
    self.assertEqual(3, len(group))

    expectations = (
        ('weight1', utf8_bytes),
        ('weight2', utf16_bytes),
        ('weight3', utf8_bytes),
    )
    for position, (wanted_name, wanted_bytes) in enumerate(expectations):
      weight = group[position]
      self.assertEqual(wanted_name, weight['name'])
      np.testing.assert_array_equal(
          weight['data'], np.array([wanted_bytes], 'object'))
コード例 #19
0
ファイル: read_weights_test.py プロジェクト: zakir2k/tfjs
  def testReadOneGroupWithInt32DataFlattened(self):
    """A float32 and an int32 weight come back as a flat two-entry list."""
    groups = [[
        dict(name='weight1', data=np.array([1, 2, 3], 'float32')),
        dict(name='weight2', data=np.array([10, 20, 30], 'int32')),
    ]]

    manifest = write_weights.write_weights(groups, self._tmp_dir)

    # Load everything back through `read_weights`, without group nesting.
    read_output = read_weights.read_weights(
        manifest, self._tmp_dir, flatten=True)
    self.assertEqual(2, len(read_output))

    for position, wanted_name in enumerate(('weight1', 'weight2')):
      restored = read_output[position]
      self.assertEqual(wanted_name, restored['name'])
      written_data = groups[0][position]['data']
      self.assertTrue(np.allclose(written_data, restored['data']))
コード例 #20
0
def _convert_graph_model_to_graph(
        model_json: Dict[str, Any],
        base_path: str,
        compat_mode: bool = False) -> Tuple[tf.Graph, util.SignatureDef]:
    """
    Build a TF Graph and signature from a parsed TFJS graph-model document

    Args:
        model_json: Dictionary parsed from the TFJS model JSON file
        base_path:  Path handed to `read_weights` to locate the weights
        compat_mode: When true, `convert_int64_to_int32` is applied to the
            graph definition before the graph is created

    Returns:
        The result of `_set_signature_dtypes` for the created graph and its
        signature definition

    Raises:
        ValueError: if the topology key or the weights manifest key is
            missing from `model_json`
    """
    if ARTIFACT_MODEL_TOPOLOGY_KEY in model_json:
        topology = model_json[ARTIFACT_MODEL_TOPOLOGY_KEY]
    else:
        raise ValueError(
            f"model_json is missing key '{ARTIFACT_MODEL_TOPOLOGY_KEY}'")

    if ARTIFACT_WEIGHTS_MANIFEST_KEY in model_json:
        weights_manifest = model_json[ARTIFACT_WEIGHTS_MANIFEST_KEY]
    else:
        raise ValueError(
            f"{ARTIFACT_MODEL_JSON_FILE_NAME} is missing key "
            f"'{ARTIFACT_WEIGHTS_MANIFEST_KEY}'")

    weight_list = read_weights(weights_manifest, base_path, flatten=True)
    graph_def = _convert_graph_def(topology)

    # Index the flat list of weight entries by weight name.
    name_key = common.TFJS_NAME_KEY
    data_key = common.TFJS_DATA_KEY
    weight_dict = {entry[name_key]: entry[data_key] for entry in weight_list}

    graph_def, weight_modifiers = _replace_unsupported_operations(graph_def)
    if compat_mode:
        graph_def = convert_int64_to_int32(graph_def)
    graph = _create_graph(graph_def, weight_dict, weight_modifiers)

    # Fall back to an inferred signature if none could be extracted.
    signature_def = _extract_signature_def(model_json)
    if not signature_def:
        signature_def = util.infer_signature(graph)
    return _set_signature_dtypes(graph, signature_def)
コード例 #21
0
def load_keras_model(config_json_path,
                     weights_path_prefix=None,
                     weights_data_buffers=None,
                     load_weights=True,
                     use_unique_name_scope=False):
    """Load a Keras Model from TensorFlow.js-format artifacts.

    Args:
      config_json_path: Path to the TensorFlow.js-format JSON file that
        includes the model topology and weights manifest.
      weights_path_prefix: Optional path prefix for the weights files.
        If not specified (`None`), will assume the prefix is the same
        directory as the dirname of `config_json_path`.
      weights_data_buffers: A buffer or a `list` of buffers containing the
        weight values concatenated and sharded in the order as specified by
        the weights manifest at `config_json_path`. This argument is mutually
        exclusive with `weights_path_prefix`.
      load_weights: Whether the weights are to be loaded according
        to the weights manifest at `config_json_path`. Default: `True`.
      use_unique_name_scope: Use a unique ID as the name scope for the loaded
        model. This may facilitate loading of multiple Keras models in the
        same TensorFlow Graph or Session context. Default: `False`.

    Returns:
      The loaded instance of `keras.Model`.

    Raises:
      TypeError, if the JSON content of `config_json_path` is not a `dict`.
      KeyError, if required keys do not exist in the JSON content of
        `config_json_path`.
      ValueError, if both `weights_data_buffers` and `weights_path_prefix`
        are provided, or if the weights path prefix is not an existing
        directory.
    """
    with open(config_json_path, 'rt') as f:
        model_and_weights_manifest = json.load(f)

    if not isinstance(model_and_weights_manifest, dict):
        raise TypeError(
            'The JSON content of %s is required to be a `dict`, but found %s' %
            (config_json_path, type(model_and_weights_manifest)))
    if 'modelTopology' not in model_and_weights_manifest:
        raise KeyError(
            'Field "modelTopology" is missing from the JSON content in %s' %
            config_json_path)

    model_json = model_and_weights_manifest['modelTopology']

    # The topology may be wrapped in an outer dict under 'model_config'.
    if 'model_config' in model_json:
        model_json = model_json['model_config']
    unique_name_scope = uuid.uuid4().hex if use_unique_name_scope else None
    with tf.name_scope(unique_name_scope):
        model = keras.models.model_from_json(json.dumps(model_json))

    if not load_weights:
        return model

    if 'weightsManifest' not in model_and_weights_manifest:
        raise KeyError(
            'Field "weightsManifest" is missing from the JSON content in %s'
            % config_json_path)
    weights_manifest = model_and_weights_manifest['weightsManifest']

    if weights_data_buffers and weights_path_prefix:
        raise ValueError(
            'The arguments weights_data_buffers and weights_path_prefix are '
            'mutually exclusive and should not be both specified.')

    # The weight names are needed no matter where the weight values come
    # from, so compute them before branching. Previously they were only
    # assigned on the path-prefix branch, which left `weight_names` unbound
    # when `weights_data_buffers` was supplied.
    # With a unique name scope, the scope prefix and its separator are cut
    # off the front; otherwise the last two characters are dropped
    # (presumably a ':0'-style suffix -- TODO confirm).
    weight_names = [
        keras_h5_conversion.normalize_weight_name(
            w.name[len(unique_name_scope) + 1:])
        if use_unique_name_scope else
        keras_h5_conversion.normalize_weight_name(w.name[:-2])
        for w in model.weights
    ]

    if weights_data_buffers:
        weight_entries = read_weights.decode_weights(weights_manifest,
                                                     weights_data_buffers,
                                                     flatten=True)
    else:
        if not weights_path_prefix:
            weights_path_prefix = os.path.dirname(config_json_path)
        if not os.path.isdir(weights_path_prefix):
            raise ValueError(
                'Weights path prefix is not an existing directory: %s' %
                weights_path_prefix)

        weight_entries = read_weights.read_weights(weights_manifest,
                                                   weights_path_prefix,
                                                   flatten=True)

    # Look the loaded values up by name, in the order of `model.weights`.
    weights_dict = {
        weight_entry['name']: weight_entry['data']
        for weight_entry in weight_entries
    }
    model.set_weights(
        [weights_dict[weight_name] for weight_name in weight_names])

    return model
コード例 #22
0
def load_keras_model(config_json_path,
                     weights_path_prefix=None,
                     weights_data_buffers=None,
                     load_weights=True,
                     use_unique_name_scope=False):
  """Load a Keras Model from TensorFlow.js-format artifacts from file system

  Args:
    config_json_path: Path to the TensorFlow.js-format JSON file that includes
      the model topology and weights manifest.
    weights_path_prefix: Optional path prefix for the weights files.
      If not specified (`None`) and no `weights_data_buffers` are given, will
      assume the prefix is the directory containing `config_json_path`.
    weights_data_buffers: A buffer or a `list` of buffers containing the weight
      values concatenated and sharded in the order as specified by the
      weights manifest at `config_json_path`. This argument is mutually
      exclusive with `weights_path_prefix`.
    load_weights: Whether the weights are to be loaded according
      to the weights manifest at `config_json_path`. Default: `True`.
    use_unique_name_scope: Use a unique ID as the name scope for the loaded
      model. This may facilitate loading of multiple Keras models in the
      same TensorFlow Graph or Session context. Default: `False`.

  Returns:
    The loaded instance of `keras.Model`.

  Raises:
    TypeError, if the format of the JSON content of `config_json_path` has an
      invalid format.
    KeyError, if required keys do not exist in the JSON content of
      `config_json_path`.
    ValueError, if both `weights_data_buffers` and `weights_path_prefix` are
      provided, or if the weights path prefix is not an existing directory.
  """
  if weights_data_buffers and weights_path_prefix:
    raise ValueError(
        'The arguments weights_data_buffers and weights_path_prefix are '
        'mutually exclusive and should not be both specified.')

  with open(config_json_path, 'rt') as f:
    config_json = json.load(f)
    _check_config_json(config_json)

  weight_entries = None
  if load_weights:
    weights_manifest = _get_weights_manifest_from_config_json(config_json)

    if weights_data_buffers:
      # In-memory buffers: there is no directory to validate. Previously
      # the directory check ran here as well, calling os.path.isdir(None)
      # and failing before the buffers could be decoded.
      weight_entries = read_weights.decode_weights(weights_manifest,
                                                   weights_data_buffers,
                                                   flatten=True)
    else:
      if not weights_path_prefix:
        # Default to the directory that holds the config JSON file.
        weights_path_prefix = os.path.dirname(
            os.path.realpath(config_json_path))
      if not os.path.isdir(weights_path_prefix):
        raise ValueError(
            'Weights path prefix is not an existing directory: %s' %
            weights_path_prefix)
      weight_entries = read_weights.read_weights(weights_manifest,
                                                 weights_path_prefix,
                                                 flatten=True)

  return _deserialize_keras_model(config_json['modelTopology'],
                                  weight_entries=weight_entries,
                                  use_unique_name_scope=use_unique_name_scope)
コード例 #23
0
def load_keras_model(config_json_path,
                     weights_path_prefix=None,
                     weights_data_buffers=None,
                     load_weights=True,
                     use_unique_name_scope=False):
  """Load a Keras Model from TensorFlow.js-format artifacts from file system

  Args:
    config_json_path: Path to the TensorFlow.js-format JSON file that includes
      the model topology and weights manifest.
    weights_path_prefix: Optional path prefix for the weights files.
      If not specified (`None`) and no `weights_data_buffers` are given, will
      assume the prefix is the directory containing `config_json_path`.
    weights_data_buffers: A buffer or a `list` of buffers containing the weight
      values concatenated and sharded in the order as specified by the
      weights manifest at `config_json_path`. This argument is mutually
      exclusive with `weights_path_prefix`.
    load_weights: Whether the weights are to be loaded according
      to the weights manifest at `config_json_path`. Default: `True`.
    use_unique_name_scope: Use a unique ID as the name scope for the loaded
      model. This may facilitate loading of multiple Keras models in the
      same TensorFlow Graph or Session context. Default: `False`.

  Returns:
    The loaded instance of `tf.keras.Model`.

  Raises:
    TypeError, if the format of the JSON content of `config_json_path` has an
      invalid format.
    KeyError, if required keys do not exist in the JSON content of
      `config_json_path`.
    ValueError, if both `weights_data_buffers` and `weights_path_prefix` are
      provided, or if the weights path prefix is not an existing directory.
  """
  if weights_data_buffers and weights_path_prefix:
    raise ValueError(
        'The arguments weights_data_buffers and weights_path_prefix are '
        'mutually exclusive and should not be both specified.')

  with open(config_json_path, 'rt') as f:
    config_json = json.load(f)
    _check_config_json(config_json)

  weight_entries = None
  if load_weights:
    weights_manifest = _get_weights_manifest_from_config_json(config_json)

    if weights_data_buffers:
      # In-memory buffers: there is no directory to validate. Previously
      # the directory check ran here as well, calling os.path.isdir(None)
      # and failing before the buffers could be decoded.
      weight_entries = read_weights.decode_weights(weights_manifest,
                                                   weights_data_buffers,
                                                   flatten=True)
    else:
      if not weights_path_prefix:
        # Default to the directory that holds the config JSON file.
        weights_path_prefix = os.path.dirname(
            os.path.realpath(config_json_path))
      if not os.path.isdir(weights_path_prefix):
        raise ValueError(
            'Weights path prefix is not an existing directory: %s' %
            weights_path_prefix)
      weight_entries = read_weights.read_weights(weights_manifest,
                                                 weights_path_prefix,
                                                 flatten=True)

  return _deserialize_keras_model(config_json['modelTopology'],
                                  weight_entries=weight_entries,
                                  use_unique_name_scope=use_unique_name_scope)