def test_find_nodes_by_inputs_and_outputs(self):
        """Checks node discovery on a linear chain and its rejection paths.

        Builds `source -> Dense(8) -> Dense(6) -> Dense(4)` plus a separate
        `Input` that feeds nothing. It then verifies two things about
        `functional_utils.find_nodes_by_inputs_and_outputs`:

        - It returns exactly the two nodes between the first hidden tensor
          and the final tensor.
        - It raises `ValueError` for invalid input/output combinations.
        """
        source = input_layer_lib.Input((10, ))
        orphan = input_layer_lib.Input((10, ))
        hidden_8 = layers.Dense(8)(source)
        hidden_6 = layers.Dense(6)(hidden_8)
        final = layers.Dense(4)(hidden_6)

        # Starting from the first hidden tensor, only the last two Dense
        # calls lie on the path to `final`.
        found = functional_utils.find_nodes_by_inputs_and_outputs(
            hidden_8, final)
        self.assertLen(found, 2)
        self.assertCountEqual(found, [final.node, hidden_6.node])

        # Invalid (inputs, outputs, expected-message) combinations, checked
        # in order:
        #   1. input and output swapped, so the "input" is downstream;
        #   2. an input with no path to the output;
        #   3. a valid input listed together with one that is never visited.
        unreachable = 'Found input tensor cannot be reached'
        disconnected = 'Found unvisited input tensors that are disconnected'
        invalid_cases = (
            (final, hidden_8, unreachable),
            (orphan, final, unreachable),
            ([source, orphan], final, disconnected),
        )
        for bad_inputs, bad_outputs, message in invalid_cases:
            with self.assertRaisesRegex(ValueError, message):
                functional_utils.find_nodes_by_inputs_and_outputs(
                    bad_inputs, bad_outputs)

    def test_find_nodes_by_inputs_and_outputs_with_complicated_network(self):
        """Checks node discovery on a graph with a shared layer and merges.

        Graph under test (five layer calls, hence five nodes):
            a = dense1(in_1)      b = dense1(in_2)     # dense1 is shared
            c = Add()([a, b])     d = dense2(in_3)
            e = Add()([c, d])
        `stray` is an `Input` that is never fed into any layer.
        """
        in_1 = input_layer_lib.Input((10, ))
        in_2 = input_layer_lib.Input((10, ))
        in_3 = input_layer_lib.Input((10, ))
        stray = input_layer_lib.Input((10, ))

        shared_dense = layers.Dense(4, name='dense1')
        side_dense = layers.Dense(4, name='dense2')
        # Calling the same layer twice yields two distinct nodes.
        a = shared_dense(in_1)
        b = shared_dense(in_2)
        c = layers.Add()([a, b])
        d = side_dense(in_3)
        e = layers.Add()([c, d])

        find = functional_utils.find_nodes_by_inputs_and_outputs
        unreachable = 'Found input tensor cannot be reached'
        disconnected = 'Found unvisited input tensors that are disconnected'

        def assert_nodes(inputs, outputs, expected):
            # The discovered nodes must match `expected`, ignoring order.
            self.assertCountEqual(find(inputs, outputs), expected)

        def assert_rejected(inputs, outputs, message):
            # The lookup must fail with a ValueError matching `message`.
            with self.assertRaisesRegex(ValueError, message):
                find(inputs, outputs)

        # Each call of the shared layer is found on its own.
        assert_nodes(in_1, a, [a.node])
        assert_nodes(in_2, b, [b.node])

        # Input order does not matter: both Dense calls plus the first Add.
        assert_nodes([in_2, in_1], c, [a.node, b.node, c.node])

        # in_3 is left out of the inputs, so the lookup for `e` fails.
        assert_rejected([in_1, in_2], e, unreachable)

        # With every source input supplied, all five nodes are found.
        assert_nodes([in_1, in_2, in_3], e,
                     [a.node, b.node, c.node, d.node, e.node])

        # Intermediate tensors may serve as inputs.
        assert_nodes([a, b, in_3], e, [c.node, d.node, e.node])
        # Listing the intermediate `d` as an extra output adds no new nodes.
        assert_nodes([a, b, in_3], [d, e], [c.node, d.node, e.node])

        # `d` only depends on in_3, leaving in_1 and in_2 unvisited.
        assert_rejected([in_1, in_2, in_3], d, disconnected)

        # `stray` is never reached from any requested output.
        assert_rejected([a, b, in_3, stray], [e, d, c], disconnected)