Example #1 (score: 0)
def var_and_name_scope(names):
    """Creates a variable scope and a name scope.

  If a variable_scope is provided, this will reenter that variable scope.
  However, if none is provided then the variable scope will match the generated
  part of the name scope.

  Args:
    names: A tuple of name_scope, variable_scope or None.
  Yields:
    The result of name_scope and variable_scope as a tuple.
  """
    # NOTE(review): this is a generator intended for use as a context manager
    # (presumably wrapped with @contextlib.contextmanager where it is defined
    # or decorated) — confirm at the call/decoration site.
    # pylint: disable=protected-access
    if not names:
        # No scopes requested: still yield a (name_scope, variable_scope)
        # pair so callers can unpack uniformly.
        yield None, None
    else:
        name, var_scope = names
        with tf.name_scope(name) as scope:
            # TODO(eiderman): This is a workaround until the variable_scope updates land
            # in a TF release.
            old_vs = tf.get_variable_scope()
            if var_scope is None:
                # Reconstruct the variable-scope name from the trailing pieces
                # of the (possibly uniquified) name scope: `scope` ends with
                # '/', so the last `count` path components before that final
                # '/' mirror `name`.
                count = len(name.split("/"))
                scoped_name = "/".join(scope.split("/")[-count - 1 : -1])
                full_name = (old_vs.name + "/" + scoped_name).lstrip("/")
            else:
                # Reenter the caller-provided variable scope by name.
                full_name = var_scope.name

            # Swap the process-wide "current VariableScope" slot stored in
            # TF's collection, restoring the old one on exit even if the
            # caller's body raises.
            # NOTE(review): in later TF versions tf.get_collection returns a
            # *copy*, in which case this mutation would not take effect —
            # confirm the targeted TF version uses a live list here.
            vs_key = tf.get_collection(variable_scope._VARSCOPE_KEY)
            try:
                vs_key[0] = variable_scope._VariableScope(old_vs.reuse, name=full_name, initializer=old_vs.initializer)
                vs_key[0].name_scope = scope
                yield scope, vs_key[0]
            finally:
                # Always restore the previous variable scope.
                vs_key[0] = old_vs
Example #2 (score: 0)
File: util.py — Project: hans/rlcomp
def track_model_updates(main_name, track_name, tau):
  """
  Build an update op to make parameters of a tracking model follow a main model.

  Call outside of the scope of both the main and tracking model.

  Returns:
    A group of `tf.assign` ops which require no inputs (only parameter values).
  """

  # Collect every variable belonging to the main model by name prefix.
  prefix = main_name + "/"
  main_params = [v for v in tf.all_variables() if v.name.startswith(prefix)]

  assign_ops = []
  for main_param in main_params:
    tracked_name = main_param.op.name.replace(prefix, track_name + "/")

    # Look up the corresponding tracking-model variable by absolute name
    # (root-level scope with reuse so no new variable is created).
    with tf.variable_scope(_VariableScope(True), reuse=True):
      try:
        tracked = tf.get_variable(tracked_name)
      except ValueError:
        logging.warn("Tracking model variable %s does not exist",
                     tracked_name)
        continue

    # TODO sparse params
    # Exponential moving average: tracked <- tau * main + (1 - tau) * tracked.
    blended = tau * main_param + (1 - tau) * tracked
    assign_ops.append(tf.assign(tracked, blended))

  return tf.group(*assign_ops)
Example #3 (score: 0)
def var_and_name_scope(names):
    """Creates a variable scope and a name scope.

  If a variable_scope is provided, this will reenter that variable scope.
  However, if none is provided then the variable scope will match the generated
  part of the name scope.

  Args:
    names: A tuple of name_scope, variable_scope or None.
  Yields:
    The result of name_scope and variable_scope as a tuple.
  """
    # NOTE(review): generator intended for use as a context manager (likely
    # wrapped with @contextlib.contextmanager at the decoration site) — confirm.
    # pylint: disable=protected-access
    if not names:
        # No scopes requested: still yield a uniform two-element tuple.
        yield None, None
    else:
        name, var_scope = names
        with tf.name_scope(name) as scope:
            # TODO(eiderman): This is a workaround until the variable_scope updates land
            # in a TF release.
            old_vs = tf.get_variable_scope()
            if var_scope is None:
                # Reconstruct the variable-scope name from the trailing pieces
                # of the (possibly uniquified) name scope; `scope` ends with
                # '/', so the last `count` components before it mirror `name`.
                count = len(name.split('/'))
                scoped_name = '/'.join(scope.split('/')[-count - 1:-1])
                full_name = (old_vs.name + '/' + scoped_name).lstrip('/')
            else:
                # Reenter the caller-provided variable scope by name.
                full_name = var_scope.name

            # Mutate the current-VariableScope slot through the collection
            # *reference* (not a copy) so the swap is visible globally;
            # restored in the finally clause below.
            vs_key = tf.get_collection_ref(variable_scope._VARSCOPESTORE_KEY)
            try:
                # TODO(eiderman): Remove this hack or fix the full file.
                try:
                    # Newer TF: public VariableScope accepts regularizer and
                    # caching_device.
                    vs_key[0] = tf.VariableScope(
                        old_vs.reuse,
                        name=full_name,
                        initializer=old_vs.initializer,
                        regularizer=old_vs.regularizer,
                        caching_device=old_vs.caching_device)
                except AttributeError:
                    # Older TF: fall back to the private _VariableScope class,
                    # which only supports reuse/name/initializer.
                    vs_key[0] = variable_scope._VariableScope(
                        old_vs.reuse,
                        name=full_name,
                        initializer=old_vs.initializer)

                vs_key[0].name_scope = scope
                yield scope, vs_key[0]
            finally:
                # Always restore the previous variable scope, even on error.
                vs_key[0] = old_vs
Example #4 (score: 0)
File: util.py — Project: hans/rlcomp
def match_variable(name, scope_name):
  """
  Match a variable (initialize with same value) from another variable scope.

  After initialization, the values of the two variables are not tied in any
  way.
  """

  # HACK: Using private _VariableScope API in order to be able to get an
  # absolute-path to the given variable scope name (i.e., not have it treated
  # as a relative path and placed under whatever variable scope might contain
  # this function call)
  with tf.variable_scope(_VariableScope(True, scope_name), reuse=True):
    source_var = tf.get_variable(name)

  # Dummy initializer that simply forwards the source variable's initial value,
  # ignoring the shape/dtype arguments TF passes to initializers.
  def _copy_initializer(*args, **kwargs):
    return source_var.initialized_value()

  return tf.get_variable(name, shape=source_var.get_shape(),
                         initializer=_copy_initializer)