Example #1
0
 def getops(e):
     """Return a fresh, modifiable list of the operands of e.

     An empty list is returned when e is a terminal, or when the free
     variable ``skip_terminal_modifiers`` is truthy and
     ``is_modified_terminal(e)`` holds, so that a modified terminal is
     treated as a single unit.
     """
     # TODO: Maybe use e._ufl_is_terminal_modifier_
     treat_as_leaf = e._ufl_is_terminal_ or (skip_terminal_modifiers
                                             and is_modified_terminal(e))
     return [] if treat_as_leaf else list(e.ufl_operands)
Example #2
0
def replace_quadratureweight(expression):
    """Return expression with every QuadratureWeight terminal replaced by 1.0.

    The unique nodes of the expression are visited in pre-order; each node
    that is a modified terminal and an instance of QuadratureWeight is
    mapped to the float 1.0, and the mapping is applied with
    ``ufl.algorithms.replace``.
    """
    weight_to_one = {
        node: 1.0
        for node in ufl.corealg.traversal.unique_pre_traversal(expression)
        if is_modified_terminal(node) and isinstance(node, QuadratureWeight)
    }
    return ufl.algorithms.replace(expression, weight_to_one)
Example #3
0
def _find_terminals_in_ufl_expression(e, etype):
    """Collect the modified terminals of type etype found below e.

    Only the operands of e (and, recursively, their operands) are
    examined; e itself is never tested. An operand that is a modified
    terminal and an instance of etype is collected and not descended
    into; every other operand is searched recursively. Matches are
    returned as a list in depth-first, left-to-right order.
    """
    found = []
    for operand in e.ufl_operands:
        if not (is_modified_terminal(operand) and isinstance(operand, etype)):
            # Not a match: keep looking further down this branch
            found.extend(_find_terminals_in_ufl_expression(operand, etype))
            continue
        found.append(operand)
    return found
Example #4
0
def rebuild_with_scalar_subexpressions(G):
    """Rebuild the expressions of graph G as scalar valued subexpressions.

    Reads from G:
    - G.nodes  - mapping from node index to a dict holding 'expression'
    - G.e2i    - mapping from an operand expression to its node index
    (G is also handed to ValueNumberer, which assigns symbols to nodes.)

    Returns:
    - The entries of the per-symbol array of scalar subexpressions selected
      by the last entry of the symbol table (``V_symbols[-1]``).

    Raises:
    - RuntimeError if a scalar modified terminal does not have exactly one
      symbol, if a MultiIndex operand appears outside an IndexSum, or if
      the number of reconstructed subexpressions differs from the number
      of symbols of a node.
    """
    # Number the values of the graph: an expression that represents a new
    # (usually algebraic) operation is given a new symbol
    numberer = ValueNumberer(G)

    # Maps a node index to the list of symbols present in that node
    V_symbols = numberer.compute_symbols()

    # One slot per symbol for its scalar subexpression; the checks below
    # rely on slots that have not been filled yet being None
    W = numpy.empty(numberer.symbol_count, dtype=object)

    # Visit the graph nodes in the order G.nodes yields them
    for node_index, node in G.nodes.items():
        expr = node['expression']
        symbols = V_symbols[node_index]

        # Nothing new to do if every symbol of this node already has an
        # expression (should be the case for indexing types, which are
        # not given new symbols)
        if not any(W[s] is None for s in symbols):
            continue

        if is_modified_terminal(expr):
            shape = expr.ufl_shape
            if shape:
                # One entry per component of the terminal. Not all of them
                # may be needed later, but that will be optimized away.
                # Note: symmetries will be dealt with in the value numbering.
                components = [expr[c] for c in ufl.permutation.compute_indices(shape)]
            elif len(symbols) == 1:
                # Scalar modified terminal: the expression is its only component
                components = [expr]
            else:
                raise RuntimeError(
                    "Expecting single symbol for scalar valued modified terminal."
                )
            # FIXME: Replace components[:] with 0's if its table is empty
            # Possible redesign: loop over modified terminals only first,
            # then build tables for them, set W[s] = 0.0 for modified terminals with zero table,
            # then loop over non-(modified terminal)s to reconstruct expression.
        else:
            # Gather the symbol lists of the operands
            operand_symbols = []
            for operand in expr.ufl_operands:
                if not isinstance(operand, ufl.classes.MultiIndex):
                    # TODO: Build edge datastructure and use instead?
                    operand_symbols.append(V_symbols[G.e2i[operand]])
                    continue
                # TODO: Store MultiIndex in G.V and allocate a symbol to it for this to work
                if not isinstance(expr, ufl.classes.IndexSum):
                    raise RuntimeError("Not expecting a %s." % type(expr))
                operand_symbols.append(())

            # Look up the already rebuilt operand expressions
            operand_exprs = [tuple(W[k] for k in syms) for syms in operand_symbols]

            # Rebuild the scalar subexpressions of this node from them
            components = reconstruct(expr, operand_exprs)

            if len(symbols) != len(components):
                raise RuntimeError("Expecting one symbol for each expression.")

        # File each new scalar subexpression under its symbol
        handled = set()
        for s, w in zip(symbols, components):
            if W[s] is not None:
                assert s in handled  # Result of symmetry! - but I think this never gets reached anyway (CNR)
                continue
            W[s] = w
            handled.add(s)

    # Pick out the expressions for the symbols of the final entry
    return W[V_symbols[-1]]
Example #5
0
def compute_integral_ir(cell, integral_type, entitytype, integrands, argument_shape,
                        p, visualise):
    """Build the intermediate representation (IR) dict for an integral.

    Parameters:
    - cell, integral_type, entitytype: forwarded unchanged to
      build_optimized_tables.
    - integrands: mapping from quadrature rule to integrand expression;
      one pass of the main loop is made per item.
    - argument_shape: only its length is used, as the rank passed to
      compute_argument_factorization.
    - p: parameter mapping; "table_rtol" and "table_atol" are read here and
      the whole mapping is stored under ir["params"].
    - visualise: if truthy, the graphs S and F are passed to
      visualise_graph with the file names 'S.pdf' and 'F.pdf'.

    Returns a dict with the keys:
    - "params": p
    - "unique_tables": name -> table values, accumulated over all
      quadrature rules (tables of type "zeros"/"ones" and tables that are
      not referenced are dropped)
    - "unique_table_types": name -> table type, accumulated over all rules
    - "integrand": quadrature rule -> dict with "factorization",
      "modified_arguments" and "block_contributions"
    - "table_dofmaps": created empty and not filled by this function
    - "needs_facet_permutations": True if, for at least one quadrature
      rule, both "+" and "-" occur among the restrictions of the initial
      modified terminals

    Raises:
    - RuntimeError if a table kept for one rule shares its name with an
      already stored table but its values differ beyond the tolerances.
    """
    # The intermediate representation dict we're building and returning
    # here
    ir = {}

    # Pass on parameters for consumption in code generation
    ir["params"] = p

    # Shared unique tables for all quadrature loops
    ir["unique_tables"] = {}
    ir["unique_table_types"] = {}

    ir["integrand"] = {}

    ir["table_dofmaps"] = {}

    # Accumulated over all quadrature rules, so that the result does not
    # depend on which rule happens to be processed last and the key exists
    # even when there are no integrands
    ir["needs_facet_permutations"] = False

    for quadrature_rule, integrand in integrands.items():

        # Rebalance order of nested terminal modifiers
        expression = balance_modifiers(integrand)

        # Remove QuadratureWeight terminals from expression and replace with 1.0
        expression = replace_quadratureweight(expression)

        # Build initial scalar list-based graph representation
        S = build_scalar_graph(expression)

        # Build terminal_data from V here before factorization. Then we
        # can use it to derive table properties for all modified
        # terminals, and then use that to rebuild the scalar graph more
        # efficiently before argument factorization. We can build
        # terminal_data again after factorization if that's necessary.

        initial_terminals = {i: analyse_modified_terminal(v['expression'])
                             for i, v in S.nodes.items()
                             if is_modified_terminal(v['expression'])}

        # The tables accumulated from earlier rules are handed in as well
        mt_unique_table_reference = build_optimized_tables(
            quadrature_rule,
            cell,
            integral_type,
            entitytype,
            initial_terminals.values(),
            ir["unique_tables"],
            rtol=p["table_rtol"],
            atol=p["table_atol"])
        unique_tables = {v.name: v.values for v in mt_unique_table_reference.values()}
        unique_table_types = {v.name: v.ttype for v in mt_unique_table_reference.values()}

        S_targets = [i for i, v in S.nodes.items() if v.get('target', False)]

        if 'zeros' in unique_table_types.values() and len(S_targets) == 1:
            # If there are any 'zero' tables, replace symbolically and rebuild graph
            #
            # TODO: Implement zero table elimination for non-scalar graphs
            for i, mt in initial_terminals.items():
                # Set modified terminals with zero tables to zero
                tr = mt_unique_table_reference.get(mt)
                if tr is not None and tr.ttype == "zeros":
                    S.nodes[i]['expression'] = ufl.as_ufl(0.0)

            # Propagate expression changes using dependency list
            for i, v in S.nodes.items():
                deps = [S.nodes[j]['expression'] for j in S.out_edges[i]]
                if deps:
                    v['expression'] = v['expression']._ufl_expr_reconstruct_(*deps)

            # Rebuild scalar target expressions and graph (this may be
            # overkill and possible to optimize away if it turns out to be
            # costly)
            expression = S.nodes[S_targets[0]]['expression']

            # Rebuild scalar list-based graph representation
            S = build_scalar_graph(expression)

        # Output diagnostic graph as pdf
        if visualise:
            visualise_graph(S, 'S.pdf')

        # Compute factorization of arguments
        rank = len(argument_shape)
        F = compute_argument_factorization(S, rank)

        # Get the 'target' nodes that are factors of arguments, and insert in dict
        FV_targets = [i for i, v in F.nodes.items() if v.get('target', False)]
        argument_factorization = {}

        for fi in FV_targets:
            targets = F.nodes[fi]['target']
            components = F.nodes[fi]['component']

            # Number of blocks using this factor must agree with number of components
            # to which this factor contributes. I.e. there are more blocks iff there are more
            # components
            assert len(targets) == len(components)

            for w, comp in zip(targets, components):
                # Store tuple of (factor index, component index)
                argument_factorization.setdefault(w, []).append((fi, comp))

        # Get list of indices in F which are the arguments (should be at start)
        argkeys = set()
        for w in argument_factorization:
            argkeys = argkeys | set(w)
        argkeys = list(argkeys)

        # Build set of modified_terminals for each mt factorized vertex in F
        # and attach tables, if appropriate
        for i, v in F.nodes.items():
            expr = v['expression']
            if is_modified_terminal(expr):
                mt = analyse_modified_terminal(expr)
                F.nodes[i]['mt'] = mt
                tr = mt_unique_table_reference.get(mt)
                if tr is not None:
                    F.nodes[i]['tr'] = tr

        # Attach 'status' to each node: 'inactive', 'piecewise' or 'varying'
        analyse_dependencies(F, mt_unique_table_reference)

        # Output diagnostic graph as pdf
        if visualise:
            visualise_graph(F, 'F.pdf')

        # Loop over factorization terms
        block_contributions = collections.defaultdict(list)
        for ma_indices, fi_ci in sorted(argument_factorization.items()):
            # Get a bunch of information about this term
            assert rank == len(ma_indices)
            trs = tuple(F.nodes[ai]['tr'] for ai in ma_indices)

            unames = tuple(tr.name for tr in trs)
            ttypes = tuple(tr.ttype for tr in trs)
            assert not any(tt == "zeros" for tt in ttypes)

            # One dofmap per argument table: offset + i * block_size for
            # each i below the extent of axis 3 of the table values
            blockmap = tuple(
                tuple(tr.offset + i * tr.block_size for i in range(tr.values.shape[3]))
                for tr in trs)
            block_is_uniform = all(tr.is_uniform for tr in trs)

            # Collect relevant restrictions to identify blocks correctly
            # in interior facet integrals
            block_restrictions = tuple(
                None if tr.is_uniform else F.nodes[ai]['mt'].restriction
                for ai, tr in zip(ma_indices, trs))

            # Check if *each* factor corresponding to this argument is piecewise
            all_factors_piecewise = all(F.nodes[ifi[0]]["status"] == 'piecewise' for ifi in fi_ci)

            # A block counts as permuted if any of its tables has more than
            # one entry along axis 0
            block_is_permuted = any(unique_tables[n].shape[0] > 1 for n in unames)

            ma_data = tuple(ma_data_t(ma, tr) for ma, tr in zip(ma_indices, trs))

            block_is_transposed = False  # FIXME: Handle transposes for these block types

            blockdata = block_data_t(ttypes, fi_ci,
                                     all_factors_piecewise, unames,
                                     block_restrictions, block_is_transposed,
                                     block_is_uniform, None, ma_data, None, block_is_permuted)

            # Insert in expr_ir for this quadrature loop
            block_contributions[blockmap].append(blockdata)

        # Figure out which table names are referenced
        active_table_names = set()
        for v in F.nodes.values():
            tr = v.get('tr')
            if tr is not None and v['status'] != 'inactive':
                active_table_names.add(tr.name)

        # Figure out which table names are referenced in blocks
        for contributions in block_contributions.values():
            for blockdata in contributions:
                for mad in blockdata.ma_data:
                    active_table_names.add(mad.tabledata.name)

        # Record all table types before dropping tables
        ir["unique_table_types"].update(unique_table_types)

        # Drop tables not referenced from modified terminals
        # and tables of zeros and ones
        unused_ttypes = ("zeros", "ones")
        keep_table_names = set()
        for name in active_table_names:
            if ir["unique_table_types"][name] not in unused_ttypes and name in unique_tables:
                keep_table_names.add(name)
        unique_tables = {name: unique_tables[name] for name in keep_table_names}

        # Add to global set of all tables, after checking that a table
        # stored earlier under the same name holds matching values
        for name, table in unique_tables.items():
            tbl = ir["unique_tables"].get(name)
            if tbl is not None and not numpy.allclose(
                    tbl, table, rtol=p["table_rtol"], atol=p["table_atol"]):
                raise RuntimeError("Table values mismatch with same name.")

        # Merge rather than rebind: rebinding would throw away the tables
        # kept for the quadrature rules processed before this one
        ir["unique_tables"].update(unique_tables)

        # Build IR dict for the given expressions
        # Store final ir for this num_points
        ir["integrand"][quadrature_rule] = {"factorization": F,
                                            "modified_arguments": [F.nodes[i]['mt'] for i in argkeys],
                                            "block_contributions": block_contributions}

        # Any rule with both restrictions present turns the flag on for
        # the whole integral
        restrictions = [mt.restriction for mt in initial_terminals.values()]
        if "+" in restrictions and "-" in restrictions:
            ir["needs_facet_permutations"] = True

    return ir