Example #1
0
 def fiat_equivalent(self):
     """Return the FIAT element corresponding to this enriched element.

     When ``self.is_mixed`` is true, the enriched element is really a mixed
     element.  In that case the FIAT equivalents of each sub-element's
     wrapped ``element`` are combined into a ``FIAT.MixedElement`` on
     ``self.cell``.  Otherwise the FIAT equivalents of the sub-elements are
     combined into a ``FIAT.EnrichedElement``.
     """
     if not self.is_mixed:
         enriched_parts = [sub.fiat_equivalent for sub in self.elements]
         return FIAT.EnrichedElement(*enriched_parts)

     # EnrichedElement is actually a MixedElement: unwrap each sub-element.
     mixed_parts = [sub.element.fiat_equivalent for sub in self.elements]
     return FIAT.MixedElement(mixed_parts, ref_el=self.cell)
Example #2
0
def _create_vector_finiteelement(
        element: ufl.VectorElement) -> FIAT.MixedElement:
    """Create a FIAT element for a UFL VectorElement with interleaved DOFs.

    ``FIAT.MixedElement`` is built from the sub-elements of ``element`` and
    then patched in place so that its DOFs are ordered component by
    component (XYZXYZ) instead of block by block (XXYYZZ).  The following
    are all replaced with reordered versions:

    - ``mapping``;
    - the dual set's ``nodes``, ``entity_ids`` and ``entity_closure_ids``;
    - ``tabulate``.

    The unpatched tabulation remains available as ``old_tabulate``.
    """
    fiat_element = FIAT.MixedElement(
        map(_create_element, element.sub_elements()))

    def interleave(values, n_components):
        """Return ``values`` reordered from XXYYZZ ordering to XYZXYZ."""
        # Number of entries belonging to each component block.
        stride = len(values) // n_components
        out = []
        for start in range(stride):
            out.extend(values[k] for k in range(start, len(values), stride))
        return out

    def interleaved_entity_dofs(scalar_entity_dofs, n_components):
        """Expand scalar entity DOFs into XYZXYZ-ordered vector entity DOFs."""
        vector_entity_dofs = {}
        for dim, per_entity in scalar_entity_dofs.items():
            vector_entity_dofs[dim] = {
                entity: [n_components * dof + comp
                         for dof in dofs
                         for comp in range(n_components)]
                for entity, dofs in per_entity.items()
            }
        return vector_entity_dofs

    def mapping(self):
        # Each scalar DOF's mapping is repeated once per component.
        scalar_mappings = self._elements[0].mapping()
        return [m for m in scalar_mappings for _ in self._elements]

    def tabulate(self, order, points, entity=None):
        # Reorder the leading (DOF) axis of every tabulated table to XYZXYZ.
        n_comp = self.num_sub_elements()
        stride = len(self.dual.nodes) // n_comp
        tables = self.old_tabulate(order, points, entity=entity)
        reordered = {}
        for key, table in tables.items():
            rows = []
            for start in range(stride):
                rows.extend(table[k] for k in range(start, len(table), stride))
            reordered[key] = numpy.array(rows)
        return reordered

    # Reorder from XXYYZZ to XYZXYZ.
    n_components = fiat_element.num_sub_elements()
    fiat_element.mapping = types.MethodType(mapping, fiat_element)
    fiat_element.dual.nodes = interleave(fiat_element.dual.nodes, n_components)

    # Entity DOF maps are derived from the first (scalar) sub-element.
    scalar_dual = fiat_element.elements()[0].dual
    fiat_element.dual.entity_ids = interleaved_entity_dofs(
        scalar_dual.entity_ids, n_components)
    fiat_element.dual.entity_closure_ids = interleaved_entity_dofs(
        scalar_dual.entity_closure_ids, n_components)

    # Keep the original bound tabulate before replacing it.
    fiat_element.old_tabulate = fiat_element.tabulate
    fiat_element.tabulate = types.MethodType(tabulate, fiat_element)

    return fiat_element
Example #3
0
    def fiat_equivalent(self):
        """Return the FIAT element corresponding to this enriched element.

        If every sub-element is a ``MixedSubElement``, the enriched element is
        really a mixed element.  In that case the FIAT equivalents of each
        sub-element's wrapped ``element`` are combined into a
        ``FIAT.MixedElement`` on ``self.cell``.  Otherwise the FIAT
        equivalents of the sub-elements are combined into a
        ``FIAT.EnrichedElement``.
        """
        # Imported locally to avoid a circular import dependency.
        from finat.mixed import MixedSubElement

        is_mixed = all(isinstance(sub, MixedSubElement)
                       for sub in self.elements)
        if not is_mixed:
            enriched_parts = [sub.fiat_equivalent for sub in self.elements]
            return FIAT.EnrichedElement(*enriched_parts)

        # EnrichedElement is actually a MixedElement.
        mixed_parts = [sub.element.fiat_equivalent for sub in self.elements]
        return FIAT.MixedElement(mixed_parts, ref_el=self.cell)
Example #4
0
def _create_mixed_finiteelement(
        element: ufl.MixedElement) -> FIAT.MixedElement:
    """Create a FIAT MixedElement from a (possibly nested) UFL MixedElement.

    Nested plain mixed sub-elements are flattened depth-first.
    VectorElement and TensorElement sub-elements are kept whole rather than
    flattened.  Each resulting leaf is converted with ``_create_element``.
    """
    def flatten(sub_elements):
        # Yield leaves in depth-first order, descending only into plain
        # (non-vector, non-tensor) mixed elements.
        for sub in sub_elements:
            is_plain_mixed = (
                isinstance(sub, ufl.MixedElement)
                and not isinstance(sub, (ufl.VectorElement,
                                         ufl.TensorElement)))
            if is_plain_mixed:
                yield from flatten(sub.sub_elements())
            else:
                yield sub

    leaves = list(flatten(element.sub_elements()))
    return FIAT.MixedElement(map(_create_element, leaves))
Example #5
0
def _(element, vector_is_mixed):
    """Convert a UFL MixedElement into a FIAT element.

    If ``vector_is_mixed`` is false, ``element`` is asserted to be a
    VectorElement or TensorElement, and only its first sub-element (the
    scalar part) is converted and returned.  Otherwise every nested mixed
    sub-element is flattened depth-first.  Each leaf is then converted with
    ``create_element`` and the results are combined into a
    ``FIAT.MixedElement``.
    """
    if not vector_is_mixed:
        # Only the scalar part of a vector/tensor element is wanted.
        # NOTE(review): this assert is skipped under ``python -O``.
        assert isinstance(element, (ufl.VectorElement, ufl.TensorElement))
        scalar_part = element.sub_elements()[0]
        return create_element(scalar_part, vector_is_mixed)

    # Explicit-stack preorder traversal; children are pushed reversed so the
    # first child is visited next, matching recursive depth-first order.
    leaves = []
    pending = list(element.sub_elements())[::-1]
    while pending:
        candidate = pending.pop()
        if isinstance(candidate, ufl.MixedElement):
            pending.extend(list(candidate.sub_elements())[::-1])
        else:
            leaves.append(candidate)

    convert = partial(create_element, vector_is_mixed=vector_is_mixed)
    return FIAT.MixedElement(map(convert, leaves))