Example #1
def _get_var_info(var, prev_tensor_name=None):
  """Helper method for standardizing Variable and naming.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following: (i) `Variable` (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable. (iv) `PartitionedVariable`
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.

  Returns:
    A tuple of the previous tensor name and the standardized `var` (a
    `Variable` or list of `Variable`s).
  """
  if checkpoint_utils._is_variable(var):  # pylint: disable=protected-access
    current_var_name = _infer_var_name([var])
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):  # pylint: disable=protected-access
    current_var_name = _infer_var_name(var)
  elif isinstance(var, variables_lib.PartitionedVariable):
    current_var_name = _infer_var_name([var])
    var = var._get_variable_list()  # pylint: disable=protected-access
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))
  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = current_var_name

  return prev_tensor_name, var
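
# --- Usage sketch for _get_var_info (added for illustration; not part of the
# original source). Assumes a TF1 graph and that this module's helpers and
# imports are in scope. The variable name "emb" and its shape are arbitrary.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
with tf.Graph().as_default():
  # A PartitionedVariable is standardized into its list of slice Variables,
  # while the inferred name refers to the full (unpartitioned) variable.
  emb = tf.get_variable(
      "emb", shape=[10, 4],
      partitioner=tf.fixed_size_partitioner(num_shards=2))
  prev_name, var_list = _get_var_info(emb)
  print(prev_name)      # "emb"
  print(len(var_list))  # 2 slice Variables
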
def _warm_start_var(var, prev_ckpt, prev_tensor_name=None):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    prev_ckpt: A string specifying the directory with checkpoint file(s) or path
      to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.
  """
  if checkpoint_utils._is_variable(var):  # pylint: disable=protected-access
    current_var_name = _infer_var_name([var])
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):  # pylint: disable=protected-access
    current_var_name = _infer_var_name(var)
  elif isinstance(var, variables_lib.PartitionedVariable):
    current_var_name = _infer_var_name([var])
    var = var._get_variable_list()  # pylint: disable=protected-access
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))
  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = current_var_name
  checkpoint_utils.init_from_checkpoint(prev_ckpt, {prev_tensor_name: var})
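
# --- Usage sketch for _warm_start_var (illustrative only). Assumes a TF1
# graph; "/tmp/prev_model" is a placeholder checkpoint prefix that must
# contain a tensor named "dense/kernel".
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
with tf.Graph().as_default() as g:
  with tf.variable_scope("dense"):
    kernel = tf.get_variable("kernel", shape=[4, 2])
  # Rewires the variable's initializer to read from the checkpoint, so the
  # warm start happens when initializers are run.
  _warm_start_var(kernel, "/tmp/prev_model")
  with tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())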
Example #3
def _warm_start_var(var, prev_ckpt, prev_tensor_name=None):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    prev_ckpt: A string specifying the directory with checkpoint file(s) or path
      to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.
  """
  if checkpoint_utils._is_variable(var):  # pylint: disable=protected-access
    current_var_name = _infer_var_name([var])
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):  # pylint: disable=protected-access
    current_var_name = _infer_var_name(var)
  elif isinstance(var, variables_lib.PartitionedVariable):
    current_var_name = _infer_var_name([var])
    var = var._get_variable_list()  # pylint: disable=protected-access
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))
  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = current_var_name
  checkpoint_utils.init_from_checkpoint(prev_ckpt, {prev_tensor_name: var})
def _get_var_info(var, prev_tensor_name=None):
  """Helper method for standarizing Variable and naming.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following: (i) `Variable` (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable. (iv) `PartitionedVariable`
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.

  Returns:
    A tuple of the previous tensor name and the standardized `var` (a
    `Variable` or list of `Variable`s).
  """
  if checkpoint_utils._is_variable(var):  # pylint: disable=protected-access
    current_var_name = _infer_var_name([var])
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):  # pylint: disable=protected-access
    current_var_name = _infer_var_name(var)
  elif isinstance(var, variables_lib.PartitionedVariable):
    current_var_name = _infer_var_name([var])
    var = var._get_variable_list()  # pylint: disable=protected-access
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))
  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = current_var_name

  return prev_tensor_name, var
def _separate_vars_and_devars_to_warm_start(vars_to_warm_start):
  """Splits `vars_to_warm_start` into ordinary and dynamic-embedding variables.

  Args:
    vars_to_warm_start: A regex string, `None`, a list of strings, or a list
      of `Variable`/`dynamic_embedding_ops.Variable` objects.

  Returns:
    A tuple of (vars_to_warm_start, devars_to_warm_start) for the ordinary
    and dynamic-embedding warm-starting paths, respectively.
  """
  _vars_to_warm_start = []
  _devars_to_warm_start = []
  if isinstance(vars_to_warm_start, str) or vars_to_warm_start is None:
    _vars_to_warm_start = vars_to_warm_start
    _devars_to_warm_start = []
  elif all(isinstance(v, str) for v in vars_to_warm_start):
    # It's ok to add redundant var strings to each output list, which is
    # consistent with the original warm-starting behavior.
    _vars_to_warm_start = list(vars_to_warm_start)
    _devars_to_warm_start = list(vars_to_warm_start)
  elif all(checkpoint_utils._is_variable(v) or
           checkpoint_utils._is_dynamic_embedding_variable(v)
           for v in vars_to_warm_start):
    _vars_to_warm_start = [
        v for v in vars_to_warm_start if checkpoint_utils._is_variable(v)
    ]
    _devars_to_warm_start = [
        v for v in vars_to_warm_start
        if checkpoint_utils._is_dynamic_embedding_variable(v)
    ]
  else:
    raise ValueError("If `vars_to_warm_start` is a list, it must be all "
                     "`Variable`, all `dynamic_embedding_ops.Variable` or all "
                     "`str`.  Given types are {}".format(
                         [type(v) for v in vars_to_warm_start]))
  return _vars_to_warm_start, _devars_to_warm_start
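
# --- Behavior sketch for _separate_vars_and_devars_to_warm_start
# (illustrative only). With a list of name strings, the same names are passed
# through to both outputs, since a pattern may match either an ordinary
# Variable or a dynamic-embedding Variable.
names = ["dense/kernel", "embedding/emb"]
plain_vars, de_vars = _separate_vars_and_devars_to_warm_start(names)
assert plain_vars == names
assert de_vars == names

# A single regex string (or None) is forwarded as-is for ordinary variables,
# and no dynamic-embedding variables are selected.
plain_vars, de_vars = _separate_vars_and_devars_to_warm_start(".*")
assert plain_vars == ".*"
assert de_vars == []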
Example #6
def _get_grouped_variables(vars_to_warm_start):
  """Collects and groups (possibly partitioned) variables into a dictionary.

  The variables can be provided explicitly through vars_to_warm_start, or they
  are retrieved from collections (see below).

  Args:
    vars_to_warm_start: One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.compat.v1.get_collection).  This expression will
        only consider variables in the TRAINABLE_VARIABLES collection.
      - A list of strings, each representing a full variable name to
        warm-start.  These will consider variables in the GLOBAL_VARIABLES
        collection.
      - A list of Variables to warm-start.
      - `None`, in which case all variables in TRAINABLE_VARIABLES will be used.

  Returns:
    A dictionary mapping variable names (strings) to lists of Variables.

  Raises:
    ValueError: If vars_to_warm_start is not a string, `None`, a list of
      `Variables`, or a list of strings.
  """
  # TODO(b/143899805): Remove unicode checks when deprecating Python2.
  if isinstance(vars_to_warm_start,
                six.string_types) or vars_to_warm_start is None:
    # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
    # everything (in TRAINABLE_VARIABLES) here.
    logging.info("Warm-starting variables only in TRAINABLE_VARIABLES.")
    list_of_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES,
                                      scope=vars_to_warm_start)
  elif isinstance(vars_to_warm_start, list):
    if all(isinstance(v, six.string_types) for v in vars_to_warm_start):
      list_of_vars = []
      for v in vars_to_warm_start:
        list_of_vars += ops.get_collection(
            ops.GraphKeys.GLOBAL_VARIABLES, scope=v)
    elif all(checkpoint_utils._is_variable(v) for v in vars_to_warm_start):  # pylint: disable=protected-access
      list_of_vars = vars_to_warm_start
    else:
      raise ValueError("If `vars_to_warm_start` is a list, it must be all "
                       "`Variable` or all `str`.  Given types are {}".format(
                           [type(v) for v in vars_to_warm_start]))
  else:
    raise ValueError("`vars_to_warm_start` must be a `list` or `str`.  Given "
                     "type is {}".format(type(vars_to_warm_start)))
  # We have to deal with partitioned variables, since get_collection flattens
  # out the list.
  grouped_variables = {}
  for v in list_of_vars:
    if not isinstance(v, list):
      var_name = _infer_var_name([v])
    else:
      var_name = _infer_var_name(v)
    grouped_variables.setdefault(var_name, []).append(v)

  return grouped_variables
def _get_grouped_variables(vars_to_warm_start):
  """Collects and groups (possibly partitioned) variables into a dictionary.

  The variables can be provided explicitly through vars_to_warm_start, or they
  are retrieved from collections (see below).

  Args:
    vars_to_warm_start: One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.get_collection).  This expression will only consider
        variables in the TRAINABLE_VARIABLES collection.
      - A list of Variables to warm-start.
      - A list of strings, each representing a full variable name to warm-start.
      - `None`, in which case all variables in TRAINABLE_VARIABLES will be
        used.
  Returns:
    A dictionary mapping variable names (strings) to lists of Variables.
  Raises:
    ValueError: If vars_to_warm_start is not a string, `None`, a list of
      `Variables`, or a list of strings.
  """
  if isinstance(vars_to_warm_start, str) or vars_to_warm_start is None:
    # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
    # everything (in TRAINABLE_VARIABLES) here.
    list_of_vars = ops.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES,
        scope=vars_to_warm_start)
  elif isinstance(vars_to_warm_start, list):
    if all(isinstance(v, str) for v in vars_to_warm_start):
      list_of_vars = []
      for v in vars_to_warm_start:
        list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                           scope=v)
    elif all(checkpoint_utils._is_variable(v) for v in vars_to_warm_start):  # pylint: disable=protected-access
      list_of_vars = vars_to_warm_start
    else:
      raise ValueError("If `vars_to_warm_start` is a list, it must be all "
                       "`Variable` or all `str`.  Given types are {}".format(
                           [type(v) for v in vars_to_warm_start]))
  else:
    raise ValueError("`vars_to_warm_start must be a `list` or `str`.  Given "
                     "type is {}".format(type(vars_to_warm_start)))
  # We have to deal with partitioned variables, since get_collection flattens
  # out the list.
  grouped_variables = {}
  for v in list_of_vars:
    if not isinstance(v, list):
      var_name = _infer_var_name([v])
    else:
      var_name = _infer_var_name(v)
    grouped_variables.setdefault(var_name, []).append(v)

  return grouped_variables
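
# --- Usage sketch for _get_grouped_variables (illustrative only; assumes a
# TF1 graph). Grouping re-assembles the slices of a partitioned variable
# under one name, because get_collection returns the slices individually.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
with tf.Graph().as_default():
  tf.get_variable("emb", shape=[10, 4],
                  partitioner=tf.fixed_size_partitioner(num_shards=2))
  tf.get_variable("bias", shape=[4])
  grouped = _get_grouped_variables("emb|bias")
  print({name: len(vs) for name, vs in grouped.items()})
  # {'emb': 2, 'bias': 1}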
Example #8
def _warm_start_var_with_vocab(var,
                               current_vocab_path,
                               current_vocab_size,
                               prev_ckpt,
                               prev_vocab_path,
                               previous_vocab_size=-1,
                               current_oov_buckets=0,
                               prev_tensor_name=None,
                               initializer=None,
                               axis=0):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Use this method when the `var` is backed by vocabulary. This method stitches
  the given `var` such that values corresponding to individual features in the
  vocabulary remain consistent irrespective of changing order of the features
  between old and new vocabularies.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    current_vocab_path: Path to the vocab file used for the given `var`.
    current_vocab_size: An `int` specifying the number of entries in the current
      vocab.
    prev_ckpt: A string specifying the directory with checkpoint file(s) or path
      to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_vocab_path: Path to the vocab file used for the tensor in `prev_ckpt`.
    previous_vocab_size: If provided, will constrain previous vocab to the first
      `previous_vocab_size` entries.  -1 means use the entire previous vocab.
    current_oov_buckets: An `int` specifying the number of out-of-vocabulary
      buckets used for given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.
    initializer: Variable initializer to be used for missing entries.  If None,
      missing entries will be zero-initialized.
    axis: Axis of the variable that the provided vocabulary corresponds to.

  Raises:
    ValueError: If required args are not provided.
  """
  if not (current_vocab_path and current_vocab_size and prev_ckpt and
          prev_vocab_path):
    raise ValueError("Invalid args: Must provide all of [current_vocab_path, "
                     "current_vocab_size, prev_ckpt, prev_vocab_path].")
  if checkpoint_utils._is_variable(var):
    var = [var]
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):
    var = var
  elif isinstance(var, variables_lib.PartitionedVariable):
    var = var._get_variable_list()
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))

  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = _infer_var_name(var)

  # TODO(eddz): Fix functionality for rank-1 Variables (like FC biases).
  total_v_first_axis = sum(v.get_shape().as_list()[0] for v in var)
  for v in var:
    v_shape = v.get_shape().as_list()
    slice_info = v._get_save_slice_info()
    partition_info = None
    if slice_info:
      partition_info = variable_scope._PartitionInfo(
          full_shape=slice_info.full_shape,
          var_offset=slice_info.var_offset)

    if axis == 0:
      new_row_vocab_size = current_vocab_size
      new_col_vocab_size = v_shape[1]
      old_row_vocab_size = previous_vocab_size
      old_row_vocab_file = prev_vocab_path
      new_row_vocab_file = current_vocab_path
      old_col_vocab_file = None
      new_col_vocab_file = None
      num_row_oov_buckets = current_oov_buckets
      num_col_oov_buckets = 0
    elif axis == 1:
      # Note that we must compute this value across all partitions, whereas
      # in the axis = 0 case, we can simply use v_shape[1] because we don't
      # allow partitioning across axis = 1.
      new_row_vocab_size = total_v_first_axis
      new_col_vocab_size = current_vocab_size
      old_row_vocab_size = -1
      old_row_vocab_file = None
      new_row_vocab_file = None
      old_col_vocab_file = prev_vocab_path
      new_col_vocab_file = current_vocab_path
      num_row_oov_buckets = 0
      num_col_oov_buckets = current_oov_buckets
    else:
      raise ValueError("The only supported values for the axis argument are 0 "
                       "and 1.  Provided axis: {}".format(axis))

    init = checkpoint_ops._load_and_remap_matrix_initializer(
        ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
        old_tensor_name=prev_tensor_name,
        new_row_vocab_size=new_row_vocab_size,
        new_col_vocab_size=new_col_vocab_size,
        old_row_vocab_size=old_row_vocab_size,
        old_row_vocab_file=old_row_vocab_file,
        new_row_vocab_file=new_row_vocab_file,
        old_col_vocab_file=old_col_vocab_file,
        new_col_vocab_file=new_col_vocab_file,
        num_row_oov_buckets=num_row_oov_buckets,
        num_col_oov_buckets=num_col_oov_buckets,
        initializer=initializer)
    new_init_val = ops.convert_to_tensor(
        init(shape=v_shape, partition_info=partition_info))
    v._initializer_op = state_ops.assign(v, new_init_val)
def _warm_start_var_with_vocab(var,
                               current_vocab_path,
                               current_vocab_size,
                               prev_ckpt,
                               prev_vocab_path,
                               previous_vocab_size=-1,
                               current_oov_buckets=0,
                               prev_tensor_name=None,
                               initializer=None,
                               axis=0):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Use this method when the `var` is backed by vocabulary. This method stitches
  the given `var` such that values corresponding to individual features in the
  vocabulary remain consistent irrespective of changing order of the features
  between old and new vocabularies.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    current_vocab_path: Path to the vocab file used for the given `var`.
    current_vocab_size: An `int` specifying the number of entries in the current
      vocab.
    prev_ckpt: A string specifying the directory with checkpoint file(s) or path
      to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_vocab_path: Path to the vocab file used for the tensor in `prev_ckpt`.
    previous_vocab_size: If provided, will constrain previous vocab to the first
      `previous_vocab_size` entries.  -1 means use the entire previous vocab.
    current_oov_buckets: An `int` specifying the number of out-of-vocabulary
      buckets used for given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.
    initializer: Variable initializer to be used for missing entries.  If None,
      missing entries will be zero-initialized.
    axis: Axis of the variable that the provided vocabulary corresponds to.

  Raises:
    ValueError: If required args are not provided.
  """
  if not (current_vocab_path and current_vocab_size and prev_ckpt and
          prev_vocab_path):
    raise ValueError("Invalid args: Must provide all of [current_vocab_path, "
                     "current_vocab_size, prev_ckpt, prev_vocab_path}.")
  if checkpoint_utils._is_variable(var):
    var = [var]
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):
    var = var
  elif isinstance(var, variables_lib.PartitionedVariable):
    var = var._get_variable_list()
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))

  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = _infer_var_name(var)

  # TODO(eddz): Fix functionality for rank-1 Variables (like FC biases).
  total_v_first_axis = sum(v.get_shape().as_list()[0] for v in var)
  for v in var:
    v_shape = v.get_shape().as_list()
    slice_info = v._get_save_slice_info()
    partition_info = None
    if slice_info:
      partition_info = variable_scope._PartitionInfo(
          full_shape=slice_info.full_shape,
          var_offset=slice_info.var_offset)

    if axis == 0:
      new_row_vocab_size = current_vocab_size
      new_col_vocab_size = v_shape[1]
      old_row_vocab_size = previous_vocab_size
      old_row_vocab_file = prev_vocab_path
      new_row_vocab_file = current_vocab_path
      old_col_vocab_file = None
      new_col_vocab_file = None
      num_row_oov_buckets = current_oov_buckets
      num_col_oov_buckets = 0
    elif axis == 1:
      # Note that we must compute this value across all partitions, whereas
      # in the axis = 0 case, we can simply use v_shape[1] because we don't
      # allow partitioning across axis = 1.
      new_row_vocab_size = total_v_first_axis
      new_col_vocab_size = current_vocab_size
      old_row_vocab_size = -1
      old_row_vocab_file = None
      new_row_vocab_file = None
      old_col_vocab_file = prev_vocab_path
      new_col_vocab_file = current_vocab_path
      num_row_oov_buckets = 0
      num_col_oov_buckets = current_oov_buckets
    else:
      raise ValueError("The only supported values for the axis argument are 0 "
                       "and 1.  Provided axis: {}".format(axis))

    init = checkpoint_ops._load_and_remap_matrix_initializer(
        ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
        old_tensor_name=prev_tensor_name,
        new_row_vocab_size=new_row_vocab_size,
        new_col_vocab_size=new_col_vocab_size,
        old_row_vocab_size=old_row_vocab_size,
        old_row_vocab_file=old_row_vocab_file,
        new_row_vocab_file=new_row_vocab_file,
        old_col_vocab_file=old_col_vocab_file,
        new_col_vocab_file=new_col_vocab_file,
        num_row_oov_buckets=num_row_oov_buckets,
        num_col_oov_buckets=num_col_oov_buckets,
        initializer=initializer)
    new_init_val = ops.convert_to_tensor(
        init(shape=v_shape, partition_info=partition_info))
    v._initializer_op = state_ops.assign(v, new_init_val)
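
# --- Usage sketch for _warm_start_var_with_vocab (illustrative only; assumes
# a TF1 graph). All paths, names, and sizes are placeholders: the checkpoint
# at "/tmp/prev_model" must contain a tensor named "emb" that was trained
# against vocab_old.txt.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
with tf.Graph().as_default() as g:
  emb = tf.get_variable("emb", shape=[5, 8])  # 5 rows = new vocab size
  _warm_start_var_with_vocab(
      emb,
      current_vocab_path="/tmp/vocab_new.txt",
      current_vocab_size=5,
      prev_ckpt="/tmp/prev_model",
      prev_vocab_path="/tmp/vocab_old.txt")
  with tf.Session(graph=g) as sess:
    # Rows whose feature appears in both vocabs are copied from the
    # checkpoint; rows for new features use `initializer` (zeros by default).
    sess.run(tf.global_variables_initializer())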