def testRepeatedAdds(self):
  """Checks membership when one tensor feeds the output more than once."""
  lhs = array_ops.constant([[1., 2.], [3., 4.]])
  rhs = array_ops.constant([[5., 6.], [7., 8.]])
  # note that lhs appears twice in this graph
  total = lhs + rhs + lhs
  sub_graph = utils.SubGraph((total,))

  # Both inputs and the output itself must be reported as members.
  for tensor in (lhs, rhs, total):
    self.assertTrue(sub_graph.is_member(tensor))
def testFilterList(self):
  """Checks that filter_list keeps only entries belonging to the sub-graph."""
  lhs = array_ops.constant([[1., 2.], [3., 4.]])
  rhs = array_ops.constant([[5., 6.], [7., 8.]])
  in_graph_output = lhs + rhs
  # The product is not an input of `in_graph_output`, so it should be dropped.
  unrelated_output = lhs * rhs
  sub_graph = utils.SubGraph((in_graph_output,))

  kept = sub_graph.filter_list([rhs, unrelated_output])

  self.assertEqual(kept, [rhs])
def testBasicGraph(self):
  """Checks is_member on a sub-graph rooted at a single sum tensor."""
  lhs = array_ops.constant([[1., 2.], [3., 4.]])
  rhs = array_ops.constant([[5., 6.], [7., 8.]])
  summed = lhs + rhs
  product = lhs * rhs
  sub_graph = utils.SubGraph((summed,))

  # The root and both of its inputs are expected to be members.
  for tensor in (lhs, rhs, summed):
    self.assertTrue(sub_graph.is_member(tensor))

  # The product was not used to build the root, so it must be excluded.
  self.assertFalse(sub_graph.is_member(product))
def testVariableUses(self):
  """Checks variable_uses counts for a regular and a resource variable."""
  with ops.Graph().as_default():
    var = variable_scope.get_variable('var', shape=[10, 10])
    resource_var = variable_scope.get_variable(
        'resource_var', shape=[10, 10], use_resource=True)
    x = array_ops.zeros([3, 10])

    # `var` is consumed by two separate matmuls; `resource_var` by one.
    first_product = math_ops.matmul(x, var)
    second_product = math_ops.matmul(x, var)
    z0 = first_product + second_product
    z1 = math_ops.matmul(x, resource_var)

    sub_graph = utils.SubGraph((z0, z1))
    expected_uses = ((2, var), (1, resource_var))
    for expected, variable in expected_uses:
      self.assertEqual(expected, sub_graph.variable_uses(variable))
def create_subgraph(self):
  """Builds `self._subgraph` from the inputs of every registered loss.

  The `inputs` attribute of each entry in `self.losses` is collected into a
  tuple, flattened with `nest.flatten`, and handed to `utils.SubGraph`.

  Raises:
    ValueError: If `self.losses` is empty.
  """
  if not self.losses:
    raise ValueError("Must have at least one registered loss.")

  per_loss_inputs = tuple(entry.inputs for entry in self.losses)
  flat_inputs = nest.flatten(per_loss_inputs)
  self._subgraph = utils.SubGraph(flat_inputs)