コード例 #1
0
    def test_graph_relations_in_format_of_layer_names(self):
        """Check that ``layer_names_only`` reports graph edges by layer name.

        The graph is a chain input -> sigmoid-2 -> sigmoid-3 -> sigmoid-4.
        Both sigmoid-2 and sigmoid-3 also feed a concatenate layer. The
        name-based relations returned by the graph are compared against an
        explicit expected list.
        """
        input_layer = layers.Input(1, name='input')
        sigmoid_2 = layers.Sigmoid(2, name='sigmoid-2')
        sigmoid_3 = layers.Sigmoid(3, name='sigmoid-3')
        sigmoid_4 = layers.Sigmoid(4, name='sigmoid-4')
        concat = layers.Concatenate(name='concat')

        graph = LayerGraph()

        # Edges are added in a fixed order. The result is checked with
        # assertListEqual, which is order-sensitive, so this order presumably
        # determines the order of the relations/children in the output.
        connections = (
            (input_layer, sigmoid_2),
            (sigmoid_2, sigmoid_3),
            (sigmoid_3, sigmoid_4),
            (sigmoid_2, concat),
            (sigmoid_3, concat),
        )
        for from_layer, to_layer in connections:
            graph.connect_layers(from_layer, to_layer)

        self.assertListEqual(
            graph.layer_names_only(),
            [
                ('input', ['sigmoid-2']),
                ('sigmoid-2', ['sigmoid-3', 'concat']),
                ('sigmoid-3', ['sigmoid-4', 'concat']),
                ('sigmoid-4', []),
                ('concat', []),
            ],
        )