예제 #1
0
    def visit_VerticalLoop(
        self, node: nir.VerticalLoop, *, merge_candidates: List[List[nir.HorizontalLoop]], **kwargs
    ):
        """Fuse each candidate group of horizontal loops of ``node`` in place.

        For every group in ``merge_candidates``, the slice of
        ``node.horizontal_loops`` from the group's first loop to its last loop
        (inclusive) is replaced by a single ``nir.HorizontalLoop``. The new
        loop's body concatenates the declarations and statements of the group's
        loops, in group order. It reuses the iteration space of the loop at the
        slice start and takes its location type from the first loop's
        iteration space.

        NOTE(review): loops are located with ``list.index`` (equality, not
        identity), and any loop lying between a group's first and last loop
        but not in the group would be dropped. Presumably candidates are
        contiguous runs of distinct loops; confirm with the caller.

        Returns the same ``node`` object, mutated.
        """
        loops = node.horizontal_loops
        for group in merge_candidates:
            head = group[0]
            merged_location = head.iteration_space.location_type

            start = loops.index(head)
            stop = loops.index(group[-1]) + 1

            merged_declarations = [
                decl for hloop in group for decl in hloop.stmt.declarations
            ]
            merged_statements = [
                stmt for hloop in group for stmt in hloop.stmt.statements
            ]

            merged_body = nir.BlockStmt(
                declarations=merged_declarations,
                statements=merged_statements,
                location_type=merged_location,
            )
            loops[start:stop] = [
                nir.HorizontalLoop(
                    stmt=merged_body,
                    iteration_space=loops[start].iteration_space,
                )
            ]

        return node
예제 #2
0
    def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, last_block,
                             **kwargs):
        """Lower a neighbor reduction into an accumulator plus a neighbor loop.

        Side effects on ``last_block``:
          * appends a FLOAT64 ``nir.LocalVar`` accumulator named
            ``"local" + str(node.id_)`` to its declarations;
          * appends an assignment of ``REDUCE_OP_INIT_VAL[node.op]`` to it;
          * appends a ``nir.NeighborLoop`` whose body updates the accumulator
            with ``REDUCE_OP_TO_BINOP[node.op]`` applied to the visited operand.

        The operand is visited with ``in_neighbor_loop=True`` and with
        ``location_comprehensions`` extended by ``node.neighbors``. The mapping
        is deep-copied first, so the caller's mapping is left unchanged.

        Returns a ``nir.VarAccess`` that reads the accumulator.
        """
        comprehensions = copy.deepcopy(kwargs["location_comprehensions"])
        neighbor_name = node.neighbors.name
        assert neighbor_name not in comprehensions
        comprehensions[neighbor_name] = node.neighbors
        kwargs["location_comprehensions"] = comprehensions

        outer_loc = node.location_type
        inner_loc = node.neighbors.chain.elements[-1]
        acc_name = "local" + str(node.id_)

        # TODO: the accumulator type is hard-coded to FLOAT64
        last_block.declarations.append(
            nir.LocalVar(
                name=acc_name,
                vtype=common.DataType.FLOAT64,
                location_type=outer_loc,
            ))

        init_target = nir.VarAccess(name=acc_name, location_type=outer_loc)
        init_value = nir.Literal(
            value=self.REDUCE_OP_INIT_VAL[node.op],
            location_type=outer_loc,
            vtype=common.DataType.FLOAT64,  # TODO: hard-coded type
        )
        last_block.statements.append(
            nir.AssignStmt(
                left=init_target,
                right=init_value,
                location_type=outer_loc,
            ))

        # Body of the neighbor loop: acc = acc <op> operand
        update_target = nir.VarAccess(name=acc_name, location_type=inner_loc)
        accumulator = nir.VarAccess(name=acc_name, location_type=inner_loc)
        binop = self.REDUCE_OP_TO_BINOP[node.op]
        operand = self.visit(node.operand, in_neighbor_loop=True, **kwargs)
        combined = nir.BinaryOp(
            left=accumulator,
            op=binop,
            right=operand,
            location_type=inner_loc,
        )
        update = nir.AssignStmt(
            left=update_target,
            right=combined,
            location_type=inner_loc,
        )
        loop_body = nir.BlockStmt(
            declarations=[],
            statements=[update],
            location_type=inner_loc,
        )

        last_block.statements.append(
            nir.NeighborLoop(
                neighbors=self.visit(node.neighbors.chain),
                body=loop_body,
                location_type=outer_loc,
            ))
        return nir.VarAccess(name=acc_name, location_type=outer_loc)  # TODO
예제 #3
0
 def visit_HorizontalLoop(self, node: gtir.HorizontalLoop, **kwargs):
     """Lower a gtir.HorizontalLoop into a nir.HorizontalLoop.

     A fresh, empty ``nir.BlockStmt`` is passed as ``last_block`` while
     visiting ``node.stmt``, so the visit can append helper declarations and
     statements to it. The visit's result is then appended as the block's
     last statement. Only ``node.location`` is provided as a location
     comprehension; ``kwargs`` are not forwarded to the visit.

     Returns a ``nir.HorizontalLoop`` over that block, located at the first
     element of the location's chain.
     """
     loop_block = nir.BlockStmt(
         declarations=[],
         statements=[],
         location_type=node.stmt.location_type,
     )
     location = node.location
     lowered = self.visit(
         node.stmt,
         last_block=loop_block,
         location_comprehensions={location.name: location},
     )
     loop_block.statements.append(lowered)

     start_location = location.chain.elements[0]
     return nir.HorizontalLoop(stmt=loop_block, location_type=start_location)
예제 #4
0
파일: gtir_to_nir.py 프로젝트: havogt/gt4py
 def visit_NeighborAssignStmt(
     self,
     node: gtir.NeighborAssignStmt,
     *,
     symtable,
     hloop_ctx: "HorizontalLoopContext",
     **kwargs,
 ):
     """Lower a per-neighbor assignment into a nir.NeighborLoop.

     The loop iterates ``node.neighbors.of`` under the loop variable
     ``node.neighbors.name``. Its body assigns the visited ``node.right`` to
     ``node.left`` indexed by the left side's first subscript (primary) and
     the loop variable (secondary). The right-hand side is visited with
     ``loop_var`` set and with ``node.symtable_`` layered over ``symtable``.

     The loop is registered through ``hloop_ctx.add_statement``; nothing is
     returned.
     """
     scope = {**symtable, **node.symtable_}
     loop_var = node.neighbors.name
     loc = node.location_type

     loop_name = nir.NeighborLoopVar(name=loop_var)
     target = nir.FieldAccess(
         name=node.left.name,
         primary=node.left.subscript[0].name,
         secondary=loop_var,
         location_type=loc,
     )
     value = self.visit(
         node.right,
         loop_var=loop_var,
         symtable=scope,
         hloop_ctx=hloop_ctx,
         **kwargs,
     )
     assignment = nir.AssignStmt(left=target, right=value, location_type=loc)
     loop_body = nir.BlockStmt(
         declarations=[],
         statements=[assignment],
         location_type=loc,
     )
     hloop_ctx.add_statement(
         nir.NeighborLoop(
             name=loop_name,
             connectivity=node.neighbors.of.name,
             body=loop_body,
             location_type=loc,
         ))
예제 #5
0
파일: gtir_to_nir.py 프로젝트: havogt/gt4py
 def visit_HorizontalLoop(self, node: gtir.HorizontalLoop, *, symtable,
                          **kwargs):
     """Lower a gtir.HorizontalLoop into a nir.HorizontalLoop.

     ``node.stmt`` is visited with a fresh ``HorizontalLoopContext``, which
     collects the declarations and statements the visit emits. The visit's
     return value is ignored. The visit also receives ``node.location`` as
     its only location comprehension and ``node.symtable_`` layered over
     ``symtable``.

     Returns a ``nir.HorizontalLoop`` whose body holds the collected
     declarations and statements, and whose iteration space is named after
     ``node.location``.
     """
     ctx = self.HorizontalLoopContext()
     location = node.location
     merged_symtable = {**symtable, **node.symtable_}
     self.visit(
         node.stmt,
         hloop_ctx=ctx,
         location_comprehensions={location.name: location},
         symtable=merged_symtable,
         **kwargs,
     )

     loop_body = nir.BlockStmt(
         declarations=ctx.declarations,
         statements=ctx.statements,
         location_type=location.location_type,
     )
     space = nir.IterationSpace(
         name=location.name,
         location_type=location.location_type,
     )
     return nir.HorizontalLoop(stmt=loop_body, iteration_space=space)
예제 #6
0
    def visit_HorizontalLoop(
        self,
        node: nir.HorizontalLoop,
        *,
        merge_groups: Dict[int, List[List[nir.NeighborLoop]]],
        **kwargs,
    ):
        """Merge groups of neighbor loops of ``node`` into one loop per group.

        ``merge_groups`` maps ``id(node)`` to the groups for this horizontal
        loop. Each group is a list of ``nir.NeighborLoop`` statements taken from
        ``node.stmt.statements``. The first loop of a group (its "head") absorbs
        the bodies of the remaining loops; their loop variable is renamed to the
        head's via ``RenameSymbol``. The head is deep-copied before its body is
        extended, so extending it does not modify the input tree.

        An ``nir.AssignStmt`` that directly precedes an absorbed (non-head)
        neighbor loop is treated as that loop's reduction initialization. It is
        moved to just before the head the loop was merged into. Moved
        statements keep their original relative order.

        Args:
            node: the horizontal loop whose neighbor loops are merged.
            merge_groups: ``id(horizontal loop) -> list of neighbor-loop groups``.
            **kwargs: unused by this method.

        Returns:
            A new ``nir.HorizontalLoop`` with the same iteration space and
            declarations as ``node`` and the merged statement list.
        """
        assert id(node) in merge_groups
        groups: List[List[nir.NeighborLoop]] = merge_groups[id(node)]

        # id(head NeighborLoop) -> the group it heads (O(1) lookup per statement)
        group_by_head: Dict[int, List[nir.NeighborLoop]] = {
            id(group[0]): group for group in groups
        }

        # id(absorbed NeighborLoop) -> the copied head loop it is merged into
        targets: Dict[int, nir.NeighborLoop] = {}

        # id(copied head loop) -> initialization statements of the loops merged
        # into it, to be placed in front of it
        targets_init: Dict[int, List[nir.AssignStmt]] = {}

        hl_statements = node.stmt.statements
        kept_statements = []
        for i, hl_stmt in enumerate(hl_statements):
            next_stmt = hl_statements[i + 1] if i + 1 < len(hl_statements) else None

            if isinstance(hl_stmt, nir.NeighborLoop):
                group = group_by_head.get(id(hl_stmt))
                if group is not None:
                    # Head loop: work on a copy (its body is extended below) and
                    # route every other loop of the group to that copy.
                    target_n_loop = copy.deepcopy(hl_stmt)
                    for other_n_loop in group[1:]:
                        targets[id(other_n_loop)] = target_n_loop
                    kept_statements.append(target_n_loop)
                else:
                    # Absorbed loop: rename its loop variable to the head's and
                    # append its body to the head copy; the loop itself is dropped.
                    assert id(hl_stmt) in targets
                    target_n_loop = targets[id(hl_stmt)]
                    other_body: nir.BlockStmt = RenameSymbol.apply(
                        hl_stmt.body, hl_stmt.name, target_n_loop.name
                    )
                    target_n_loop.body.declarations.extend(other_body.declarations)
                    target_n_loop.body.statements.extend(other_body.statements)

            elif (
                isinstance(hl_stmt, nir.AssignStmt)
                and isinstance(next_stmt, nir.NeighborLoop)
                and id(next_stmt) in targets
            ):
                # Reduction init of an absorbed loop: defer it to the head copy.
                targets_init.setdefault(id(targets[id(next_stmt)]), []).append(hl_stmt)

            else:
                kept_statements.append(hl_stmt)

        # Place the deferred init statements, in source order, right before the
        # head loop they were merged into.
        stmt_statements = []
        for stmt in kept_statements:
            stmt_statements.extend(targets_init.get(id(stmt), []))
            stmt_statements.append(stmt)

        return nir.HorizontalLoop(
            iteration_space=node.iteration_space,
            stmt=nir.BlockStmt(
                declarations=node.stmt.declarations,
                statements=stmt_statements,
                # keep the original block's location type on the rebuilt block
                location_type=node.stmt.location_type,
            ),
        )
예제 #7
0
파일: gtir_to_nir.py 프로젝트: havogt/gt4py
    def visit_NeighborReduce(self, node: gtir.NeighborReduce, *,
                             hloop_ctx: "HorizontalLoopContext", **kwargs):
        """Lower a neighbor reduction into an accumulator plus a neighbor loop.

        ``kwargs["symtable"]`` is looked up with ``node.neighbors.of.name`` to
        obtain the ``gtir.Connectivity`` being reduced over.

        Side effects on ``hloop_ctx``:
          * declares a FLOAT64 ``nir.LocalVar`` accumulator;
          * adds an assignment of ``REDUCE_OP_INIT_VAL[node.op]`` to it;
          * adds a ``nir.NeighborLoop`` over the connectivity whose body folds
            the visited operand into the accumulator. Min/max reductions use a
            ``nir.NativeFuncCall``; all other operators use a ``nir.BinaryOp``.

        Returns a ``nir.VarAccess`` that reads the accumulator.
        """
        connectivity: gtir.Connectivity = kwargs["symtable"][
            node.neighbors.of.name]

        # NOTE(review): id(node) is an object address. It is unique only among
        # live objects and is not stable across runs, so generated names may
        # vary between runs. Confirm this is acceptable.
        acc_name = "local" + str(id(node))
        outer_loc = node.location_type

        # TODO: the accumulator type is hard-coded to FLOAT64
        hloop_ctx.add_declaration(
            nir.LocalVar(
                name=acc_name,
                vtype=common.DataType.FLOAT64,
                location_type=outer_loc,
            ))

        init_target = nir.VarAccess(name=acc_name, location_type=outer_loc)
        init_value = nir.Literal(
            value=self.REDUCE_OP_INIT_VAL[node.op],
            location_type=outer_loc,
            vtype=common.DataType.FLOAT64,  # TODO: hard-coded type
        )
        hloop_ctx.add_statement(
            nir.AssignStmt(
                left=init_target,
                right=init_value,
                location_type=outer_loc,
            ))

        inner_loc = connectivity.secondary
        op = self.REDUCE_OP_TO_BINOP[node.op]
        accumulator = nir.VarAccess(name=acc_name, location_type=inner_loc)
        operand = self.visit(node.operand, in_neighbor_loop=True, **kwargs)

        is_max = op == common.BuiltInLiteral.MAX_VALUE
        is_extremum = op == common.BuiltInLiteral.MIN_VALUE or is_max
        if is_extremum:
            # min/max are not binary operators: express them as native calls
            native = common.NativeFunction.MAX if is_max else common.NativeFunction.MIN
            folded = nir.NativeFuncCall(
                func=native,
                args=[accumulator, operand],
                location_type=inner_loc,
            )
        else:
            folded = nir.BinaryOp(
                left=accumulator,
                op=op,
                right=operand,
                location_type=inner_loc,
            )

        update = nir.AssignStmt(
            left=nir.VarAccess(name=acc_name, location_type=inner_loc),
            right=folded,
            location_type=inner_loc,
        )
        loop_body = nir.BlockStmt(
            declarations=[],
            statements=[update],
            location_type=inner_loc,
        )
        hloop_ctx.add_statement(
            nir.NeighborLoop(
                name=nir.NeighborLoopVar(name=node.neighbors.name),
                connectivity=connectivity.name,
                body=loop_body,
                location_type=outer_loc,
            ))
        return nir.VarAccess(name=acc_name, location_type=outer_loc)  # TODO