Exemplo n.º 1
0
    def _AddShardedSaveOps(self, filename_tensor, per_device):
        """Add one save op per device shard and return the sharded filespec.

    Each (device, vars_to_save) entry of `per_device` is assigned the shard
    index matching its position in the list. Its sharded filename and save
    ops are built under that device's scope.

    Args:
      filename_tensor: String Tensor.
      per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
        returned by _GroupByDevices().

    Returns:
      The result of gen_io_ops._sharded_filespec(filename_tensor, <shard
      count>). It is created under control dependencies on the op of every
      per-shard save result, so evaluating it forces all shard saves to run.
    """
        shard_count = len(per_device)
        shard_count_tensor = constant_op.constant(shard_count, name="num_shards")
        shard_results = []
        for shard_index, device_and_vars in enumerate(per_device):
            device, vars_to_save = device_and_vars
            # Pin both the filename computation and the save to this shard's device.
            with ops.device(device):
                shard_path = self.sharded_filename(
                    filename_tensor, shard_index, shard_count_tensor)
                shard_results.append(self._AddSaveOps(shard_path, vars_to_save))
        save_ops = [result.op for result in shard_results]
        # The returned filespec must not be produced until every shard is saved.
        with ops.control_dependencies(save_ops):
            # pylint: disable=protected-access
            return gen_io_ops._sharded_filespec(filename_tensor,
                                                shard_count_tensor)
 def testShardedFileName(self):
   # Run both ops in a session whose config requests two CPU devices.
   session_config = tf.ConfigProto(device_count={"CPU": 2})
   with tf.Session(target="", config=session_config):
     # A concrete shard name zero-pads the shard index and the total to 5 digits.
     filename_tensor = gen_io_ops._sharded_filename("foo", 4, 100)
     self.assertEqual(filename_tensor.eval(), b"foo-00004-of-00100")
     # A filespec replaces the shard index with a '?????' wildcard.
     filespec_tensor = gen_io_ops._sharded_filespec("foo", 100)
     self.assertEqual(filespec_tensor.eval(), b"foo-?????-of-00100")
Exemplo n.º 3
0
    def _AddShardedSaveOps(self, filename_tensor, per_device):
        """Add one save op per device shard and return the sharded filespec.

    Each (device, vars_to_save) entry of `per_device` is assigned the shard
    index matching its position in the list. Its sharded filename and save
    ops are built under that device's scope.

    Args:
      filename_tensor: String Tensor.
      per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
        returned by _GroupByDevices().

    Returns:
      The result of gen_io_ops._sharded_filespec(filename_tensor, <shard
      count>). It is created under control dependencies on the op of every
      per-shard save result, so evaluating it forces all shard saves to run.
    """
        shard_count = len(per_device)
        shard_count_tensor = constant_op.constant(shard_count, name="num_shards")
        shard_results = []
        for shard_index, device_and_vars in enumerate(per_device):
            device, vars_to_save = device_and_vars
            # Pin both the filename computation and the save to this shard's device.
            with ops.device(device):
                shard_path = self.sharded_filename(
                    filename_tensor, shard_index, shard_count_tensor)
                shard_results.append(self._AddSaveOps(shard_path, vars_to_save))
        save_ops = [result.op for result in shard_results]
        # The returned filespec must not be produced until every shard is saved.
        with ops.control_dependencies(save_ops):
            return gen_io_ops._sharded_filespec(filename_tensor, shard_count_tensor)