예제 #1
0
파일: utils_test.py 프로젝트: leox1v/kfac
 def testRepeatedAdds(self):
   """Membership is unaffected by a tensor feeding the graph twice."""
   x = tf.constant([[1., 2.], [3., 4.]])
   y = tf.constant([[5., 6.], [7., 8.]])
   total = x + y + x  # x is consumed twice on the path to total
   graph = utils.SubGraph((total,))
   for tensor in (x, y, total):
     self.assertTrue(graph.is_member(tensor))
예제 #2
0
파일: utils_test.py 프로젝트: leox1v/kfac
 def testFilterList(self):
   """filter_list keeps only the tensors that belong to the subgraph."""
   lhs = tf.constant([[1., 2.], [3., 4.]])
   rhs = tf.constant([[5., 6.], [7., 8.]])
   added = lhs + rhs
   multiplied = lhs * rhs  # not reachable from the subgraph root below
   graph = utils.SubGraph((added,))
   kept = graph.filter_list([rhs, multiplied])
   self.assertEqual(kept, [rhs])
예제 #3
0
파일: utils_test.py 프로젝트: leox1v/kfac
 def testBasicGraph(self):
   """Tensors reachable from the root are members; unrelated ops are not."""
   u = tf.constant([[1., 2.], [3., 4.]])
   v = tf.constant([[5., 6.], [7., 8.]])
   total = u + v
   product = u * v  # shares inputs with total but is outside its subgraph
   graph = utils.SubGraph((total,))
   self.assertTrue(graph.is_member(u))
   self.assertTrue(graph.is_member(v))
   self.assertTrue(graph.is_member(total))
   self.assertFalse(graph.is_member(product))
예제 #4
0
파일: utils_test.py 프로젝트: leox1v/kfac
 def testVariableUses(self):
   """variable_uses counts reads of both regular and resource variables."""
   with tf.Graph().as_default():
     weights = tf.get_variable('var', shape=[10, 10])
     resource_weights = tf.get_variable(
         'resource_var', shape=[10, 10], use_resource=True)
     inputs = tf.zeros([3, 10])
     # weights is read twice below; resource_weights only once.
     out_double = tf.matmul(inputs, weights) + tf.matmul(inputs, weights)
     out_single = tf.matmul(inputs, resource_weights)
     graph = utils.SubGraph((out_double, out_single))
     self.assertEqual(2, graph.variable_uses(weights))
     self.assertEqual(1, graph.variable_uses(resource_weights))
예제 #5
0
파일: utils_test.py 프로젝트: leox1v/kfac
  def testVariableUsesRelayOps(self):
    """Uses are still counted when reads pass through relay ops (identity)."""
    with tf.Graph().as_default():
      first = tf.get_variable("a", shape=[2, 2])
      second = tf.get_variable("b", shape=[2, 2])
      relayed = tf.identity(first)  # first is only ever read via this relay
      product_one = tf.matmul(relayed, second)
      product_two = tf.matmul(relayed, second)

      graph = utils.SubGraph((product_one, product_two))
      self.assertEqual(2, graph.variable_uses(first))
      self.assertEqual(2, graph.variable_uses(second))
예제 #6
0
파일: graph_search.py 프로젝트: leox1v/kfac
def register_layers(layer_collection, varlist, batch_size=None):
  """Walk the graph and register all layers to layer_collection.

  Parameters used multiple times in the graph need to be handled differently
  depending on context: repeated use may indicate an RNN layer, or a graph
  replicated into several "towers" for data parallelism. The two cases are
  distinguished via the losses registered on layer_collection: when losses
  were registered multiple times with reuse=True, the subgraph of each tower
  is separated out and its layers registered independently with reuse=True.

  Args:
    layer_collection: A `LayerCollection` to use for registering layers.
    varlist: A list of the variables in the graph.
    batch_size: A `int` representing the batch size. Needs to specified if
      registering generic variables that don't match any layer patterns or
      if time/uses is folded. If the time/uses dimension is merged with
      batch then this is used to infer number of uses/time-steps.

  Returns:
    A `dict` of the entries registered to layer_collection.fisher_blocks.

  Raises:
    ValueError: If not all losses were registered the same number of times.
      If any variables specified as part of linked groups were not
      matched with their group.
      If the same variable is used in multiple layers types
      (e.g. fully connected and 2d convolution), or if the same variable is
      used in multiple layers of a type that doesn't support shared parameters.
    AmbiguousRegistrationError: If any variables must be registered as generic
      and batch_size is not specified, or if even after filtering, there are
      matches with overlapping but unequal sets of variables (see
      filter_records).
  """
  # Snapshot the blocks present before we register anything, so the return
  # value can report only what this call added.
  previously_registered = layer_collection.fisher_blocks.copy()

  user_registered_variables = frozenset(
      variable
      for params in layer_collection.fisher_blocks
      for variable in ensure_sequence(params))

  if layer_collection.losses:
    # Group loss inputs by tower; every loss must span the same tower count.
    inputs_by_loss = tuple(
        tuple(loss.inputs for loss in loss_list)
        for loss_list in layer_collection.towers_by_loss)

    num_towers = len(inputs_by_loss[0])
    if any(len(inputs) != num_towers for inputs in inputs_by_loss):
      raise ValueError(
          'If losses are registered with reuse=True, each name must be '
          'registered the same number of times.')

    # Register each tower's subgraph separately, reusing variables after the
    # first tower.
    for tower_number, tower_inputs in enumerate(zip(*inputs_by_loss)):
      reuse = tower_number > 0
      with tf.variable_scope('tower_%d' % tower_number, reuse=reuse):
        register_subgraph_layers(
            layer_collection,
            varlist,
            user_registered_variables=user_registered_variables,
            reuse=reuse,
            batch_size=batch_size,
            subgraph=utils.SubGraph(tower_inputs))
  else:
    # No losses registered: treat the whole graph as a single subgraph.
    register_subgraph_layers(
        layer_collection,
        varlist,
        user_registered_variables=user_registered_variables,
        batch_size=batch_size)

  all_blocks = layer_collection.fisher_blocks
  newly_added = set(all_blocks) - set(previously_registered)
  return {params: all_blocks[params] for params in newly_added}