예제 #1
0
    def test_merge_infer_only_second_executable(self):
        graph = build_graph_with_attrs(
            nodes_with_attrs=self.nodes,
            edges_with_attrs=self.edges,
            update_nodes_attributes=[
                ('first', {'executable': False, 'value': np.ones([2, 2]), 'shape': int64_array([2, 2])}),
                ('second', {'executable': True, 'value': np.zeros([4, 4]), 'shape': int64_array([4, 4])})
            ]
        )

        ref_graph = build_graph_with_attrs(
            nodes_with_attrs=self.nodes,
            edges_with_attrs=self.edges,
            update_nodes_attributes=[
                ('first', {'executable': False, 'value': np.ones([2, 2]), 'shape': int64_array([2, 2])}),
                ('second', {'executable': True, 'value': np.zeros([4, 4]), 'shape': int64_array([4, 4])}),
                ('merge', {'is_not_fully_inferred': False}),
                ('merge_output', {'shape': int64_array([4, 4]), 'value': np.zeros([4, 4])})
            ]
        )

        tested_class = Merge(graph=graph, attrs={})
        node = Node(graph, 'merge')
        tested_class.merge_infer(node)

        (flag, resp) = compare_graphs(graph, ref_graph, 'merge_output', check_op_attrs=True)
        self.assertTrue(flag, resp)
예제 #2
0
    def test_merge_infer_simple_case_one_executable(self):
        graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges)

        # We should propagate value of the first input since only this input is executable
        graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                           edges_with_attrs=self.edges,
                                           update_nodes_attributes=[('merge_output', {'shape': np.array([2, 2]),
                                                                                      'value': np.ones((2, 2))}),
                                                                    ('merge', {'is_not_fully_inferred': False})])

        tested_class = Merge(graph=graph, attrs={})
        node = Node(graph, 'merge')
        tested_class.merge_infer(node)

        (flag, resp) = compare_graphs(graph, graph_ref, 'merge_output', check_op_attrs=True)
        self.assertTrue(flag, resp)
예제 #3
0
    def test_merge_infer_complex_case(self):
        """
        Case as in cycles when in first visit only one input are inferred and in the second -- both.
        """
        graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
                                       update_nodes_attributes=[('first', {'is_partial_inferred': False,
                                                                           'value': None}),
                                                                ('second', {'executable': True})])

        # In first visit we should propagate only shapes
        graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                           edges_with_attrs=self.edges,
                                           update_nodes_attributes=[('second', {'executable': True}),
                                                                    ('first', {'is_partial_inferred': False,
                                                                               'value': None}),
                                                                    ('merge_output', {'shape': np.array([2, 2]),
                                                                                      'value': None}),
                                                                    ('merge', {'is_not_fully_inferred': True})])
        tested_class = Merge(graph=graph, attrs={})
        node = Node(graph, 'merge')
        tested_class.merge_infer(node)

        (flag, resp) = compare_graphs(graph, graph_ref, 'merge_output', check_op_attrs=True)
        self.assertTrue(flag, resp)

        # Imitate that inputs nodes now is inferred
        graph.node['first']['is_partial_inferred'] = True

        # Run infer second time
        tested_class = Merge(graph=graph, attrs={})
        node = Node(graph, 'merge')
        tested_class.merge_infer(node)

        graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                           edges_with_attrs=self.edges,
                                           update_nodes_attributes=[('second', {'executable': True}),
                                                                    ('first', {'is_partial_inferred': True,
                                                                               'value': None}),
                                                                    ('merge_output', {'shape': np.array([2, 2]),
                                                                                      'value': None}),
                                                                    ('merge', {'is_not_fully_inferred': False})])
        (flag, resp) = compare_graphs(graph, graph_ref, 'merge_output', check_op_attrs=True)
        self.assertTrue(flag, resp)