def _AddShardedSaveOps(self, variables, checkpoint_prefix, var_key_fn):
  """Adds per-device save ops to save `variables` to `checkpoint_prefix`.

  Variables are grouped by the device they are placed on, each device writes
  its own shard file under a temporary prefix, and the shards are merged into
  the final checkpoint once all per-device saves have completed.

  Args:
    variables: Iterable of variables to save.
    checkpoint_prefix: String tensor with the checkpoint path prefix.
    var_key_fn: Callable mapping a variable to the string name under which it
      is stored in the checkpoint.

  Returns:
    The op that merges the per-device shard files into the final checkpoint
    at `checkpoint_prefix`.
  """
  with self._var_graph.as_default():
    # Group variables by device so each device writes exactly one shard.
    per_device = collections.defaultdict(list)
    for var in variables:
      per_device[var.device].append(var)
    tmp_save_prefix = tf.strings.join([checkpoint_prefix, "_temp/part"])
    num_shards = tf.constant(len(per_device))
    sharded_saves = []
    sharded_prefixes = []
    for shard, (device, var_list) in enumerate(per_device.items()):
      with self._var_graph.device(device):
        sharded_filename = gen_io_ops.sharded_filename(
            tmp_save_prefix, shard, num_shards)
        sharded_prefixes.append(sharded_filename)
        save_op = io_ops.save_v2(
            prefix=sharded_filename,
            tensor_names=[var_key_fn(v) for v in var_list],
            tensors=[v.read_value() for v in var_list],
            # Empty slice spec means each variable is saved whole.
            shape_and_slices=[""] * len(var_list))
        sharded_saves.append(save_op)
    # Merge shard files into the final checkpoint only after every
    # per-device save has run; the temporary shard dirs are then removed.
    with tf.control_dependencies(sharded_saves):
      return gen_io_ops.merge_v2_checkpoints(
          sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)
def testShardedFileName(self):
  """Verifies shard-name and shard-spec string formatting on a 2-CPU config."""
  config = config_pb2.ConfigProto(device_count={"CPU": 2})
  with session.Session(target="", config=config):
    # A concrete shard: index 4 out of 100 total shards.
    filename = gen_io_ops.sharded_filename("foo", 4, 100).eval()
    self.assertEqual(b"foo-00004-of-00100", filename)
    # The glob-style spec matching any shard of 100.
    filespec = gen_io_ops.sharded_filespec("foo", 100).eval()
    self.assertEqual(b"foo-?????-of-00100", filespec)
def sharded_filename(filename_tensor, shard, num_shards):
  """Builds a filename tensor annotated with shard information.

  Thin wrapper around the generated `sharded_filename` op.

  Args:
    filename_tensor: A string tensor holding the base filename.
    shard: Integer index of this shard.
    num_shards: An int Tensor giving the total number of shards.

  Returns:
    A string tensor: the filename with sharding information appended.
  """
  result = gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)
  return result