Пример #1
0
    def _AddShardedSaveOps(self, variables, checkpoint_prefix, var_key_fn):
        """Adds per-device save ops to save `variables` to `checkpoint_prefix`.

        Variables are grouped by their `device` attribute. Each group is
        written to its own temporary shard, whose name is derived from
        `<checkpoint_prefix>_temp/part`, by a `save_v2` op placed on that
        device. A `merge_v2_checkpoints` op then combines the shards into
        `checkpoint_prefix`. That op is control-dependent on every shard save.

        Args:
          variables: Iterable of variables to save. Each one must expose a
            `device` attribute and a `read_value()` method.
          checkpoint_prefix: Prefix (string or scalar string tensor) of the
            final merged checkpoint.
          var_key_fn: Callable mapping a variable to the tensor name it is
            stored under in the checkpoint.

        Returns:
          The `merge_v2_checkpoints` op. Running it forces all per-device
          shard saves to run first, then merges their outputs.
        """
        with self._var_graph.as_default():
            # Group variables by device, so each shard is written on the
            # device that holds its variables.
            per_device = collections.defaultdict(list)
            for var in variables:
                per_device[var.device].append(var)

            tmp_save_prefix = tf.strings.join(
                [checkpoint_prefix, "_temp/part"])
            num_shards = tf.constant(len(per_device))
            sharded_saves = []
            sharded_prefixes = []

            for shard, (device, var_list) in enumerate(per_device.items()):
                with self._var_graph.device(device):
                    sharded_filename = gen_io_ops.sharded_filename(
                        tmp_save_prefix, shard, num_shards)
                    sharded_prefixes.append(sharded_filename)
                    save_op = io_ops.save_v2(
                        prefix=sharded_filename,
                        tensor_names=[var_key_fn(v) for v in var_list],
                        tensors=[v.read_value() for v in var_list],
                        # Per the SaveV2 docs, an empty slice spec means the
                        # tensor is saved whole (non-partitioned).
                        shape_and_slices=[""] * len(var_list))
                    sharded_saves.append(save_op)

            # The merge may only run after every shard has been written.
            with tf.control_dependencies(sharded_saves):
                # delete_old_dirs=True asks the op to remove the directories of
                # the input shard prefixes (the `_temp` dir) after merging --
                # see the MergeV2Checkpoints op documentation.
                return gen_io_ops.merge_v2_checkpoints(sharded_prefixes,
                                                       checkpoint_prefix,
                                                       delete_old_dirs=True)
Пример #2
0
 def testShardedFileName(self):
   """Checks the exact bytes produced by sharded_filename / sharded_filespec."""
   two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
   with session.Session(target="", config=two_cpu_config):
     # Shard index and shard count both appear zero-padded in the name.
     shard_name = gen_io_ops.sharded_filename("foo", 4, 100)
     self.assertEqual(shard_name.eval(), b"foo-00004-of-00100")
     # The filespec replaces each digit of the shard index with '?'.
     shard_spec = gen_io_ops.sharded_filespec("foo", 100)
     self.assertEqual(shard_spec.eval(), b"foo-?????-of-00100")
Пример #3
0
def sharded_filename(filename_tensor, shard, num_shards):
  """Return `filename_tensor` with shard information appended.

  Forwards to `gen_io_ops.sharded_filename`. For example, the basename "foo"
  with shard 4 and num_shards 100 evaluates to b"foo-00004-of-00100".

  Args:
    filename_tensor: A string tensor holding the base filename.
    shard: Integer. Index of the shard this filename is for.
    num_shards: An int Tensor giving the total number of shards.

  Returns:
    A string tensor containing the sharded filename.
  """
  op_inputs = (filename_tensor, shard, num_shards)
  return gen_io_ops.sharded_filename(*op_inputs)
Пример #4
0
def sharded_filename(filename_tensor, shard, num_shards):
  """Build the filename for one shard of a sharded file.

  Thin wrapper around `gen_io_ops.sharded_filename`: shard 4 of 100 for the
  basename "foo" evaluates to b"foo-00004-of-00100".

  Args:
    filename_tensor: String tensor with the base filename.
    shard: Integer index of the shard.
    num_shards: Int Tensor holding the total shard count.

  Returns:
    A string tensor with the shard suffix appended to the base filename.
  """
  named_shard = gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)
  return named_shard