Example 1
def _set_mht_saveable_checkpoint_initializer(saveable,
                                             ckpt_file,
                                             saveable_name_in_ckpt,
                                             name="mht_checkpoint_initializer"):
  canonical_device = set(
      pydev.canonical_name(spec.tensor.device) for spec in saveable.specs)
  if len(canonical_device) != 1:
    raise ValueError("All tensors of a saveable object must be "
                     "on the same device: %s" % saveable.name)
  device = canonical_device.pop()
  keys_spec, values_spec = saveable.specs
  with ops.device(device), ops.device("/cpu:0"):
    restore_keys_op = io_ops.restore_v2(ckpt_file,
                                        [saveable_name_in_ckpt + "-keys"],
                                        [keys_spec.slice_spec],
                                        [keys_spec.dtype],
                                        name=name)[0]
    restore_values_op = io_ops.restore_v2(ckpt_file,
                                          [saveable_name_in_ckpt + "-values"],
                                          [values_spec.slice_spec],
                                          [values_spec.dtype],
                                          name=name)[0]
  init_op = saveable.restore([restore_keys_op, restore_values_op],
                             restored_shapes=None)
  ops.add_to_collection(ops.GraphKeys.DYNAMIC_EMBEDDING_VARIABLE_INITIALIZERS,
                        init_op)
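
A note on the calling convention shared by every snippet on this page: restore_v2(prefix, tensor_names, shape_and_slices, dtypes) returns one tensor per requested name. Below is a minimal hedged round trip (assumes TF 2.x eager execution; io_ops is an internal module, and the prefix is hypothetical):

import tensorflow as tf
from tensorflow.python.ops import io_ops

prefix = "/tmp/demo_ckpt"  # hypothetical checkpoint prefix
# Write two tensors under names shaped like the keys/values pair above.
io_ops.save_v2(prefix, ["t-keys", "t-values"], ["", ""],
               [tf.constant([1, 2], dtype=tf.int64),
                tf.constant([0.5, 1.5], dtype=tf.float32)])
keys, values = io_ops.restore_v2(prefix, ["t-keys", "t-values"], ["", ""],
                                 [tf.int64, tf.float32])
print(keys.numpy(), values.numpy())  # [1 2] [0.5 1.5]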
Example 2
 def _process_slot_restoration(self, slot_restoration, variable):
     """Restore a slot variable's value (creating it if necessary)."""
     # TODO(allenl): Move this to Optimizer
     assert isinstance(self, optimizer_lib.Optimizer)
     named_slots = self._slot_dict(slot_restoration.slot_name)
     variable_key = optimizer_lib._var_key(variable)  # pylint: disable=protected-access
     existing_slot_variable = named_slots.get(variable_key, None)
     if existing_slot_variable is None:
         base_dtype = slot_restoration.value_pointer.dtype.base_dtype
         initializer, = io_ops.restore_v2(
             prefix=slot_restoration.value_pointer.save_path,
             tensor_names=[slot_restoration.value_pointer.checkpoint_key],
             shape_and_slices=[""],
             dtypes=[base_dtype],
             name="checkpoint_initializer")
         new_slot_variable = slot_creator.create_slot(
             variable, initializer, slot_restoration.slot_name)
         if slot_restoration.value_pointer.session is not None:
             slot_restoration.value_pointer.session.run(
                 new_slot_variable.initializer)
         named_slots[variable_key] = new_slot_variable
     else:
         _assign_existing_variable(
             existing_slot_variable,
             value_pointer=slot_restoration.value_pointer)
Example 3
def ReadNpArrays(file_prefix, nmap):
  """Reads from a tf checkpoint to fill in values of a NesteMap.

  Args:
    file_prefix: A TF checkpoint filename prefix.
    nmap: A NestedMap of numpy dtypes.

  Returns:
    A NestedMap with numpy arrays compatible with nmap.
  """
  g = tf.Graph()
  with g.as_default():
    reads = []
    for name, dtype in nmap.FlattenItems():
      reads.append(
          io_ops.restore_v2(
              prefix=file_prefix,
              tensor_names=[name],
              shape_and_slices=[""],
              dtypes=[dtype])[0])

  with tf.Session(graph=g) as sess:
    vals = sess.run(reads)

  return nmap.Pack(vals)
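
A hedged usage sketch for the helper above (assumes the same TF1-style environment as the snippet, i.e. tf.Session, plus lingvo's py_utils.NestedMap; the prefix and tensor names are hypothetical and must exist in the checkpoint):

import tensorflow as tf
from lingvo.core import py_utils  # assumption: lingvo-style NestedMap

nmap = py_utils.NestedMap(w=tf.float32, b=tf.float32)  # tensor name -> dtype
vals = ReadNpArrays("/tmp/model_ckpt", nmap)  # checkpoint must contain "w", "b"
print(vals.w.shape, vals.b.shape)  # numpy arrays packed like nmap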
Example 4
    def restore(self, file_prefix):
        """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.

    Returns:
      A scalar string Tensor containing `file_prefix` with control dependencies
      on the restore ops.
    """
        restore_specs = []
        tensor_structure = []
        for saveable in self._saveable_objects:
            saveable_tensor_structure = []
            tensor_structure.append(saveable_tensor_structure)
            for spec in saveable.specs:
                saveable_tensor_structure.append(spec.name)
                restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
        tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
        with ops.device("cpu:0"):
            restored_tensors = io_ops.restore_v2(file_prefix, tensor_names,
                                                 tensor_slices, tensor_dtypes)
        structured_restored_tensors = nest.pack_sequence_as(
            tensor_structure, restored_tensors)
        for saveable, restored_tensors in zip(self._saveable_objects,
                                              structured_restored_tensors):
            saveable.restore(restored_tensors, restored_shapes=None)
        return file_prefix
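
The zip(*restore_specs) transposition and nest.pack_sequence_as regrouping used above can be seen in isolation; a hedged toy sketch where strings stand in for the restore_v2 outputs:

from tensorflow.python.util import nest

restore_specs = [("a", "", "f32"), ("b", "", "f32"), ("c", "", "i64")]
names, slices, dtypes = zip(*restore_specs)    # three parallel tuples
structure = [["a", "b"], ["c"]]                # one inner list per saveable
flat = ["tensor_a", "tensor_b", "tensor_c"]    # stands in for restored tensors
print(nest.pack_sequence_as(structure, flat))  # [['tensor_a', 'tensor_b'], ['tensor_c']]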
Example 5
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
    """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
    base_type = variable.dtype.base_dtype
    # Do not colocate with variable since RestoreV2 op only runs on CPU and
    # colocation will force variable (and other ops that colocate with variable)
    # to be on CPU as well. It is okay to place the variable's initializer op on
    # CPU since it will only be run once at the start.
    with ops.device(variable.device), ops.device("/cpu:0"):
        restore_op = io_ops.restore_v2(ckpt_file, [tensor_name], [slice_spec],
                                       [base_type],
                                       name=name)[0]
        if isinstance(variable, resource_variable_ops.ResourceVariable):
            init_op = variable.assign(restore_op, read_value=False)
        else:
            init_op = state_ops.assign(variable, restore_op)
        variable._initializer_op = init_op  # pylint:disable=protected-access
        restore_op.set_shape(variable.shape)
        variable._initial_value = restore_op  # pylint:disable=protected-access
Example 6
 def testRelativePath(self):
   os.chdir(self.get_temp_dir())
   self.evaluate(io_ops.save_v2(
       "ckpt", ["x"], [""], [constant_op.constant(100.)]))
   self.assertAllEqual([100.],
                       self.evaluate(io_ops.restore_v2(
                           "ckpt", ["x"], [""], [dtypes.float32])))
Example 7
    def restore(self, file_prefix):
        """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.

    Returns:
      An operation which restores the `Saver`'s `SaveableObject`s when run, or
      None if executing eagerly.
    """
        restore_ops = []
        for saveable in self._saveable_objects:
            if saveable.device:
                device = saveable_object_util.set_cpu0(saveable.device)
            else:
                device = None
            with ops.device(device):
                tensors = []
                for spec in saveable.specs:
                    tensors.append(
                        io_ops.restore_v2(file_prefix, [spec.name],
                                          [spec.slice_spec], [spec.dtype])[0])
                restore_ops.append(
                    saveable.restore(tensors, restored_shapes=None))
        return control_flow_ops.group(restore_ops)
Example 8
def restore_tfra_variable(ckpt_path, variable):
  mhts = variable._tables
  key_tensor_names = []
  value_tensor_names = []
  for mht in mhts:
    assert len(mht.saveable.specs) == 2
    key_tensor_names.append(mht.saveable.specs[0].name)
    value_tensor_names.append(mht.saveable.specs[1].name)

  latest_ckpt = _get_checkpoint_filename(ckpt_path)
  tensor_names = key_tensor_names + value_tensor_names
  restore_op = io_ops.restore_v2(
      latest_ckpt, tensor_names, [""] * len(tensor_names),
      [tf.int64] * len(key_tensor_names) +
      [tf.float32] * len(value_tensor_names))

  key_tensor = restore_op[:len(key_tensor_names)]
  value_tensor = restore_op[len(key_tensor_names):]

  mht_restore_ops = []
  index = 0
  for mht in mhts:
    mht_restore_ops.append(
        mht.saveable.restore((key_tensor[index], value_tensor[index]), None))
    index += 1
  return mht_restore_ops
Example 9
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
    """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
    base_type = variable.dtype.base_dtype
    with ops.colocate_with(variable):
        restore_op = io_ops.restore_v2(ckpt_file, [tensor_name], [slice_spec],
                                       [base_type],
                                       name=name)[0]
        variable._initializer_op = state_ops.assign(variable, restore_op)  # pylint:disable=protected-access
        restore_op.set_shape(variable.shape)
        variable._initial_value = restore_op  # pylint:disable=protected-access
Example 10
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
    if isinstance(variable, resource_variable_ops.ResourceVariable):
      init_op = variable.assign(restore_op, read_value=False)
    else:
      init_op = state_ops.assign(variable, restore_op)
    variable._initializer_op = init_op  # pylint:disable=protected-access
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op  # pylint:disable=protected-access
Example 11
  def restore(self, file_prefix, options=None):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      A dictionary mapping from SaveableObject names to restore operations.
    """
    options = options or checkpoint_options.CheckpointOptions()
    restore_specs = []
    tensor_structure = []
    for saveable in self._saveable_objects:
      saveable_tensor_structure = []
      tensor_structure.append(saveable_tensor_structure)
      for spec in saveable.specs:
        saveable_tensor_structure.append(spec.name)
        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
    tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
    restore_device = options.experimental_io_device or "cpu:0"
    with ops.device(restore_device):
      restored_tensors = io_ops.restore_v2(
          file_prefix, tensor_names, tensor_slices, tensor_dtypes)
    structured_restored_tensors = nest.pack_sequence_as(
        tensor_structure, restored_tensors)
    restore_ops = {}
    for saveable, restored_tensors in zip(self._saveable_objects,
                                          structured_restored_tensors):
      restore_ops[saveable.name] = saveable.restore(
          restored_tensors, restored_shapes=None)
    return restore_ops
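
The options.experimental_io_device knob above mirrors the public tf.train.CheckpointOptions API; a hedged sketch of the user-facing analogue (the device string and directory are hypothetical):

import tensorflow as tf

opts = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
ckpt = tf.train.Checkpoint(v=tf.Variable(1.0))
status = ckpt.restore(tf.train.latest_checkpoint("/tmp/train_dir"), options=opts)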
Example 12
  def restore(self, file_prefix):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.

    Returns:
      A dictionary mapping from SaveableObject names to restore operations.
    """
    restore_specs = []
    tensor_structure = []
    for saveable in self._saveable_objects:
      saveable_tensor_structure = []
      tensor_structure.append(saveable_tensor_structure)
      for spec in saveable.specs:
        saveable_tensor_structure.append(spec.name)
        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
    tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
    with ops.device("cpu:0"):
      restored_tensors = io_ops.restore_v2(
          file_prefix, tensor_names, tensor_slices, tensor_dtypes)
    structured_restored_tensors = nest.pack_sequence_as(
        tensor_structure, restored_tensors)
    restore_ops = {}
    for saveable, restored_tensors in zip(self._saveable_objects,
                                          structured_restored_tensors):
      restore_ops[saveable.name] = saveable.restore(
          restored_tensors, restored_shapes=None)
    return restore_ops
Example 13
 def testRestoreV2WithSliceInput(self):
     with ops.Graph().as_default():
         op = io_ops.restore_v2("model", ["var1", "var2"],
                                ["", "3 4 0,1:-"],
                                [dtypes.float32, dtypes.float32])
         self.assertEqual(2, len(op))
         self.assertFalse(op[0].get_shape().is_fully_defined())
         self.assertEqual([1, 4], op[1].get_shape())
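
The slice string grammar used in this test is "full_shape dim0_spec:dim1_spec:...", where each dimension spec is either "offset,length" or "-" for the whole axis; "3 4 0,1:-" therefore selects rows [0, 1) and all four columns of a 3x4 tensor, which is why the second output has static shape [1, 4]. A hedged eager round trip under the same assumptions (hypothetical prefix):

import tensorflow as tf
from tensorflow.python.ops import io_ops

prefix = "/tmp/slice_demo"  # hypothetical
full = tf.reshape(tf.range(12, dtype=tf.float32), [3, 4])
io_ops.save_v2(prefix, ["var2"], [""], [full])
# Ask for the first row only: shape 3x4, dim 0 -> offset 0 length 1, dim 1 -> all.
row, = io_ops.restore_v2(prefix, ["var2"], ["3 4 0,1:-"], [tf.float32])
print(row.numpy())  # [[0. 1. 2. 3.]]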
Example 14
    def restore_op(self, filename_tensor, saveable, preferred_shard):
        tensors = []
        for spec in saveable.specs:
            print(spec.name)
            tensors.append(
                io_ops.restore_v2(filename_tensor, [spec.name],
                                  [spec.slice_spec], [spec.tensor.dtype])[0])

        return tensors
Example 15
 def restore_fn(trackables, merged_prefix):
     tensor_names, shapes_and_slices, tensors, restored_trackables = (
         _get_tensors(trackables))
     dtypes = [t.dtype for t in tensors]
     try:
         restored_tensors = io_ops.restore_v2(merged_prefix,
                                              tensor_names,
                                              shapes_and_slices, dtypes)
     except errors_impl.NotFoundError:
         # If a NotFoundError is caught, then it means that the checkpoint
         # was written prior to the saver registration migration.
         tensor_names, shapes_and_slices, tensors, restored_trackables = (
             _get_tensors(trackables, append_name=False))
         restored_tensors = io_ops.restore_v2(merged_prefix,
                                              tensor_names,
                                              shapes_and_slices, dtypes)
     for trackable, name_tensor in zip(restored_trackables,
                                       restored_tensors):
         trackable.name = name_tensor
Example 16
    def restore_op(self, filename_tensor, saveable, preferred_shard):
        tensors = []
        for spec in saveable.specs:
            # Ignore the moving_mean and moving_variance in other towers.
            if spec.name.startswith('replicated_'):
                if not spec.name.startswith(
                        'replicated_0') and 'BatchNorm/moving_' in spec.name:
                    continue
                tensors.append(
                    io_ops.restore_v2(filename_tensor,
                                      ['/'.join(spec.name.split('/')[1:])],
                                      [spec.slice_spec],
                                      [spec.tensor.dtype])[0])
            else:
                tensors.append(
                    io_ops.restore_v2(filename_tensor, [spec.name],
                                      [spec.slice_spec],
                                      [spec.tensor.dtype])[0])

        return tensors
Example 17
 def _BuildRestore(self):
   """Builds restore ops."""
   assign_ops = []
   for var in self._vars:
     val, = io_ops.restore_v2(
         prefix=self._restore_prefix_ph,
         tensor_names=[_VarKey(var)],
         shape_and_slices=[""],
         dtypes=[var.dtype])
     assign_ops.append(var.assign(val))
   self._restore_op = tf.group(*assign_ops)
Example 18
 def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                  restore_sequentially):
     from tensorflow.python.ops import io_ops
     restore_specs = []
     for saveable in saveables:
         for spec in saveable.specs:
             restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
     names, slices, dtypes = zip(*restore_specs)
     restore_dtypes = [tf.float32 for _ in dtypes]
     with tf.device("cpu:0"):
         restored = io_ops.restore_v2(filename_tensor, names, slices, restore_dtypes)
         return [tf.cast(r, dt) for r, dt in zip(restored, dtypes)]
Example 19
def restore_stacks_and_parts(trackables, merged_prefix):
  tensor_names, shapes_and_slices, tensors, restored_trackables = (
      get_tensor_slices(trackables))
  dtypes = [t.dtype for t in tensors]
  restored_tensors = io_ops.restore_v2(merged_prefix, tensor_names,
                                       shapes_and_slices, dtypes)
  for trackable, restored_tensor in zip(restored_trackables, restored_tensors):
    expected_shape = trackable.value().get_shape()
    restored_tensor = array_ops.reshape(restored_tensor, expected_shape)
    parts = array_ops.unstack(restored_tensor)
    for part, restored_part in zip(trackable.parts, parts):
      part.assign(restored_part)
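
The reshape-and-unstack fan-out above, isolated as a hedged toy sketch (TF 2.x eager; the constant stands in for a tensor returned by restore_v2):

import tensorflow as tf

restored = tf.constant([[1., 2.], [3., 4.]])      # stands in for a restored stack
parts = tf.unstack(tf.reshape(restored, [2, 2]))  # one tensor per stacked part
variables = [tf.Variable(tf.zeros([2])), tf.Variable(tf.zeros([2]))]
for var, part in zip(variables, parts):
    var.assign(part)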
Example 20
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
    """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
    base_type = variable.dtype.base_dtype
    # Do not colocate with variable since RestoreV2 op only runs on CPU and
    # colocation will force variable (and other ops that colocate with variable)
    # to be on CPU as well. It is okay to place the variable's initializer op on
    # CPU since it will only be run once at the start.
    with ops.device(variable.device), ops.device("/cpu:0"):
        restore_op = io_ops.restore_v2(ckpt_file, [tensor_name], [slice_spec],
                                       [base_type],
                                       name=name)[0]

        # TODO(priyag, allenl): Use `SaveableObject.restore` instead here.
        if resource_variable_ops.is_resource_variable(variable):
            init_op = variable.assign(restore_op, read_value=False)
        else:
            init_op = state_ops.assign(variable, restore_op)

        # pylint:disable=protected-access
        # We need special handling for `DistributedVariable`s as they contain
        # multiple actual variables. `assign` on a `DistributedVariable` returns a
        # combined `init_op` which contains initializers for all the contained
        # variables. We then set each underlying variable's `_initializer_op` using
        # the corresponding `init_op`.
        # TODO(priyag): Use `isinstance` checks when `DistributedVariable` class
        # moves out of contrib.
        if any(base.__name__ == "DistributedVariable"
               for base in variable.__class__.__bases__):
            assert distribute_lib.get_cross_tower_context()
            assert hasattr(variable, "_index")
            for (d, v) in six.iteritems(variable._index):
                v._initializer_op = init_op._index[d]
                restore_op.set_shape(v.shape)
                v._initial_value = restore_op
        else:
            variable._initializer_op = init_op
            restore_op.set_shape(variable.shape)
            variable._initial_value = restore_op
Example 21
    def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
        # Ignored: bulk restore is internally sequential.
        del restore_sequentially
        restore_specs = []
        for saveable in saveables:
            for spec in saveable.specs:
                restore_specs.append((spec.name, spec.slice_spec, spec.dtype))

        names, slices, dtypes = zip(*restore_specs)
        # Load all tensors onto CPU 0 for compatibility with existing code.
        with ops.device("cpu:0"):
            return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
Example 22
def _assign_existing_variable(variable_to_restore, value_pointer):
  """Set a variable from a _ValuePointer object."""
  base_type = variable_to_restore.dtype.base_dtype
  with ops.colocate_with(variable_to_restore):
    # TODO(allenl): Handle partitioned variables
    value_to_restore, = io_ops.restore_v2(
        prefix=value_pointer.save_path,
        tensor_names=[value_pointer.checkpoint_key],
        shape_and_slices=[""],
        dtypes=[base_type],
        name="checkpoint_initializer")
    initializer_op = state_ops.assign(variable_to_restore, value_to_restore)
    variable_to_restore._initializer_op = initializer_op  # pylint:disable=protected-access
    if value_pointer.session is not None:
      value_pointer.session.run(initializer_op)
Example 23
def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec,
                                name="checkpoint_initializer"):
  """Sets variable initializer to assign op form value in checkpoint's tensor.

  Args:
    variable: `Variable` object.
    file_pattern: string, where to load checkpoints from.
    tensor_name: Name of the `Tensor` to load from checkpoint reader.
    slice_spec: Slice specification for loading partitioned variables.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  restore_op = io_ops.restore_v2(
      file_pattern, [tensor_name], [slice_spec], [base_type], name=name)[0]
  variable._initializer_op = state_ops.assign(variable, restore_op)
Example 24
 def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                  restore_sequentially):
     restore_specs = []
     for saveable in saveables:
         for spec in saveable.specs:
             restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
     names, slices, dtypes = zip(*restore_specs)
      restore_dtypes = [tf.float32 if dtype.base_dtype == tf.float16 else dtype for dtype in dtypes]
     # print info
     for i in range(len(restore_specs)):
         print(names[i], 'from', restore_dtypes[i], 'to', dtypes[i].base_dtype)
     with tf.device("cpu:0"):
         restored = io_ops.restore_v2(
             filename_tensor, names, slices, restore_dtypes)
         return [tf.cast(r, dt.base_dtype) for r, dt in zip(restored, dtypes)]
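
This variant exists because RestoreV2 fails when the requested dtype does not match what the checkpoint actually stores; reading float16 variables from a float32 checkpoint therefore requests float32 and casts afterwards. A hedged standalone version of the trick (TF 2.x eager, hypothetical names):

import tensorflow as tf
from tensorflow.python.ops import io_ops

prefix = "/tmp/fp32_ckpt"  # hypothetical: "w" was saved as float32
io_ops.save_v2(prefix, ["w"], [""], [tf.constant([1.0, 2.0])])
restored, = io_ops.restore_v2(prefix, ["w"], [""], [tf.float32])  # stored dtype
w_fp16 = tf.cast(restored, tf.float16)                            # graph dtype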
Example 25
def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec,
                                name="checkpoint_initializer"):
  """Sets variable initializer to assign op form value in checkpoint's tensor.

  Args:
    variable: `Variable` object.
    file_pattern: string, where to load checkpoints from.
    tensor_name: Name of the `Tensor` to load from checkpoint reader.
    slice_spec: Slice specification for loading partitioned variables.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  restore_op = io_ops.restore_v2(
      file_pattern, [tensor_name], [slice_spec], [base_type], name=name)[0]
  variable._initializer_op = state_ops.assign(variable, restore_op)
Example 26
def _assign_existing_variable(variable_to_restore, value_pointer):
    """Set a variable from a _ValuePointer object."""
    base_type = variable_to_restore.dtype.base_dtype
    with ops.colocate_with(variable_to_restore):
        # TODO(allenl): Handle partitioned variables
        value_to_restore, = io_ops.restore_v2(
            prefix=value_pointer.save_path,
            tensor_names=[value_pointer.checkpoint_key],
            shape_and_slices=[""],
            dtypes=[base_type],
            name="checkpoint_initializer")
        initializer_op = state_ops.assign(variable_to_restore,
                                          value_to_restore)
        variable_to_restore._initializer_op = initializer_op  # pylint:disable=protected-access
        if value_pointer.session is not None:
            value_pointer.session.run(initializer_op)
Example 27
 def __call__(self, shape, dtype=None, partition_info=None):
     # Creating different RestoreV2 ops when a single one could
     # output several tensors seems inefficient, but that's actually
     # what tf.Saver.restore_op (via tf.BaseSaverBuilder) does too.
     if self._scope is None:
         scope_name = tf.get_variable_scope().name
     elif callable(self._scope):
         scope_name = self._scope(tf.get_variable_scope().name)
     else:
         scope_name = self._scope
     tensor_name = self._var_name
     if scope_name:
         tensor_name = '{}/{}'.format(scope_name, tensor_name)
     tensor = io_ops.restore_v2(
         self._filename, [tensor_name],
         [self._partition_spec(shape, partition_info)], [dtype])[0]
     tensor.set_shape(shape)
     return tensor
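
A hedged usage sketch for such a checkpoint-backed initializer (assumes this __call__ lives on a class, here called CheckpointInitializer, and TF1-style variable creation; all names are hypothetical):

import tensorflow.compat.v1 as tf

init = CheckpointInitializer(filename="/tmp/old_ckpt", var_name="w")  # hypothetical ctor
with tf.variable_scope("encoder"):
    # get_variable invokes init(shape, dtype, partition_info), which emits a
    # RestoreV2 op reading the tensor named "encoder/w" from the checkpoint.
    w = tf.get_variable("w", shape=[128, 64], dtype=tf.float32, initializer=init)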
Example 28
    def _BuildRestore(self):
        """Builds restore ops."""
        assign_ops = []
        with self._var_graph.as_default():
            per_device = collections.defaultdict(lambda: [])
            for var in self._vars:
                per_device[var.device].append(var)

            for device, var_list in per_device.items():
                with self._var_graph.device(device):
                    for var in var_list:
                        val, = io_ops.restore_v2(
                            prefix=self._restore_prefix_ph,
                            tensor_names=[_VarKey(var)],
                            shape_and_slices=[""],
                            dtypes=[var.dtype])
                        assign_ops.append(var.assign(val))

        self._restore_op = tf.group(*assign_ops)
Example 29
 def __call__(self, shape, dtype=None, partition_info=None):
     # Creating different RestoreV2 ops when a single one could
     # output several tensors seems inefficient, but that's actually
     # what tf.Saver.restore_op (via tf.BaseSaverBuilder) does too.
     if self._scope is None:
         scope_name = tf.get_variable_scope().name
     elif callable(self._scope):
         scope_name = self._scope(tf.get_variable_scope().name)
     else:
         scope_name = self._scope
     tensor_name = self._var_name
     if scope_name:
         tensor_name = '{}/{}'.format(scope_name, tensor_name)
     tensor = io_ops.restore_v2(
         self._filename,
         [tensor_name],
         [self._partition_spec(shape, partition_info)],
         [dtype])[0]
     tensor.set_shape(shape)
     return tensor
Example 30
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
    """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
    base_type = variable.dtype.base_dtype
    # Do not colocate with variable since RestoreV2 op only runs on CPU and
    # colocation will force variable (and other ops that colocate with variable)
    # to be on CPU as well. It is okay to place the variable's initializer op on
    # CPU since it will only be run once at the start.
    with ops.device(variable.device), ops.device("/cpu:0"):
        restore_op = io_ops.restore_v2(ckpt_file, [tensor_name], [slice_spec],
                                       [base_type],
                                       name=name)[0]

        names_to_saveables = saver.BaseSaverBuilder.OpListToDict([variable])
        saveable_objects = []
        for name, op in names_to_saveables.items():
            for s in saver.BaseSaverBuilder.SaveableObjectsForOp(op, name):
                saveable_objects.append(s)

        assert len(saveable_objects) == 1  # Should be only one variable.
        init_op = saveable_objects[0].restore([restore_op],
                                              restored_shapes=None)

        # pylint:disable=protected-access
        variable._initializer_op = init_op
        restore_op.set_shape(variable.shape)
        variable._initial_value = restore_op
Example 31
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  restore_op = io_ops.restore_v2(
      ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
  variable._initializer_op = state_ops.assign(variable, restore_op)  # pylint:disable=protected-access
Example 32
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

    # TODO(priyag, allenl): Use `SaveableObject.restore` instead here.
    if resource_variable_ops.is_resource_variable(variable):
      init_op = variable.assign(restore_op, read_value=False)
      # TODO(priyag): Remove this when using `SaveableObject.restore` instead.
      if hasattr(init_op, "_index"):
        init_op = distribute_lib.get_distribution_strategy().group(init_op)
    else:
      init_op = state_ops.assign(variable, restore_op)

    # pylint:disable=protected-access
    variable._initializer_op = init_op
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op
Example 33
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

    names_to_saveables = saver.BaseSaverBuilder.OpListToDict([variable])
    saveable_objects = []
    for name, op in names_to_saveables.items():
      for s in saver.BaseSaverBuilder.SaveableObjectsForOp(op, name):
        saveable_objects.append(s)

    assert len(saveable_objects) == 1  # Should be only one variable.
    init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

    # pylint:disable=protected-access
    variable._initializer_op = init_op
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op
Example 34
def restore_from_saveable_objects(file_prefix, saveable_objects):
    """Reads from a checkpoint and returns restore ops for `saveable_objects`s."""
    restore_specs = []
    tensor_structure = []
    for saveable in saveable_objects:
        saveable_tensor_structure = []
        tensor_structure.append(saveable_tensor_structure)
        for spec in saveable.specs:
            saveable_tensor_structure.append(spec.name)
            restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
    tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
    with ops.device("cpu:0"):
        restored_tensors = io_ops.restore_v2(file_prefix, tensor_names,
                                             tensor_slices, tensor_dtypes)
    structured_restored_tensors = nest.pack_sequence_as(
        tensor_structure, restored_tensors)
    restore_ops = []
    for saveable, restored_tensors in zip(saveable_objects,
                                          structured_restored_tensors):
        restore_ops.append(
            saveable.restore(restored_tensors, restored_shapes=None))
    return restore_ops
Example 35
def restore_from_saveable_objects(file_prefix, saveable_objects):
  """Reads from a checkpoint and returns restore ops for `saveable_objects`s."""
  restore_specs = []
  tensor_structure = []
  for saveable in saveable_objects:
    saveable_tensor_structure = []
    tensor_structure.append(saveable_tensor_structure)
    for spec in saveable.specs:
      saveable_tensor_structure.append(spec.name)
      restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
  tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
  with ops.device("cpu:0"):
    restored_tensors = io_ops.restore_v2(
        file_prefix, tensor_names, tensor_slices, tensor_dtypes)
  structured_restored_tensors = nest.pack_sequence_as(
      tensor_structure, restored_tensors)
  restore_ops = []
  for saveable, restored_tensors in zip(saveable_objects,
                                        structured_restored_tensors):
    restore_ops.append(saveable.restore(restored_tensors,
                                        restored_shapes=None))
  return restore_ops
Example 36
  def restore(self, file_prefix, options=None):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      A restored tensor dict (maps checkpoint_key -> slice_spec -> tensor).
    """
    options = options or checkpoint_options.CheckpointOptions()
    tensor_names = []
    tensor_dtypes = []
    slice_specs = []

    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec, tensor in tensor_slices.items():
        tensor_dtypes.append(tensor.dtype)
        if isinstance(tensor, saveable_object.SaveSpec):
          slice_specs.append(tensor.slice_spec)
          tensor_names.append(tensor.name)
        else:
          slice_specs.append(slice_spec)
          tensor_names.append(checkpoint_key)

    restore_device = options.experimental_io_device or "cpu:0"
    with ops.device(restore_device):
      restored_tensors = io_ops.restore_v2(
          file_prefix, tensor_names, slice_specs, tensor_dtypes)

    restored_tensor_dict = {}
    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec in tensor_slices:
        restored_tensor = restored_tensors.pop(0)
        restored_tensor_dict.setdefault(checkpoint_key, {})[slice_spec] = (
            restored_tensor)
    return restored_tensor_dict
Example 37
 def _process_slot_restoration(self, slot_restoration, variable):
   """Restore a slot variable's value (creating it if necessary)."""
   # TODO(allenl): Move this to Optimizer
   assert isinstance(self, optimizer_lib.Optimizer)
   named_slots = self._slot_dict(slot_restoration.slot_name)
   variable_key = optimizer_lib._var_key(variable)  # pylint: disable=protected-access
   existing_slot_variable = named_slots.get(variable_key, None)
   if existing_slot_variable is None:
     base_dtype = slot_restoration.value_pointer.dtype.base_dtype
     initializer, = io_ops.restore_v2(
         prefix=slot_restoration.value_pointer.save_path,
         tensor_names=[slot_restoration.value_pointer.checkpoint_key],
         shape_and_slices=[""],
         dtypes=[base_dtype],
         name="checkpoint_initializer")
     new_slot_variable = slot_creator.create_slot(variable, initializer,
                                                  slot_restoration.slot_name)
     if slot_restoration.value_pointer.session is not None:
       slot_restoration.value_pointer.session.run(
           new_slot_variable.initializer)
     named_slots[variable_key] = new_slot_variable
   else:
     _assign_existing_variable(
         existing_slot_variable, value_pointer=slot_restoration.value_pointer)
Example 38
 def restore_op(self, filename_tensor, saveable, preferred_shard):
     return [io_ops.restore_v2(filename_tensor,
                               [unscope_name(spec.name)],
                               [spec.slice_spec],
                               [spec.tensor.dtype])[0] for spec in saveable.specs]
Example 39
def restore(save_path, root_checkpointable, session=None):
  """Restore a training checkpoint.

  Restores the values of variables created with `Checkpointable._add_variable`
  in `root_checkpointable` and any objects that it tracks (transitive). Either
  assigns values immediately if variables to restore have been created already,
  or defers restoration until the variables are created. Dependencies added to
  `root_checkpointable` after this call will be matched if they have a
  corresponding object in the checkpoint.

  When building a graph, restorations are executed in the default session if
  `session` is `None`. Variable initializers read checkpointed values.

  To disallow deferred loading, assert immediately that all checkpointed
  variables have been matched to variable objects:

  ```python
  restore(path, root).assert_consumed()
  ```

  An exception will be raised unless every object was matched and its variables
  already exist.

  Args:
    save_path: The path to the checkpoint, as returned by `save` or
      `tf.train.latest_checkpoint`. If None (as when there is no latest
      checkpoint for `tf.train.latest_checkpoint` to return), does nothing.
    root_checkpointable: The root of the object graph to restore. Variables to
      restore need not have been created yet, but all dependencies on other
      Checkpointable objects should already be declared. Objects in the
      dependency graph are matched to objects in the checkpointed graph, and
      matching objects have their variables restored (or the checkpointed values
      saved for eventual restoration when the variable is created).
    session: The session to evaluate assignment ops in. Ignored when executing
      eagerly. If not provided when graph building, the default session is used.
  Returns:
    A CheckpointLoadStatus object, which can be used to make assertions about
    the status of checkpoint restoration.
  """
  if save_path is None:
    return
  if context.in_graph_mode():
    if session is None:
      session = ops.get_default_session()
  else:
    session = None
  object_graph_string, = io_ops.restore_v2(
      prefix=save_path,
      tensor_names=[_OBJECT_GRAPH_PROTO_KEY],
      shape_and_slices=[""],
      dtypes=[dtypes.string],
      name="object_graph_proto_read")
  if session is not None:
    object_graph_string = session.run(object_graph_string)
  else:
    object_graph_string = object_graph_string.numpy()
  object_graph_proto = (
      checkpointable_object_graph_pb2.CheckpointableObjectGraph())
  object_graph_proto.ParseFromString(object_graph_string)
  checkpoint = core_checkpointable._Checkpoint(  # pylint: disable=protected-access
      object_graph_proto=object_graph_proto,
      save_path=save_path,
      session=session)
  core_checkpointable._CheckpointPosition(  # pylint: disable=protected-access
      checkpoint=checkpoint, proto_id=0).restore(root_checkpointable)
  return CheckpointLoadStatus(checkpoint)
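
For reference, _OBJECT_GRAPH_PROTO_KEY is the string "_CHECKPOINTABLE_OBJECT_GRAPH" in TensorFlow; a hedged eager sketch of reading it directly from an object-based checkpoint (the directory is hypothetical):

import tensorflow as tf
from tensorflow.python.ops import io_ops

ckpt = tf.train.latest_checkpoint("/tmp/train_dir")  # hypothetical
proto_bytes, = io_ops.restore_v2(ckpt, ["_CHECKPOINTABLE_OBJECT_GRAPH"],
                                 [""], [tf.string])
# proto_bytes.numpy() parses as a TrackableObjectGraph proto.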
Example 40
  def restore(self, save_path, session=None):
    """Restore a training checkpoint.

    Restores `root_checkpointable` and any objects that it tracks
    (transitive). Either assigns values immediately if variables to restore have
    been created already, or defers restoration until the variables are
    created. Dependencies added to the `root_checkpointable` passed to the
    constructor after this call will be matched if they have a corresponding
    object in the checkpoint.

    When building a graph, restorations are added to the graph but not run. A
    session is required to retrieve checkpoint metadata.

    To disallow deferred loading, assert immediately that all checkpointed
    variables have been matched to variable objects:

    ```python
    saver = Saver(root)
    saver.restore(path).assert_consumed()
    ```

    An exception will be raised unless every object was matched and its
    variables already exist.

    When graph building, `assert_consumed()` indicates that all of the restore
    ops which will be created for this checkpoint have been created. They can be
    run via the `run_restore_ops()` function of the status object:

    ```python
    saver.restore(path).assert_consumed().run_restore_ops()
    ```

    If the checkpoint has not been consumed completely, then the list of restore
    ops will grow as more objects are added to the dependency graph.

    Name-based `tf.train.Saver` checkpoints can be loaded using this
    method. There is no deferred loading, and names are used to match
    variables. No restore ops are created/run until `run_restore_ops()` or
    `initialize_or_restore()` are called on the returned status object, even
    when executing eagerly. Re-encode name-based checkpoints using this
    object-based `Saver.save` as soon as possible.

    Args:
      save_path: The path to the checkpoint, as returned by `save` or
        `tf.train.latest_checkpoint`. If None (as when there is no latest
        checkpoint for `tf.train.latest_checkpoint` to return), returns an
        object which may run initializers for objects in the dependency
        graph. If the checkpoint was written by the name-based `tf.train.Saver`,
        names are used to match variables.
      session: The session to retrieve metadata with. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.

    Returns:
      A load status object, which can be used to make assertions about the
      status of checkpoint restoration and run initialization/restore ops
      (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if
      `save_path` is `None`).

      If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus`
      object is returned which runs restore ops from a name-based saver.
    """
    if save_path is None:
      return InitializationOnlyStatus(self._root_checkpointable)
    in_graph_mode = context.in_graph_mode()
    if in_graph_mode:
      if session is None:
        session = ops.get_default_session()
      file_prefix_tensor = self._file_prefix_placeholder
      file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
    else:
      session = None
      file_prefix_tensor = constant_op.constant(save_path)
      file_prefix_feed_dict = None
    try:
      if not in_graph_mode or self._object_graph_restore_tensor is None:
        object_graph_string, = io_ops.restore_v2(
            prefix=file_prefix_tensor,
            tensor_names=[_OBJECT_GRAPH_PROTO_KEY],
            shape_and_slices=[""],
            dtypes=[dtypes.string],
            name="object_graph_proto_read")
        if in_graph_mode:
          self._object_graph_restore_tensor = object_graph_string
      if in_graph_mode:
        object_graph_string = session.run(
            self._object_graph_restore_tensor,
            feed_dict=file_prefix_feed_dict)
      else:
        object_graph_string = object_graph_string.numpy()
    except errors_impl.NotFoundError:
      # The object graph proto does not exist in this checkpoint. Try again with
      # name-based saving.
      return NameBasedSaverStatus(self, save_path)

    object_graph_proto = (
        checkpointable_object_graph_pb2.CheckpointableObjectGraph())
    object_graph_proto.ParseFromString(object_graph_string)
    if in_graph_mode and object_graph_proto == self._last_restore_object_graph:
      checkpoint = self._last_restore_checkpoint
    else:
      if in_graph_mode:
        dtype_map = None
      else:
        reader = pywrap_tensorflow.NewCheckpointReader(save_path)
        dtype_map = reader.get_variable_to_dtype_map()
      checkpoint = core_checkpointable_utils._Checkpoint(  # pylint: disable=protected-access
          object_graph_proto=object_graph_proto,
          save_path=file_prefix_tensor,
          dtype_map=dtype_map)
      if in_graph_mode:
        if self._last_restore_object_graph is not None:
          raise NotImplementedError(
              "Using a single Saver to restore different object graphs is not "
              "currently supported when graph building. Use a different Saver "
              "for each object graph (restore ops will be duplicated), or "
              "file a feature request if this limitation bothers you.")
        self._last_restore_checkpoint = checkpoint
        self._last_restore_object_graph = object_graph_proto
    core_checkpointable._CheckpointPosition(  # pylint: disable=protected-access
        checkpoint=checkpoint, proto_id=0).restore(self._root_checkpointable)
    load_status = CheckpointLoadStatus(
        checkpoint, feed_dict=file_prefix_feed_dict)
    return load_status
Example 41
    def restore(self, save_path, session=None):
        """Restore a training checkpoint.

    Restores `root_checkpointable` and any objects that it tracks
    (transitive). Either assigns values immediately if variables to restore have
    been created already, or defers restoration until the variables are
    created. Dependencies added to the `root_checkpointable` passed to the
    constructor after this call will be matched if they have a corresponding
    object in the checkpoint.

    When building a graph, restorations are added to the graph but not run. A
    session is required to retrieve checkpoint metadata.

    To disallow deferred loading, assert immediately that all checkpointed
    variables have been matched to variable objects:

    ```python
    saver = Saver(root)
    saver.restore(path).assert_consumed()
    ```

    An exception will be raised unless every object was matched and its
    variables already exist.

    When graph building, `assert_consumed()` indicates that all of the restore
    ops which will be created for this checkpoint have been created. They can be
    run via the `run_restore_ops()` function of the status object:

    ```python
    saver.restore(path).assert_consumed().run_restore_ops()
    ```

    If the checkpoint has not been consumed completely, then the list of restore
    ops will grow as more objects are added to the dependency graph.

    Name-based `tf.train.Saver` checkpoints can be loaded using this
    method. There is no deferred loading, and names are used to match
    variables. No restore ops are created/run until `run_restore_ops()` or
    `initialize_or_restore()` are called on the returned status object, even
    when executing eagerly. Re-encode name-based checkpoints using this
    object-based `Saver.save` as soon as possible.

    Args:
      save_path: The path to the checkpoint, as returned by `save` or
        `tf.train.latest_checkpoint`. If None (as when there is no latest
        checkpoint for `tf.train.latest_checkpoint` to return), returns an
        object which may run initializers for objects in the dependency
        graph. If the checkpoint was written by the name-based `tf.train.Saver`,
        names are used to match variables.
      session: The session to retrieve metadata with. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.

    Returns:
      A load status object, which can be used to make assertions about the
      status of checkpoint restoration and run initialization/restore ops
      (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if
      `save_path` is `None`).

      If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus`
      object is returned which runs restore ops from a name-based saver.
    """
        if save_path is None:
            return InitializationOnlyStatus(self._root_checkpointable)
        in_graph_mode = context.in_graph_mode()
        if in_graph_mode:
            if session is None:
                session = ops.get_default_session()
            file_prefix_tensor = self._file_prefix_placeholder
            file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
        else:
            session = None
            file_prefix_tensor = constant_op.constant(save_path)
            file_prefix_feed_dict = None
        try:
            if not in_graph_mode or self._object_graph_restore_tensor is None:
                object_graph_string, = io_ops.restore_v2(
                    prefix=file_prefix_tensor,
                    tensor_names=[_OBJECT_GRAPH_PROTO_KEY],
                    shape_and_slices=[""],
                    dtypes=[dtypes.string],
                    name="object_graph_proto_read")
                if in_graph_mode:
                    self._object_graph_restore_tensor = object_graph_string
            if in_graph_mode:
                object_graph_string = session.run(
                    self._object_graph_restore_tensor,
                    feed_dict=file_prefix_feed_dict)
            else:
                object_graph_string = object_graph_string.numpy()
        except errors_impl.NotFoundError:
            # The object graph proto does not exist in this checkpoint. Try again with
            # name-based saving.
            return NameBasedSaverStatus(self, save_path)

        object_graph_proto = (
            checkpointable_object_graph_pb2.CheckpointableObjectGraph())
        object_graph_proto.ParseFromString(object_graph_string)
        if in_graph_mode and object_graph_proto == self._last_restore_object_graph:
            checkpoint = self._last_restore_checkpoint
        else:
            if in_graph_mode:
                dtype_map = None
            else:
                reader = pywrap_tensorflow.NewCheckpointReader(save_path)
                dtype_map = reader.get_variable_to_dtype_map()
            checkpoint = core_checkpointable_utils._Checkpoint(  # pylint: disable=protected-access
                object_graph_proto=object_graph_proto,
                save_path=file_prefix_tensor,
                dtype_map=dtype_map)
            if in_graph_mode:
                if self._last_restore_object_graph is not None:
                    raise NotImplementedError(
                        "Using a single Saver to restore different object graphs is not "
                        "currently supported when graph building. Use a different Saver "
                        "for each object graph (restore ops will be duplicated), or "
                        "file a feature request if this limitation bothers you."
                    )
                self._last_restore_checkpoint = checkpoint
                self._last_restore_object_graph = object_graph_proto
        core_checkpointable._CheckpointPosition(  # pylint: disable=protected-access
            checkpoint=checkpoint,
            proto_id=0).restore(self._root_checkpointable)
        load_status = CheckpointLoadStatus(checkpoint,
                                           feed_dict=file_prefix_feed_dict)
        return load_status
Example 42
  def add_variable(self, name, shape=None, dtype=dtypes.float32,
                   initializer=None, **kwargs):
    """Create a new variable object to be saved with this `Checkpointable`.

    If the user has requested that this object or another `Checkpointable` which
    depends on this object be restored from a checkpoint (deferred loading
    before variable object creation), `initializer` may be ignored and the value
    from the checkpoint used instead.

    Args:
      name: A name for the variable. Must be unique within this object.
      shape: The shape of the variable.
      dtype: The data type of the variable.
      initializer: The initializer to use. Ignored if deferred loading has been
        requested.
      **kwargs: Passed to the ResourceVariable constructor.

    Returns:
      The new variable object.

    Raises:
      ValueError: If the variable name is not unique.
      RuntimeError: If __init__ has not been called.
    """
    if not hasattr(self, "_owned_variables"):
      raise RuntimeError("Need to call Checkpointable.__init__ before adding "
                         "variables.")
    if name in self._owned_variables:
      raise ValueError(
          ("A variable named '%s' already exists in this Checkpointable, but "
           "Checkpointable.add_variable called to create another with "
           "that name. Variable names must be unique within a Checkpointable "
           "object.") % (name,))
    if "getter" in kwargs:
      # Allow the getter to be overridden, typically because there is a need for
      # compatibility with some other variable creation mechanism. This should
      # be relatively uncommon in user code.
      getter = kwargs.pop("getter")
    else:
      getter = _default_getter
    deferred_restoration = self._deferred_restorations.pop(name, None)
    if deferred_restoration is not None:
      dtype = deferred_restoration.value_pointer.dtype
      base_type = dtype.base_dtype
      # TODO(allenl): Handle partitioned variables here too
      with ops.init_scope():
        initializer, = io_ops.restore_v2(
            prefix=deferred_restoration.value_pointer.save_path,
            tensor_names=[deferred_restoration.value_pointer.checkpoint_key],
            shape_and_slices=[""],
            dtypes=[base_type],
            name="checkpoint_initializer")
      # We need to un-set the shape so get_variable doesn't complain, but we
      # also need to set the static shape information on the initializer if
      # possible so we don't get a variable with an unknown shape.
      initializer.set_shape(shape)
      # Un-set shape since we're using a constant initializer
      shape = None

    new_variable = getter(
        name=name, shape=shape, dtype=dtype, initializer=initializer, **kwargs)
    if deferred_restoration is not None:
      if deferred_restoration.value_pointer.session is not None:
        deferred_restoration.value_pointer.session.run(new_variable.initializer)
      for slot_restoration in deferred_restoration.slot_restorations:
        strong_ref = slot_restoration.optimizer_ref()
        if strong_ref is None:
          # If the optimizer object has been garbage collected, there's no need
          # to create the slot variable.
          continue
        strong_ref._process_slot_restoration(  # pylint: disable=protected-access
            slot_restoration, new_variable)
    self._owned_variables[name] = new_variable
    return new_variable
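
A closing note on the initializer.set_shape(shape) trick above: RestoreV2 outputs have unknown static shape at graph-construction time, so pinning the shape is what lets the created variable carry full shape information. A hedged graph-mode sketch (the graph is never run, so no checkpoint file is needed; the names are hypothetical):

import tensorflow as tf
from tensorflow.python.ops import io_ops

g = tf.Graph()
with g.as_default():
    t, = io_ops.restore_v2("/tmp/ckpt", ["w"], [""], [tf.float32])
    print(t.shape)          # <unknown>
    t.set_shape([128, 64])  # pin the statically known shape
    print(t.shape)          # (128, 64)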