Example #1
def _get_object_checkpoint_renames(path, variable_names):
  """Returns a dictionary mapping variable names to checkpoint keys.

  The warm-starting utility expects variable names to match with the variable
  names in the checkpoint. For object-based checkpoints, the variable names
  and names in the checkpoint are different. Thus, for object-based checkpoints,
  this function is used to obtain the map from variable names to checkpoint
  keys.

  Args:
    path: path to checkpoint directory or file.
    variable_names: list of variable names to load from the checkpoint.

  Returns:
    If the checkpoint is object-based, this function returns a map from variable
    names to their corresponding checkpoint keys.
    If the checkpoint is name-based, this returns an empty dict.

  Raises:
    ValueError: If the object-based checkpoint is missing variables.
  """
  fname = checkpoint_utils._get_checkpoint_filename(path)  # pylint: disable=protected-access
  try:
    names_to_keys = saver_lib.object_graph_key_mapping(fname)
  except errors.NotFoundError:
    # If an error is raised from `object_graph_key_mapping`, then the
    # checkpoint is name-based. There are no renames, so return an empty dict.
    return {}

  missing_names = set(variable_names) - set(names_to_keys.keys())
  if missing_names:
    raise ValueError(
        "Attempting to warm-start from an object-based checkpoint, but found "
        "that the checkpoint did not contain values for all variables. The "
        "following variables were missing: {}"
        .format(missing_names))
  return {name: names_to_keys[name] for name in variable_names}
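
A minimal usage sketch, assuming the function above is in scope; the checkpoint directory and variable names are hypothetical. The helper returns a rename map for object-based checkpoints and an empty dict for name-based ones.

# Hypothetical checkpoint directory and variable names, for illustration only.
renames = _get_object_checkpoint_renames(
    "/tmp/prev_model_dir", ["dense/kernel", "dense/bias"])
if renames:
  # Object-based checkpoint: variable names map to object-graph keys.
  print("checkpoint keys:", renames)
else:
  # Name-based checkpoint: variable names already match checkpoint keys.
  print("no renames needed")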
Example #2
def _warm_start_var_with_vocab(var,
                               current_vocab_path,
                               current_vocab_size,
                               prev_ckpt,
                               prev_vocab_path,
                               previous_vocab_size=-1,
                               current_oov_buckets=0,
                               prev_tensor_name=None,
                               initializer=None,
                               axis=0):
    """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

    Use this method when the `var` is backed by vocabulary. This method stitches
    the given `var` such that values corresponding to individual features in the
    vocabulary remain consistent irrespective of changing order of the features
    between old and new vocabularies.

    Args:
      var: Current graph's variable that needs to be warm-started (initialized).
        Can be either of the following:
        (i) `Variable`
        (ii) `ResourceVariable`
        (iii) list of `Variable`: The list must contain slices of the same larger
          variable.
        (iv) `PartitionedVariable`
      current_vocab_path: Path to the vocab file used for the given `var`.
      current_vocab_size: An `int` specifying the number of entries in the current
        vocab.
      prev_ckpt: A string specifying the directory with checkpoint file(s) or path
        to checkpoint. The given checkpoint must have tensor with name
        `prev_tensor_name` (if not None) or tensor with name same as given `var`.
      prev_vocab_path: Path to the vocab file used for the tensor in `prev_ckpt`.
      previous_vocab_size: If provided, will constrain previous vocab to the first
        `previous_vocab_size` entries.  -1 means use the entire previous vocab.
      current_oov_buckets: An `int` specifying the number of out-of-vocabulary
        buckets used for given `var`.
      prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
        None, we lookup tensor with same name as given `var`.
      initializer: Variable initializer to be used for missing entries.  If None,
        missing entries will be zero-initialized.
      axis: Axis of the variable that the provided vocabulary corresponds to.

    Raises:
      ValueError: If required args are not provided.
    """
    if not (current_vocab_path and current_vocab_size and prev_ckpt
            and prev_vocab_path):
        raise ValueError(
            "Invalid args: Must provide all of [current_vocab_path, "
            "current_vocab_size, prev_ckpt, prev_vocab_path}.")
    if checkpoint_utils._is_variable(var):
        var = [var]
    elif (isinstance(var, list)
          and all(checkpoint_utils._is_variable(v) for v in var)):
        var = var
    elif isinstance(var, variables_lib.PartitionedVariable):
        var = var._get_variable_list()
    else:
        raise TypeError(
            "var MUST be one of the following: a Variable, list of Variable or "
            "PartitionedVariable, but is {}".format(type(var)))

    if not prev_tensor_name:
        # Assume tensor name remains the same.
        prev_tensor_name = _infer_var_name(var)

    # TODO(eddz): Fix functionality for rank-1 Variables (like FC biases).
    total_v_first_axis = sum(v.get_shape().as_list()[0] for v in var)
    for v in var:
        v_shape = v.get_shape().as_list()
        slice_info = v._get_save_slice_info()
        partition_info = None
        if slice_info:
            partition_info = variable_scope._PartitionInfo(
                full_shape=slice_info.full_shape,
                var_offset=slice_info.var_offset)

        if axis == 0:
            new_row_vocab_size = current_vocab_size
            new_col_vocab_size = v_shape[1]
            old_row_vocab_size = previous_vocab_size
            old_row_vocab_file = prev_vocab_path
            new_row_vocab_file = current_vocab_path
            old_col_vocab_file = None
            new_col_vocab_file = None
            num_row_oov_buckets = current_oov_buckets
            num_col_oov_buckets = 0
        elif axis == 1:
            # Note that we must compute this value across all partitions, whereas
            # in the axis = 0 case, we can simply use v_shape[1] because we don't
            # allow partitioning across axis = 1.
            new_row_vocab_size = total_v_first_axis
            new_col_vocab_size = current_vocab_size
            old_row_vocab_size = -1
            old_row_vocab_file = None
            new_row_vocab_file = None
            old_col_vocab_file = prev_vocab_path
            new_col_vocab_file = current_vocab_path
            num_row_oov_buckets = 0
            num_col_oov_buckets = current_oov_buckets
        else:
            raise ValueError(
                "The only supported values for the axis argument are 0 "
                "and 1.  Provided axis: {}".format(axis))

        init = checkpoint_ops._load_and_remap_matrix_initializer(
            ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
            old_tensor_name=prev_tensor_name,
            new_row_vocab_size=new_row_vocab_size,
            new_col_vocab_size=new_col_vocab_size,
            old_row_vocab_size=old_row_vocab_size,
            old_row_vocab_file=old_row_vocab_file,
            new_row_vocab_file=new_row_vocab_file,
            old_col_vocab_file=old_col_vocab_file,
            new_col_vocab_file=new_col_vocab_file,
            num_row_oov_buckets=num_row_oov_buckets,
            num_col_oov_buckets=num_col_oov_buckets,
            initializer=initializer)
        new_init_val = ops.convert_to_tensor(
            init(shape=v_shape, partition_info=partition_info))
        v._initializer_op = state_ops.assign(v, new_init_val)
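
A minimal sketch of calling this helper, assuming TF1 graph mode (e.g. after tf.compat.v1.disable_eager_execution()); the paths and the embedding variable are hypothetical. Because the helper rewrites each partition's `_initializer_op`, it must run before the variable initializers are executed.

import tensorflow as tf

# Hypothetical: a 100-row embedding table whose row vocabulary was reordered
# or extended since the previous run.
embeddings = tf.compat.v1.get_variable("input_layer/emb", shape=[100, 16])
_warm_start_var_with_vocab(
    embeddings,
    current_vocab_path="/tmp/new_vocab.txt",  # hypothetical path
    current_vocab_size=100,
    prev_ckpt="/tmp/prev_model_dir",          # hypothetical path
    prev_vocab_path="/tmp/old_vocab.txt",     # hypothetical path
    axis=0)  # the vocabulary indexes the rows of the embedding matrix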
Example #3
def _warm_start_var_with_vocab(var,
                               current_vocab_path,
                               current_vocab_size,
                               prev_ckpt,
                               prev_vocab_path,
                               previous_vocab_size=-1,
                               current_oov_buckets=0,
                               prev_tensor_name=None,
                               initializer=None,
                               axis=0):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Use this method when the `var` is backed by vocabulary. This method stitches
  the given `var` such that values corresponding to individual features in the
  vocabulary remain consistent irrespective of changing order of the features
  between old and new vocabularies.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    current_vocab_path: Path to the vocab file used for the given `var`.
    current_vocab_size: An `int` specifying the number of entries in the current
      vocab.
    prev_ckpt: A string specifying the directory with checkpoint file(s) or path
      to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_vocab_path: Path to the vocab file used for the tensor in `prev_ckpt`.
    previous_vocab_size: If provided, will constrain previous vocab to the first
      `previous_vocab_size` entries.  -1 means use the entire previous vocab.
    current_oov_buckets: An `int` specifying the number of out-of-vocabulary
      buckets used for given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.
    initializer: Variable initializer to be used for missing entries.  If None,
      missing entries will be zero-initialized.
    axis: Axis of the variable that the provided vocabulary corresponds to.

  Raises:
    ValueError: If required args are not provided.
  """
  if not (current_vocab_path and current_vocab_size and prev_ckpt and
          prev_vocab_path):
    raise ValueError("Invalid args: Must provide all of [current_vocab_path, "
                     "current_vocab_size, prev_ckpt, prev_vocab_path}.")
  if checkpoint_utils._is_variable(var):
    var = [var]
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):
    var = var
  elif isinstance(var, variables_lib.PartitionedVariable):
    var = var._get_variable_list()
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))

  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = _infer_var_name(var)

  # TODO(eddz): Fix functionality for rank-1 Variables (like FC biases).
  total_v_first_axis = sum(v.get_shape().as_list()[0] for v in var)
  for v in var:
    v_shape = v.get_shape().as_list()
    slice_info = v._get_save_slice_info()
    partition_info = None
    if slice_info:
      partition_info = variable_scope._PartitionInfo(
          full_shape=slice_info.full_shape,
          var_offset=slice_info.var_offset)

    if axis == 0:
      new_row_vocab_size = current_vocab_size
      new_col_vocab_size = v_shape[1]
      old_row_vocab_size = previous_vocab_size
      old_row_vocab_file = prev_vocab_path
      new_row_vocab_file = current_vocab_path
      old_col_vocab_file = None
      new_col_vocab_file = None
      num_row_oov_buckets = current_oov_buckets
      num_col_oov_buckets = 0
    elif axis == 1:
      # Note that we must compute this value across all partitions, whereas
      # in the axis = 0 case, we can simply use v_shape[1] because we don't
      # allow partitioning across axis = 1.
      new_row_vocab_size = total_v_first_axis
      new_col_vocab_size = current_vocab_size
      old_row_vocab_size = -1
      old_row_vocab_file = None
      new_row_vocab_file = None
      old_col_vocab_file = prev_vocab_path
      new_col_vocab_file = current_vocab_path
      num_row_oov_buckets = 0
      num_col_oov_buckets = current_oov_buckets
    else:
      raise ValueError("The only supported values for the axis argument are 0 "
                       "and 1.  Provided axis: {}".format(axis))

    init = checkpoint_ops._load_and_remap_matrix_initializer(
        ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
        old_tensor_name=prev_tensor_name,
        new_row_vocab_size=new_row_vocab_size,
        new_col_vocab_size=new_col_vocab_size,
        old_row_vocab_size=old_row_vocab_size,
        old_row_vocab_file=old_row_vocab_file,
        new_row_vocab_file=new_row_vocab_file,
        old_col_vocab_file=old_col_vocab_file,
        new_col_vocab_file=new_col_vocab_file,
        num_row_oov_buckets=num_row_oov_buckets,
        num_col_oov_buckets=num_col_oov_buckets,
        initializer=initializer)
    new_init_val = ops.convert_to_tensor(
        init(shape=v_shape, partition_info=partition_info))
    v._initializer_op = state_ops.assign(v, new_init_val)
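
The row remapping that `_load_and_remap_matrix_initializer` performs can be pictured with a small pure-Python sketch (vocabularies and values are hypothetical): rows for features present in both vocabularies keep their old values regardless of position, while genuinely new features fall back to the missing-entry initializer.

# Pure-Python illustration of the row remapping, not the TF implementation.
old_vocab = ["cat", "dog", "bird"]
new_vocab = ["dog", "cat", "fish"]  # reordered; "fish" is new
old_rows = {"cat": [1.0], "dog": [2.0], "bird": [3.0]}

# Copy the old row when the feature existed; otherwise zero-initialize.
new_rows = [old_rows.get(f, [0.0]) for f in new_vocab]
print(new_rows)  # [[2.0], [1.0], [0.0]]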
Example #4
def warm_start(ckpt_to_initialize_from,
               vars_to_warm_start=".*",
               var_name_to_prev_var_name=None):
  """Warm-starts de.Variable using the given settings.

  Args:
    ckpt_to_initialize_from: [Required] A string specifying the directory with
      checkpoint file(s) or path to checkpoint from which to warm-start the
      model parameters.
    vars_to_warm_start: [Optional] One of the following:
      - A regular expression (string) that captures which variables to
        warm-start (see tf.compat.v1.get_collection).  This expression will only
        consider variables in the TRAINABLE_VARIABLES collection -- if you need
        to warm-start non_TRAINABLE vars (such as optimizer accumulators or
        batch norm statistics), please use the below option.
      - A list of strings, each a regex scope provided to
        tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see
        tf.compat.v1.get_collection).  For backwards compatibility reasons,
        this is separate from the single-string argument type.
      - A list of Variables to warm-start.  If you do not have access to the
        `Variable` objects at the call site, please use the above option.
      - `None`, in which case only TRAINABLE variables specified in
        `var_name_to_vocab_info` will be warm-started.
      Defaults to `'.*'`, which warm-starts all variables in the
      TRAINABLE_VARIABLES collection.  Note that this excludes variables such
      as accumulators and moving statistics from batch norm.
    var_name_to_prev_var_name: [Optional] A dict mapping variable names in the
      current graph to names of the corresponding variables in the checkpoint.
      Variables absent from the dict are looked up under their current names.

  Raises:
    ValueError: If a saveable's spec.name does not match the pattern
      defined by de.Variable._make_name.
  """

  def _replace_var_in_spec_name(spec_name, var_name):

    def _replace(m):
      return '{}_mht_{}of{}'.format(var_name, m.groups()[1], m.groups()[2])

    # `re.sub` never returns None, so the original `out is None` check could
    # never fire; `re.subn` reports whether the pattern actually matched.
    out, num_subs = re.subn(r'(\w+)_mht_(\d+)of(\d+)', _replace, spec_name)
    if num_subs == 0:
      raise ValueError(
          "Invalid spec name, should match `{}_mht_{}of{}`, given %s" %
          spec_name)
    return out

  logging.info("Warm-starting from: {}".format(ckpt_to_initialize_from))

  de_variables = _get_de_variables(vars_to_warm_start)
  if not var_name_to_prev_var_name:
    var_name_to_prev_var_name = {}

  ckpt_file = checkpoint_utils._get_checkpoint_filename(ckpt_to_initialize_from)
  assign_ops = []
  for variable in de_variables:
    var_name = variable.name
    prev_var_name = var_name_to_prev_var_name.get(var_name)
    if prev_var_name:
      logging.debug("Warm-start variable: {}: prev_var_name: {}".format(
          var_name, prev_var_name or "Unchanged"))
    else:
      prev_var_name = var_name

    saveables = saveable_object_util.validate_and_slice_inputs([variable])
    for saveable in saveables:
      restore_specs = []
      for spec in saveable.specs:
        restore_specs.append((_replace_var_in_spec_name(spec.name,
                                                        prev_var_name),
                              spec.slice_spec, spec.dtype))

      names, slices, dtypes = zip(*restore_specs)
      # Load tensors on the cuckoo_hashtable op's device.
      with ops.colocate_with(saveable.op._resource_handle.op):
        saveable_tensors = io_ops.restore_v2(ckpt_file, names, slices, dtypes)
        assign_ops.append(saveable.restore(saveable_tensors, None))

  return control_flow_ops.group(assign_ops)
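
A minimal graph-mode sketch of calling warm_start; the checkpoint path and variable names are hypothetical. The function returns a grouped assign op, so it must be run explicitly.

import tensorflow as tf

# Hypothetical: restore de.Variable parameters saved under a different name.
restore_op = warm_start(
    ckpt_to_initialize_from="/tmp/prev_model_dir",  # hypothetical path
    vars_to_warm_start=["embeddings.*"],
    var_name_to_prev_var_name={"embeddings_new": "embeddings_old"})
with tf.compat.v1.Session() as sess:
  sess.run(restore_op)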
Example #5
def _warmstart_var_with_vocab(var,
                              current_vocab_path,
                              current_vocab_size,
                              prev_ckpt,
                              prev_vocab_path,
                              previous_vocab_size=-1,
                              current_oov_buckets=0,
                              prev_tensor_name=None,
                              initializer=None):
    """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

    Use this method when the `var` is backed by vocabulary. This method stitches
    the given `var` such that values corresponding to individual features in the
    vocabulary remain consistent irrespective of changing order of the features
    between old and new vocabularies.

    Args:
      var: Current graph's variable that needs to be warm-started (initialized).
        Can be either of the following:
        (i) `Variable`
        (ii) `ResourceVariable`
        (iii) list of `Variable`: The list must contain slices of the same larger
          variable.
        (iv) `PartitionedVariable`
      current_vocab_path: Path to the vocab file used for the given `var`.
      current_vocab_size: An `int` specifying the number of entries in the current
        vocab.
      prev_ckpt: A string specifying the directory with checkpoint file(s) or path
        to checkpoint. The given checkpoint must have tensor with name
        `prev_tensor_name` (if not None) or tensor with name same as given `var`.
      prev_vocab_path: Path to the vocab file used for the tensor in `prev_ckpt`.
      previous_vocab_size: If provided, will constrain previous vocab to the first
        `previous_vocab_size` entries.  -1 means use the entire previous vocab.
      current_oov_buckets: An `int` specifying the number of out-of-vocabulary
        buckets used for given `var`.
      prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
        None, we lookup tensor with same name as given `var`.
      initializer: Variable initializer to be used for missing entries.  If None,
        missing entries will be zero-initialized.

    Raises:
      ValueError: If required args are not provided.
    """
    if not (current_vocab_path and current_vocab_size and prev_ckpt
            and prev_vocab_path):
        raise ValueError(
            "Invalid args: Must provide all of [current_vocab_path, "
            "current_vocab_size, prev_ckpt, prev_vocab_path}.")
    if _is_variable(var):
        var = [var]
    elif isinstance(var, list) and all(_is_variable(v) for v in var):
        var = var
    elif isinstance(var, variables_lib.PartitionedVariable):
        var = var._get_variable_list()
    else:
        raise TypeError(
            "var MUST be one of the following: a Variable, list of Variable or "
            "PartitionedVariable, but is {}".format(type(var)))

    if not prev_tensor_name:
        # Assume tensor name remains the same.
        prev_tensor_name = _infer_var_name(var)

    for v in var:
        v_shape = v.get_shape().as_list()
        slice_info = v._get_save_slice_info()
        partition_info = None
        if slice_info:
            partition_info = variable_scope._PartitionInfo(
                full_shape=slice_info.full_shape,
                var_offset=slice_info.var_offset)

        # TODO(vihanjain): Support _WarmstartSettings where class vocabularies need
        # remapping too.
        init = checkpoint_ops._load_and_remap_matrix_initializer(
            ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
            old_tensor_name=prev_tensor_name,
            new_row_vocab_size=current_vocab_size,
            new_col_vocab_size=v_shape[1],
            old_row_vocab_size=previous_vocab_size,
            old_row_vocab_file=prev_vocab_path,
            new_row_vocab_file=current_vocab_path,
            old_col_vocab_file=None,
            new_col_vocab_file=None,
            num_row_oov_buckets=current_oov_buckets,
            num_col_oov_buckets=0,
            initializer=initializer)
        new_init_val = ops.convert_to_tensor(
            init(shape=v_shape, partition_info=partition_info))
        v._initializer_op = state_ops.assign(v, new_init_val)
Example #6
def _warmstart_var_with_vocab(var,
                              current_vocab_path,
                              current_vocab_size,
                              prev_ckpt,
                              prev_vocab_path,
                              previous_vocab_size=-1,
                              current_oov_buckets=0,
                              prev_tensor_name=None,
                              initializer=None):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Use this method when the `var` is backed by vocabulary. This method stitches
  the given `var` such that values corresponding to individual features in the
  vocabulary remain consistent irrespective of changing order of the features
  between old and new vocabularies.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    current_vocab_path: Path to the vocab file used for the given `var`.
    current_vocab_size: An `int` specifying the number of entries in the current
      vocab.
    prev_ckpt: A string specifying the directory with checkpoint file(s) or path
      to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_vocab_path: Path to the vocab file used for the tensor in `prev_ckpt`.
    previous_vocab_size: If provided, will constrain previous vocab to the first
      `previous_vocab_size` entries.  -1 means use the entire previous vocab.
    current_oov_buckets: An `int` specifying the number of out-of-vocabulary
      buckets used for given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.
    initializer: Variable initializer to be used for missing entries.  If None,
      missing entries will be zero-initialized.

  Raises:
    ValueError: If required args are not provided.
  """
  if not (current_vocab_path and current_vocab_size and prev_ckpt and
          prev_vocab_path):
    raise ValueError("Invalid args: Must provide all of [current_vocab_path, "
                     "current_vocab_size, prev_ckpt, prev_vocab_path}.")
  if _is_variable(var):
    var = [var]
  elif isinstance(var, list) and all(_is_variable(v) for v in var):
    var = var
  elif isinstance(var, variables_lib.PartitionedVariable):
    var = var._get_variable_list()
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))

  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = _infer_var_name(var)

  for v in var:
    v_shape = v.get_shape().as_list()
    slice_info = v._get_save_slice_info()
    partition_info = None
    if slice_info:
      partition_info = variable_scope._PartitionInfo(
          full_shape=slice_info.full_shape,
          var_offset=slice_info.var_offset)

    # TODO(vihanjain): Support _WarmstartSettings where class vocabularies need
    # remapping too.
    init = checkpoint_ops._load_and_remap_matrix_initializer(
        ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
        old_tensor_name=prev_tensor_name,
        new_row_vocab_size=current_vocab_size,
        new_col_vocab_size=v_shape[1],
        old_row_vocab_size=previous_vocab_size,
        old_row_vocab_file=prev_vocab_path,
        new_row_vocab_file=current_vocab_path,
        old_col_vocab_file=None,
        new_col_vocab_file=None,
        num_row_oov_buckets=current_oov_buckets,
        num_col_oov_buckets=0,
        initializer=initializer)
    new_init_val = ops.convert_to_tensor(
        init(shape=v_shape, partition_info=partition_info))
    v._initializer_op = state_ops.assign(v, new_init_val)
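
Unlike the later variant that takes an `axis` argument, this older helper remaps only the row (axis 0) vocabulary. A minimal sketch with a partitioned embedding, assuming TF1 graph mode; all paths are hypothetical. Each partition is remapped using its save-slice offsets.

import tensorflow as tf

# Hypothetical: a 50-row embedding split across two partitions.
partitioned_emb = tf.compat.v1.get_variable(
    "emb", shape=[50, 8],
    partitioner=tf.compat.v1.fixed_size_partitioner(2))
_warmstart_var_with_vocab(
    partitioned_emb,
    current_vocab_path="/tmp/new_vocab.txt",  # hypothetical path
    current_vocab_size=50,
    prev_ckpt="/tmp/prev_model_dir",          # hypothetical path
    prev_vocab_path="/tmp/old_vocab.txt")     # hypothetical path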