def _AddShardedSaveOps(self, filename_tensor, per_device):
  """Adds save ops for the params, one save per device shard.

  Args:
    filename_tensor: String Tensor.
    per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
      returned by _GroupByDevices().

  Returns:
    The tensor built by gen_io_ops._sharded_filespec(filename_tensor, <shard
    count>), i.e. the sharded name for the save path. It is created under
    control dependencies on every shard's save ops, so evaluating it runs all
    of the per-shard saves first.
  """
  shard_count = len(per_device)
  # One constant, created before any shard, is shared by every shard's
  # filename op and by the final filespec op.
  shard_count_tensor = constant_op.constant(shard_count, name="num_shards")

  def _save_shard(shard_index, device, vars_to_save):
    # The shard's filename op and its save ops are both placed on the device
    # paired with this shard in per_device.
    with ops.device(device):
      shard_filename = self.sharded_filename(
          filename_tensor, shard_index, shard_count_tensor)
      return self._AddSaveOps(shard_filename, vars_to_save)

  # Shards are built in per_device order; shard_index is the position.
  shard_saves = [
      _save_shard(shard_index, device, vars_to_save)
      for shard_index, (device, vars_to_save) in enumerate(per_device)
  ]

  save_ops = [save.op for save in shard_saves]
  with ops.control_dependencies(save_ops):
    # pylint: disable=protected-access
    return gen_io_ops._sharded_filespec(filename_tensor, shard_count_tensor)
def testShardedFileName(self):
  """Checks the string formats of the sharded filename and filespec ops."""
  session_config = tf.ConfigProto(device_count={"CPU": 2})
  # The session is installed as the default, so .eval() runs in it.
  with tf.Session(target="", config=session_config):
    # pylint: disable=protected-access
    # Shard 4 of 100: index and total are both zero-padded to five digits.
    shard_name = gen_io_ops._sharded_filename("foo", 4, 100)
    self.assertEqual(shard_name.eval(), b"foo-00004-of-00100")
    # The filespec swaps the shard index for five "?" wildcards.
    shard_spec = gen_io_ops._sharded_filespec("foo", 100)
    self.assertEqual(shard_spec.eval(), b"foo-?????-of-00100")
def _AddShardedSaveOps(self, filename_tensor, per_device):
  """Adds save ops for the params, one save per device shard.

  Args:
    filename_tensor: String Tensor.
    per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
      returned by _GroupByDevices().

  Returns:
    The tensor built by gen_io_ops._sharded_filespec(filename_tensor, <shard
    count>), i.e. the sharded name for the save path. It is created under
    control dependencies on every shard's save ops, so evaluating it runs all
    of the per-shard saves first.
  """
  total_shards = len(per_device)
  # Built once, before any shard, and reused by every filename op as well as
  # the final filespec op.
  total_shards_tensor = constant_op.constant(total_shards, name="num_shards")

  def _build_shard_save(index, device, vars_to_save):
    # Pin the shard's filename op and save ops to its paired device.
    with ops.device(device):
      filename_for_shard = self.sharded_filename(
          filename_tensor, index, total_shards_tensor)
      return self._AddSaveOps(filename_for_shard, vars_to_save)

  # Shards are emitted in per_device order, indexed by position.
  per_shard_saves = [
      _build_shard_save(index, device, vars_to_save)
      for index, (device, vars_to_save) in enumerate(per_device)
  ]

  dependencies = [shard_save.op for shard_save in per_shard_saves]
  with ops.control_dependencies(dependencies):
    # pylint: disable=protected-access
    return gen_io_ops._sharded_filespec(filename_tensor, total_shards_tensor)