def var_and_name_scope(names):
    """Creates a variable scope and a name scope.

    If a variable_scope is provided, this will reenter that variable scope.
    However, if none is provided then the variable scope will match the
    generated part of the name scope: the last ``len(name.split("/"))``
    components of the opened name scope, appended to the current variable
    scope's name.

    Args:
      names: A tuple of (name_scope, variable_scope), where variable_scope may
        be None; or a falsy value, in which case no scopes are opened.

    Yields:
      The result of name_scope and variable_scope as a tuple, or
      ``(None, None)`` when ``names`` is falsy.
    """
    # pylint: disable=protected-access
    if not names:
        yield None, None
    else:
        name, var_scope = names
        with tf.name_scope(name) as scope:
            # TODO(eiderman): This is a workaround until the variable_scope
            # updates land in a TF release.
            old_vs = tf.get_variable_scope()
            if var_scope is None:
                # The slice drops the final element of scope.split("/"),
                # presumably the empty string after the name scope's trailing
                # "/", and keeps the `count` components just before it.
                count = len(name.split("/"))
                scoped_name = "/".join(scope.split("/")[-count - 1 : -1])
                full_name = (old_vs.name + "/" + scoped_name).lstrip("/")
            else:
                full_name = var_scope.name

            # The scope is swapped by writing into the collection list, so the
            # list must be the live one. In TF releases that provide
            # get_collection_ref, get_collection returns a *copy* and writes
            # into it would be silently lost. Fall back to get_collection only
            # when get_collection_ref is unavailable.
            get_live_collection = getattr(tf, "get_collection_ref", tf.get_collection)
            vs_key = get_live_collection(variable_scope._VARSCOPE_KEY)
            try:
                vs_key[0] = variable_scope._VariableScope(
                    old_vs.reuse, name=full_name, initializer=old_vs.initializer
                )
                vs_key[0].name_scope = scope
                yield scope, vs_key[0]
            finally:
                # Restore the caller's scope whether the generator resumes
                # normally or an exception is thrown in at the yield.
                vs_key[0] = old_vs
def track_model_updates(main_name, track_name, tau):
    """Build an update op to make parameters of a tracking model follow a main model.

    For every variable whose name starts with ``main_name + "/"``, the variable
    with the same relative path under ``track_name + "/"`` is assigned
    ``tau * main_param + (1 - tau) * track_param``. Main-model variables with
    no tracking counterpart are skipped with a warning.

    Call outside of the scope of both the main and tracking model.

    Args:
      main_name: Scope name prefix of the model being followed.
      track_name: Scope name prefix of the model being updated.
      tau: Weight given to the main model's parameter in the interpolation.

    Returns:
      A group of `tf.assign` ops which require no inputs (only parameter values).
    """
    main_prefix = main_name + "/"
    track_prefix = track_name + "/"

    updates = []
    params = [var for var in tf.all_variables() if var.name.startswith(main_prefix)]
    for param in params:
        # Swap only the leading scope prefix. str.replace would also rewrite
        # any later occurrence of "<main_name>/" deeper in the variable path.
        track_param_name = track_prefix + param.op.name[len(main_prefix):]
        with tf.variable_scope(_VariableScope(True), reuse=True):
            try:
                track_param = tf.get_variable(track_param_name)
            except ValueError:
                logging.warning(
                    "Tracking model variable %s does not exist", track_param_name
                )
                continue

        # TODO sparse params
        update_op = tf.assign(track_param, tau * param + (1 - tau) * track_param)
        updates.append(update_op)

    return tf.group(*updates)
def var_and_name_scope(names): """Creates a variable scope and a name scope. If a variable_scope is provided, this will reenter that variable scope. However, if none is provided then the variable scope will match the generated part of the name scope. Args: names: A tuple of name_scope, variable_scope or None. Yields: The result of name_scope and variable_scope as a tuple. """ # pylint: disable=protected-access if not names: yield None, None else: name, var_scope = names with tf.name_scope(name) as scope: # TODO(eiderman): This is a workaround until the variable_scope updates land # in a TF release. old_vs = tf.get_variable_scope() if var_scope is None: count = len(name.split('/')) scoped_name = '/'.join(scope.split('/')[-count - 1:-1]) full_name = (old_vs.name + '/' + scoped_name).lstrip('/') else: full_name = var_scope.name vs_key = tf.get_collection_ref(variable_scope._VARSCOPESTORE_KEY) try: # TODO(eiderman): Remove this hack or fix the full file. try: vs_key[0] = tf.VariableScope( old_vs.reuse, name=full_name, initializer=old_vs.initializer, regularizer=old_vs.regularizer, caching_device=old_vs.caching_device) except AttributeError: vs_key[0] = variable_scope._VariableScope( old_vs.reuse, name=full_name, initializer=old_vs.initializer) vs_key[0].name_scope = scope yield scope, vs_key[0] finally: vs_key[0] = old_vs
def match_variable(name, scope_name):
    """
    Match a variable (initialize with same value) from another variable scope.

    Looks up the existing variable ``name`` under the absolute scope
    ``scope_name`` and creates a variable ``name`` in the current variable scope
    with the same shape and dtype, whose initial value is the source variable's
    initialized value.

    After initialization, the values of the two variables are not tied in any way.

    Args:
      name: Variable name, relative to ``scope_name`` and to the current scope.
      scope_name: Absolute variable-scope name that holds the source variable.

    Returns:
      The newly created variable in the current variable scope.
    """
    # HACK: Using private _VariableScope API in order to be able to get an
    # absolute-path to the given variable scope name (i.e., not have it treated
    # as a relative path and placed under whatever variable scope might contain
    # this function call)
    with tf.variable_scope(_VariableScope(True, scope_name), reuse=True):
        track_var = tf.get_variable(name)

    def initializer(*args, **kwargs):  # pylint: disable=unused-argument
        # Dummy initializer: ignore the requested shape/dtype and copy the
        # source variable's initial value.
        return track_var.initialized_value()

    # Pass the source dtype explicitly. tf.get_variable defaults to float32,
    # which would conflict with the copied initial value for any other dtype.
    # base_dtype strips a reference dtype if the variable reports one.
    return tf.get_variable(
        name,
        shape=track_var.get_shape(),
        dtype=track_var.dtype.base_dtype,
        initializer=initializer,
    )