Ejemplo n.º 1
0
    def test_backward_bfs_multi_consumer_data_nodes(self):
        """Backward BFS must cross a multi-consumer data node only on request.

        The output data node of ``mul`` ("mul_d") is consumed by both ``result``
        and ``result2``. Starting from ``result`` and walking back through
        ``Mul`` ops, ``parameter`` is reachable only when
        ``follow_multi_consumer_data_nodes=True``; with the default the search
        must stop at "mul_d" and find nothing.
        """
        # Parameter -> Mul -> Result
        # Const    -/    \-> Result2

        graph = build_graph(
            {
                **regular_op_with_shaped_data('parameter', [1], {
                                                  'op': 'Parameter'
                                              }),
                **valued_const_with_data('const', int64_array([5])),
                **regular_op_with_shaped_data('mul', [1], {'op': 'Mul'}),
                **result('result'),
                **result('result2'),
            }, [
                *connect('parameter', '0:mul'),
                *connect('const', '1:mul'),
                *connect('mul:0', 'result'),
                # second consumer of the Mul output data node -> makes it multi-consumer
                *connect_data('mul', 'result2'),
            ])

        res = common_bfs(Node(graph, 'result'), ['Mul'], ['Parameter'],
                         is_backward=True,
                         attr_to_check='op',
                         follow_multi_consumer_data_nodes=True)
        # assertEqual reports the actual length on failure, unlike assertTrue(len(...) == 1)
        self.assertEqual(
            len(res), 1,
            'The multi-consumer data node "mul_d" was not followed')
        self.assertEqual(res[0].id, 'parameter')

        res = common_bfs(Node(graph, 'result'), ['Mul'], ['Parameter'],
                         is_backward=True,
                         attr_to_check='op')
        self.assertEqual(
            len(res), 0, 'The multi-consumer data node "mul_d" was followed')
Ejemplo n.º 2
0
    def test_backward_bfs_check_op_instead_of_type(self):
        """Backward BFS keyed on the ``op`` attribute (``attr_to_check='op'``).

        Exercises several allowed/target op combinations starting from
        ``concat_1`` and checks exactly which nodes (if any) are returned.
        """
        # ScaleShift_w->ScaleShift->Mul1->Add1---->Concat->OpOutput
        # Placeholder------------->Add2->Mul2--'
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'add_2'),
                             ('scaleshift_1_w', 'scaleshift_1'),
                             ('scaleshift_1', 'scaleshift_1_data'),
                             ('scaleshift_1_data', 'mul_1'),
                             ('mul_1', 'mul_1_data'), ('mul_1_data', 'add_1'),
                             ('add_1', 'add_1_data'), ('add_2', 'add_2_data'),
                             ('add_2_data', 'mul_2'), ('mul_2', 'mul_2_data'),
                             ('add_1_data', 'concat_1'),
                             ('mul_2_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('concat_1_data', 'op_output')])

        res = common_bfs(Node(graph, 'concat_1'), ['Mul', 'Add'],
                         ['Parameter'],
                         is_backward=True,
                         attr_to_check='op')
        self.assertEqual(len(res), 0, 'Smth went wrong with bfs')

        res = common_bfs(Node(graph, 'concat_1'), ['Mul'], ['Add'],
                         is_backward=True,
                         attr_to_check='op')
        # assertCountEqual: order-independent, but rejects duplicates such as
        # [add_1, add_1], which the former `len == 2 and all(id in ...)` accepted.
        self.assertCountEqual([node.id for node in res], ['add_1', 'add_2'],
                              'Add operations was not found by bfs')

        res = common_bfs(Node(graph, 'concat_1'), ['ScaleShift'], ['Add'],
                         is_backward=True,
                         attr_to_check='op')
        self.assertEqual(len(res), 0, 'BFS shouldn\'t find any operations')

        res = common_bfs(Node(graph, 'concat_1'), [], ['Add'],
                         allowed_all=True,
                         is_backward=True,
                         attr_to_check='op')
        self.assertCountEqual([node.id for node in res], ['add_1', 'add_2'],
                              'Add operations was not found by bfs')

        res = common_bfs(Node(graph, 'concat_1'), ['ScaleShift'],
                         ['ScaleShift'],
                         is_backward=True,
                         attr_to_check='op')
        self.assertEqual(
            len(res), 0,
            'No one node should be found! But bfs found {} nodes'.format(
                len(res)))
Ejemplo n.º 3
0
    def parameter_unchanged_after_iteration(loop_node: Node, body_parameter: Node):
        """
        Checks if the body Parameter node is connected to some body Result and the data provided to Result is not
        changed between iterations. The data is considered unchanged if:
        1. There is no back edge for this Parameter OR
        2. There is a back edge from some Result to Parameter and there are only Identity ops in between or
           Parameter is connected to Result directly.

        :param loop_node: the Loop node to check
        :param body_parameter: the body Parameter node (must belong to ``loop_node.body``)
        :return: True if one of the conditions above holds, False otherwise
        :raises AssertionError: if ``body_parameter`` is not a Parameter of the Loop body, or if a back edge
            refers to a source layer id that does not map to exactly one body op node
        """
        assert body_parameter.id in loop_node.body
        assert body_parameter.soft_get('op') == 'Parameter'

        parameter_internal_id = body_parameter.soft_get('internal_layer_id')
        # Back edges feeding this Parameter; computed once instead of scanning back_edges twice.
        incoming_back_edges = [attrs for attrs in loop_node.back_edges
                               if attrs['to_layer'] == parameter_internal_id]
        if not incoming_back_edges:
            # Condition 1: nothing overwrites the Parameter between iterations.
            return True

        for back_edge_attrs in incoming_back_edges:
            result_internal_id = back_edge_attrs['from_layer']
            result_nodes = loop_node.body.get_op_nodes(internal_layer_id=result_internal_id)
            assert len(result_nodes) == 1, 'There should be exactly one node with id {}, but there are {}' \
                                           ''.format(result_internal_id, len(result_nodes))
            result_node = result_nodes[0]
            # check that the Result node consumes data from Parameter node directly or through Identity operations
            parameters = common_bfs(result_node, ['Identity'], ['Parameter'], is_backward=True, attr_to_check='op',
                                    follow_multi_consumer_data_nodes=True)
            # generator (not list) so any() can short-circuit
            if any(node.soft_get('internal_layer_id') == parameter_internal_id for node in parameters):
                return True
        return False