def _get_var_info(var, prev_tensor_name=None):
  """Normalizes `var` and resolves the tensor name to look up for it.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    prev_tensor_name: Name of the tensor to look up in the previous checkpoint.
      If falsy, the name inferred from `var` is used instead.

  Returns:
    A `(tensor_name, var)` tuple. `var` is handed back unchanged unless it was
    a `PartitionedVariable`, in which case its variable list is returned.

  Raises:
    TypeError: If `var` is none of the supported kinds.
  """
  # pylint: disable=protected-access
  partitioned = False
  if checkpoint_utils._is_variable(var):
    name_source = [var]
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):
    name_source = var
  elif isinstance(var, variables_lib.PartitionedVariable):
    name_source = [var]
    partitioned = True
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))

  # The name is inferred from the original object before a PartitionedVariable
  # is unpacked into its list of slices.
  current_var_name = _infer_var_name(name_source)
  if partitioned:
    var = var._get_variable_list()
  # pylint: enable=protected-access

  # A missing previous name means the tensor name is assumed unchanged.
  return (prev_tensor_name or current_var_name), var
def _warm_start_var(var, prev_ckpt, prev_tensor_name=None):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    prev_ckpt: A string specifying the directory with checkpoint file(s) or path
      to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.

  Raises:
    TypeError: If `var` is not a Variable, a list of Variables or a
      `PartitionedVariable` (raised by `_get_var_info`).
  """
  # `_get_var_info` performs the same type validation, PartitionedVariable
  # unpacking and name defaulting this function used to carry inline; sharing
  # it keeps the two code paths from drifting apart.
  prev_tensor_name, var = _get_var_info(var, prev_tensor_name)
  checkpoint_utils.init_from_checkpoint(prev_ckpt, {prev_tensor_name: var})
def _separate_vars_and_devars_to_warm_start(vars_to_warm_start): _vars_to_warm_start = [] _devars_to_warm_start = [] if isinstance(vars_to_warm_start, str) or vars_to_warm_start is None: _vars_to_warm_start = vars_to_warm_start _devars_to_warm_start = [] elif all(isinstance(v, str) for v in vars_to_warm_start): # It's ok to add redundant var str to each vars_to_warm_start, which is # consistent with the original warm starting behavior. _vars_to_warm_start = [v for v in vars_to_warm_start] _devars_to_warm_start = [v for v in vars_to_warm_start] elif all(checkpoint_utils._is_variable(v) or checkpoint_utils._is_dynamic_embedding_variable(v) for v in vars_to_warm_start): _vars_to_warm_start = [v for v in vars_to_warm_start if checkpoint_utils._is_variable(v)] _devars_to_warm_start = [ v for v in vars_to_warm_start if checkpoint_utils._is_dynamic_embedding_variable(v) ] else: raise ValueError("If `vars_to_warm_start` is a list, it must be all " "`Variable`, all `dynamic_embedding_ops.Variable` or all " "`str`. Given types are {}".format( [type(v) for v in vars_to_warm_start])) return _vars_to_warm_start, _devars_to_warm_start
def _get_grouped_variables(vars_to_warm_start):
  """Collects and groups (possibly partitioned) variables into a dictionary.

  The variables are either given explicitly through `vars_to_warm_start` or
  looked up in graph collections, as described below.

  Args:
    vars_to_warm_start: One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.compat.v1.get_collection).  This expression will
        only consider variables in the TRAINABLE_VARIABLES collection.
      - A list of strings, each representing a full variable name to warm-start.
        These will consider variables in GLOBAL_VARIABLES collection.
      - A list of Variables to warm-start.
      - `None`, in which case all variables in TRAINABLE_VARIABLES will be used.

  Returns:
    A dictionary mapping variable names (strings) to lists of Variables.

  Raises:
    ValueError: If vars_to_warm_start is not a string, `None`, a list of
      `Variables`, or a list of strings.
  """
  # TODO(b/143899805): Remove unicode checks when deprecating Python2.
  if (vars_to_warm_start is None or
      isinstance(vars_to_warm_start, six.string_types)):
    # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
    # everything (in TRAINABLE_VARIABLES) here.
    logging.info("Warm-starting variables only in TRAINABLE_VARIABLES.")
    list_of_vars = ops.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES, scope=vars_to_warm_start)
  elif not isinstance(vars_to_warm_start, list):
    raise ValueError(
        "`vars_to_warm_start must be a `list` or `str`.  Given "
        "type is {}".format(type(vars_to_warm_start)))
  elif all(isinstance(v, six.string_types) for v in vars_to_warm_start):
    list_of_vars = []
    for scope_name in vars_to_warm_start:
      list_of_vars.extend(
          ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope=scope_name))
  elif all(checkpoint_utils._is_variable(v) for v in vars_to_warm_start):  # pylint: disable=protected-access
    list_of_vars = vars_to_warm_start
  else:
    raise ValueError(
        "If `vars_to_warm_start` is a list, it must be all "
        "`Variable` or all `str`.  Given types are {}".format(
            [type(v) for v in vars_to_warm_start]))

  # We have to deal with partitioned variables, since get_collection flattens
  # out the list.
  grouped_variables = {}
  for entry in list_of_vars:
    name_source = entry if isinstance(entry, list) else [entry]
    grouped_variables.setdefault(_infer_var_name(name_source), []).append(entry)
  return grouped_variables
def _get_grouped_variables(vars_to_warm_start): """Collects and groups (possibly partitioned) variables into a dictionary. The variables can be provided explicitly through vars_to_warm_start, or they are retrieved from collections (see below). Args: vars_to_warm_start: One of the following: - A regular expression (string) that captures which variables to warm-start (see tf.get_collection). This expression will only consider variables in the TRAINABLE_VARIABLES collection. - A list of Variables to warm-start. - A list of strings, each representing a full variable name to warm-start. - `None`, in which case only variables specified in `var_name_to_vocab_info` will be warm-started. Returns: A dictionary mapping variable names (strings) to lists of Variables. Raises: ValueError: If vars_to_warm_start is not a string, `None`, a list of `Variables`, or a list of strings. """ if isinstance(vars_to_warm_start, str) or vars_to_warm_start is None: # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match # everything (in TRAINABLE_VARIABLES) here. list_of_vars = ops.get_collection( ops.GraphKeys.TRAINABLE_VARIABLES, scope=vars_to_warm_start) elif isinstance(vars_to_warm_start, list): if all([isinstance(v, str) for v in vars_to_warm_start]): list_of_vars = [] for v in vars_to_warm_start: list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope=v) elif all([checkpoint_utils._is_variable(v) for v in vars_to_warm_start]): # pylint: disable=protected-access list_of_vars = vars_to_warm_start else: raise ValueError("If `vars_to_warm_start` is a list, it must be all " "`Variable` or all `str`. Given types are {}".format( [type(v) for v in vars_to_warm_start])) else: raise ValueError("`vars_to_warm_start must be a `list` or `str`. Given " "type is {}".format(type(vars_to_warm_start))) # We have to deal with partitioned variables, since get_collection flattens # out the list. 
grouped_variables = {} for v in list_of_vars: if not isinstance(v, list): var_name = _infer_var_name([v]) else: var_name = _infer_var_name(v) grouped_variables.setdefault(var_name, []).append(v) return grouped_variables
def _warm_start_var_with_vocab(var, current_vocab_path, current_vocab_size, prev_ckpt, prev_vocab_path, previous_vocab_size=-1, current_oov_buckets=0, prev_tensor_name=None, initializer=None, axis=0): """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`. Use this method when the `var` is backed by vocabulary. This method stitches the given `var` such that values corresponding to individual features in the vocabulary remain consistent irrespective of changing order of the features between old and new vocabularies. Args: var: Current graph's variable that needs to be warm-started (initialized). Can be either of the following: (i) `Variable` (ii) `ResourceVariable` (iii) list of `Variable`: The list must contain slices of the same larger variable. (iv) `PartitionedVariable` current_vocab_path: Path to the vocab file used for the given `var`. current_vocab_size: An `int` specifying the number of entries in the current vocab. prev_ckpt: A string specifying the directory with checkpoint file(s) or path to checkpoint. The given checkpoint must have tensor with name `prev_tensor_name` (if not None) or tensor with name same as given `var`. prev_vocab_path: Path to the vocab file used for the tensor in `prev_ckpt`. previous_vocab_size: If provided, will constrain previous vocab to the first `previous_vocab_size` entries. -1 means use the entire previous vocab. current_oov_buckets: An `int` specifying the number of out-of-vocabulary buckets used for given `var`. prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If None, we lookup tensor with same name as given `var`. initializer: Variable initializer to be used for missing entries. If None, missing entries will be zero-initialized. axis: Axis of the variable that the provided vocabulary corresponds to. Raises: ValueError: If required args are not provided. 
""" if not (current_vocab_path and current_vocab_size and prev_ckpt and prev_vocab_path): raise ValueError( "Invalid args: Must provide all of [current_vocab_path, " "current_vocab_size, prev_ckpt, prev_vocab_path}.") if checkpoint_utils._is_variable(var): var = [var] elif (isinstance(var, list) and all(checkpoint_utils._is_variable(v) for v in var)): var = var elif isinstance(var, variables_lib.PartitionedVariable): var = var._get_variable_list() else: raise TypeError( "var MUST be one of the following: a Variable, list of Variable or " "PartitionedVariable, but is {}".format(type(var))) if not prev_tensor_name: # Assume tensor name remains the same. prev_tensor_name = _infer_var_name(var) # TODO(eddz): Fix functionality for rank-1 Variables (like FC biases). total_v_first_axis = sum(v.get_shape().as_list()[0] for v in var) for v in var: v_shape = v.get_shape().as_list() slice_info = v._get_save_slice_info() partition_info = None if slice_info: partition_info = variable_scope._PartitionInfo( full_shape=slice_info.full_shape, var_offset=slice_info.var_offset) if axis == 0: new_row_vocab_size = current_vocab_size new_col_vocab_size = v_shape[1] old_row_vocab_size = previous_vocab_size old_row_vocab_file = prev_vocab_path new_row_vocab_file = current_vocab_path old_col_vocab_file = None new_col_vocab_file = None num_row_oov_buckets = current_oov_buckets num_col_oov_buckets = 0 elif axis == 1: # Note that we must compute this value across all partitions, whereas # in the axis = 0 case, we can simply use v_shape[1] because we don't # allow partitioning across axis = 1. new_row_vocab_size = total_v_first_axis new_col_vocab_size = current_vocab_size old_row_vocab_size = -1 old_row_vocab_file = None new_row_vocab_file = None old_col_vocab_file = prev_vocab_path new_col_vocab_file = current_vocab_path num_row_oov_buckets = 0 num_col_oov_buckets = current_oov_buckets else: raise ValueError( "The only supported values for the axis argument are 0 " "and 1. 
Provided axis: {}".format(axis)) init = checkpoint_ops._load_and_remap_matrix_initializer( ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt), old_tensor_name=prev_tensor_name, new_row_vocab_size=new_row_vocab_size, new_col_vocab_size=new_col_vocab_size, old_row_vocab_size=old_row_vocab_size, old_row_vocab_file=old_row_vocab_file, new_row_vocab_file=new_row_vocab_file, old_col_vocab_file=old_col_vocab_file, new_col_vocab_file=new_col_vocab_file, num_row_oov_buckets=num_row_oov_buckets, num_col_oov_buckets=num_col_oov_buckets, initializer=initializer) new_init_val = ops.convert_to_tensor( init(shape=v_shape, partition_info=partition_info)) v._initializer_op = state_ops.assign(v, new_init_val)
def _warm_start_var_with_vocab(var, current_vocab_path, current_vocab_size, prev_ckpt, prev_vocab_path, previous_vocab_size=-1, current_oov_buckets=0, prev_tensor_name=None, initializer=None, axis=0): """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`. Use this method when the `var` is backed by vocabulary. This method stitches the given `var` such that values corresponding to individual features in the vocabulary remain consistent irrespective of changing order of the features between old and new vocabularies. Args: var: Current graph's variable that needs to be warm-started (initialized). Can be either of the following: (i) `Variable` (ii) `ResourceVariable` (iii) list of `Variable`: The list must contain slices of the same larger variable. (iv) `PartitionedVariable` current_vocab_path: Path to the vocab file used for the given `var`. current_vocab_size: An `int` specifying the number of entries in the current vocab. prev_ckpt: A string specifying the directory with checkpoint file(s) or path to checkpoint. The given checkpoint must have tensor with name `prev_tensor_name` (if not None) or tensor with name same as given `var`. prev_vocab_path: Path to the vocab file used for the tensor in `prev_ckpt`. previous_vocab_size: If provided, will constrain previous vocab to the first `previous_vocab_size` entries. -1 means use the entire previous vocab. current_oov_buckets: An `int` specifying the number of out-of-vocabulary buckets used for given `var`. prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If None, we lookup tensor with same name as given `var`. initializer: Variable initializer to be used for missing entries. If None, missing entries will be zero-initialized. axis: Axis of the variable that the provided vocabulary corresponds to. Raises: ValueError: If required args are not provided. 
""" if not (current_vocab_path and current_vocab_size and prev_ckpt and prev_vocab_path): raise ValueError("Invalid args: Must provide all of [current_vocab_path, " "current_vocab_size, prev_ckpt, prev_vocab_path}.") if checkpoint_utils._is_variable(var): var = [var] elif (isinstance(var, list) and all(checkpoint_utils._is_variable(v) for v in var)): var = var elif isinstance(var, variables_lib.PartitionedVariable): var = var._get_variable_list() else: raise TypeError( "var MUST be one of the following: a Variable, list of Variable or " "PartitionedVariable, but is {}".format(type(var))) if not prev_tensor_name: # Assume tensor name remains the same. prev_tensor_name = _infer_var_name(var) # TODO(eddz): Fix functionality for rank-1 Variables (like FC biases). total_v_first_axis = sum([v.get_shape().as_list()[0] for v in var]) for v in var: v_shape = v.get_shape().as_list() slice_info = v._get_save_slice_info() partition_info = None if slice_info: partition_info = variable_scope._PartitionInfo( full_shape=slice_info.full_shape, var_offset=slice_info.var_offset) if axis == 0: new_row_vocab_size = current_vocab_size new_col_vocab_size = v_shape[1] old_row_vocab_size = previous_vocab_size old_row_vocab_file = prev_vocab_path new_row_vocab_file = current_vocab_path old_col_vocab_file = None new_col_vocab_file = None num_row_oov_buckets = current_oov_buckets num_col_oov_buckets = 0 elif axis == 1: # Note that we must compute this value across all partitions, whereas # in the axis = 0 case, we can simply use v_shape[1] because we don't # allow partitioning across axis = 1. new_row_vocab_size = total_v_first_axis new_col_vocab_size = current_vocab_size old_row_vocab_size = -1 old_row_vocab_file = None new_row_vocab_file = None old_col_vocab_file = prev_vocab_path new_col_vocab_file = current_vocab_path num_row_oov_buckets = 0 num_col_oov_buckets = current_oov_buckets else: raise ValueError("The only supported values for the axis argument are 0 " "and 1. 
Provided axis: {}".format(axis)) init = checkpoint_ops._load_and_remap_matrix_initializer( ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt), old_tensor_name=prev_tensor_name, new_row_vocab_size=new_row_vocab_size, new_col_vocab_size=new_col_vocab_size, old_row_vocab_size=old_row_vocab_size, old_row_vocab_file=old_row_vocab_file, new_row_vocab_file=new_row_vocab_file, old_col_vocab_file=old_col_vocab_file, new_col_vocab_file=new_col_vocab_file, num_row_oov_buckets=num_row_oov_buckets, num_col_oov_buckets=num_col_oov_buckets, initializer=initializer) new_init_val = ops.convert_to_tensor( init(shape=v_shape, partition_info=partition_info)) v._initializer_op = state_ops.assign(v, new_init_val)