Пример #1
0
    def facet_integral_predicates(self, mesh, integral_type, kinfo):
        """Build the predicates and facet argument for a facet integral.

        Marks the bag as needing the cell-facet array, stores the number of
        reference-cell facets on ``self.num_facets`` and creates a loop index
        over those facets.

        :arg mesh: the mesh of the integral; when ``mesh.cell_set._extruded``
            is set, the facet count comes from ``mesh._base_mesh``.
        :arg integral_type: integral type string; types starting with
            ``"interior_facet"`` require column 0 of the cell-facet entry to
            equal 1, all other types require it to equal 0.
        :arg kinfo: kernel info; unless its ``subdomain_id`` is
            ``"otherwise"``, column 1 of the cell-facet entry must equal it.
        :returns: ``(predicates, fidx, facet_arg)``: the list of
            ``pym.Comparison`` predicates, the facet loop index, and the
            ``SubArrayRef`` into the local facet array that is fed to the
            TSFC loopy kernel.
        """
        self.bag.needs_cell_facets = True

        # Number of reference-cell facets; extruded meshes use the cell of
        # their base mesh.
        if mesh.cell_set._extruded:
            reference_cell = mesh._base_mesh.ufl_cell()
        else:
            reference_cell = mesh.ufl_cell()
        self.num_facets = reference_cell.num_facets()

        # Loop index over the facets of the reference cell. It is created
        # before the offset index below so index creation order is unchanged.
        facet_index = self.bag.index_creator((self.num_facets,))

        # 1 selects interior facets, 0 selects exterior ones.
        facet_kind = 1 if integral_type.startswith("interior_facet") else 0

        offset = self.bag.index_creator((1,))
        cell_facets = pym.Variable(self.cell_facets_arg)
        predicates = [
            pym.Comparison(pym.Subscript(cell_facets, (facet_index[0], 0)),
                           "==", facet_kind),
        ]

        # TODO subdomain boundary integrals, this does the wrong thing for integrals like f*ds + g*ds(1)
        # "otherwise" is treated incorrectly as "everywhere"
        # However, this replicates an existing slate bug.
        if kinfo.subdomain_id != "otherwise":
            predicates.append(
                pym.Comparison(pym.Subscript(cell_facets, (facet_index[0], 1)),
                               "==", kinfo.subdomain_id))

        # Extra facet array argument handed to the TSFC loopy kernel.
        local_facets = pym.Subscript(pym.Variable(self.local_facet_array_arg),
                                     pym.Sum((offset[0], facet_index[0])))
        facet_arg = SubArrayRef(offset, local_facets)

        return predicates, facet_index, facet_arg
Пример #2
0
    def layer_integral_predicates(self, tensor, integral_type):
        """Return the layer predicate for a top/bottom facet integral.

        Marks the bag as needing mesh layers. The current layer variable is
        compared with ``0`` or with the layer count, depending on
        *integral_type*. *tensor* is not used.

        :raises KeyError: if *integral_type* is not one of
            ``"interior_facet_horiz_top"``, ``"interior_facet_horiz_bottom"``,
            ``"exterior_facet_top"`` or ``"exterior_facet_bottom"``.
        :returns: a one-element list holding the ``pym.Comparison``.
        """
        self.bag.needs_mesh_layers = True
        current_layer = pym.Variable(self.layer_arg)

        # TODO: Variable layers
        layer_count = pym.Variable(self.layer_count)

        # integral type -> (operator, right-hand side) of the layer test
        operator, bound = {
            "interior_facet_horiz_top": ("<", layer_count),
            "interior_facet_horiz_bottom": (">", 0),
            "exterior_facet_top": ("==", layer_count),
            "exterior_facet_bottom": ("==", 0),
        }[integral_type]

        return [pym.Comparison(current_layer, operator, bound)]
Пример #3
0
    def map_subscript(self, expr, *args):
        """Differentiate a subscript expression.

        If *expr* subscripts the variable being differentiated by
        (``self.diff_context.by_name``), the result is the indicator
        ``If(ti_0 == ei_0 and ..., 1, 0)``. It pairs each differentiation
        iname ``ti`` with the matching index component ``ei``.

        Otherwise the result is the derivative variable returned by
        ``get_diff_var`` for the aggregate, subscripted by the original index
        followed by the differentiation inames. It is ``0`` when
        ``get_diff_var`` returns ``None``.
        """
        dc = self.diff_context

        # Normalize a scalar index to a 1-tuple so both branches can treat
        # the index uniformly.
        index = expr.index
        if not isinstance(index, tuple):
            index = (index,)

        if expr.aggregate.name == dc.by_name:
            assert len(self.diff_inames) == len(index)

            conds = [
                p.Comparison(var(ti), "==", ei)
                for ti, ei in zip(self.diff_inames, index)
            ]

            if len(conds) == 1:
                and_conds, = conds
            elif len(conds) > 1:
                and_conds = p.LogicalAnd(tuple(conds))
            else:
                assert False

            return p.If(and_conds, 1, 0)

        dvar_dby = dc.get_diff_var(expr.aggregate.name)
        if dvar_dby is None:
            return 0

        # Use the normalized tuple index: ``expr.index + ...`` raises a
        # TypeError for a scalar (non-tuple) index. The inames are names, so
        # they are wrapped with ``var`` exactly as in the comparisons above.
        return type(expr)(
            var(dvar_dby),
            index + tuple(var(iname) for iname in self.diff_inames))
Пример #4
0
def nasa7_conditional(t, poly, part_gen):
    """Choose between the two coefficient sets of a NASA-7 polynomial.

    Returns a ``p.If`` expression. It evaluates to
    ``part_gen(poly.coeffs[1:8], t)`` when ``t > poly.coeffs[0]`` and to
    ``part_gen(poly.coeffs[8:15], t)`` otherwise. ``coeffs[0]`` presumably
    holds the switch-over temperature (confirm against the data source).
    """
    # FIXME: Should check minTemp, maxTemp
    coeffs = poly.coeffs
    above_switch = p.Comparison(t, ">", coeffs[0])
    upper_branch = part_gen(coeffs[1:8], t)
    lower_branch = part_gen(coeffs[8:15], t)
    return p.If(above_switch, upper_branch, lower_branch)
Пример #5
0
 def upper_half_square_root(x):
     """Return a square root of *x* chosen by the sign of an imaginary part.

     Builds a ``p.If`` expression. When ``(x**0.5).a.imag`` is negative the
     result is ``1j*(-x)**0.5``; otherwise it is ``x**0.5``.
     """
     # NOTE(review): the ``.a`` lookup is carried over unchanged. Presumably
     # *x* is a symbolic expression on which attribute access builds a lookup
     # node; confirm what ``.a`` is meant to refer to.
     root = x**0.5
     negative_imag = p.Comparison(root.a.imag, "<", 0)
     return p.If(negative_imag, 1j*(-x)**0.5, root)
Пример #6
0
def test_sympy_if_condition():
    """Round-trip an ``If`` with a ``<=`` condition through sympy and back."""
    pytest.importorskip("sympy")
    from pymbolic.interop.sympy import PymbolicToSympyMapper, SympyToPymbolicMapper

    to_sympy = PymbolicToSympyMapper()
    to_pymbolic = SympyToPymbolicMapper()

    condition = prim.Comparison(x_, "<=", y_)
    expr = prim.If(condition, 1, 0)
    assert to_pymbolic(to_sympy(expr)) == expr
Пример #7
0
    def map_Compare(self, expr):  # noqa
        """Map a Python ``ast.Compare`` node to a pymbolic comparison.

        A single comparison ``a OP b`` becomes ``p.Comparison(a, OP, b)``.
        A chained comparison ``a OP1 b OP2 c ...`` becomes the logical
        conjunction ``(a OP1 b) and (b OP2 c) and ...``, matching Python's
        semantics for chained comparisons.

        :raises NotImplementedError: if an operator has no entry in
            ``self.comparison_op_map``.
        """
        # (expr left, cmpop* ops, expr* comparators)
        comps = []
        for op in expr.ops:
            try:
                comps.append(self.comparison_op_map[type(op)])
            except KeyError:
                raise NotImplementedError(
                    "%s does not know how to map operator '%s'" %
                    (type(self).__name__, type(op).__name__))

        # Map every operand exactly once; each middle operand is shared by
        # the two comparisons it takes part in.
        operands = [self.rec(expr.left)]
        operands.extend(self.rec(c) for c in expr.comparators)

        comparisons = [
            p.Comparison(lhs, comp, rhs)
            for lhs, comp, rhs in zip(operands, comps, operands[1:])
        ]

        if len(comparisons) == 1:
            return comparisons[0]
        return p.LogicalAnd(tuple(comparisons))
Пример #8
0
    def emit_sequential_loop(self, codegen_state, iname, iname_dtype, lbound,
                             ubound, inner):
        """Emit an ISPC ``for`` loop running *iname* from *lbound* to *ubound*.

        The loop variable is declared ``uniform`` with type *iname_dtype* and
        initialized to *lbound*. The loop runs while ``iname <= ubound``, so
        *ubound* is inclusive. The variable is pre-incremented each step, and
        *inner* is the loop body.
        """
        to_code = codegen_state.expression_to_code_mapper

        from loopy.target.c import POD

        from pymbolic.mapper.stringifier import PREC_NONE
        from cgen import For, InlineInitializer

        from cgen.ispc import ISPCUniform

        loop_var = ISPCUniform(POD(self, iname_dtype, iname))
        initializer = InlineInitializer(
            loop_var, to_code(lbound, PREC_NONE, "i"))
        condition = to_code(
            p.Comparison(var(iname), "<=", ubound), PREC_NONE, "i")
        increment = "++%s" % iname

        return For(initializer, condition, increment, inner)
Пример #9
0
def substitute_into_domain(domain, param_name, expr, allowed_param_dims):
    """
    Return *domain* with the dimension *param_name* replaced by *expr*.

    The dimension is first renamed so it cannot clash with any name in
    *allowed_param_dims*. Each variable *expr* depends on is then added as a
    parameter dimension. Finally the domain is intersected with the
    constraint ``<renamed dimension> == expr`` and the renamed dimension is
    projected out.

    :arg domain: the :mod:`islpy` set to transform.
    :arg param_name: name of the dimension to substitute. If *domain* has no
        such dimension, *domain* is returned unchanged.
    :arg expr: the :mod:`pymbolic` expression giving the value of
        *param_name*.
    :arg allowed_param_dims: a collection of :class:`str` naming the variables
        *expr* may depend on; each dependency is added to *domain* as a
        parameter.
    :raises ValueError: if *expr* depends on a variable that is not in
        *allowed_param_dims*.
    """
    import pymbolic.primitives as prim
    from loopy.symbolic import get_dependencies, isl_set_from_expr
    if param_name not in domain.get_var_dict():
        # param_name not in domain => domain will be unchanged
        return domain

    # {{{ rename 'param_name' to avoid namespace pollution with allowed_param_dims

    dt, pos = domain.get_var_dict()[param_name]
    # The constraint below must refer to this renamed dimension. If
    # param_name is itself an allowed dependency, the original name is
    # re-added as a fresh parameter. Constraining by the old name would then
    # tie that new parameter to itself and leave this dimension unconstrained.
    renamed_param = UniqueNameGenerator(set(allowed_param_dims))(param_name)
    domain = domain.set_dim_name(dt, pos, renamed_param)

    # }}}

    for dep in get_dependencies(expr):
        if dep in allowed_param_dims:
            domain = domain.add_dims(isl.dim_type.param, 1)
            domain = domain.set_dim_name(isl.dim_type.param,
                                         domain.dim(isl.dim_type.param) - 1,
                                         dep)
        else:
            raise ValueError("Augmenting caller's domain "
                             f"with '{dep}' is not allowed.")

    set_ = isl_set_from_expr(
        domain.space,
        prim.Comparison(prim.Variable(renamed_param), "==", expr))

    bset, = set_.get_basic_sets()
    domain = domain & bset

    # New parameters are appended at the end, so (dt, pos) still locates
    # the renamed dimension.
    return domain.project_out(dt, pos, 1)
Пример #10
0
def expression_comparison(expr, parameters):
    """Translate a binary comparison node into a ``pym.Comparison``.

    Both children of *expr* are translated with ``expression`` using the
    same *parameters*. The node's ``operator`` is passed through unchanged.
    """
    translated = (expression(child, parameters) for child in expr.children)
    lhs, rhs = translated
    return pym.Comparison(lhs, expr.operator, rhs)
Пример #11
0
def _expression_comparison(expr, ctx):
    """Translate a binary comparison node into a ``p.Comparison``.

    Both children of *expr* are translated with ``expression`` in *ctx*.
    The node's ``operator`` is passed through unchanged.
    """
    operands = [expression(child, ctx) for child in expr.children]
    lhs, rhs = operands
    return p.Comparison(lhs, expr.operator, rhs)
Пример #12
0
def loopy_inst_compare(expr, context):
    """Translate a UFL comparison into a ``p.Comparison``.

    Both UFL operands are translated with ``loopy_instructions`` in
    *context*. The comparison operator string is taken from ``expr._name``.
    """
    translated = [loopy_instructions(operand, context)
                  for operand in expr.ufl_operands]
    lhs, rhs = translated
    return p.Comparison(lhs, expr._name, rhs)
Пример #13
0
 def _comparison_operator(self, expr, operator=None):
     """Map a two-argument relational node to a ``prim.Comparison``.

     ``expr.args[0]`` and ``expr.args[1]`` are mapped recursively with
     ``self.rec`` and joined by *operator*, which is passed through as-is
     (default ``None``).
     """
     operands = expr.args
     return prim.Comparison(self.rec(operands[0]), operator,
                            self.rec(operands[1]))