Пример #1
0
    def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
        """Registers `fisher_block` under `layer_key`, or returns the existing block.

        Args:
          layer_key: A variable or tuple of variables. Used both to look up an
            existing registration and as the key for a new one.
          fisher_block: The `FisherBlock` to associate with `layer_key`.
          reuse: How to treat an existing registration. May be True, False,
            `variable_scope.AUTO_REUSE`, or the `VARIABLE_SCOPE` sentinel, in
            which case the current variable scope's `reuse` value is used.

        Returns:
          The `FisherBlock` stored under `layer_key`: the previously registered
          block when one is reused, otherwise `fisher_block` itself.

        Raises:
          ValueError: If `layer_key` is already registered and `reuse` does not
            allow returning it, if the existing block has a different type
            than `fisher_block`, or if any variable in `layer_key` already
            belongs to a different registered key.
          KeyError: If `reuse` is True but `layer_key` was never registered.
        """
        if reuse is VARIABLE_SCOPE:
            # Defer to whatever the enclosing variable scope says.
            reuse = variable_scope.get_variable_scope().reuse

        # `reuse=True` always looks the key up; AUTO_REUSE only does so when
        # the key is already present.
        fetch_existing = reuse is True
        if not fetch_existing and reuse is variable_scope.AUTO_REUSE:
            fetch_existing = layer_key in self.fisher_blocks

        if fetch_existing:
            existing = self.fisher_blocks[layer_key]
            if type(existing) != type(fisher_block):  # pylint: disable=unidiomatic-typecheck
                raise ValueError(
                    "Attempted to register FisherBlock of type %s when existing "
                    "FisherBlock has type %s." %
                    (type(fisher_block), type(existing)))
            return existing

        # From here on a new entry is being inserted, so the key must be unused.
        if layer_key in self.fisher_blocks:
            if reuse is False:
                raise ValueError(
                    "FisherBlock for %s is already in LayerCollection." %
                    (layer_key, ))
            raise ValueError("Duplicate registration: {}".format(layer_key))

        # Map every registered variable to the (key, block) pair that holds it
        # so that partial overlaps with the new key can be rejected.
        owner_of = {}
        for registered_key, registered_block in self.fisher_blocks.items():
            for var in utils.ensure_sequence(registered_key):
                owner_of[var] = (registered_key, registered_block)

        for variable in utils.ensure_sequence(layer_key):
            if variable not in owner_of:
                continue
            prev_key, prev_block = owner_of[variable]
            raise ValueError(
                "Attempted to register layer_key {} with block {}, but variable {}"
                " was already registered in key {} with block {}.".format(
                    layer_key, fisher_block, variable, prev_key, prev_block))

        self.fisher_blocks[layer_key] = fisher_block
        return fisher_block
Пример #2
0
  def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
    """Associates `fisher_block` with `layer_key`, honoring the reuse policy.

    Args:
      layer_key: A variable or tuple of variables. The key that is checked
        against existing registrations and registered when valid.
      fisher_block: The `FisherBlock` to store for `layer_key`.
      reuse: Reuse policy. One of True, False, `variable_scope.AUTO_REUSE`, or
        the `VARIABLE_SCOPE` sentinel, which means "take the value from the
        current variable scope".

    Returns:
      The `FisherBlock` held under `layer_key`. When an existing registration
      is reused this is the previously registered block, not `fisher_block`.

    Raises:
      ValueError: If `layer_key` is already registered and may not be reused,
        if the existing block's type differs from `fisher_block`'s type, or if
        `layer_key` shares a variable with a different registered key.
      KeyError: If `reuse` is True but `layer_key` was never registered.
    """
    if reuse is VARIABLE_SCOPE:
      reuse = variable_scope.get_variable_scope().reuse

    # True forces a lookup; AUTO_REUSE looks up only keys that already exist.
    use_existing = (
        reuse is True or
        (reuse is variable_scope.AUTO_REUSE and
         layer_key in self.fisher_blocks))
    if use_existing:
      previous = self.fisher_blocks[layer_key]
      if type(previous) != type(fisher_block):  # pylint: disable=unidiomatic-typecheck
        raise ValueError(
            "Attempted to register FisherBlock of type %s when existing "
            "FisherBlock has type %s." % (type(fisher_block), type(previous)))
      return previous

    # A fresh insertion requires that the exact key is not present yet.
    if layer_key in self.fisher_blocks:
      if reuse is False:
        raise ValueError("FisherBlock for %s is already in LayerCollection." %
                         (layer_key,))
      raise ValueError("Duplicate registration: {}".format(layer_key))

    # Index registered variables by the (key, block) pair that contains them,
    # then reject the new key if it overlaps any of them.
    claimed = {}
    for params, block in self.fisher_blocks.items():
      for var in utils.ensure_sequence(params):
        claimed[var] = (params, block)

    for variable in utils.ensure_sequence(layer_key):
      clash = claimed.get(variable)
      if clash is None:
        continue
      prev_key, prev_block = clash
      raise ValueError(
          "Attempted to register layer_key {} with block {}, but variable {}"
          " was already registered in key {} with block {}.".format(
              layer_key, fisher_block, variable, prev_key, prev_block))

    self.fisher_blocks[layer_key] = fisher_block
    return fisher_block
Пример #3
0
 def _get_linked_approx(self, params):
     """Looks up the approximation recorded for a linked parameter group.

     Args:
       params: A variable or sequence of variables. It is normalized with
         `utils.ensure_sequence` and compared as an unordered `frozenset`.

     Returns:
       The value stored in `self.linked_parameters` for exactly this set of
       variables, or None when no such group is present.
     """
     group = frozenset(utils.ensure_sequence(params))
     linked = self.linked_parameters
     if group not in linked:
         return None
     return linked[group]
Пример #4
0
 def registered_variables(self):
     """Returns a flat tuple of every variable found in the registered keys.

     Each key of `self.fisher_blocks` is expanded with
     `utils.ensure_sequence`; the results are concatenated in iteration order.
     """
     collected = []
     for layer_key, _ in six.iteritems(self.fisher_blocks):
         collected.extend(utils.ensure_sequence(layer_key))
     return tuple(collected)
Пример #5
0
 def _get_linked_approx(self, params):
   """Returns the approximation stored for `params` if they were linked.

   Args:
     params: A variable or sequence of variables. It is normalized with
       `utils.ensure_sequence` and matched as a `frozenset`.

   Returns:
     The entry of `self.linked_parameters` for this set of variables, or None
     when the set is not a key of `self.linked_parameters`.
   """
   key = frozenset(utils.ensure_sequence(params))
   return (self.linked_parameters[key]
           if key in self.linked_parameters else None)
Пример #6
0
 def __init__(self, params_grads, batch_size):
   """Stores the normalized gradients and the batch size.

   Args:
     params_grads: An iterable. Each element is passed through
       `utils.ensure_sequence` and the results are kept as a tuple in
       `self._params_grads`.
     batch_size: Stored unchanged as `self._batch_size`.
   """
   normalized = []
   for params_grad in params_grads:
     normalized.append(utils.ensure_sequence(params_grad))
   self._params_grads = tuple(normalized)
   self._batch_size = batch_size
   # The base initializer is invoked last, after both attributes are set.
   super(NaiveDiagonalFactor, self).__init__()
Пример #7
0
 def __init__(self, params_grads, batch_size):
   """Records the per-parameter gradients and batch size for this factor.

   Args:
     params_grads: An iterable whose elements are each normalized with
       `utils.ensure_sequence`; the normalized elements are stored as a tuple.
     batch_size: Value saved as `self._batch_size` without modification.
   """
   self._params_grads = tuple(map(utils.ensure_sequence, params_grads))
   self._batch_size = batch_size
   # Both attributes are assigned before the base class initializer runs.
   super(NaiveDiagonalFactor, self).__init__()
Пример #8
0
    def define_linked_parameters(self, params, approximation=None):
        """Declares that a set of parameters must be treated as one group.

        During automatic graph scanning, any matches containing variables that
        have been identified as part of a linked group will be filtered out
        unless the match parameters are exactly equal to the ones specified in
        the linked group.

        Args:
          params: A variable, or a tuple or list of variables, to be linked.
          approximation: Optional string naming the approximation to use for
            these variables. If unspecified, this layer collection's default
            approximation for the layer type will be used.

        Raises:
          ValueError: If any of the variables is already registered under a
            key whose variable set differs from `params`, or already belongs
            to a previously linked group.
        """
        linked_group = frozenset(utils.ensure_sequence(params))

        # Overlap with an existing registration is only acceptable when the
        # registered key covers exactly the same variables as the new group.
        for registered_key, registered_block in self.fisher_blocks.items():
            registered_vars = set(utils.ensure_sequence(registered_key))
            if linked_group != registered_vars:
                for variable in linked_group:
                    if variable in registered_vars:
                        raise ValueError(
                            "Can't link parameters {}, variable {} was already registered in "
                            "group {} with layer {}".format(
                                linked_group, variable, registered_key,
                                registered_block))

        # No variable may appear in any group that has already been linked.
        for variable in linked_group:
            for existing_group in self.linked_parameters:
                if variable in existing_group:
                    raise ValueError(
                        "Can't link parameters {}, variable {} was already "
                        "linked in group {}.".format(linked_group, variable,
                                                     existing_group))

        self._linked_parameters[linked_group] = approximation
Пример #9
0
  def define_linked_parameters(self, params, approximation=None):
    """Marks `params` as a linked group with an optional approximation.

    During automatic graph scanning, any matches containing variables that have
    been identified as part of a linked group will be filtered out unless
    the match parameters are exactly equal to the ones specified in the linked
    group.

    Args:
      params: A variable, or a tuple or list of variables. The variables
        to be linked.
      approximation: Optional string specifying the type of approximation to use
        for these variables. If unspecified, this layer collection's default
        approximation for the layer type will be used.

    Raises:
      ValueError: If a variable in `params` is already registered under a key
        with a different variable set, or is already part of a linked group.
    """
    group = frozenset(utils.ensure_sequence(params))

    # Reject groups that only partially coincide with a registered key.
    for reg_key, reg_block in self.fisher_blocks.items():
      reg_vars = set(utils.ensure_sequence(reg_key))
      mismatched = group != reg_vars
      for variable in group:
        if not mismatched:
          break
        if variable not in reg_vars:
          continue
        raise ValueError(
            "Can't link parameters {}, variable {} was already registered in "
            "group {} with layer {}".format(group, variable, reg_key,
                                            reg_block))

    # Reject groups that share a variable with an already linked group.
    for variable in group:
      for other_group in self.linked_parameters:
        if variable not in other_group:
          continue
        raise ValueError("Can't link parameters {}, variable {} was already "
                         "linked in group {}.".format(group, variable,
                                                      other_group))

    self._linked_parameters[group] = approximation
Пример #10
0
 def get_use_count_map(self):
     """Builds a mapping from each registered variable to its use count.

     Returns:
       A `defaultdict(int)` keyed by variable. Each block contributes its
       `num_registered_minibatches` to every variable of its key; for
       `fb.FullyConnectedSeriesFB` and `fb.FullyConnectedMultiIndepFB` blocks
       that amount is multiplied by `block.num_inputs()`.
     """
     # TODO(b/70283403): Reimplement this in the old way, where each
     # registration function would be responsible for incrementing the count.
     # Also, this version has a bug: it won't do the right thing for generic
     # registration for parameters that are shared.  i.e. it won't set the use
     # count to infinity.
     use_counts = defaultdict(int)
     for layer_key, block in six.iteritems(self.fisher_blocks):
         if isinstance(block, (fb.FullyConnectedSeriesFB,
                               fb.FullyConnectedMultiIndepFB)):
             uses = block.num_inputs() * block.num_registered_minibatches
         else:
             uses = block.num_registered_minibatches
         for variable in utils.ensure_sequence(layer_key):
             use_counts[variable] += uses
     return use_counts
Пример #11
0
 def get_use_count_map(self):
   """Counts, per variable, how many times it is used across all blocks.

   Returns:
     A `defaultdict(int)` from variable to accumulated count. A block adds
     `num_registered_minibatches` for each variable in its key, scaled by
     `num_inputs()` when the block is an `fb.FullyConnectedSeriesFB` or
     `fb.FullyConnectedMultiIndepFB`.
   """
   # TODO(b/70283403): Reimplement this in the old way, where each
   # registration function would be responsible for incrementing the count.
   # Also, this version has a bug: it won't do the right thing for generic
   # registration for parameters that are shared.  i.e. it won't set the use
   # count to infinity.
   counts = defaultdict(int)
   for params, block in six.iteritems(self.fisher_blocks):
     scales_with_inputs = isinstance(
         block, (fb.FullyConnectedSeriesFB, fb.FullyConnectedMultiIndepFB))
     increment = (block.num_inputs() * block.num_registered_minibatches
                  if scales_with_inputs else block.num_registered_minibatches)
     for var in utils.ensure_sequence(params):
       counts[var] += increment
   return counts
Пример #12
0
 def registered_variables(self):
   """Returns every variable contained in a registered key, as one flat tuple.

   Keys of `self.fisher_blocks` are expanded with `utils.ensure_sequence` and
   their elements are emitted in iteration order.
   """
   return tuple(
       variable
       for layer_key, _ in six.iteritems(self.fisher_blocks)
       for variable in utils.ensure_sequence(layer_key))