Example #1
0
    def test2(self):
        """
        Non-constant input to ``init``: the matcher must leave the graph untouched.

        The graph is assembled from the matcher's own pattern plus an ``Enter``
        op that feeds the ``init`` data node, so ``init`` has a non-constant
        producer. After the matcher runs, the graph is compared against a
        freshly built reference graph with identical structure and attributes.
        """
        pattern_matcher = BackEdgeSimpleInputMatcher()
        pattern = pattern_matcher.pattern()

        def make_graph():
            # Build brand-new attribute literals on every call so the graph
            # under test and the reference graph never share attribute dicts
            # (or the shape array) with each other.
            extra_nodes = [
                ('cycle_data', {'kind': 'data'}),
                ('condition', {'kind': 'data'}),
                ('init', {'kind': 'data', 'shape': np.array([1, 3])}),
                ('Enter', {'kind': 'op', 'op': 'Enter'}),
            ]
            extra_edges = [
                ('Enter', 'init'),
                ('condition', 'BackEdge', {'in': 2}),
                ('init', 'BackEdge', {'in': 0}),
                ('cycle_data', 'BackEdge', {'in': 1}),
            ]
            return build_graph_with_attrs(nodes_with_attrs=pattern['nodes'],
                                          edges_with_attrs=pattern['edges'],
                                          new_nodes_with_attrs=extra_nodes,
                                          new_edges_with_attrs=extra_edges)

        graph = make_graph()
        pattern_matcher.find_and_replace_pattern(graph)

        # The expected result is simply the untransformed graph.
        graph_ref = make_graph()

        flag, resp = compare_graphs(graph, graph_ref, 'BackEdge',
                                    check_op_attrs=True)
        self.assertTrue(flag, resp)