Beispiel #1
0
    def emit_atomic_update(self, codegen_state, lhs_atomicity, lhs_var,
                           lhs_expr, rhs_expr, lhs_dtype, rhs_type_context):
        """Emit a compare-and-swap retry loop that atomically updates *lhs_expr*.

        The generated code declares two temporaries (old and new value), then
        loops: read the current value of *lhs_expr* into the old-value
        temporary, evaluate *rhs_expr* with every occurrence of *lhs_expr*
        replaced by that temporary, and attempt ``atomic_cmpxchg`` (4-byte
        types) or ``atom_cmpxchg`` (8-byte types) until the swap succeeds.

        For floating-point types the temporaries and the target address are
        reinterpreted as ``int``/``long`` through pointer casts, and the
        target pointer cast carries the ``__global``/``__local`` qualifier
        derived from ``lhs_var.address_space``.

        :arg codegen_state: supplies ``var_name_generator`` and
            ``expression_to_code_mapper``.
        :arg lhs_atomicity: not used by this implementation.
        :arg lhs_var: the variable being updated; only inspected (type,
            ``address_space``, ``name``) for floating-point updates.
        :arg lhs_expr: expression naming the memory location to update.
        :arg rhs_expr: expression for the new value; may refer to *lhs_expr*.
        :arg lhs_dtype: type of the location; must be a :class:`NumpyType`
            wrapping int32, int64, float32 or float64.
        :arg rhs_type_context: forwarded as ``type_context`` when generating
            code for the right-hand side.
        :returns: a :class:`cgen.Block` with the declarations and the loop.
        :raises NotImplementedError: for any other *lhs_dtype*.
        :raises LoopyError: if a floating-point target is not a global/local
            array argument or temporary, or if the item size is unexpected.
        """
        from pymbolic.mapper.stringifier import PREC_NONE

        # FIXME: Could detect operations, generate atomic_{add,...} when
        # appropriate.

        if not (isinstance(lhs_dtype, NumpyType) and lhs_dtype.numpy_dtype in [
                np.int32, np.int64, np.float32, np.float64]):
            raise NotImplementedError("atomic update for '%s'" % lhs_dtype)

        from cgen import Block, DoWhile, Assign
        from loopy.target.c import POD
        from loopy.kernel.data import (
            TemporaryVariable, ArrayArg, AddressSpace)
        from pymbolic.mapper.substitutor import make_subst_func
        from pymbolic import var
        from loopy.symbolic import SubstitutionMapper

        old_val_var = codegen_state.var_name_generator("loopy_old_val")
        new_val_var = codegen_state.var_name_generator("loopy_new_val")

        # Make the two temporaries known to the expression-to-code mapper so
        # that references to them in the substituted RHS can be typed.
        ecm = codegen_state.expression_to_code_mapper.with_assignments({
            old_val_var: TemporaryVariable(old_val_var, lhs_dtype),
            new_val_var: TemporaryVariable(new_val_var, lhs_dtype),
        })

        lhs_expr_code = ecm(lhs_expr, prec=PREC_NONE, type_context=None)

        # The RHS must be computed from the snapshot in old_val_var, not from
        # a second (racy) read of the target location.
        subst = SubstitutionMapper(
            make_subst_func({lhs_expr: var(old_val_var)}))
        rhs_expr_code = ecm(subst(rhs_expr),
                            prec=PREC_NONE,
                            type_context=rhs_type_context,
                            needed_dtype=lhs_dtype)

        if lhs_dtype.numpy_dtype.itemsize == 4:
            func_name = "atomic_cmpxchg"
        elif lhs_dtype.numpy_dtype.itemsize == 8:
            func_name = "atom_cmpxchg"
        else:
            raise LoopyError("unexpected atomic size")

        cast_str = ""
        old_val = old_val_var
        new_val = new_val_var

        if lhs_dtype.numpy_dtype.kind == "f":
            # The compare-exchange call is issued on an integer view of the
            # float's bits, using an integer type of the same size.
            if lhs_dtype.numpy_dtype == np.float32:
                ctype = "int"
            elif lhs_dtype.numpy_dtype == np.float64:
                ctype = "long"
            else:
                # Explicit raise: a bare 'assert' is removed under 'python -O'
                # and would leave ctype unbound.
                raise AssertionError()

            var_kind = None
            if isinstance(lhs_var, (ArrayArg, TemporaryVariable)):
                if lhs_var.address_space == AddressSpace.GLOBAL:
                    var_kind = "__global"
                elif lhs_var.address_space == AddressSpace.LOCAL:
                    var_kind = "__local"

            if var_kind is None:
                # Two placeholders for two arguments; with a single '%s' the
                # formatting itself raised TypeError and hid this error.
                raise LoopyError("unexpected kind of variable '%s' in "
                                 "atomic operation: '%s'" %
                                 (lhs_var.name, type(lhs_var).__name__))

            old_val = "*(%s *) &" % ctype + old_val
            new_val = "*(%s *) &" % ctype + new_val
            cast_str = "(%s %s *) " % (var_kind, ctype)

        return Block([
            POD(self, NumpyType(lhs_dtype.dtype, target=self.target),
                old_val_var),
            POD(self, NumpyType(lhs_dtype.dtype, target=self.target),
                new_val_var),
            # Retry while the value found at the target differs from the
            # snapshot, i.e. the location changed since it was read.
            DoWhile(
                "%(func_name)s("
                "%(cast_str)s&(%(lhs_expr)s), "
                "%(old_val)s, "
                "%(new_val)s"
                ") != %(old_val)s" % {
                    "func_name": func_name,
                    "cast_str": cast_str,
                    "lhs_expr": lhs_expr_code,
                    "old_val": old_val,
                    "new_val": new_val,
                },
                Block([
                    Assign(old_val_var, lhs_expr_code),
                    Assign(new_val_var, rhs_expr_code),
                ]))
        ])
Beispiel #2
0
    def emit_atomic_update(self, codegen_state, lhs_atomicity, lhs_var,
                           lhs_expr, rhs_expr, lhs_dtype, rhs_type_context):
        """Emit code that atomically updates *lhs_expr* with *rhs_expr*.

        Two strategies are used:

        * If *rhs_expr* is a sum in which *lhs_expr* occurs exactly once as a
          direct summand (``a = a + <rest>``), a single
          ``atomicAdd(&a, <rest>)`` statement is emitted.
        * Otherwise an ``atomicCAS`` retry loop is emitted: the current value
          is read into a temporary, the new value is computed from that
          temporary, and the swap is retried until it succeeds. For
          floating-point types the operands are reinterpreted as
          ``int``/``long`` through pointer casts.

        :arg codegen_state: supplies ``var_name_generator`` and
            ``expression_to_code_mapper``.
        :arg lhs_atomicity: not used by this implementation.
        :arg lhs_var: not used by this implementation.
        :arg lhs_expr: expression naming the memory location to update.
        :arg rhs_expr: expression for the new value; may refer to *lhs_expr*.
        :arg lhs_dtype: type of the location; must be a :class:`NumpyType`
            wrapping int32, int64, float32 or float64.
        :arg rhs_type_context: forwarded as ``type_context`` when generating
            code for the right-hand side in the CAS loop.
        :returns: a :class:`cgen.Statement` (atomicAdd) or a
            :class:`cgen.Block` (CAS loop).
        :raises NotImplementedError: for any other *lhs_dtype*.
        """

        from pymbolic.primitives import Sum
        from cgen import Statement
        from pymbolic.mapper.stringifier import PREC_NONE

        if not (isinstance(lhs_dtype, NumpyType) and lhs_dtype.numpy_dtype in [
                np.int32, np.int64, np.float32, np.float64]):
            raise NotImplementedError("atomic update for '%s'" % lhs_dtype)

        # NOTE(review): CUDA's atomicAdd/atomicCAS overload sets may not cover
        # every C type these dtypes map to (e.g. signed 64-bit) -- confirm
        # against the CUDA programming guide for the targeted architectures.

        # atomicAdd
        if isinstance(rhs_expr, Sum):
            other_terms = tuple(
                c for c in rhs_expr.children if c != lhs_expr)

            # Only "lhs = lhs + <rest>" is an increment. If lhs is not a
            # summand at all the update is a plain assignment, and if it
            # occurs several times the increment is not just <rest>; both
            # cases must go through the CAS loop below to stay correct.
            # NOTE(review): references to lhs nested inside <rest> are not
            # detected here.
            if other_terms and (
                    len(other_terms) == len(rhs_expr.children) - 1):
                ecm = self.get_expression_to_code_mapper(codegen_state)

                lhs_expr_code = ecm(lhs_expr)
                rhs_expr_code = ecm(Sum(other_terms))

                return Statement("atomicAdd(&{}, {})".format(
                    lhs_expr_code, rhs_expr_code))

        from cgen import Block, DoWhile, Assign
        from loopy.target.c import POD
        from loopy.kernel.data import TemporaryVariable
        from pymbolic.mapper.substitutor import make_subst_func
        from pymbolic import var
        from loopy.symbolic import SubstitutionMapper

        old_val_var = codegen_state.var_name_generator("loopy_old_val")
        new_val_var = codegen_state.var_name_generator("loopy_new_val")

        # Make the two temporaries known to the expression-to-code mapper so
        # that references to them in the substituted RHS can be typed.
        ecm = codegen_state.expression_to_code_mapper.with_assignments({
            old_val_var: TemporaryVariable(old_val_var, lhs_dtype),
            new_val_var: TemporaryVariable(new_val_var, lhs_dtype),
        })

        lhs_expr_code = ecm(lhs_expr, prec=PREC_NONE, type_context=None)

        # The RHS must be computed from the snapshot in old_val_var, not from
        # a second (racy) read of the target location.
        subst = SubstitutionMapper(
            make_subst_func({lhs_expr: var(old_val_var)}))
        rhs_expr_code = ecm(subst(rhs_expr),
                            prec=PREC_NONE,
                            type_context=rhs_type_context,
                            needed_dtype=lhs_dtype)

        cast_str = ""
        old_val = old_val_var
        new_val = new_val_var

        if lhs_dtype.numpy_dtype.kind == "f":
            # atomicCAS is issued on an integer view of the float's bits.
            if lhs_dtype.numpy_dtype == np.float32:
                ctype = "int"
            elif lhs_dtype.numpy_dtype == np.float64:
                ctype = "long"
            else:
                raise AssertionError()

            old_val = "*(%s *) &" % ctype + old_val
            new_val = "*(%s *) &" % ctype + new_val
            cast_str = "(%s *) " % (ctype)

        return Block([
            POD(self, NumpyType(lhs_dtype.dtype, target=self.target),
                old_val_var),
            POD(self, NumpyType(lhs_dtype.dtype, target=self.target),
                new_val_var),
            # Retry while the value found at the target differs from the
            # snapshot, i.e. the location changed since it was read.
            DoWhile(
                "atomicCAS("
                "%(cast_str)s&(%(lhs_expr)s), "
                "%(old_val)s, "
                "%(new_val)s"
                ") != %(old_val)s" % {
                    "cast_str": cast_str,
                    "lhs_expr": lhs_expr_code,
                    "old_val": old_val,
                    "new_val": new_val,
                },
                Block([
                    Assign(old_val_var, lhs_expr_code),
                    Assign(new_val_var, rhs_expr_code),
                ]))
        ])