Пример #1
0
    def visit_VerticalLoop(
        self, node: nir.VerticalLoop, *, merge_candidates: List[List[nir.HorizontalLoop]], **kwargs
    ):
        """Fuse each group of ``merge_candidates`` into a single horizontal loop.

        For every group, the slice of ``node.horizontal_loops`` from the group's
        first loop to its last loop is replaced by one ``nir.HorizontalLoop``.
        Both ends are located with ``list.index``, i.e. by equality. The new
        loop's body concatenates the declarations and statements of the group's
        loops in group order. It reuses the iteration space of the group's
        first loop.

        ``node`` is modified in place and also returned.

        NOTE(review): assumes each group is a contiguous run of
        ``node.horizontal_loops``. Any loop lying between the first and last
        member but not in the group would be dropped. Statements of the later
        loops are not renamed to the first loop's iteration-space name;
        presumably the candidate selection guarantees they share it. Confirm.
        """
        loops = node.horizontal_loops
        for group in merge_candidates:
            head = group[0]
            block_location = head.iteration_space.location_type
            start = loops.index(head)
            stop = loops.index(group[-1]) + 1
            merged_body = nir.BlockStmt(
                declarations=[decl for hl in group for decl in hl.stmt.declarations],
                statements=[st for hl in group for st in hl.stmt.statements],
                location_type=block_location,
            )
            # Collapse the whole [start, stop) run into the single fused loop.
            loops[start:stop] = [
                nir.HorizontalLoop(
                    stmt=merged_body,
                    iteration_space=loops[start].iteration_space,
                )
            ]
        return node
Пример #2
0
 def visit_HorizontalLoop(self, node: gtir.HorizontalLoop, **kwargs):
     """Lower a gtir horizontal loop into an ``nir.HorizontalLoop``.

     An empty ``nir.BlockStmt`` carrying ``node.stmt``'s location type is
     created first. It is passed to the statement visitor as ``last_block``,
     together with a single location comprehension for ``node.location``.
     The visitor's result is then appended to that block. The block becomes
     the body of the returned loop, whose location type is the first element
     of ``node.location.chain``.
     """
     body = nir.BlockStmt(
         declarations=[],
         statements=[],
         location_type=node.stmt.location_type,
     )
     comprehensions = {node.location.name: node.location}
     lowered = self.visit(
         node.stmt, last_block=body, location_comprehensions=comprehensions
     )
     body.statements.append(lowered)
     first_location = node.location.chain.elements[0]
     return nir.HorizontalLoop(stmt=body, location_type=first_location)
Пример #3
0
 def visit_HorizontalLoop(
     self, node: gtir.HorizontalLoop, *, symtable, **kwargs
 ):
     """Lower a gtir horizontal loop into an ``nir.HorizontalLoop``.

     The loop body is visited with a fresh ``HorizontalLoopContext``. The
     visitor collects ``declarations`` and ``statements`` into that context.
     The body is also visited with a location comprehension for
     ``node.location`` and with ``symtable`` overlaid by the loop's own
     ``symtable_`` entries. The collected items form the body of the
     returned loop. Its iteration space is named after ``node.location``.
     """
     ctx = self.HorizontalLoopContext()
     scoped_symtable = {**symtable, **node.symtable_}
     self.visit(
         node.stmt,
         hloop_ctx=ctx,
         location_comprehensions={node.location.name: node.location},
         symtable=scoped_symtable,
         **kwargs,
     )
     loc_name = node.location.name
     loc_type = node.location.location_type
     body = nir.BlockStmt(
         declarations=ctx.declarations,
         statements=ctx.statements,
         location_type=loc_type,
     )
     space = nir.IterationSpace(name=loc_name, location_type=loc_type)
     return nir.HorizontalLoop(stmt=body, iteration_space=space)
Пример #4
0
    def visit_HorizontalLoop(
        self,
        node: nir.HorizontalLoop,
        *,
        merge_groups: Dict[int, List[List[nir.NeighborLoop]]],
        **kwargs,
    ):
        """Merge groups of neighbor loops inside a horizontal loop's body.

        ``merge_groups[id(node)]`` lists groups of neighbor loops found in
        ``node.stmt.statements``. For each group, the first loop (the *head*)
        is deep-copied. The bodies of the remaining loops are appended to the
        copy's body, with each loop's variable renamed to the head's via
        ``RenameSymbol``. The remaining loops themselves are dropped.

        An ``AssignStmt`` directly preceding a merged-away loop (its reduction
        initialization) is moved in front of the head copy. Moved statements
        keep their original relative order.

        Args:
            node: the horizontal loop whose body statements are rewritten.
            merge_groups: maps ``id()`` of a horizontal loop to its groups;
                ``id(node)`` must be present. Every neighbor loop in the body
                must belong to some group (singleton groups included). A
                group's head must precede its other members in the body.

        Returns:
            A new ``nir.HorizontalLoop`` with ``node``'s iteration space and the
            rewritten body. Head loops are deep-copied before their bodies are
            extended, so the original head loops are left unmodified.
        """
        assert id(node) in merge_groups
        groups: List[List[nir.NeighborLoop]] = merge_groups[id(node)]

        # id(head NeighborLoop) -> its group. Heads are the loops the rest of
        # their group is merged into. A dict gives O(1) lookup per statement.
        group_by_head: Dict[int, List[nir.NeighborLoop]] = {
            id(group[0]): group for group in groups
        }

        # id(merged-away NeighborLoop) -> the (copied) head loop receiving it.
        targets: Dict[int, nir.NeighborLoop] = {}

        # id(copied head loop) -> initialization statements of the loops merged
        # into it, in their original order.
        targets_init: Dict[int, List[nir.AssignStmt]] = {}

        statements = node.stmt.statements
        num_stmts = len(statements)
        kept_statements = []
        for i, hl_stmt in enumerate(statements):
            next_stmt = statements[i + 1] if i < num_stmts - 1 else None

            if isinstance(hl_stmt, nir.NeighborLoop):
                group = group_by_head.get(id(hl_stmt))
                if group is not None:
                    # Head loop. Copy it so that extending its body below does
                    # not modify the input tree. (Comparing the copy's id with
                    # hl_stmt's, as done previously, can never succeed.)
                    target_n_loop = copy.deepcopy(hl_stmt)
                    for other_n_loop in group[1:]:
                        targets[id(other_n_loop)] = target_n_loop
                    kept_statements.append(target_n_loop)
                else:
                    # Merged-away loop: append its (renamed) body to its head.
                    assert id(hl_stmt) in targets
                    target_n_loop = targets[id(hl_stmt)]
                    other_body: nir.BlockStmt = RenameSymbol.apply(
                        hl_stmt.body, hl_stmt.name, target_n_loop.name)
                    target_n_loop.body.declarations.extend(
                        other_body.declarations)
                    target_n_loop.body.statements.extend(other_body.statements)

            elif (isinstance(hl_stmt, nir.AssignStmt)
                  and isinstance(next_stmt, nir.NeighborLoop)
                  and id(next_stmt) in targets):
                # Initialization of a reduction whose loop is merged away:
                # set it aside so it can be placed before the head loop.
                targets_init.setdefault(id(targets[id(next_stmt)]),
                                        []).append(hl_stmt)

            else:
                # Any other statement passes through unchanged.
                kept_statements.append(hl_stmt)

        # Place each head loop's collected init statements right in front of
        # it. Building a new list keeps their order. The previous
        # insert-at-same-index loop reversed them.
        stmt_statements = []
        for stmt in kept_statements:
            stmt_statements.extend(targets_init.get(id(stmt), ()))
            stmt_statements.append(stmt)

        return nir.HorizontalLoop(
            iteration_space=node.iteration_space,
            stmt=nir.BlockStmt(
                declarations=list(node.stmt.declarations),
                statements=stmt_statements,
                location_type=node.stmt.location_type,
            ),
        )