Example #1
  def testMapLayerExactlyOneFallthrough(self):
    names = ['conv/0/weight', 'conv/0/bias', 'conv/1/weight', 'conv/1/bias']
    quantization_dtype_map = {'float16': True, 'uint8': True}

    with self.assertRaises(ValueError):
      quantization.map_layers_to_quantization_dtype(
          names, quantization_dtype_map)
Example #2
  def testMapLayerConflictingMap(self):
    names = ['conv/0/weight', 'conv/0/bias', 'conv/1/weight', 'conv/1/bias']
    quantization_dtype_map = {'float16': ['conv/0/*'], 'uint8': ['conv/0/bias']}

    with self.assertRaises(ValueError):
      quantization.map_layers_to_quantization_dtype(
          names, quantization_dtype_map)
Example #3
  def testMapLayerNoDtypeMap(self):
    names = ['conv/0/weight', 'conv/0/bias', 'conv/1/weight', 'conv/1/bias']
    quantization_dtype_map = {}
    dtype_map = quantization.map_layers_to_quantization_dtype(
        names, quantization_dtype_map)

    self.assertDictEqual(dtype_map, {})
Example #4
  def testMapLayerFallthrough(self):
    names = ['conv/0/weight', 'conv/0/bias', 'conv/1/weight', 'conv/1/bias']
    quantization_dtype_map = {'float16': ['conv/0/*'], 'uint8': True}
    dtype_map = quantization.map_layers_to_quantization_dtype(
        names, quantization_dtype_map)

    self.assertDictEqual(dtype_map, {
        'conv/0/weight': np.float16,
        'conv/0/bias': np.float16,
        'conv/1/weight': np.uint8,
        'conv/1/bias': np.uint8
    })
Example #5
  def testMapLayerStringToList(self):
    names = ['conv/0/weight', 'conv/0/bias', 'conv/1/weight', 'conv/1/bias']
    quantization_dtype_map = {'float16': '*'}

    dtype_map = quantization.map_layers_to_quantization_dtype(
        names, quantization_dtype_map)

    self.assertDictEqual(dtype_map, {
        'conv/0/weight': np.float16,
        'conv/0/bias': np.float16,
        'conv/1/weight': np.float16,
        'conv/1/bias': np.float16
    })
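
The tests above pin down the mapping semantics: explicit patterns (with `*` wildcards) win, a bare string is treated as a single pattern, conflicting or unmatched names raise `ValueError`, and a value of True marks the fallthrough dtype for every remaining weight. Below is a minimal standalone sketch (not from the source) of the same call outside a test case; the import path is an assumption about where the `quantization` module used by these tests lives.

# Minimal sketch: map weight names to quantization dtypes directly.
# The import path is an assumption.
import numpy as np
from tensorflowjs import quantization  # assumed import path

names = ['conv/0/weight', 'conv/0/bias', 'dense/kernel']
# 'conv/0/*' weights quantize to float16; True makes uint8 the fallthrough
# dtype for every remaining weight, as in Example #4 above.
dtype_map = quantization.map_layers_to_quantization_dtype(
    names, {'float16': ['conv/0/*'], 'uint8': True})

assert dtype_map == {
    'conv/0/weight': np.float16,
    'conv/0/bias': np.float16,
    'dense/kernel': np.uint8,
}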
Example #6
def write_weights(weight_groups,
                  write_dir,
                  shard_size_bytes=1024 * 1024 * 4,
                  write_manifest=True,
                  quantization_dtype_map=None):
    """Writes weights to a binary format on disk for ingestion by JavaScript.

    Weights are organized into groups. When writing to disk, the bytes from all
    weights in each group are concatenated together and then split into shards
    (default is 4MB). This means that large weights (> shard_size) get sharded
    and small weights (< shard_size) will be packed. If the bytes can't be split
    evenly into shards, there will be a leftover shard that is smaller than the
    shard size.

    Weights are optionally quantized to either 8 or 16 bits for compression,
    which is enabled via the `quantization_dtype_map`.

    Args:
      weight_groups: A list of groups. Each group is an array of weight
        entries. Each entry is a dict that maps a unique name to a numpy array,
        for example:
        entry = {
          'name': 'weight1',
          'data': np.array([1, 2, 3], 'float32')
        }

        Weight groups would then look like:
        weight_groups = [
          [group_0_entry1, group_0_entry2],
          [group_1_entry1, group_1_entry2],
        ]

        The 'name' must be unique across all groups and all entries. The 'data'
        field must be a numpy ndarray.
      write_dir: A directory to write the files to.
      shard_size_bytes: The size of shards in bytes. Defaults to 4MB, which is
        the max file size for caching for all major browsers.
      write_manifest: Whether to write the manifest JSON to disk. Defaults to
        True.
      quantization_dtype_map: (Optional) A mapping from dtype
        (`uint8`, `uint16`, `float16`) to weight names. The weight mapping
        supports wildcard substitution.
    Returns:
      The weights manifest JSON dict.

      An example manifest with 2 groups, 2 weights, and each weight sharded
      into 2 looks like the following:
      [{
        'paths': ['group1-shard1of2', 'group1-shard2of2'],
        'weights': [{
          'name': 'weight1',
          'shape': [1000, 1000],
          'dtype': 'float32'
        }]
      }, {
        'paths': ['group2-shard1of2', 'group2-shard2of2'],
        'weights': [{
          'name': 'weight2',
          'shape': [2000, 2000],
          'dtype': 'float32'
        }]
      }]
      or, if quantization is used:
      [{
        'paths': ['group1-shard1of2', 'group1-shard2of2'],
        'weights': [{
          'name': 'weight1',
          'shape': [1000, 1000],
          'dtype': 'float32',
          'quantization': {'min': -0.1, 'scale': 0.01, 'dtype': 'uint8'}
        }]
      }, {
        'paths': ['group2-shard1of2', 'group2-shard2of2'],
        'weights': [{
          'name': 'weight2',
          'shape': [2000, 2000],
          'dtype': 'float32',
          'quantization': {'dtype': 'float16'}
        }]
      }]
  """
    _assert_weight_groups_valid(weight_groups)
    _assert_shard_size_bytes_valid(shard_size_bytes)
    _assert_no_duplicate_weight_names(weight_groups)

    manifest = []

    for group_index, group in enumerate(weight_groups):
        for e in group:
            _auto_convert_weight_entry(e)
        names = [entry['name'] for entry in group]
        quantization_dtype = quantization.map_layers_to_quantization_dtype(
            names, quantization_dtype_map)

        group = [
            _quantize_entry(e, quantization_dtype[e['name']])
            if e['name'] in quantization_dtype else e for e in group
        ]
        group_bytes, total_bytes, _ = _stack_group_bytes(group)

        shard_filenames = _shard_group_bytes_to_disk(write_dir, group_index,
                                                     group_bytes, total_bytes,
                                                     shard_size_bytes)

        weights_entries = _get_weights_manifest_for_group(group)
        manifest_entry = {'paths': shard_filenames, 'weights': weights_entries}
        manifest.append(manifest_entry)

    if write_manifest:
        manifest_path = os.path.join(write_dir, 'weights_manifest.json')
        with tf.io.gfile.GFile(manifest_path, 'wb') as f:
            f.write(json.dumps(manifest).encode())

    return manifest
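
For context, here is a minimal usage sketch (not part of the source) that exercises write_weights with a quantization_dtype_map as described in the docstring above. The import path and output directory are assumptions; the call itself follows the documented signature.

# Minimal usage sketch; import path and output directory are assumptions.
import os

import numpy as np
from tensorflowjs.write_weights import write_weights  # assumed module path

weight_groups = [
    [{'name': 'conv/0/weight',
      'data': np.random.rand(64, 64).astype(np.float32)},
     {'name': 'conv/0/bias', 'data': np.zeros(64, np.float32)}],
    [{'name': 'dense/kernel',
      'data': np.random.rand(64, 10).astype(np.float32)}],
]

write_dir = '/tmp/model'  # hypothetical output directory
os.makedirs(write_dir, exist_ok=True)

# conv/* weights are stored as float16; True makes uint8 the fallthrough
# dtype for every other weight (see the quantization tests above).
manifest = write_weights(
    weight_groups,
    write_dir,
    shard_size_bytes=1024 * 1024 * 4,
    quantization_dtype_map={'float16': 'conv/*', 'uint8': True})

# The returned manifest mirrors weights_manifest.json: one entry per group,
# each with its shard paths and per-weight metadata.
for group_entry in manifest:
    print(group_entry['paths'], [w['name'] for w in group_entry['weights']])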