def testRepeatedAdds(self):
    """Membership is correct when one tensor feeds the output twice."""
    first = tf.constant([[1., 2.], [3., 4.]])
    second = tf.constant([[5., 6.], [7., 8.]])
    # `first` contributes twice to the result: (first + second) + first.
    total = first + second + first
    sub_graph = utils.SubGraph((total,))
    for tensor in (first, second, total):
        self.assertTrue(sub_graph.is_member(tensor))
def testFilterList(self):
    """filter_list keeps only the tensors that belong to the subgraph."""
    x = tf.constant([[1., 2.], [3., 4.]])
    y = tf.constant([[5., 6.], [7., 8.]])
    in_graph = x + y       # reachable from the subgraph output
    out_of_graph = x * y   # not reachable from the subgraph output
    sub_graph = utils.SubGraph((in_graph,))
    self.assertEqual(sub_graph.filter_list([y, out_of_graph]), [y])
def testBasicGraph(self):
    """Membership test on a simple graph with one included and one excluded op."""
    lhs = tf.constant([[1., 2.], [3., 4.]])
    rhs = tf.constant([[5., 6.], [7., 8.]])
    added = lhs + rhs
    multiplied = lhs * rhs
    sub_graph = utils.SubGraph((added,))
    # Everything upstream of `added` is a member; `multiplied` is not.
    self.assertTrue(sub_graph.is_member(lhs))
    self.assertTrue(sub_graph.is_member(rhs))
    self.assertTrue(sub_graph.is_member(added))
    self.assertFalse(sub_graph.is_member(multiplied))
def testVariableUses(self):
    """variable_uses counts each read of a (resource) variable separately."""
    with tf.Graph().as_default():
        ref_var = tf.get_variable('var', shape=[10, 10])
        res_var = tf.get_variable(
            'resource_var', shape=[10, 10], use_resource=True)
        inputs = tf.zeros([3, 10])
        # `ref_var` is read by two matmuls, `res_var` by only one.
        double_use = tf.matmul(inputs, ref_var) + tf.matmul(inputs, ref_var)
        single_use = tf.matmul(inputs, res_var)
        sub_graph = utils.SubGraph((double_use, single_use))
        self.assertEqual(2, sub_graph.variable_uses(ref_var))
        self.assertEqual(1, sub_graph.variable_uses(res_var))
def testVariableUsesRelayOps(self):
    """Uses are counted through relay ops such as tf.identity."""
    with tf.Graph().as_default():
        left = tf.get_variable("a", shape=[2, 2])
        right = tf.get_variable("b", shape=[2, 2])
        relayed = tf.identity(left)
        out_one = tf.matmul(relayed, right)
        out_two = tf.matmul(relayed, right)
        sub_graph = utils.SubGraph((out_one, out_two))
        # Both variables feed two matmuls, even though `left` is only
        # read via a single identity op.
        self.assertEqual(2, sub_graph.variable_uses(left))
        self.assertEqual(2, sub_graph.variable_uses(right))
def register_layers(layer_collection, varlist, batch_size=None):
    """Walk the graph and register all layers to layer_collection.

    Parameters used multiple times in the graph need to be handled
    differently depending on context: this could either mean the parameters
    represent an RNN layer, or that the graph has been replicated as multiple
    "towers" to allow data parallelism. We differentiate these cases by
    examining the loss functions registered by layer_collection: if losses
    have been registered multiple times with reuse=True, we separate the
    subgraphs corresponding to each tower and register layers independently
    for each with reuse=True.

    Args:
      layer_collection: A `LayerCollection` to use for registering layers.
      varlist: A list of the variables in the graph.
      batch_size: An `int` representing the batch size. Needs to specified if
        registering generic variables that don't match any layer patterns or
        if time/uses is folded. If the time/uses dimension is merged with
        batch then this is used to infer number of uses/time-steps.

    Returns:
      A `dict` of the entries registered to layer_collection.fisher_blocks.

    Raises:
      ValueError: If not all losses were registered the same number of times.
        If any variables specified as part of linked groups were not matched
        with their group. If the same variable is used in multiple layers
        types (e.g. fully connected and 2d convolution), or if the same
        variable is used in multiple layers of a type that doesn't support
        shared parameters.
      AmbiguousRegistrationError: If any variables must be registered as
        generic and batch_size is not specified, or if even after filtering,
        there are matches with overlapping but unequal sets of variables
        (see filter_records).
    """
    # Snapshot the blocks present before we start, so only newly registered
    # entries are returned at the end.
    preexisting_blocks = layer_collection.fisher_blocks.copy()
    # Variables the user already registered by hand; flattened from the
    # (possibly tuple-valued) fisher_blocks keys.
    user_registered_variables = frozenset(
        variable
        for params in layer_collection.fisher_blocks
        for variable in ensure_sequence(params))

    if layer_collection.losses:
        # One tuple of input tensors per registered loss, per tower.
        inputs_by_loss = tuple(
            tuple(loss.inputs for loss in loss_list)
            for loss_list in layer_collection.towers_by_loss)
        num_towers = len(inputs_by_loss[0])
        if any(
            len(input_tensors) != num_towers
            for input_tensors in inputs_by_loss):
            raise ValueError(
                'If losses are registered with reuse=True, each name must be '
                'registered the same number of times.')
        # Register each tower's subgraph independently; towers after the
        # first reuse the variables of the first.
        for tower_number, tower_input_tensors in enumerate(zip(*inputs_by_loss)):
            reuse = tower_number > 0
            with tf.variable_scope('tower_%d' % tower_number, reuse=reuse):
                register_subgraph_layers(
                    layer_collection,
                    varlist,
                    user_registered_variables=user_registered_variables,
                    reuse=reuse,
                    batch_size=batch_size,
                    subgraph=utils.SubGraph(tower_input_tensors))
    else:
        # No losses registered: treat the whole graph as a single tower.
        register_subgraph_layers(
            layer_collection,
            varlist,
            user_registered_variables=user_registered_variables,
            batch_size=batch_size)

    all_blocks = layer_collection.fisher_blocks
    return {
        params: all_blocks[params]
        for params in set(all_blocks) - set(preexisting_blocks)
    }