Esempio n. 1
0
 def testRepeatedAdds(self):
     """A tensor consumed by several ops is still reported as a member."""
     lhs = array_ops.constant([[1., 2.], [3., 4.]])
     rhs = array_ops.constant([[5., 6.], [7., 8.]])
     # `lhs` feeds both additions, so it is reached along two paths.
     total = lhs + rhs + lhs
     sub_graph = utils.SubGraph((total, ))
     for tensor in (lhs, rhs, total):
         self.assertTrue(sub_graph.is_member(tensor))
Esempio n. 2
0
 def testFilterList(self):
     """filter_list keeps only the entries that belong to the subgraph."""
     lhs = array_ops.constant([[1., 2.], [3., 4.]])
     rhs = array_ops.constant([[5., 6.], [7., 8.]])
     total = lhs + rhs
     # Built from the same inputs, but `total` does not depend on it.
     product = lhs * rhs
     sub_graph = utils.SubGraph((total, ))
     self.assertEqual(sub_graph.filter_list([rhs, product]), [rhs])
Esempio n. 3
0
 def testBasicGraph(self):
     """The output and its inputs are members; a sibling op is not."""
     lhs = array_ops.constant([[1., 2.], [3., 4.]])
     rhs = array_ops.constant([[5., 6.], [7., 8.]])
     total = lhs + rhs
     product = lhs * rhs
     sub_graph = utils.SubGraph((total, ))
     for member in (lhs, rhs, total):
         self.assertTrue(sub_graph.is_member(member))
     # `product` shares inputs with `total` but is not on any path to it.
     self.assertFalse(sub_graph.is_member(product))
Esempio n. 4
0
 def testVariableUses(self):
     """variable_uses counts how often each variable is consumed in the subgraph."""
     with ops.Graph().as_default():
         plain_var = variable_scope.get_variable('var', shape=[10, 10])
         resource_var = variable_scope.get_variable(
             'resource_var', shape=[10, 10], use_resource=True)
         x = array_ops.zeros([3, 10])
         # `plain_var` is multiplied twice, `resource_var` only once.
         first_use = math_ops.matmul(x, plain_var)
         second_use = math_ops.matmul(x, plain_var)
         z0 = first_use + second_use
         z1 = math_ops.matmul(x, resource_var)
         sub_graph = utils.SubGraph((z0, z1))
         expected_uses = ((plain_var, 2), (resource_var, 1))
         for variable, count in expected_uses:
             self.assertEqual(count, sub_graph.variable_uses(variable))
Esempio n. 5
0
 def create_subgraph(self):
     """Build the SubGraph spanned by the inputs of every registered loss.

     The resulting SubGraph is stored on ``self._subgraph``.

     Raises:
       ValueError: If ``self.losses`` is empty (no loss registered).
     """
     if not self.losses:
         raise ValueError("Must have at least one registered loss.")
     # `loss.inputs` may itself be a nested structure (presumably a tuple of
     # tensors); flatten everything into one list for SubGraph.
     per_loss_inputs = tuple(loss.inputs for loss in self.losses)
     self._subgraph = utils.SubGraph(nest.flatten(per_loss_inputs))