Example #1
  def _AddRestoreOps(self,
                     filename_tensor,
                     vars_to_save,
                     restore_sequentially,
                     reshape,
                     preferred_shard=-1,
                     name="restore_all"):
    """Add operations to restore vars_to_save.

    Args:
      filename_tensor: Tensor for the path of the file to load.
      vars_to_save: A list of _VarToSave objects.
      restore_sequentially: True if we want to restore variables sequentially
        within a shard.
      reshape: True if we want to reshape loaded tensors to the shape of
        the corresponding variable.
      preferred_shard: Shard to open first when loading a sharded file.
      name: Name for the returned op.

    Returns:
      An Operation that restores the variables.
    """
    assign_ops = []
    for vs in vars_to_save:
      v = vs.var
      restore_control_inputs = assign_ops[-1:] if restore_sequentially else []
      # Load and optionally reshape on the CPU, as string tensors are not
      # available on the GPU.
      # TODO(touts): Re-enable restore on GPU when we can support annotating
      # string tensors as "HostMemory" inputs.
      with ops.device(graph_util.set_cpu0(v.device) if v.device else None):
        with ops.control_dependencies(restore_control_inputs):
          values = self.restore_op(filename_tensor, vs, preferred_shard)
        if reshape:
          shape = v.get_shape()
          if not shape.is_fully_defined():
            shape = array_ops.shape(v)
          values = array_ops.reshape(values, shape)

      # Assign on the same device as the variable.
      with ops.device(v.device):
        assign_ops.append(state_ops.assign(v,
                                           values,
                                           validate_shape=not reshape))

    # Create a NoOp that has control dependencies from all the updates.
    return control_flow_ops.group(*assign_ops, name=name)
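This snippet appears to come from an early version of TensorFlow's tensorflow/python/training/saver.py. The graph_util.set_cpu0 call pins the read to the host: the restore produces string tensors, which had no GPU kernels when this code was written (hence the TODO). As a rough illustration, below is a minimal re-implementation sketch of that helper using the public tf.DeviceSpec API of TF 2.x; the assumption, matching the comment in the code, is that it keeps the job/replica/task fields of the variable's device and forces the device itself to CPU:0. The real helper is internal and may differ in detail.

import tensorflow as tf

def set_cpu0(device_string):
  # Keep job/replica/task from the variable's device spec, but force the
  # op onto CPU:0, where string tensors can live. (Sketch, not the actual
  # graph_util implementation.)
  spec = tf.DeviceSpec.from_string(device_string)
  return spec.replace(device_type="CPU", device_index=0).to_string()

print(set_cpu0("/job:ps/task:7/device:GPU:1"))
# -> /job:ps/task:7/device:CPU:0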
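Taken together, _AddRestoreOps implements a small pattern: optionally chain each read behind the previous assign (restore_sequentially), reshape the loaded tensor to the variable's static shape, falling back to its dynamic shape when the static one is not fully defined, assign on the variable's own device, and return one grouped op. The standalone sketch below reproduces that pattern with public TF 1.x-style APIs via tf.compat.v1; the tf.identity calls stand in for the actual checkpoint reads, so the variables and values are illustrative only, not the Saver's internals.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Variables to "restore"; their initial values get overwritten.
v1 = tf.get_variable("v1", shape=[2, 3], initializer=tf.zeros_initializer())
v2 = tf.get_variable("v2", shape=[4], initializer=tf.zeros_initializer())

# Stand-ins for tensors read from a checkpoint. They arrive flat,
# hence the reshape below.
to_restore = [(v1, tf.ones([6])), (v2, tf.ones([4]))]

restore_sequentially = True
assign_ops = []
for v, values in to_restore:
  # Sequential restore: each read waits on the previous assign, so at
  # most one restore is in flight within the shard at any time.
  deps = assign_ops[-1:] if restore_sequentially else []
  with tf.control_dependencies(deps):
    values = tf.identity(values)  # stand-in for the checkpoint read
  # Prefer the static shape; fall back to the dynamic shape tensor when
  # the static shape is not fully defined (same branch as above).
  shape = v.get_shape()
  if not shape.is_fully_defined():
    shape = tf.shape(v)
  values = tf.reshape(values, shape)
  assign_ops.append(tf.assign(v, values, validate_shape=False))

# A single NoOp with control dependencies on every assign.
restore_all = tf.group(*assign_ops, name="restore_all")

with tf.Session() as sess:
  sess.run(restore_all)
  print(sess.run(v1))  # [[1. 1. 1.] [1. 1. 1.]]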