Example #1
0
def translate_coefficient(terminal, mt, ctx):
    """Translate a coefficient modified terminal into an expression.

    For a coefficient on the 'Real' family the expression obtained from
    ``ctx.coefficient`` is returned unchanged.  Otherwise the element is
    tabulated on every entity in ``ctx.entity_ids``, the tabulations
    are merged per derivative multi-index, contracted against the
    coefficient expression, and finally rearranged by ``fiat_to_ufl``.

    :arg terminal: the coefficient terminal; ``ufl_element()`` is
                   called on it.
    :arg mt: the modified terminal; its ``restriction``,
             ``local_derivatives`` and ``expr`` attributes are read.
    :arg ctx: translation context providing ``coefficient``,
              ``create_element``, ``basis_evaluation``,
              ``entity_ids``, ``entity_number``, ``epsilon``,
              ``index_cache`` and ``point_indices``.
    :returns: an expression whose shape equals ``mt.expr.ufl_shape``
              and whose free indices are a subset of
              ``ctx.point_indices`` (both are asserted).
    """
    coefficient_expr = ctx.coefficient(terminal, mt.restriction)

    # 'Real' coefficients need no tabulation; they must not carry
    # any local derivatives.
    if terminal.ufl_element().family() == 'Real':
        assert mt.local_derivatives == 0
        return coefficient_expr

    element = ctx.create_element(terminal.ufl_element())

    # Gather, for each multi-index, one table per entity
    tables_by_alpha = collections.defaultdict(list)
    for entity in ctx.entity_ids:
        tabulation = ctx.basis_evaluation(element, mt.local_derivatives, entity)
        for alpha, raw_table in tabulation.items():
            # Skip multi-indices whose total order is not the one asked for
            if sum(alpha) != mt.local_derivatives:
                continue
            # A numerical hack that FFC used to apply on FIAT
            # tables still lives on after ditching FFC and
            # switching to FInAT.
            rounded = ffc_rounding(raw_table, ctx.epsilon)
            tables_by_alpha[alpha].append(rounded)

    # Collapse the per-entity lists into a single table per multi-index
    if len(ctx.entity_ids) == 1:
        merged = {}
        for alpha, tables in tables_by_alpha.items():
            only_table, = tables  # fails unless exactly one table
            merged[alpha] = only_table
    else:
        entity_number = ctx.entity_number(mt.restriction)
        merged = {alpha: gem.select_expression(tables, entity_number)
                  for alpha, tables in tables_by_alpha.items()}

    # Contract each table with the coefficient expression
    ctx.index_cache.setdefault(terminal.ufl_element(), element.get_indices())
    basis_indices = ctx.index_cache[terminal.ufl_element()]
    value_indices = element.get_value_indices()
    indexed_coefficient, = gem.optimise.remove_componenttensors(
        [gem.Indexed(coefficient_expr, basis_indices)])

    values = {}
    for alpha, table in merged.items():
        indexed_table = gem.Indexed(table, basis_indices + value_indices)
        terms = [
            gem.optimise.contraction(
                gem.IndexSum(gem.Product(expr, var), var.index_ordering()))
            for var, expr in unconcatenate([(indexed_coefficient, indexed_table)],
                                           ctx.index_cache)
        ]
        values[alpha] = gem.ComponentTensor(gem.optimise.make_sum(terms),
                                            value_indices)

    # Change from FIAT to UFL arrangement
    result = fiat_to_ufl(values, mt.local_derivatives)
    assert result.shape == mt.expr.ufl_shape
    assert set(result.free_indices) <= set(ctx.point_indices)

    # Detect Jacobian of affine cells: with no free indices left and
    # at most two nonzeros in every literal, unroll aggressively.
    if result.free_indices:
        return result
    literals = (node for node in traversal((result,))
                if isinstance(node, gem.Literal))
    if all(numpy.count_nonzero(literal.array) <= 2 for literal in literals):
        result = gem.optimise.aggressive_unroll(result)
    return result
Example #2
0
    def entity_selector(self, callback, restriction):
        """Produce the code for whichever entity applies at run-time.

        ``callback`` is handed an entity number and must return the
        code generated for that entity.  With exactly one entry in
        ``self.entity_ids`` its result is returned directly; otherwise
        ``callback`` is applied to every entity number and the results
        are combined with ``gem.select_expression``, indexed by
        ``self.entity_number(restriction)``.

        :arg callback: A function taking an entity number and
                       returning the code for that entity.
        :arg restriction: Restriction of the modified terminal, used
                          for entity selection.
        :returns: the result of ``callback`` for a single entity, or
                  the selection expression over all entities.
        """
        if len(self.entity_ids) != 1:
            # Several candidates: pick one by the run-time entity number
            selector = self.entity_number(restriction)
            candidates = [callback(entity) for entity in self.entity_ids]
            return gem.select_expression(candidates, selector)

        # Only one entity: nothing to select between
        return callback(self.entity_ids[0])
Example #3
0
File: fem.py Project: inducer/tsfc
    def entity_selector(self, callback, restriction):
        """Produce the code for whichever entity applies at run-time.

        ``callback`` is handed an entity number and must return the
        code generated for that entity.  With exactly one entry in
        ``self.entity_ids`` its result is returned directly; otherwise
        ``callback`` is applied to every entity number and the results
        are combined with ``gem.select_expression``, indexed by
        ``self.entity_number(restriction)``.

        :arg callback: A function taking an entity number and
                       returning the code for that entity.
        :arg restriction: Restriction of the modified terminal, used
                          for entity selection.
        :returns: the result of ``callback`` for a single entity, or
                  the selection expression over all entities.
        """
        if len(self.entity_ids) != 1:
            # Several candidates: pick one by the run-time entity number
            selector = self.entity_number(restriction)
            candidates = [callback(entity) for entity in self.entity_ids]
            return gem.select_expression(candidates, selector)

        # Only one entity: nothing to select between
        return callback(self.entity_ids[0])