Example no. 1
0
def _compute_dofmap_ir(ufl_element, element_numbers, dofmap_names):
    """Compute intermediate representation of dofmap."""

    logger.info("Computing IR for dofmap of {}".format(ufl_element))

    # Build the FIAT element backing this UFL element
    fiat_element = create_element(ufl_element)

    # Values used more than once below
    num_dofs_per_entity = _num_dofs_per_entity(fiat_element)
    entity_dofs = fiat_element.entity_dofs()
    num_global = _num_global_support_dofs(fiat_element)

    # Assemble all dofmap data in one literal
    ir = {
        "id": element_numbers[ufl_element],
        "name": dofmap_names[ufl_element],
        "signature": "FFCX dofmap for " + repr(ufl_element),
        "num_global_support_dofs": num_global,
        "num_element_support_dofs": fiat_element.space_dimension() - num_global,
        "num_entity_dofs": num_dofs_per_entity,
        "tabulate_entity_dofs": (entity_dofs, num_dofs_per_entity),
        "num_sub_dofmaps": ufl_element.num_sub_elements(),
        "create_sub_dofmap": [
            dofmap_names[sub] for sub in ufl_element.sub_elements()
        ],
        "dof_types": [
            dof.functional_type for dof in fiat_element.dual_basis()
        ],
        "base_permutations": dof_permutations.base_permutations(ufl_element),
        "dof_reflection_entities": dof_permutations.reflection_entities(
            ufl_element),
    }

    return ir_dofmap(**ir)
Example no. 2
0
def _compute_element_ir(ufl_element, element_numbers, finite_element_names,
                        epsilon):
    """Compute intermediate representation of element."""

    logger.info("Computing IR for element {}".format(ufl_element))

    # FIAT element and cell metadata
    fiat_element = create_element(ufl_element)
    cell = ufl_element.cell()

    # Assemble the bulk of the element data in one literal
    ir = {
        "id": element_numbers[ufl_element],
        "name": finite_element_names[ufl_element],
        "signature": repr(ufl_element),
        "cell_shape": cell.cellname(),
        "topological_dimension": cell.topological_dimension(),
        "geometric_dimension": cell.geometric_dimension(),
        "space_dimension": fiat_element.space_dimension(),
        "value_shape": ufl_element.value_shape(),
        "reference_value_shape": ufl_element.reference_value_shape(),
        "degree": ufl_element.degree(),
        "family": ufl_element.family(),
        "evaluate_basis": _evaluate_basis(ufl_element, fiat_element, epsilon),
        "evaluate_dof": _evaluate_dof(ufl_element, fiat_element),
        "tabulate_dof_coordinates": _tabulate_dof_coordinates(
            ufl_element, fiat_element),
        "num_sub_elements": ufl_element.num_sub_elements(),
        "create_sub_element": [
            finite_element_names[sub] for sub in ufl_element.sub_elements()
        ],
    }

    # For blocked (vector/tensor) elements, permutation and dof data
    # below are computed from the scalar sub-element instead.
    if isinstance(ufl_element, (ufl.VectorElement, ufl.TensorElement)):
        ir["block_size"] = ufl_element.num_sub_elements()
        ufl_element = ufl_element.sub_elements()[0]
        fiat_element = create_element(ufl_element)
    else:
        ir["block_size"] = 1

    ir["base_permutations"] = dof_permutations.base_permutations(ufl_element)
    ir["dof_reflection_entities"] = dof_permutations.reflection_entities(
        ufl_element)
    ir["dof_face_tangents"] = dof_permutations.face_tangents(ufl_element)

    ir["dof_types"] = [
        dof.functional_type for dof in fiat_element.dual_basis()
    ]
    ir["entity_dofs"] = fiat_element.entity_dofs()

    return ir_element(**ir)
Example no. 3
0
def compute_integral_ir(cell, integral_type, entitytype, integrands,
                        argument_shape, p, visualise):
    """Compute the intermediate representation for a set of integrands.

    ``integrands`` maps each quadrature rule to an integrand expression;
    every entry is processed independently and its factorized graph is
    stored under ``ir["integrand"][quadrature_rule]``.  ``p`` is a
    parameters dict (read here for ``"table_rtol"``/``"table_atol"`` and
    passed through unchanged as ``ir["params"]``); ``len(argument_shape)``
    gives the argument rank; ``visualise`` enables pdf graph dumps.
    ``cell``, ``integral_type`` and ``entitytype`` are forwarded to
    ``build_optimized_tables``.
    """
    # The intermediate representation dict we're building and returning
    # here
    ir = {}

    # Pass on parameters for consumption in code generation
    ir["params"] = p

    # Shared unique tables for all quadrature loops
    ir["unique_tables"] = {}
    ir["unique_table_types"] = {}

    ir["integrand"] = {}

    # Per-table metadata accumulated across all quadrature rules
    ir["table_dofmaps"] = {}
    ir["table_dof_face_tangents"] = {}
    ir["table_dof_reflection_entities"] = {}

    for quadrature_rule, integrand in integrands.items():

        expression = integrand

        # Rebalance order of nested terminal modifiers
        expression = balance_modifiers(expression)

        # Remove QuadratureWeight terminals from expression and replace with 1.0
        expression = replace_quadratureweight(expression)

        # Build initial scalar list-based graph representation
        S = build_scalar_graph(expression)

        # Build terminal_data from V here before factorization. Then we
        # can use it to derive table properties for all modified
        # terminals, and then use that to rebuild the scalar graph more
        # efficiently before argument factorization. We can build
        # terminal_data again after factorization if that's necessary.

        # Map graph node index -> analysed modified terminal, for every
        # node whose expression is a modified terminal
        initial_terminals = {
            i: analyse_modified_terminal(v['expression'])
            for i, v in S.nodes.items()
            if is_modified_terminal(v['expression'])
        }

        # Tabulate element tables for the modified terminals; tolerances
        # from the parameters dict control table deduplication
        (unique_tables, unique_table_types, unique_table_num_dofs,
         mt_unique_table_reference,
         table_origins) = build_optimized_tables(quadrature_rule,
                                                 cell,
                                                 integral_type,
                                                 entitytype,
                                                 initial_terminals.values(),
                                                 ir["unique_tables"],
                                                 rtol=p["table_rtol"],
                                                 atol=p["table_atol"])

        # Record permutation metadata for each table's originating element
        # (v[0] is assumed to be the origin element — confirm against
        # build_optimized_tables)
        for k, v in table_origins.items():
            ir["table_dof_face_tangents"][k] = dof_permutations.face_tangents(
                v[0])
            ir["table_dof_reflection_entities"][
                k] = dof_permutations.reflection_entities(v[0])

        for td in mt_unique_table_reference.values():
            ir["table_dofmaps"][td.name] = td.dofmap

        S_targets = [i for i, v in S.nodes.items() if v.get('target', False)]

        if 'zeros' in unique_table_types.values() and len(S_targets) == 1:
            # If there are any 'zero' tables, replace symbolically and rebuild graph
            #
            # TODO: Implement zero table elimination for non-scalar graphs
            for i, mt in initial_terminals.items():
                # Set modified terminals with zero tables to zero
                tr = mt_unique_table_reference.get(mt)
                if tr is not None and tr.ttype == "zeros":
                    S.nodes[i]['expression'] = ufl.as_ufl(0.0)

            # Propagate expression changes using dependency list
            for i, v in S.nodes.items():
                deps = [S.nodes[j]['expression'] for j in S.out_edges[i]]
                if deps:
                    v['expression'] = v['expression']._ufl_expr_reconstruct_(
                        *deps)

            # Rebuild scalar target expressions and graph (this may be
            # overkill and possible to optimize away if it turns out to be
            # costly)
            expression = S.nodes[S_targets[0]]['expression']

            # Rebuild scalar list-based graph representation
            S = build_scalar_graph(expression)

        # Output diagnostic graph as pdf
        if visualise:
            visualise_graph(S, 'S.pdf')

        # Compute factorization of arguments
        rank = len(argument_shape)
        F = compute_argument_factorization(S, rank)

        # Get the 'target' nodes that are factors of arguments, and insert in dict
        FV_targets = [i for i, v in F.nodes.items() if v.get('target', False)]
        argument_factorization = {}

        for fi in FV_targets:
            # Number of blocks using this factor must agree with number of components
            # to which this factor contributes. I.e. there are more blocks iff there are more
            # components
            assert len(F.nodes[fi]['target']) == len(F.nodes[fi]['component'])

            k = 0
            for w in F.nodes[fi]['target']:
                comp = F.nodes[fi]['component'][k]
                argument_factorization[w] = argument_factorization.get(w, [])

                # Store tuple of (factor index, component index)
                argument_factorization[w].append((fi, comp))
                k += 1

        # Get list of indices in F which are the arguments (should be at start)
        argkeys = set()
        for w in argument_factorization:
            argkeys = argkeys | set(w)
        argkeys = list(argkeys)

        # Output diagnostic graph as pdf
        if visualise:
            visualise_graph(F, 'F.pdf')

        # Build set of modified_terminals for each mt factorized vertex in F
        # and attach tables, if appropriate
        for i, v in F.nodes.items():
            expr = v['expression']
            if is_modified_terminal(expr):
                mt = analyse_modified_terminal(expr)
                F.nodes[i]['mt'] = mt
                tr = mt_unique_table_reference.get(mt)
                if tr is not None:
                    F.nodes[i]['tr'] = tr

        # Attach 'status' to each node: 'inactive', 'piecewise' or 'varying'
        analyse_dependencies(F, mt_unique_table_reference)

        # Loop over factorization terms
        block_contributions = collections.defaultdict(list)
        for ma_indices, fi_ci in sorted(argument_factorization.items()):
            # Get a bunch of information about this term
            assert rank == len(ma_indices)
            trs = tuple(F.nodes[ai]['tr'] for ai in ma_indices)

            unames = tuple(tr.name for tr in trs)
            ttypes = tuple(tr.ttype for tr in trs)
            # Zero tables were eliminated above, so none may survive here
            assert not any(tt == "zeros" for tt in ttypes)

            blockmap = tuple(tr.dofmap for tr in trs)
            block_is_uniform = all(tr.is_uniform for tr in trs)

            # Collect relevant restrictions to identify blocks correctly
            # in interior facet integrals
            block_restrictions = []
            for i, ai in enumerate(ma_indices):
                if trs[i].is_uniform:
                    r = None
                else:
                    r = F.nodes[ai]['mt'].restriction

                block_restrictions.append(r)
            block_restrictions = tuple(block_restrictions)

            # Check if each *each* factor corresponding to this argument is piecewise
            all_factors_piecewise = all(
                F.nodes[ifi[0]]["status"] == 'piecewise' for ifi in fi_ci)
            # A leading table dimension > 1 indicates per-permutation variants
            block_is_permuted = False
            for n in unames:
                if unique_tables[n].shape[0] > 1:
                    block_is_permuted = True
            ma_data = []
            for i, ma in enumerate(ma_indices):
                ma_data.append(ma_data_t(ma, trs[i]))

            block_is_transposed = False  # FIXME: Handle transposes for these block types

            block_unames = unames
            blockdata = block_data_t(ttypes, fi_ci, all_factors_piecewise,
                                     block_unames, block_restrictions,
                                     block_is_transposed, block_is_uniform,
                                     None, tuple(ma_data), None,
                                     block_is_permuted)

            # Insert in expr_ir for this quadrature loop
            block_contributions[blockmap].append(blockdata)

        # Figure out which table names are referenced
        active_table_names = set()
        for i, v in F.nodes.items():
            tr = v.get('tr')
            if tr is not None and F.nodes[i]['status'] != 'inactive':
                active_table_names.add(tr.name)

        # Figure out which table names are referenced in blocks
        # NOTE(review): itertools.chain with a single iterable is a
        # no-op wrapper here — iterating block_contributions.items()
        # directly would be equivalent
        for blockmap, contributions in itertools.chain(
                block_contributions.items()):
            for blockdata in contributions:
                for mad in blockdata.ma_data:
                    active_table_names.add(mad.tabledata.name)

        # Record all table types before dropping tables
        ir["unique_table_types"].update(unique_table_types)

        # Drop tables not referenced from modified terminals
        # and tables of zeros and ones
        unused_ttypes = ("zeros", "ones")
        keep_table_names = set()
        for name in active_table_names:
            ttype = ir["unique_table_types"][name]
            if ttype not in unused_ttypes:
                if name in unique_tables:
                    keep_table_names.add(name)
        unique_tables = {
            name: unique_tables[name]
            for name in keep_table_names
        }

        # Add to global set of all tables; a name collision with
        # different values (beyond tolerance) is an error
        for name, table in unique_tables.items():
            tbl = ir["unique_tables"].get(name)
            if tbl is not None and not numpy.allclose(
                    tbl, table, rtol=p["table_rtol"], atol=p["table_atol"]):
                raise RuntimeError("Table values mismatch with same name.")
        ir["unique_tables"].update(unique_tables)

        # Analyse active terminals to check what we'll need to generate code for
        active_mts = []
        for i, v in F.nodes.items():
            mt = v.get('mt', False)
            if mt and F.nodes[i]['status'] != 'inactive':
                active_mts.append(mt)

        # Build IR dict for the given expressions
        # Store final ir for this num_points
        ir["integrand"][quadrature_rule] = {
            "factorization": F,
            "modified_arguments": [F.nodes[i]['mt'] for i in argkeys],
            "block_contributions": block_contributions
        }
    return ir