def test_find_nodes_by_inputs_and_outputs(self):
  """Checks node lookup on a linear Dense chain, plus its error paths.

  The graph is ``inputs -> Dense(8) -> x -> Dense(6) -> y -> Dense(4) ->
  output``, plus a second ``Input`` that feeds nothing. Searching from ``x``
  to ``output`` must yield exactly the nodes of ``y`` and ``output``. Invalid
  input/output combinations must raise ``ValueError``.
  """
  inputs = input_layer_lib.Input((10, ))
  unconnected_inputs = input_layer_lib.Input((10, ))
  x = layers.Dense(8)(inputs)
  y = layers.Dense(6)(x)
  output = layers.Dense(4)(y)

  # Starting from the intermediate tensor `x`, only the two Dense calls that
  # produce `y` and `output` lie between the requested endpoints.
  found = functional_utils.find_nodes_by_inputs_and_outputs(x, output)
  self.assertLen(found, 2)
  self.assertCountEqual(found, [output.node, y.node])

  # The first pair has its endpoints reversed. The second pair uses an Input
  # that feeds nothing. Both must report an input that cannot be reached.
  unreachable_cases = [
      (output, x),
      (unconnected_inputs, output),
  ]
  for start, end in unreachable_cases:
    with self.assertRaisesRegex(ValueError,
                                'Found input tensor cannot be reached'):
      functional_utils.find_nodes_by_inputs_and_outputs(start, end)

  # `inputs` alone is enough to reach `output`, so the extra unrelated input
  # must be reported as unvisited/disconnected.
  with self.assertRaisesRegex(
      ValueError, 'Found unvisited input tensors that are disconnected'):
    functional_utils.find_nodes_by_inputs_and_outputs(
        [inputs, unconnected_inputs], output)
def test_find_nodes_by_inputs_and_outputs_with_complicated_network(self):
  """Checks node lookup on a branching graph with a shared layer.

  ``dense1`` is applied to both ``input1`` and ``input2``. Its two outputs
  are added into ``c``, and ``c`` is added to ``dense2(input3)`` to form
  ``e``. The test covers lookups from graph inputs and from intermediate
  tensors, plus both kinds of ``ValueError`` for bad input/output sets.
  """
  input1 = input_layer_lib.Input((10, ))
  input2 = input_layer_lib.Input((10, ))
  input3 = input_layer_lib.Input((10, ))
  unconnected_input = input_layer_lib.Input((10, ))

  dense1 = layers.Dense(4, name='dense1')
  dense2 = layers.Dense(4, name='dense2')
  # `dense1` is called twice, on input1 and on input2: one shared layer,
  # two separate calls.
  a = dense1(input1)
  b = dense1(input2)
  c = layers.Add()([a, b])
  d = dense2(input3)
  e = layers.Add()([c, d])
  # Five layer calls were made above, giving five nodes: a, b, c, d and e.

  find = functional_utils.find_nodes_by_inputs_and_outputs

  def expect_nodes(start, end, expected):
    # Order of the returned nodes is not part of the contract being tested.
    self.assertCountEqual(find(start, end), expected)

  expect_nodes(input1, a, [a.node])
  expect_nodes(input2, b, [b.node])
  # Inputs are given in reverse order here. The result holds both Dense calls
  # plus the Add that merges them.
  expect_nodes([input2, input1], c, [a.node, b.node, c.node])

  # input3 is omitted, so the dense2 branch that feeds `e` has no provided
  # source.
  with self.assertRaisesRegex(ValueError,
                              'Found input tensor cannot be reached'):
    find([input1, input2], e)

  expect_nodes([input1, input2, input3], e,
               [a.node, b.node, c.node, d.node, e.node])
  # Intermediate tensors may be used as inputs...
  expect_nodes([a, b, input3], e, [c.node, d.node, e.node])
  # ...and intermediate tensors may also appear among the outputs.
  expect_nodes([a, b, input3], [d, e], [c.node, d.node, e.node])

  disconnected_cases = [
      # `d` depends only on input3, so input1 and input2 are never visited.
      ([input1, input2, input3], d),
      # `unconnected_input` feeds none of the requested outputs.
      ([a, b, input3, unconnected_input], [e, d, c]),
  ]
  for start, end in disconnected_cases:
    with self.assertRaisesRegex(
        ValueError, 'Found unvisited input tensors that are disconnected'):
      find(start, end)