def dump_checkpoint(
    checkpoint_file, output_dir, shard_mb=4, remove_variables_regex=None,
    quantization_dtype=None):
  reader = tf.train.NewCheckpointReader(checkpoint_file)
  var_to_shape_map = reader.get_variable_to_shape_map()

  remove_variables_regex_re = (
      re.compile(remove_variables_regex) if remove_variables_regex else None)

  entries = []
  for var_name, shape in var_to_shape_map.items():
    if (remove_variables_regex_re and remove_variables_regex_re.match(var_name)
        or var_name == 'global_step'):
      print('Ignoring Regex Match: ' + var_name)
      continue
    if not shape:
      print('Ignoring Scalar: ' + var_name)
      continue

    tensor = reader.get_tensor(var_name)
    entries.append({'name': var_name, 'data': tensor})
    print('Dumping %s (%r)' % (var_name, shape))

  write_weights(
    [entries],
    output_dir,
    write_manifest=True,
    quantization_dtype=quantization_dtype,
    shard_size_bytes=shard_mb * 1024 * 1024)
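
A minimal usage sketch for `dump_checkpoint`, assuming a TF1-style checkpoint prefix and an output directory (both paths below are hypothetical); the regex is just an illustrative way to skip optimizer slot variables:

dump_checkpoint(
    '/path/to/model.ckpt',      # hypothetical checkpoint prefix
    '/tmp/web_model',           # hypothetical output directory
    shard_mb=4,
    remove_variables_regex=r'.*(Adam|RMSProp).*')  # drop optimizer slots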
Example #2
def extract_weights(graph_def, output_graph):
    """Takes a Python GraphDef object and extract the weights.

  Args:
    graph_def: tf.GraphDef tensorflow GraphDef proto object, which represents
      the model topology
  """
    constants = [node for node in graph_def.node if node.op == 'Const']
    # Remove the conditional inputs from constants.
    for const in constants:
        del const.input[:]

    print('Writing weight file ' + output_graph + '...')
    const_manifest = []
    path = os.path.dirname(output_graph)

    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        tf.import_graph_def(graph_def, name='')
        for const in constants:
            tensor = graph.get_tensor_by_name(const.name + ':0')
            value = tensor.eval(session=sess)
            if not isinstance(value, np.ndarray):
                value = np.array(value)

            const_manifest.append({'name': const.name, 'data': value})

            # Remove the binary array from tensor and save it to the external file.
            const.attr["value"].tensor.ClearField('tensor_content')

    write_weights.write_weights([const_manifest], path)

    file_io.atomic_write_string_to_file(os.path.abspath(output_graph),
                                        graph_def.SerializeToString())
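
A hedged usage sketch for `extract_weights` above, assuming a TF1 frozen graph at a hypothetical path; the weight shards are written next to `output_graph`:

import tensorflow as tf

# Load a frozen GraphDef (hypothetical path) and extract its Const weights.
graph_def = tf.GraphDef()
with tf.gfile.GFile('/path/to/frozen_model.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
extract_weights(graph_def, '/tmp/web_model/tensorflowjs_model.pb')  # hypothetical output
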
  def test_non_grouped_weights_throws(self):
    groups = [{
        'name': 'weight1',
        'data': np.array([1, 2, 3], 'float32')
    }]

    with self.assertRaises(Exception):
      write_weights.write_weights(groups, TMP_DIR)
    def test_bad_numpy_array_dtype_throws(self):
        groups = [[{
            'name': 'weight1',
            'data': np.array([1, 2, 3], 'float64')
        }]]

        with self.assertRaises(Exception):
            write_weights.write_weights(groups, TMP_DIR)
Example #5
    def test_bad_weights_entry_throws_no_data(self):
        groups = [[{
            'name': 'weight1',
            'nodata': np.array([1, 2, 3], 'float32')
        }]]

        with self.assertRaises(Exception):
            write_weights.write_weights(groups, TMP_DIR)
  def test_bad_weights_entry_throws_no_data(self):
    groups = [
        [{
            'name': 'weight1',
            'nodata': np.array([1, 2, 3], 'float32')
        }]
    ]

    with self.assertRaises(Exception):
      write_weights.write_weights(groups, TMP_DIR)
  def test_bad_numpy_array_dtype_throws(self):
    groups = [
        [{
            'name': 'weight1',
            'data': np.array([1, 2, 3], 'float64')
        }]
    ]

    with self.assertRaises(Exception):
      write_weights.write_weights(groups, TMP_DIR)
    def testReadWeightsWithIncorrectTypeInWeightsManifestRaisesError(self):
        groups = [[{
            'name': 'weight1',
            'data': np.random.rand(1, 100).astype(np.float32)
        }]]

        write_weights.write_weights(groups, self._tmp_dir)

        with self.assertRaises(ValueError):
            read_weights.read_weights(groups[0][0], self._tmp_dir)
    def test_duplicate_weight_name_throws(self):
        groups = [[{
            'name': 'duplicate',
            'data': np.array([1, 2, 3], 'float32')
        }], [{
            'name': 'duplicate',
            'data': np.array([4, 5, 6], 'float32')
        }]]

        with self.assertRaises(Exception):
            write_weights.write_weights(groups, TMP_DIR)
  def testReadWeightsWithIncorrectTypeInWeightsManifestRaisesError(self):
    groups = [
        [{
            'name': 'weight1',
            'data': np.random.rand(1, 100).astype(np.float32)
        }]
    ]

    write_weights.write_weights(groups, self._tmp_dir)

    with self.assertRaises(ValueError):
      read_weights.read_weights(groups[0][0], self._tmp_dir)
  def test_duplicate_weight_name_throws(self):
    groups = [
        [{
            'name': 'duplicate',
            'data': np.array([1, 2, 3], 'float32')
        }], [{
            'name': 'duplicate',
            'data': np.array([4, 5, 6], 'float32')
        }]
    ]

    with self.assertRaises(Exception):
      write_weights.write_weights(groups, TMP_DIR)
Example #13
    def test_1_group_1_weight_bool(self):
        groups = [[{
            'name': 'weight1',
            'data': np.array([True, False, True], 'bool')
        }]]

        manifest = write_weights.write_weights(groups,
                                               TMP_DIR,
                                               shard_size_bytes=4 * 4)

        self.assertTrue(
            os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')),
            'weights_manifest.json does not exist')

        self.assertEqual(manifest, [{
            'paths': ['group1-shard1of1.bin'],
            'weights': [{
                'name': 'weight1',
                'shape': [3],
                'dtype': 'bool'
            }]
        }])

        weights_path = os.path.join(TMP_DIR, 'group1-shard1of1.bin')
        weight1 = np.fromfile(weights_path, 'bool')
        np.testing.assert_array_equal(weight1,
                                      np.array([True, False, True], 'bool'))
Example #14
    def test_1_group_2_packed_sharded_multi_dtype(self):
        groups = [[{
            'name': 'weight1',
            'data': np.array([1, 2, 3], 'int32')
        }, {
            'name': 'weight2',
            'data': np.array([True, False, False, True], 'bool')
        }, {
            'name': 'weight3',
            'data': np.array([4.1, 5.1], 'float32')
        }]]

        # Shard size is smaller than the sum of the weights so they get packed and
        # then sharded. The packed buffer is split into 3 shard files of 8 bytes
        # each; the second shard contains mixed dtypes (int32 and bool).
        manifest = write_weights.write_weights(groups,
                                               TMP_DIR,
                                               shard_size_bytes=2 * 4)

        self.assertTrue(
            os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')),
            'weights_manifest.json does not exist')
        self.assertEqual(manifest, [{
            'paths': [
                'group1-shard1of3.bin', 'group1-shard2of3.bin',
                'group1-shard3of3.bin'
            ],
            'weights': [{
                'name': 'weight1',
                'shape': [3],
                'dtype': 'int32'
            }, {
                'name': 'weight2',
                'shape': [4],
                'dtype': 'bool'
            }, {
                'name': 'weight3',
                'shape': [2],
                'dtype': 'float32'
            }]
        }])

        shard_1_path = os.path.join(TMP_DIR, 'group1-shard1of3.bin')
        shard_1 = np.fromfile(shard_1_path, 'int32')
        np.testing.assert_array_equal(shard_1, np.array([1, 2], 'int32'))

        # Shard 2 has a mixed dtype so we parse the bytes directly.
        shard_2_path = os.path.join(TMP_DIR, 'group1-shard2of3.bin')
        with open(shard_2_path, 'rb') as f:
            shard_2_bytes = f.read()
        self.assertEqual(len(shard_2_bytes), 8)
        shard_2_int = np.frombuffer(shard_2_bytes[:4], 'int32')
        np.testing.assert_array_equal(shard_2_int, np.array([3], 'int32'))
        shard_2_bool = np.frombuffer(shard_2_bytes[4:], 'bool')
        np.testing.assert_array_equal(
            shard_2_bool, np.array([True, False, False, True], 'bool'))

        shard_3_path = os.path.join(TMP_DIR, 'group1-shard3of3.bin')
        shard_3 = np.fromfile(shard_3_path, 'float32')
        np.testing.assert_array_equal(shard_3, np.array([4.1, 5.1], 'float32'))
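
The layout exercised above (weights packed in order, then split into fixed-size shards) can be undone by concatenating the shard files and slicing by dtype. This is only a simplified sketch of what `read_weights` does, limited to fixed-size dtypes (no 'string' entries, no quantization):

import os
import numpy as np

def decode_group(group_manifest, base_dir):
    # Concatenate all shards of the group back into one packed buffer.
    data = b''.join(
        open(os.path.join(base_dir, path), 'rb').read()
        for path in group_manifest['paths'])
    offset = 0
    weights = {}
    for entry in group_manifest['weights']:
        count = int(np.prod(entry['shape'])) if entry['shape'] else 1
        array = np.frombuffer(data, dtype=entry['dtype'], count=count, offset=offset)
        weights[entry['name']] = array.reshape(entry['shape'])
        offset += array.nbytes
    return weights

For the test above, `decode_group(manifest[0], TMP_DIR)` would return weight1, weight2 and weight3 with their original dtypes and shapes.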
Example #15
  def test_1_group_1_weight(self):
    groups = [
        [{
            'name': 'weight1',
            'data': np.array([1, 2, 3], 'float32')
        }]
    ]

    manifest_json = write_weights.write_weights(
        groups, TMP_DIR, shard_size_bytes=4 * 4)
    manifest = json.loads(manifest_json)

    self.assertTrue(
        os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')),
        'weights_manifest.json does not exist')

    self.assertEqual(
        manifest,
        [{
            'paths': ['group1-shard1of1'],
            'weights': [{
                'name': 'weight1',
                'shape': [3],
                'dtype': 'float32'
            }]
        }])

    weights_path = os.path.join(TMP_DIR, 'group1-shard1of1')
    weight1 = np.fromfile(weights_path, 'float32')
    np.testing.assert_array_equal(weight1, np.array([1, 2, 3], 'float32'))
  def test_1_group_1_weight_bool(self):
    groups = [
        [{
            'name': 'weight1',
            'data': np.array([True, False, True], 'bool')
        }]
    ]

    manifest_json = write_weights.write_weights(
        groups, TMP_DIR, shard_size_bytes=4 * 4)
    manifest = json.loads(manifest_json)

    self.assertTrue(
        os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')),
        'weights_manifest.json does not exist')

    self.assertEqual(
        manifest,
        [{
            'paths': ['group1-shard1of1'],
            'weights': [{
                'name': 'weight1',
                'shape': [3],
                'dtype': 'bool'
            }]
        }])

    weights_path = os.path.join(TMP_DIR, 'group1-shard1of1')
    weight1 = np.fromfile(weights_path, 'bool')
    np.testing.assert_array_equal(
        weight1, np.array([True, False, True], 'bool'))
Example #17
  def test_no_write_manfest(self):
    groups = [
        [{
            'name': 'weight1',
            'data': np.array([1, 2, 3], 'float32')
        }]
    ]

    manifest_json = write_weights.write_weights(
        groups, TMP_DIR, write_manifest=False)
    manifest = json.loads(manifest_json)

    self.assertFalse(
        os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')),
        'weights_manifest.json exists, but expected not to exist')
    self.assertEqual(
        manifest,
        [{
            'paths': ['group1-shard1of1'],
            'weights': [{
                'name': 'weight1',
                'shape': [3],
                'dtype': 'float32'
            }]
        }])

    weights_path = os.path.join(TMP_DIR, 'group1-shard1of1')
    weight1 = np.fromfile(weights_path, 'float32')
    np.testing.assert_array_equal(weight1, np.array([1, 2, 3], 'float32'))
  def test_1_group_1_weight_sharded(self):
    groups = [
        [{
            'name': 'weight1',
            'data': np.array([1, 2, 3], 'float32')
        }]
    ]
    # Shard size is smaller than the size of the array so it gets split between
    # multiple files.
    manifest_json = write_weights.write_weights(
        groups, TMP_DIR, shard_size_bytes=2 * 4)
    manifest = json.loads(manifest_json)

    self.assertTrue(
        os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')),
        'weights_manifest.json does not exist')

    self.assertEqual(
        manifest,
        [{
            'paths': ['group1-shard1of2', 'group1-shard2of2'],
            'weights': [{
                'name': 'weight1',
                'shape': [3],
                'dtype': 'float32'
            }]
        }])

    shard_1_path = os.path.join(TMP_DIR, 'group1-shard1of2')
    shard_1 = np.fromfile(shard_1_path, 'float32')
    np.testing.assert_array_equal(shard_1, np.array([1, 2], 'float32'))

    shard_2_path = os.path.join(TMP_DIR, 'group1-shard2of2')
    shard_2 = np.fromfile(shard_2_path, 'float32')
    np.testing.assert_array_equal(shard_2, np.array([3], 'float32'))
    def test_1_group_2_weights_packed(self):
        groups = [[{
            'name': 'weight1',
            'data': np.array([1, 2, 3], 'float32')
        }, {
            'name': 'weight2',
            'data': np.array([4, 5], 'float32')
        }]]

        # Shard size is larger than the sum of the two weights so they get packed.
        manifest_json = write_weights.write_weights(groups,
                                                    TMP_DIR,
                                                    shard_size_bytes=8 * 4)
        manifest = json.loads(manifest_json)

        self.assertTrue(
            os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')),
            'weights_manifest.json does not exist')
        self.assertEqual(manifest, [{
            'paths': ['group1-shard1of1'],
            'weights': [{
                'name': 'weight1',
                'shape': [3],
                'dtype': 'float32'
            }, {
                'name': 'weight2',
                'shape': [2],
                'dtype': 'float32'
            }]
        }])

        weights_path = os.path.join(TMP_DIR, 'group1-shard1of1')
        weights = np.fromfile(weights_path, 'float32')
        np.testing.assert_array_equal(weights,
                                      np.array([1, 2, 3, 4, 5], 'float32'))
Example #21
  def testReadCyrillicStringUnicodeAndEncoded(self):
    groups = [
        [{
            'name': 'weight1',
            # String is stored as unicode.
            'data': np.array([u'здраво'], 'object')
        }, {
            'name': 'weight2',
            # String is stored encoded.
            'data': np.array([u'поздрав'.encode('utf-8')], 'object')
        }]
    ]

    manifest = write_weights.write_weights(groups, self._tmp_dir)

    # Read the weights using `read_weights`.
    read_output = read_weights.read_weights(manifest, self._tmp_dir)
    self.assertEqual(1, len(read_output))
    group = read_output[0]
    self.assertEqual(2, len(group))

    weight1 = group[0]
    self.assertEqual('weight1', weight1['name'])
    np.testing.assert_array_equal(
        weight1['data'],
        np.array([u'здраво'.encode('utf-8')], 'object'))

    weight2 = group[1]
    self.assertEqual('weight2', weight2['name'])
    np.testing.assert_array_equal(
        weight2['data'],
        np.array([u'поздрав'.encode('utf-8')], 'object'))
Example #22
    def test_1_group_1_weight_string_empty(self):
        groups = [[{'name': 'weight1', 'data': np.array([''], 'object')}]]

        manifest = write_weights.write_weights(
            groups, TMP_DIR, shard_size_bytes=4 * 1024 * 1024)

        self.assertTrue(
            os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')),
            'weights_manifest.json does not exist')

        self.assertEqual(manifest, [{
            'paths': ['group1-shard1of1.bin'],
            'weights': [{
                'name': 'weight1',
                'shape': [1],
                'dtype': 'string'
            }]
        }])

        weights_path = os.path.join(TMP_DIR, 'group1-shard1of1.bin')
        with open(weights_path, 'rb') as f:
            weight_bytes = f.read()
            self.assertEqual(len(weight_bytes), 4)
            size = np.frombuffer(weight_bytes[:4], 'uint32')[0]
            self.assertEqual(size, 0)  # Empty string.
Example #23
def write_artifacts(topology,
                    weights,
                    output_graph,
                    tf_version,
                    quantization_dtype=None):
  """Writes weights and topology to the output_dir.

  If `topology` is Falsy (e.g., `None`), only emit weights to output_dir.

  Args:
    topology: tf.GraphDef TensorFlow GraphDef proto object, which represents
      the model topology.
    weights: an array of weight groups (as defined in tfjs write_weights).
    output_graph: the output file name to hold all the contents.
    tf_version: Tensorflow version of the input graph.
    quantization_dtype: An optional numpy dtype to quantize weights to for
      compression. Only np.uint8 and np.uint16 are supported.
  """
  model_json = {
      common.FORMAT_KEY: common.TFJS_GRAPH_MODEL_FORMAT,
      # TODO(piyu): Add tensorflow version below by using `meta_info_def`.
      common.GENERATED_BY_KEY: tf_version,
      common.CONVERTED_BY_KEY: common.get_converted_by(),
  }

  model_json[common.ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None
  weights_manifest = write_weights.write_weights(
      weights, os.path.dirname(output_graph), write_manifest=False,
      quantization_dtype=quantization_dtype)
  assert isinstance(weights_manifest, list)
  model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

  with open(output_graph, 'wt') as f:
    json.dump(model_json, f)
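
A hypothetical invocation of this `write_artifacts`, assuming `graph_def` is an already-frozen tf.GraphDef and the weight groups follow the format used in the tests above; passing the proto as a dict via MessageToDict is an assumption made here for illustration only:

from google.protobuf.json_format import MessageToDict
import numpy as np
import tensorflow as tf

weight_groups = [[{'name': 'dense/kernel', 'data': np.zeros((3, 4), np.float32)}]]
write_artifacts(
    MessageToDict(graph_def),       # topology as a JSON-friendly dict (assumption)
    weight_groups,
    '/tmp/web_model/model.json',    # hypothetical output file
    tf.__version__)
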
    def write_artifacts(self, topology, weights, output_dir):
        """Writes weights and topology to the output_dir.

    If `topology` is Falsy (e.g., `None`), only emit weights to output_dir.

    Args:
      topology: a JSON dictionary, representing the Keras config.
      weights: an array of weight groups (as defined in tfjs write_weights).
      output_dir: the directory to hold all the contents.
    """
        # TODO(cais, nielsene): This method should allow optional arguments of
        #   `write_weights.write_weights` (e.g., shard size) and forward them.
        # We write the topology after since write_weights makes no promises about
        # preserving directory contents.
        if os.path.isfile(output_dir):
            raise ValueError(
                'Path "%d" already exists as a file (not a directory).' %
                output_dir)

        model_json = {}

        model_json[ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None
        weights_manifest = write_weights.write_weights(weights,
                                                       output_dir,
                                                       write_manifest=False)
        if not isinstance(weights_manifest, list):
            weights_manifest = json.loads(weights_manifest)
        assert isinstance(weights_manifest, list)
        model_json[ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

        model_json_path = os.path.join(output_dir,
                                       ARTIFACT_MODEL_JSON_FILE_NAME)
        with open(model_json_path, 'wt') as f:
            json.dump(model_json, f)
Example #25
    def testReadOneGroupEmptyStrings(self):
        groups = [[{
            'name': 'weight1',
            'data': np.array(['', ''], 'object')
        }, {
            'name': 'weight2',
            'data': np.array([], 'object')
        }, {
            'name': 'weight3',
            'data': np.array([[]], 'object')
        }]]

        manifest = write_weights.write_weights(groups, self._tmp_dir)

        # Read the weights using `read_weights`.
        read_output = read_weights.read_weights(manifest, self._tmp_dir)
        self.assertEqual(1, len(read_output))
        group = read_output[0]
        self.assertEqual(3, len(group))

        weight1 = group[0]
        self.assertEqual('weight1', weight1['name'])
        np.testing.assert_array_equal(
            weight1['data'],
            np.array([u''.encode('utf-8'), u''.encode('utf-8')], 'object'))

        weight2 = group[1]
        self.assertEqual('weight2', weight2['name'])
        np.testing.assert_array_equal(weight2['data'], np.array([], 'object'))

        weight3 = group[2]
        self.assertEqual('weight3', weight3['name'])
        np.testing.assert_array_equal(weight3['data'], np.array([[]],
                                                                'object'))
Example #27
def write_artifacts(topology,
                    weights,
                    output_graph,
                    quantization_dtype=None):
  """Writes weights and topology to the output_dir.

  If `topology` is Falsy (e.g., `None`), only emit weights to output_dir.

  Args:
    topology: tf.GraphDef TensorFlow GraphDef proto object, which represents
      the model topology.
    weights: an array of weight groups (as defined in tfjs write_weights).
    output_graph: the output file name to hold all the contents.
    quantization_dtype: An optional numpy dtype to quantize weights to for
      compression. Only np.uint8 and np.uint16 are supported.
  """
  model_json = {}

  model_json[ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None
  weights_manifest = write_weights.write_weights(
      weights, os.path.dirname(output_graph), write_manifest=False,
      quantization_dtype=quantization_dtype)
  assert isinstance(weights_manifest, list)
  model_json[ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

  with open(output_graph, 'wt') as f:
    json.dump(model_json, f)
def write_artifacts(topology,
                    weights,
                    output_graph,
                    tf_version,
                    signature_def,
                    quantization_dtype_map=None,
                    weight_shard_size_bytes=1024 * 1024 * 4,
                    initializer_graph_def=None,
                    metadata=None):
    """Writes weights and topology to the output_dir.

  If `topology` is Falsy (e.g., `None`), only emit weights to output_dir.

  Args:
    topology: tf.GraphDef TensorFlow GraphDef proto object, which represents
      the model topology.
    weights: an array of weight groups (as defined in tfjs write_weights).
    output_graph: the output file name to hold all the contents.
    tf_version: Tensorflow version of the input graph.
    signature_def: the SignatureDef of the inference graph.
    quantization_dtype_map: A mapping from dtype
      (`uint8`, `uint16`, `float16`) to weights names. The weight mapping
      supports wildcard substitution.
    weight_shard_size_bytes: Shard size (in bytes) of the weight files.
      The size of each weight file will be <= this value.
    initializer_graph_def: tf.GraphDef proto object for initializer graph.
    metadata: User defined metadata map.
  """
    model_json = {
        common.FORMAT_KEY: common.TFJS_GRAPH_MODEL_FORMAT,
        # TODO(piyu): Add tensorflow version below by using `meta_info_def`.
        common.GENERATED_BY_KEY: tf_version,
        common.CONVERTED_BY_KEY: common.get_converted_by(),
        common.SIGNATURE_KEY: MessageToDict(signature_def),
    }
    model_json[common.ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None

    if metadata:
        model_json[common.USER_DEFINED_METADATA_KEY] = metadata

    if initializer_graph_def and initializer_graph_def.node:
        model_json[common.ARTIFACT_MODEL_INITIALIZER] = MessageToDict(
            initializer_graph_def)

    weights_manifest = write_weights.write_weights(
        weights,
        os.path.dirname(output_graph),
        write_manifest=False,
        quantization_dtype_map=quantization_dtype_map,
        shard_size_bytes=weight_shard_size_bytes)
    assert isinstance(weights_manifest, list)
    model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

    with tf.io.gfile.GFile(output_graph, 'w') as f:
        json.dump(model_json, f)
Example #29
def write_artifacts(topology,
                    weights,
                    output_dir,
                    quantization_dtype_map=None,
                    weight_shard_size_bytes=1024 * 1024 * 4,
                    metadata=None):
    """Writes weights and topology to the output_dir.

  If `topology` is Falsy (e.g., `None`), only emit weights to output_dir.

  Args:
    topology: a JSON dictionary, representing the Keras config.
    weights: an array of weight groups (as defined in tfjs write_weights).
    output_dir: the directory to hold all the contents.
    quantization_dtype_map: (Optional) A mapping from dtype
      (`uint8`, `uint16`, `float16`) to weights names. The weight mapping
      supports wildcard substitution.
    weight_shard_size_bytes: Shard size (in bytes) of the weight files.
      The size of each weight file will be <= this value.
    metadata: User defined metadata map.
  """
    # TODO(cais, nielsene): This method should allow optional arguments of
    #   `write_weights.write_weights` (e.g., shard size) and forward them.
    # We write the topology after since write_weights makes no promises about
    # preserving directory contents.
    if not (isinstance(weight_shard_size_bytes, int)
            and weight_shard_size_bytes > 0):
        raise ValueError(
            'Expected weight_shard_size_bytes to be a positive integer, '
            'but got %s' % weight_shard_size_bytes)

    if os.path.isfile(output_dir):
        raise ValueError(
            'Path "%d" already exists as a file (not a directory).' %
            output_dir)

    model_json = {
        common.FORMAT_KEY: common.TFJS_LAYERS_MODEL_FORMAT,
        common.GENERATED_BY_KEY: _get_generated_by(topology),
        common.CONVERTED_BY_KEY: common.get_converted_by()
    }

    if metadata:
        model_json[common.USER_DEFINED_METADATA_KEY] = metadata

    model_json[common.ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None
    weights_manifest = write_weights.write_weights(
        weights,
        output_dir,
        write_manifest=False,
        quantization_dtype_map=quantization_dtype_map,
        shard_size_bytes=weight_shard_size_bytes)
    assert isinstance(weights_manifest, list)
    model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

    model_json_path = os.path.join(output_dir,
                                   common.ARTIFACT_MODEL_JSON_FILE_NAME)
    with open(model_json_path, 'wt') as f:
        json.dump(model_json, f)
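
A hedged usage sketch for the Keras-flavored `write_artifacts` above; `model_config` stands in for a JSON-serializable Keras topology dict and the output path is hypothetical:

import numpy as np

weight_groups = [[{'name': 'dense/kernel', 'data': np.zeros((4, 2), np.float32)}]]
write_artifacts(
    model_config,                              # assumed Keras config dict
    weight_groups,
    '/tmp/tfjs_layers_model',                  # hypothetical output directory
    weight_shard_size_bytes=4 * 1024 * 1024)   # at most 4 MB per weight file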
Example #30
def extract_weights(graph_def, output_graph, quantization_dtype=None):
    """Takes a Python GraphDef object and extract the weights.

  Args:
    graph_def: tf.GraphDef TensorFlow GraphDef proto object, which represents
      the model topology.
    quantization_dtype: An optional numpy dtype to quantize weights to for
        compression. Only np.uint8 and np.uint16 are supported.
  """
    constants = [node for node in graph_def.node if node.op == 'Const']
    constInputs = {}
    # Remove the conditional inputs from constants; they are restored below.
    for const in constants:
        constInputs[const.name] = const.input[:]
        del const.input[:]

    print('Writing weight file ' + output_graph + '...')
    const_manifest = []
    path = os.path.dirname(output_graph)

    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        tf.import_graph_def(graph_def, name='')
        for const in constants:
            tensor = graph.get_tensor_by_name(const.name + ':0')
            value = tensor.eval(session=sess)
            if not isinstance(value, np.ndarray):
                value = np.array(value)

            # Restore the conditional inputs
            const_manifest.append({'name': const.name, 'data': value})
            const.input[:] = constInputs[const.name]

            # Remove the binary array from tensor and save it to the external file.
            for field_name in CLEARED_TENSOR_FIELDS:
                const.attr["value"].tensor.ClearField(field_name)

    write_weights.write_weights([const_manifest],
                                path,
                                quantization_dtype=quantization_dtype)

    file_io.atomic_write_string_to_file(os.path.abspath(output_graph),
                                        graph_def.SerializeToString())
def extract_weights(graph_def,
                    output_graph,
                    quantization_dtype=None):
  """Takes a Python GraphDef object and extract the weights.

  Args:
    graph_def: tf.GraphDef TensorFlow GraphDef proto object, which represents
      the model topology.
    quantization_dtype: An optional numpy dtype to quantize weights to for
        compression. Only np.uint8 and np.uint16 are supported.
  """
  constants = [node for node in graph_def.node if node.op == 'Const']
  constInputs = {}
  # Remove the conditional inputs from constants; they are restored below.
  for const in constants:
    constInputs[const.name] = const.input[:]
    del const.input[:]

  print('Writing weight file ' + output_graph + '...')
  const_manifest = []
  path = os.path.dirname(output_graph)

  graph = tf.Graph()
  with tf.Session(graph=graph) as sess:
    tf.import_graph_def(graph_def, name='')
    for const in constants:
      tensor = graph.get_tensor_by_name(const.name + ':0')
      value = tensor.eval(session=sess)
      if not isinstance(value, np.ndarray):
        value = np.array(value)

      # Restore the conditional inputs
      const_manifest.append({'name': const.name, 'data': value})
      const.input[:] = constInputs[const.name]

      # Remove the binary array from tensor and save it to the external file.
      const.attr["value"].tensor.ClearField('tensor_content')

  write_weights.write_weights(
      [const_manifest], path, quantization_dtype=quantization_dtype)

  file_io.atomic_write_string_to_file(
      os.path.abspath(output_graph), graph_def.SerializeToString())
def dump_checkpoint(
    checkpoint_file, output_dir, remove_variables_regex=None,
    quantization_dtype=None):
  reader = tf.train.NewCheckpointReader(checkpoint_file)
  var_to_shape_map = reader.get_variable_to_shape_map()

  remove_variables_regex_re = (
      re.compile(remove_variables_regex) if remove_variables_regex else None)

  entries = []
  for var_name in var_to_shape_map:
    if (remove_variables_regex_re and remove_variables_regex_re.match(var_name)
        or var_name == 'global_step'):
      print('Ignoring ' + var_name)
      continue

    tensor = reader.get_tensor(var_name)
    entries.append({'name': var_name, 'data': tensor})
    print('Dumping ' + var_name)

  write_weights([entries], output_dir, quantization_dtype=quantization_dtype)
Example #33
    def test_1_group_1_weight_string(self):
        groups = [[{
            'name': 'weight1',
            'data': np.array([['здраво', 'end'], ['test', 'a']], 'object')
        }]]

        manifest = write_weights.write_weights(
            groups, TMP_DIR, shard_size_bytes=4 * 1024 * 1024)

        self.assertTrue(
            os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')),
            'weights_manifest.json does not exist')

        self.assertEqual(manifest, [{
            'paths': ['group1-shard1of1.bin'],
            'weights': [{
                'name': 'weight1',
                'shape': [2, 2],
                'dtype': 'string'
            }]
        }])

        weights_path = os.path.join(TMP_DIR, 'group1-shard1of1.bin')
        with open(weights_path, 'rb') as f:
            weight_bytes = f.read()

            self.assertEqual(len(weight_bytes), 36)
            # 'здраво'
            size = np.frombuffer(weight_bytes[:4], 'uint32')[0]
            self.assertEqual(size, 12)  # 6 cyrillic chars (2 bytes each).
            string = weight_bytes[4:16].decode('utf-8')
            self.assertEqual(string, u'здраво')
            # 'end'
            size = np.frombuffer(weight_bytes[16:20], 'uint32')[0]
            self.assertEqual(size, 3)  # 3 ascii chars.
            string = weight_bytes[20:23].decode('utf-8')
            self.assertEqual(string, u'end')
            # 'test'
            size = np.frombuffer(weight_bytes[23:27], 'uint32')[0]
            self.assertEqual(size, 4)  # 4 ascii chars.
            string = weight_bytes[27:31].decode('utf-8')
            self.assertEqual(string, u'test')
            # 'a'
            size = np.frombuffer(weight_bytes[31:35], 'uint32')[0]
            self.assertEqual(size, 1)  # 1 ascii char.
            string = weight_bytes[35:36].decode('utf-8')
            self.assertEqual(string, u'a')
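
The byte walkthrough above implies a simple framing for string tensors: each element is a uint32 byte length followed by that many UTF-8 bytes. A small sketch that decodes a string shard under that assumption:

import numpy as np

def decode_string_shard(shard_bytes):
    # Each element: uint32 length prefix, then that many UTF-8 encoded bytes.
    strings = []
    offset = 0
    while offset < len(shard_bytes):
        size = int(np.frombuffer(shard_bytes, dtype='uint32', count=1, offset=offset)[0])
        offset += 4
        strings.append(shard_bytes[offset:offset + size].decode('utf-8'))
        offset += size
    return strings

For the shard written above this yields ['здраво', 'end', 'test', 'a'], in row-major order of the [2, 2] tensor.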
    def testReadOneGroup(self):
        groups = [[{
            'name': 'weight1',
            'data': np.array([1, 2, 3], 'float32')
        }]]

        manifest = write_weights.write_weights(groups, self._tmp_dir)

        # Read the weights using `read_weights`.
        read_output = read_weights.read_weights(manifest, self._tmp_dir)
        self.assertEqual(1, len(read_output))
        self.assertEqual(1, len(read_output[0]))
        self.assertEqual('weight1', read_output[0][0]['name'])
        self.assertTrue(
            np.allclose(groups[0][0]['data'], read_output[0][0]['data']))
    def testReadOneGroupWithShards(self):
        groups = [[{
            'name': 'weight1',
            'data': np.random.rand(1, 100).astype(np.float32)
        }]]

        manifest = write_weights.write_weights(groups, self._tmp_dir)

        # Read the weights using `read_weights`.
        read_output = read_weights.read_weights(manifest, self._tmp_dir)
        self.assertEqual(1, len(read_output))
        self.assertEqual(1, len(read_output[0]))
        self.assertEqual('weight1', read_output[0][0]['name'])
        self.assertTrue(
            np.allclose(groups[0][0]['data'], read_output[0][0]['data']))
Example #36
    def test_quantize_group(self):
        groups = [[{
            'name': 'weight1',
            'data': np.array([1, 2, 3], 'float32')
        }, {
            'name': 'weight2',
            'data': np.array([4, 5], 'int32')
        }]]

        manifest_json = write_weights.write_weights(
            groups,
            TMP_DIR,
            shard_size_bytes=8 * 4,
            quantization_dtype=np.uint8)
        manifest = json.loads(manifest_json)

        self.assertTrue(
            os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')),
            'weights_manifest.json does not exist')
        q, s, m = zip(
            quantization.quantize_weights(groups[0][0]['data'], np.uint8),
            quantization.quantize_weights(groups[0][1]['data'], np.uint8))
        self.assertEqual(manifest, [{
            'paths': ['group1-shard1of1'],
            'weights': [{
                'name': 'weight1',
                'shape': [3],
                'dtype': 'float32',
                'quantization': {
                    'min': m[0],
                    'scale': s[0],
                    'dtype': 'uint8'
                }
            }, {
                'name': 'weight2',
                'shape': [2],
                'dtype': 'int32',
                'quantization': {
                    'min': m[1],
                    'scale': s[1],
                    'dtype': 'uint8'
                }
            }]
        }])

        weights_path = os.path.join(TMP_DIR, 'group1-shard1of1')
        weights = np.fromfile(weights_path, 'uint8')
        np.testing.assert_array_equal(weights, np.concatenate([q[0], q[1]]))
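
The manifest's `quantization` block (min, scale, uint8 payload) implies the usual affine mapping back to the stored dtype; a hedged sketch of that inverse, approximating what the reading side does:

def dequantize(quantized, scale, min_val, original_dtype):
    # original value is approximately quantized * scale + min
    return (quantized.astype('float64') * scale + min_val).astype(original_dtype)

For the test above, `dequantize(np.fromfile(weights_path, 'uint8')[:3], s[0], m[0], 'float32')` approximately recovers weight1.
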
def write_artifacts(topology,
                    weights,
                    output_graph,
                    tf_version,
                    signature_def,
                    quantization_dtype=None,
                    weight_shard_size_bytes=1024 * 1024 * 4):
    """Writes weights and topology to the output_dir.

  If `topology` is Falsy (e.g., `None`), only emit weights to output_dir.

  Args:
    topology: tf.GraphDef TensorFlow GraphDef proto object, which represents
      the model topology.
    weights: an array of weight groups (as defined in tfjs write_weights).
    output_graph: the output file name to hold all the contents.
    tf_version: Tensorflow version of the input graph.
    signature_def: the SignatureDef of the inference graph.
    quantization_dtype: An optional numpy dtype to quantize weights to for
      compression. Only np.uint8 and np.uint16 are supported.
    weight_shard_size_bytes: Shard size (in bytes) of the weight files.
      The size of each weight file will be <= this value.
  """

    model_json = {
        common.FORMAT_KEY: common.TFJS_GRAPH_MODEL_FORMAT,
        # TODO(piyu): Add tensorflow version below by using `meta_info_def`.
        common.GENERATED_BY_KEY: tf_version,
        common.CONVERTED_BY_KEY: common.get_converted_by(),
        common.USER_DEFINED_METADATA_KEY: {
            common.SIGNATURE_KEY: MessageToDict(signature_def)
        }
    }
    model_json[common.ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None
    weights_manifest = write_weights.write_weights(
        weights,
        os.path.dirname(output_graph),
        write_manifest=False,
        quantization_dtype=quantization_dtype,
        shard_size_bytes=weight_shard_size_bytes)
    assert isinstance(weights_manifest, list)
    model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

    with tf.io.gfile.GFile(output_graph, 'w') as f:
        json.dump(model_json, f)
  def testReadOneGroupFlattened(self):
    groups = [
        [{
            'name': 'weight1',
            'data': np.array([1, 2, 3], 'float32')
        }]
    ]

    manifest_json = write_weights.write_weights(groups, self._tmp_dir)
    manifest = json.loads(manifest_json)

    # Read the weights using `read_weights`.
    read_output = read_weights.read_weights(
        manifest, self._tmp_dir, flatten=True)
    self.assertEqual(1, len(read_output))
    self.assertEqual('weight1', read_output[0]['name'])
    self.assertTrue(np.allclose(groups[0][0]['data'], read_output[0]['data']))
    def testReadQuantizedWeights(self):
        groups = [[{
            'name': 'weight1',
            'data': np.array([0, 1, 2, 3], 'float32')
        }]]

        manifest_json = write_weights.write_weights(
            groups, self._tmp_dir, quantization_dtype=np.uint8)
        manifest = json.loads(manifest_json)

        # Read the weights using `read_weights`.
        read_output = read_weights.read_weights(manifest, self._tmp_dir)
        self.assertEqual(1, len(read_output))
        self.assertEqual(1, len(read_output[0]))
        self.assertEqual('weight1', read_output[0][0]['name'])
        self.assertTrue(
            np.allclose(groups[0][0]['data'], read_output[0][0]['data']))
    def testReadOneGroupString(self):
        groups = [[{
            'name': 'weight1',
            'data': np.array([['test', 'a'], ['b', 'c']], 'object')
        }]]

        manifest = write_weights.write_weights(groups, self._tmp_dir)

        # Read the weights using `read_weights`.
        read_output = read_weights.read_weights(manifest, self._tmp_dir)
        self.assertEqual(1, len(read_output))
        self.assertEqual(1, len(read_output[0]))
        self.assertEqual('weight1', read_output[0][0]['name'])
        np.testing.assert_array_equal(
            read_output[0][0]['data'],
            np.array([[u'test'.encode('utf-8'), u'a'.encode('utf-8')],
                      [u'b'.encode('utf-8'), u'c'.encode('utf-8')]], 'object'))
  def test_quantize_group(self):
    groups = [
        [{
            'name': 'weight1',
            'data': np.array([1, 2, 3], 'float32')
        }, {
            'name': 'weight2',
            'data': np.array([4, 5], 'int32')
        }]
    ]

    manifest_json = write_weights.write_weights(
        groups, TMP_DIR, shard_size_bytes=8 * 4, quantization_dtype=np.uint8)
    manifest = json.loads(manifest_json)

    self.assertTrue(
        os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')),
        'weights_manifest.json does not exist')
    q, s, m = zip(
        quantization.quantize_weights(groups[0][0]['data'], np.uint8),
        quantization.quantize_weights(groups[0][1]['data'], np.uint8))
    self.assertEqual(
        manifest,
        [{
            'paths': ['group1-shard1of1'],
            'weights': [{
                'name': 'weight1',
                'shape': [3],
                'dtype': 'float32',
                'quantization': {
                    'min': m[0], 'scale': s[0], 'dtype': 'uint8'
                }
            }, {
                'name': 'weight2',
                'shape': [2],
                'dtype': 'int32',
                'quantization': {
                    'min': m[1], 'scale': s[1], 'dtype': 'uint8'
                }
            }]
        }])

    weights_path = os.path.join(TMP_DIR, 'group1-shard1of1')
    weights = np.fromfile(weights_path, 'uint8')
    np.testing.assert_array_equal(weights, np.concatenate([q[0], q[1]]))
  def testReadOneGroupWithShards(self):
    groups = [
        [{
            'name': 'weight1',
            'data': np.random.rand(1, 100).astype(np.float32)
        }]
    ]

    manifest_json = write_weights.write_weights(groups, self._tmp_dir)
    manifest = json.loads(manifest_json)

    # Read the weights using `read_weights`.
    read_output = read_weights.read_weights(manifest, self._tmp_dir)
    self.assertEqual(1, len(read_output))
    self.assertEqual(1, len(read_output[0]))
    self.assertEqual('weight1', read_output[0][0]['name'])
    self.assertTrue(
        np.allclose(groups[0][0]['data'], read_output[0][0]['data']))
  def testReadQuantizedWeights(self):
    groups = [
        [{
            'name': 'weight1',
            'data': np.array([0, 1, 2, 3], 'float32')
        }]
    ]

    manifest_json = write_weights.write_weights(
        groups, self._tmp_dir, quantization_dtype=np.uint8)
    manifest = json.loads(manifest_json)

    # Read the weights using `read_weights`.
    read_output = read_weights.read_weights(manifest, self._tmp_dir)
    self.assertEqual(1, len(read_output))
    self.assertEqual(1, len(read_output[0]))
    self.assertEqual('weight1', read_output[0][0]['name'])
    self.assertTrue(
        np.allclose(groups[0][0]['data'], read_output[0][0]['data']))
  def write_artifacts(self,
                      topology,
                      weights,
                      output_dir,
                      quantization_dtype=None):
    """Writes weights and topology to the output_dir.

    If `topology` is Falsy (e.g., `None`), only emit weights to output_dir.

    Args:
      topology: a JSON dictionary, representing the Keras config.
      weights: an array of weight groups (as defined in tfjs write_weights).
      output_dir: the directory to hold all the contents.
      quantization_dtype: An optional numpy dtype to quantize weights to for
        compression. Only np.uint8 and np.uint16 are supported.
    """
    # TODO(cais, nielsene): This method should allow optional arguments of
    #   `write_weights.write_weights` (e.g., shard size) and forward them.
    # We write the topology after since write_weights makes no promises about
    # preserving directory contents.
    if os.path.isfile(output_dir):
      raise ValueError(
          'Path "%d" already exists as a file (not a directory).' % output_dir)

    model_json = {}

    model_json[ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None
    weights_manifest = write_weights.write_weights(
        weights, output_dir, write_manifest=False,
        quantization_dtype=quantization_dtype)
    if not isinstance(weights_manifest, list):
      weights_manifest = json.loads(weights_manifest)
    assert isinstance(weights_manifest, list)
    model_json[ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

    model_json_path = os.path.join(output_dir, ARTIFACT_MODEL_JSON_FILE_NAME)
    with open(model_json_path, 'wt') as f:
      json.dump(model_json, f)
  def test_1_group_2_weights_packed(self):
    groups = [
        [{
            'name': 'weight1',
            'data': np.array([1, 2, 3], 'float32')
        }, {
            'name': 'weight2',
            'data': np.array([4, 5], 'float32')
        }]
    ]

    # Shard size is larger than the sum of the two weights so they get packed.
    manifest_json = write_weights.write_weights(
        groups, TMP_DIR, shard_size_bytes=8 * 4)
    manifest = json.loads(manifest_json)

    self.assertTrue(
        os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')),
        'weights_manifest.json does not exist')
    self.assertEqual(
        manifest,
        [{
            'paths': ['group1-shard1of1'],
            'weights': [{
                'name': 'weight1',
                'shape': [3],
                'dtype': 'float32'
            }, {
                'name': 'weight2',
                'shape': [2],
                'dtype': 'float32'
            }]
        }])

    weights_path = os.path.join(TMP_DIR, 'group1-shard1of1')
    weights = np.fromfile(weights_path, 'float32')
    np.testing.assert_array_equal(weights, np.array([1, 2, 3, 4, 5], 'float32'))
  def test_1_group_2_packed_sharded_multi_dtype(self):
    groups = [
        [{
            'name': 'weight1',
            'data': np.array([1, 2, 3], 'int32')
        }, {
            'name': 'weight2',
            'data': np.array([True, False, False, True], 'bool')
        }, {
            'name': 'weight3',
            'data': np.array([4.1, 5.1], 'float32')
        }]
    ]

    # Shard size is smaller than the sum of the weights so they get packed and
    # then sharded. The packed buffer is split into 3 shard files of 8 bytes
    # each; the second shard contains mixed dtypes (int32 and bool).
    manifest_json = write_weights.write_weights(
        groups, TMP_DIR, shard_size_bytes=2 * 4)
    manifest = json.loads(manifest_json)

    self.assertTrue(
        os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')),
        'weights_manifest.json does not exist')
    self.assertEqual(
        manifest,
        [{
            'paths': ['group1-shard1of3',
                      'group1-shard2of3',
                      'group1-shard3of3'],
            'weights': [{
                'name': 'weight1',
                'shape': [3],
                'dtype': 'int32'
            }, {
                'name': 'weight2',
                'shape': [4],
                'dtype': 'bool'
            }, {
                'name': 'weight3',
                'shape': [2],
                'dtype': 'float32'
            }]
        }])

    shard_1_path = os.path.join(TMP_DIR, 'group1-shard1of3')
    shard_1 = np.fromfile(shard_1_path, 'int32')
    np.testing.assert_array_equal(shard_1, np.array([1, 2], 'int32'))

    # Shard 2 has a mixed dtype so we parse the bytes directly.
    shard_2_path = os.path.join(TMP_DIR, 'group1-shard2of3')
    with open(shard_2_path, 'rb') as f:
      shard_2_bytes = f.read()
    shard_2_int = np.frombuffer(shard_2_bytes[:4], 'int32')
    np.testing.assert_array_equal(shard_2_int, np.array([3], 'int32'))
    shard_2_bool = np.frombuffer(shard_2_bytes[4:], 'bool')
    np.testing.assert_array_equal(
        shard_2_bool, np.array([True, False, False, True], 'bool'))

    shard_3_path = os.path.join(TMP_DIR, 'group1-shard3of3')
    shard_3 = np.fromfile(shard_3_path, 'float32')
    np.testing.assert_array_equal(shard_3, np.array([4.1, 5.1], 'float32'))
  def test_2_groups_4_weights_sharded_packed(self):
    groups = [
        # Group 1
        [{
            'name': 'weight1',
            'data': np.array([1, 2, 3], 'float32')
        }, {
            'name': 'weight2',
            'data': np.array([[4, 5], [6, 7]], 'float32')
        }],
        # Group 2
        [{
            'name': 'weight3',
            'data': np.array([1, 2, 3, 4], 'int32')
        }, {
            'name': 'weight4',
            'data': np.array([[1.1, 1.2], [1.3, 1.4], [1.5, 1.6]], 'float32')
        }]
    ]

    manifest_json = write_weights.write_weights(
        groups, TMP_DIR, shard_size_bytes=4 * 4)
    manifest = json.loads(manifest_json)

    self.assertTrue(
        os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')),
        'weights_manifest.json does not exist')
    self.assertEqual(
        manifest,
        [{
            'paths': ['group1-shard1of2', 'group1-shard2of2'],
            'weights': [{
                'name': 'weight1',
                'shape': [3],
                'dtype': 'float32'
            }, {
                'name': 'weight2',
                'shape': [2, 2],
                'dtype': 'float32'
            }]
        }, {
            'paths': ['group2-shard1of3',
                      'group2-shard2of3',
                      'group2-shard3of3'],
            'weights': [{
                'name': 'weight3',
                'shape': [4],
                'dtype': 'int32'
            }, {
                'name': 'weight4',
                'shape': [3, 2],
                'dtype': 'float32'
            }]
        }])

    group0_shard_1_path = os.path.join(TMP_DIR, 'group1-shard1of2')
    group0_shard_1 = np.fromfile(group0_shard_1_path, 'float32')
    np.testing.assert_array_equal(
        group0_shard_1, np.array([1, 2, 3, 4], 'float32'))

    group0_shard_2_path = os.path.join(TMP_DIR, 'group1-shard2of2')
    group0_shard_2 = np.fromfile(group0_shard_2_path, 'float32')
    np.testing.assert_array_equal(
        group0_shard_2, np.array([5, 6, 7], 'float32'))

    group1_shard_1_path = os.path.join(TMP_DIR, 'group2-shard1of3')
    group1_shard_1 = np.fromfile(group1_shard_1_path, 'int32')
    np.testing.assert_array_equal(
        group1_shard_1, np.array([1, 2, 3, 4], 'int32'))

    group2_shard_2_path = os.path.join(TMP_DIR, 'group2-shard2of3')
    group2_shard_2 = np.fromfile(group2_shard_2_path, 'float32')
    np.testing.assert_array_equal(
        group2_shard_2, np.array([1.1, 1.2, 1.3, 1.4], 'float32'))

    group2_shard_3_path = os.path.join(TMP_DIR, 'group2-shard3of3')
    group2_shard_3 = np.fromfile(group2_shard_3_path, 'float32')
    np.testing.assert_array_equal(
        group2_shard_3, np.array([1.5, 1.6], 'float32'))
  def test_no_weights_groups_throws(self):
    groups = None
    with self.assertRaises(Exception):
      write_weights.write_weights(groups, TMP_DIR)

  def test_empty_groups_throws(self):
    groups = []
    with self.assertRaises(Exception):
      write_weights.write_weights(groups, TMP_DIR)