예제 #1
0
def make_horizontal_loop_with_init(field: Str):
    """Build a horizontal loop that assigns the literal ONE to ``field``.

    All nodes are placed on ``default_location``. The block and the
    assignment carry their ``location_type`` explicitly, like the other
    ``nir.BlockStmt`` / ``nir.AssignStmt`` constructions in this module.

    Args:
        field: name of the field written by the loop.

    Returns:
        A tuple ``(loop, write_access)``. ``write_access`` is the same
        ``nir.FieldAccess`` object used as the assignment target.
    """
    write_access = nir.FieldAccess(
        name=field,
        primary=no_extent,
        location_type=default_location,
    )
    init_value = nir.Literal(
        value=common.BuiltInLiteral.ONE,
        vtype=default_vtype,
        location_type=default_location,
    )
    loop = nir.HorizontalLoop(
        stmt=nir.BlockStmt(
            declarations=[],
            statements=[
                nir.AssignStmt(
                    left=write_access,
                    right=init_value,
                    location_type=default_location,
                )
            ],
            location_type=default_location,
        ),
        location_type=default_location,
    )
    return loop, write_access
예제 #2
0
    def visit_VerticalLoop(self, node: nir.VerticalLoop, *,
                           merge_candidates: List[List[nir.HorizontalLoop]],
                           **kwargs):
        """Fuse each run of merge candidates into a single horizontal loop.

        ``node.horizontal_loops`` is modified in place. For every candidate
        run, the slice from its first loop through its last loop is replaced
        by one ``nir.HorizontalLoop``. That loop's block concatenates, in
        order, the declarations and statements of the candidate loops.

        Args:
            node: vertical loop whose horizontal loops are merged.
            merge_candidates: runs of loops to fuse. Each run is presumably a
                contiguous sequence of loops in ``node.horizontal_loops``;
                this is not checked here. Empty runs are skipped.

        Returns:
            The same ``node`` object, after mutation.
        """
        for candidate in merge_candidates:
            if not candidate:
                # Nothing to fuse; candidate[0] below would raise IndexError.
                continue
            location_type = candidate[0].location_type

            first_index = node.horizontal_loops.index(candidate[0])
            # list.index compares by equality. Start the search for the last
            # loop at first_index: an equal loop earlier in the list would
            # otherwise yield last_index < first_index, and the slice
            # assignment would insert the merged loop without removing the
            # originals.
            last_index = node.horizontal_loops.index(candidate[-1],
                                                     first_index)

            declarations = [
                decl for loop in candidate for decl in loop.stmt.declarations
            ]
            statements = [
                stmt for loop in candidate for stmt in loop.stmt.statements
            ]

            node.horizontal_loops[first_index:last_index + 1] = [  # noqa: E203
                nir.HorizontalLoop(
                    stmt=nir.BlockStmt(
                        declarations=declarations,
                        statements=statements,
                        location_type=location_type,
                    ),
                    location_type=location_type,
                )
            ]

        return node
예제 #3
0
def make_horizontal_loop_with_copy(write: Str, read: Str,
                                   read_has_extent: Bool):
    """Build a horizontal loop that copies field ``read`` into field ``write``.

    All nodes are placed on ``default_location``. The block and the
    assignment carry their ``location_type`` explicitly, like the other
    ``nir.BlockStmt`` / ``nir.AssignStmt`` constructions in this module.

    Args:
        write: name of the field assigned by the loop.
        read: name of the field read by the loop.
        read_has_extent: if true, the read access uses ``with_extent`` as its
            primary; otherwise it uses ``no_extent``.

    Returns:
        A tuple ``(loop, write_access, read_access)``. The two accesses are the
        same ``nir.FieldAccess`` objects used inside the loop's assignment.
    """
    write_access = nir.FieldAccess(
        name=write,
        primary=no_extent,
        location_type=default_location,
    )
    read_access = nir.FieldAccess(
        name=read,
        # Only the read side can be given an extent; the write never has one.
        primary=with_extent if read_has_extent else no_extent,
        location_type=default_location,
    )

    loop = nir.HorizontalLoop(
        stmt=nir.BlockStmt(
            declarations=[],
            statements=[
                nir.AssignStmt(
                    left=write_access,
                    right=read_access,
                    location_type=default_location,
                )
            ],
            location_type=default_location,
        ),
        location_type=default_location,
    )
    return loop, write_access, read_access
예제 #4
0
def make_block_stmt(stmts: List[nir.Stmt], declarations: List[nir.LocalVar]):
    """Wrap ``stmts`` and ``declarations`` in a ``nir.BlockStmt``.

    The block takes the location type of its first statement. When
    ``stmts`` is empty, it falls back to ``common.LocationType.Vertex``.
    """
    if stmts:
        block_location = stmts[0].location_type
    else:
        block_location = common.LocationType.Vertex
    return nir.BlockStmt(
        location_type=block_location,
        statements=stmts,
        declarations=declarations,
    )
예제 #5
0
    def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, last_block,
                             **kwargs):
        """Lower a neighbor reduction into NIR statements.

        Three things are appended to ``last_block``, in this order:
          1. a local accumulator declaration;
          2. its initialisation with ``self.REDUCE_OP_INIT_VAL[node.op]``;
          3. a ``nir.NeighborLoop`` whose body folds the visited operand into
             the accumulator with ``self.REDUCE_OP_TO_BINOP[node.op]``.

        Args:
            node: the reduction to lower.
            last_block: block that receives the emitted declaration and
                statements; it is mutated in place.
            **kwargs: forwarded when visiting ``node.operand``. Must contain
                ``location_comprehensions``, which is extended with
                ``node.neighbors`` for the operand visit only.
                ``in_neighbor_loop`` is always set to True for that visit,
                overriding any value already present.

        Returns:
            A ``nir.VarAccess`` that reads the accumulator.
        """
        # Work on a copy so the caller's mapping does not see this
        # reduction's neighbor comprehension.
        loc_comprehension = copy.deepcopy(kwargs["location_comprehensions"])
        assert node.neighbors.name not in loc_comprehension
        loc_comprehension[node.neighbors.name] = node.neighbors
        kwargs["location_comprehensions"] = loc_comprehension

        # The loop body lives on the last location of the neighbor chain.
        body_location = node.neighbors.chain.elements[-1]
        reduce_var_name = "local" + str(node.id_)
        reduce_vtype = common.DataType.FLOAT64  # TODO

        last_block.declarations.append(
            nir.LocalVar(
                name=reduce_var_name,
                vtype=reduce_vtype,
                location_type=node.location_type,
            ))
        last_block.statements.append(
            nir.AssignStmt(
                left=nir.VarAccess(name=reduce_var_name,
                                   location_type=node.location_type),
                right=nir.Literal(
                    value=self.REDUCE_OP_INIT_VAL[node.op],
                    location_type=node.location_type,
                    vtype=reduce_vtype,
                ),
                location_type=node.location_type,
            ))

        # Merge instead of passing in_neighbor_loop=True next to **kwargs.
        # The old call raised TypeError ("multiple values for keyword
        # argument") whenever the caller had already forwarded
        # in_neighbor_loop.
        operand_kwargs = {**kwargs, "in_neighbor_loop": True}
        operand = self.visit(node.operand, **operand_kwargs)

        body = nir.BlockStmt(
            declarations=[],
            statements=[
                nir.AssignStmt(
                    left=nir.VarAccess(name=reduce_var_name,
                                       location_type=body_location),
                    right=nir.BinaryOp(
                        left=nir.VarAccess(name=reduce_var_name,
                                           location_type=body_location),
                        op=self.REDUCE_OP_TO_BINOP[node.op],
                        right=operand,
                        location_type=body_location,
                    ),
                    location_type=body_location,
                )
            ],
            location_type=body_location,
        )
        last_block.statements.append(
            nir.NeighborLoop(
                neighbors=self.visit(node.neighbors.chain),
                body=body,
                location_type=node.location_type,
            ))
        return nir.VarAccess(name=reduce_var_name,
                             location_type=node.location_type)  # TODO
예제 #6
0
 def visit_HorizontalLoop(self, node: gtir.HorizontalLoop, **kwargs):
     """Lower a GTIR horizontal loop into a ``nir.HorizontalLoop``.

     ``node.stmt`` is visited with a fresh, empty block as ``last_block``.
     Only this loop's own location comprehension is in scope during the
     visit; ``kwargs`` is not forwarded. The visited statement is appended
     to the block after anything the visit emitted there.
     """
     outer_location = node.location
     body_block = nir.BlockStmt(
         declarations=[],
         statements=[],
         location_type=node.stmt.location_type,
     )
     comprehensions = {outer_location.name: outer_location}
     lowered_stmt = self.visit(node.stmt,
                               last_block=body_block,
                               location_comprehensions=comprehensions)
     body_block.statements.append(lowered_stmt)
     primary_location = outer_location.chain.elements[0]
     return nir.HorizontalLoop(stmt=body_block,
                               location_type=primary_location)
예제 #7
0
def make_empty_block_stmt(location_type: common.LocationType):
    """Return a ``nir.BlockStmt`` on ``location_type`` with no declarations
    and no statements."""
    no_declarations = []
    no_statements = []
    return nir.BlockStmt(
        declarations=no_declarations,
        statements=no_statements,
        location_type=location_type,
    )