Ejemplo n.º 1
0
def reconstruct_element(element, cell=None):
    """Return a version of ``element`` rebuilt on ``cell``.

    When ``cell`` is ``None`` the input element is handed back untouched.
    Otherwise the element type is matched against the known UFL element
    classes, in a fixed order, and a new element is constructed on ``cell``;
    composite and wrapper elements are rebuilt recursively from their parts.

    :arg element: the element to rebuild.
    :arg cell: the cell to rebuild on, or ``None`` to leave ``element`` as is.
    :returns: ``element`` itself (no cell given) or a newly built element.
    :raises NotImplementedError: if the element type matches none of the
        handled classes.
    """
    # No target cell: there is nothing to rebuild.
    if cell is None:
        return element

    # NOTE(review): the order of the checks below is kept exactly as in the
    # original; the generic MixedElement case is deliberately tested last,
    # presumably because more specific classes derive from it -- confirm
    # against the UFL class hierarchy before reordering.
    if isinstance(element, ufl.FiniteElement):
        return ufl.FiniteElement(element.family(), cell, element.degree())

    if isinstance(element, ufl.VectorElement):
        return ufl.VectorElement(element.family(),
                                 cell,
                                 element.degree(),
                                 len(element.sub_elements()))

    if isinstance(element, ufl.TensorElement):
        return ufl.TensorElement(element.family(),
                                 cell,
                                 element.degree(),
                                 element.value_shape(),
                                 element.symmetry())

    if isinstance(element, ufl.EnrichedElement):
        # Rebuild every constituent element, then enrich again.
        return ufl.EnrichedElement(
            *(reconstruct_element(part, cell=cell)
              for part in element._elements))

    if isinstance(element, ufl.RestrictedElement):
        rebuilt_base = reconstruct_element(element.sub_element(), cell=cell)
        return ufl.RestrictedElement(rebuilt_base,
                                     element.restriction_domain())

    # Wrappers around a single element: rebuild the wrapped element and
    # wrap it again with the same wrapper class.
    single_wrappers = (ufl.TraceElement, ufl.InteriorElement,
                       ufl.HDivElement, ufl.HCurlElement,
                       ufl.BrokenElement, ufl.FacetElement)
    if isinstance(element, single_wrappers):
        wrapper_cls = type(element)
        return wrapper_cls(reconstruct_element(element._element, cell=cell))

    if isinstance(element, ufl.OuterProductElement):
        # The two factors are reused as they are; only the cell changes.
        return ufl.OuterProductElement(element._A, element._B, cell=cell)

    if isinstance(element, ufl.OuterProductVectorElement):
        n_components = len(element.sub_elements())
        return ufl.OuterProductVectorElement(element._A,
                                             element._B,
                                             cell=cell,
                                             dim=n_components)

    if isinstance(element, ufl.OuterProductTensorElement):
        # This class is asked to rebuild itself.
        return element.reconstruct(cell=cell)

    if isinstance(element, ufl.MixedElement):
        return ufl.MixedElement(
            *(reconstruct_element(part, cell=cell)
              for part in element.sub_elements()))

    raise NotImplementedError(
        "Don't know how to reconstruct element of type %s" % type(element))
Ejemplo n.º 2
0
def _extract_elements(ufl_element, restriction_domain=None):
    """Flatten ``ufl_element`` into a list of created component elements.

    Mixed elements are expanded recursively, one sub-element at a time, with
    the current ``restriction_domain`` passed down. A restricted element is
    unwrapped and its own restriction domain is passed down to the recursion
    on its base element. Any other element is wrapped in a
    ``ufl.RestrictedElement`` when a (truthy) ``restriction_domain`` is in
    effect and then handed to ``create_element``.

    :arg ufl_element: the element to flatten.
    :arg restriction_domain: restriction domain applied to the leaf
        elements, if any.
    :returns: a flat list of the results of ``create_element``.
    """
    if isinstance(ufl_element, ufl.MixedElement):
        flattened = []
        for child in ufl_element.sub_elements():
            flattened.extend(_extract_elements(child, restriction_domain))
        return flattened

    # A restricted element may itself wrap a mixed element, so unwrap it
    # and recurse with the restriction domain it carries.
    if isinstance(ufl_element, ufl.RestrictedElement):
        return _extract_elements(ufl_element.sub_element(),
                                 ufl_element.restriction_domain())

    # Leaf element: re-apply the pending restriction, if there is one.
    leaf = ufl_element
    if restriction_domain:
        leaf = ufl.RestrictedElement(ufl_element, restriction_domain)

    return [create_element(leaf)]