def getops(e):
    """Return the operands of *e* as a fresh list, or [] if *e* is a leaf.

    *e* counts as a leaf when it is a terminal. It also counts as a leaf when
    the flag ``skip_terminal_modifiers`` (read from an enclosing scope; it is
    not a parameter here) is truthy and *e* is a modified terminal. In that
    case a terminal together with its modifiers is handled as one unit.
    """
    # TODO: Maybe use e._ufl_is_terminal_modifier_
    treat_as_leaf = e._ufl_is_terminal_ or (skip_terminal_modifiers and is_modified_terminal(e))
    return [] if treat_as_leaf else list(e.ufl_operands)
def replace_quadratureweight(expression):
    """Return *expression* with every QuadratureWeight terminal replaced by 1.0.

    Each distinct node is visited once, in pre-order. Nodes that are modified
    terminals and instances of QuadratureWeight are mapped to the float 1.0.
    The substitution is done with ufl.algorithms.replace.
    """
    traversal = ufl.corealg.traversal.unique_pre_traversal(expression)
    weights = [node for node in traversal
               if is_modified_terminal(node) and isinstance(node, QuadratureWeight)]
    # Same mapping as {w: 1.0 for w in weights}, in the same key order
    return ufl.algorithms.replace(expression, dict.fromkeys(weights, 1.0))
def _find_terminals_in_ufl_expression(e, etype): """Recursively search expression for terminals of type etype.""" r = [] for op in e.ufl_operands: if is_modified_terminal(op) and isinstance(op, etype): r.append(op) else: r += _find_terminals_in_ufl_expression(op, etype) return r
def rebuild_with_scalar_subexpressions(G):
    """Rebuild the expressions of graph G as scalar-valued subexpressions.

    A ValueNumberer assigns symbols to the components of every node of G.
    Each symbol is then given a scalar UFL expression. For modified terminals
    this is one component of the terminal. For other nodes it is the result of
    ``reconstruct`` applied to the scalar expressions of the node's operands.

    Input:
    - G.nodes: mapping node index -> dict with an 'expression' entry. Nodes are
      visited in iteration order, and operand entries of W are read while
      visiting a node. The order presumably must place operands before their
      users (TODO confirm).
    - G.e2i: mapping from an operand expression to its node index.

    Output:
    - A numpy object array of the scalar expressions for the symbols in
      ``V_symbols[-1]``. Presumably these are the symbols of the final (root)
      node of G (TODO confirm).

    Raises:
    - RuntimeError if a scalar-valued modified terminal has more than one
      symbol, if a MultiIndex operand appears outside an IndexSum, or if the
      number of reconstructed scalar expressions differs from the number of
      symbols of a node.
    """
    # Compute symbols over graph and rebuild scalar expression
    #
    # New expression which represents usually an algebraic operation
    # generates a new symbol
    value_numberer = ValueNumberer(G)

    # V_symbols maps an index of a node to a list of
    # symbols which are present in that node
    V_symbols = value_numberer.compute_symbols()
    total_unique_symbols = value_numberer.symbol_count

    # Array to store the scalar subexpression in for each symbol
    # (numpy.empty with dtype=object starts every entry as None)
    W = numpy.empty(total_unique_symbols, dtype=object)

    # Iterate over each graph node in order
    for i, v in G.nodes.items():
        expr = v['expression']

        # Find symbols of v components
        vs = V_symbols[i]

        # Skip if there's nothing new here (should be the case for indexing types).
        # New symbols are not given to indexing types, so W[symbol] already holds
        # an expression, assigned to the symbol in a previous loop cycle.
        if all(W[s] is not None for s in vs):
            continue

        if is_modified_terminal(expr):
            sh = expr.ufl_shape
            if sh:
                # Store each terminal expression component. We may not
                # actually need all of these later, but that will be
                # optimized away.
                # Note: symmetries will be dealt with in the value numbering.
                ws = [expr[c] for c in ufl.permutation.compute_indices(sh)]
            else:
                # Store single modified terminal expression component
                if len(vs) != 1:
                    raise RuntimeError(
                        "Expecting single symbol for scalar valued modified terminal."
                    )
                ws = [expr]
            # FIXME: Replace ws[:] with 0's if its table is empty
            # Possible redesign: loop over modified terminals only first,
            # then build tables for them, set W[s] = 0.0 for modified terminals with zero table,
            # then loop over non-(modified terminal)s to reconstruct expression.
        else:
            # Find symbols of operands
            sops = []
            for j, vop in enumerate(expr.ufl_operands):
                if isinstance(vop, ufl.classes.MultiIndex):
                    # TODO: Store MultiIndex in G.V and allocate a symbol to it for this to work
                    if not isinstance(expr, ufl.classes.IndexSum):
                        raise RuntimeError("Not expecting a %s." % type(expr))
                    # A MultiIndex operand contributes no symbols
                    sops.append(())
                else:
                    # TODO: Build edge datastructure and use instead?
                    # k = G.E[i][j]
                    k = G.e2i[vop]
                    sops.append(V_symbols[k])

            # Fetch reconstructed operand expressions
            wops = [tuple(W[k] for k in so) for so in sops]

            # Reconstruct scalar subexpressions of v
            ws = reconstruct(expr, wops)

            # Store all scalar subexpressions for v symbols
            if len(vs) != len(ws):
                raise RuntimeError("Expecting one symbol for each expression.")

        # Store each new scalar subexpression in W at the index of its symbol.
        # 'handled' is reset for every node, so the assert below only accepts a
        # symbol that repeats within this node's own components.
        handled = set()
        for s, w in zip(vs, ws):
            if W[s] is None:
                W[s] = w
                handled.add(s)
            else:
                assert s in handled  # Result of symmetry! - but I think this never gets reached anyway (CNR)

    # Find symbols of final v from input graph
    vs = V_symbols[-1]
    # Fancy-index W with the symbol list (assumes vs is a list/array of ints)
    scalar_expressions = W[vs]
    return scalar_expressions
def compute_integral_ir(cell, integral_type, entitytype, integrands, argument_shape, p, visualise):
    """Build the intermediate representation (IR) dict for an integral.

    For each quadrature rule, the integrand is processed in these steps:
    - its terminal modifiers are rebalanced and QuadratureWeight terminals
      are replaced by 1.0;
    - it is turned into a scalar graph and its modified terminals are
      tabulated;
    - the graph is factorized with respect to the arguments;
    - the factorization terms are grouped into block contributions.

    Parameters
    ----------
    cell, integral_type, entitytype
        Forwarded unchanged to ``build_optimized_tables``.
    integrands
        Mapping from quadrature rule to the UFL integrand evaluated with it.
    argument_shape
        Only its length is used, as the rank of the form.
    p
        Parameter dict. It is stored as ``ir["params"]``, and its
        ``"table_rtol"`` and ``"table_atol"`` entries are read.
    visualise
        When truthy, diagnostic graphs are written to ``S.pdf`` and ``F.pdf``.

    Returns
    -------
    dict
        Contains the following keys:
        - ``"params"``;
        - ``"unique_tables"`` (the kept tables of *all* quadrature rules);
        - ``"unique_table_types"``;
        - ``"integrand"`` (per quadrature rule: factorization, modified
          arguments and block contributions);
        - ``"table_dofmaps"``;
        - ``"needs_facet_permutations"``.

    Raises
    ------
    RuntimeError
        If a kept table has the same name as a table kept for an earlier
        quadrature rule but different values.
    """
    # The intermediate representation dict we're building and returning here
    ir = {}

    # Pass on parameters for consumption in code generation
    ir["params"] = p

    # Shared unique tables for all quadrature loops
    ir["unique_tables"] = {}
    ir["unique_table_types"] = {}

    ir["integrand"] = {}

    ir["table_dofmaps"] = {}

    # Facet permutations are needed if any single quadrature loop accesses both
    # sides ("+" and "-") of a facet. The flag is accumulated over all rules, so
    # a later rule cannot overwrite an earlier rule's requirement. It is also
    # defined when there are no integrands at all.
    needs_facet_permutations = False

    for quadrature_rule, integrand in integrands.items():

        # Rebalance order of nested terminal modifiers
        expression = balance_modifiers(integrand)

        # Remove QuadratureWeight terminals from expression and replace with 1.0
        expression = replace_quadratureweight(expression)

        # Build initial scalar list-based graph representation
        S = build_scalar_graph(expression)

        # Build terminal_data from V here, before factorization. Then we
        # can use it to derive table properties for all modified
        # terminals, and use those to rebuild the scalar graph more
        # efficiently before argument factorization. We can build
        # terminal_data again after factorization if that's necessary.
        initial_terminals = {i: analyse_modified_terminal(v['expression'])
                             for i, v in S.nodes.items()
                             if is_modified_terminal(v['expression'])}

        mt_unique_table_reference = build_optimized_tables(
            quadrature_rule, cell, integral_type, entitytype, initial_terminals.values(),
            ir["unique_tables"], rtol=p["table_rtol"], atol=p["table_atol"])
        unique_tables = {v.name: v.values for v in mt_unique_table_reference.values()}
        unique_table_types = {v.name: v.ttype for v in mt_unique_table_reference.values()}

        S_targets = [i for i, v in S.nodes.items() if v.get('target', False)]

        if 'zeros' in unique_table_types.values() and len(S_targets) == 1:
            # If there are any 'zero' tables, replace symbolically and rebuild graph
            #
            # TODO: Implement zero table elimination for non-scalar graphs
            for i, mt in initial_terminals.items():
                # Set modified terminals with zero tables to zero
                tr = mt_unique_table_reference.get(mt)
                if tr is not None and tr.ttype == "zeros":
                    S.nodes[i]['expression'] = ufl.as_ufl(0.0)

            # Propagate expression changes using dependency list.
            # NOTE(review): assumes S.nodes iterates operands before their
            # users, so each node sees already-updated operands - TODO confirm.
            for i, v in S.nodes.items():
                deps = [S.nodes[j]['expression'] for j in S.out_edges[i]]
                if deps:
                    v['expression'] = v['expression']._ufl_expr_reconstruct_(*deps)

            # Rebuild scalar target expressions and graph (this may be
            # overkill and possible to optimize away if it turns out to be
            # costly)
            expression = S.nodes[S_targets[0]]['expression']

            # Rebuild scalar list-based graph representation
            S = build_scalar_graph(expression)

        # Output diagnostic graph as pdf
        if visualise:
            visualise_graph(S, 'S.pdf')

        # Compute factorization of arguments
        rank = len(argument_shape)
        F = compute_argument_factorization(S, rank)

        # Get the 'target' nodes that are factors of arguments, and insert in dict
        FV_targets = [i for i, v in F.nodes.items() if v.get('target', False)]
        argument_factorization = {}
        for fi in FV_targets:
            targets = F.nodes[fi]['target']
            components = F.nodes[fi]['component']
            # Number of blocks using this factor must agree with number of components
            # to which this factor contributes. I.e. there are more blocks iff there
            # are more components
            assert len(targets) == len(components)
            for w, comp in zip(targets, components):
                # Store tuple of (factor index, component index)
                argument_factorization.setdefault(w, []).append((fi, comp))

        # Get list of indices in F which are the arguments (should be at start)
        argkeys = set()
        for w in argument_factorization:
            argkeys = argkeys | set(w)
        argkeys = list(argkeys)

        # Build set of modified_terminals for each mt factorized vertex in F
        # and attach tables, if appropriate
        for i, v in F.nodes.items():
            expr = v['expression']
            if is_modified_terminal(expr):
                mt = analyse_modified_terminal(expr)
                F.nodes[i]['mt'] = mt
                tr = mt_unique_table_reference.get(mt)
                if tr is not None:
                    F.nodes[i]['tr'] = tr

        # Attach 'status' to each node: 'inactive', 'piecewise' or 'varying'
        analyse_dependencies(F, mt_unique_table_reference)

        # Output diagnostic graph as pdf
        if visualise:
            visualise_graph(F, 'F.pdf')

        # Loop over factorization terms
        block_contributions = collections.defaultdict(list)
        for ma_indices, fi_ci in sorted(argument_factorization.items()):
            # Get a bunch of information about this term
            assert rank == len(ma_indices)
            trs = tuple(F.nodes[ai]['tr'] for ai in ma_indices)

            unames = tuple(tr.name for tr in trs)
            ttypes = tuple(tr.ttype for tr in trs)
            assert not any(tt == "zeros" for tt in ttypes)

            # One dof tuple per argument: the table offset plus a stride of
            # block_size for each of the tr.values.shape[3] dofs
            blockmap = tuple(
                tuple(tr.offset + i * tr.block_size for i in range(tr.values.shape[3]))
                for tr in trs)
            block_is_uniform = all(tr.is_uniform for tr in trs)

            # Collect relevant restrictions to identify blocks correctly
            # in interior facet integrals (uniform tables ignore restriction)
            block_restrictions = tuple(
                None if tr.is_uniform else F.nodes[ai]['mt'].restriction
                for ai, tr in zip(ma_indices, trs))

            # Check if *each* factor corresponding to this argument is piecewise
            all_factors_piecewise = all(F.nodes[ifi[0]]["status"] == 'piecewise' for ifi in fi_ci)

            # A table with more than one entry along axis 0 marks the block as permuted
            block_is_permuted = any(unique_tables[n].shape[0] > 1 for n in unames)

            ma_data = tuple(ma_data_t(ma, tr) for ma, tr in zip(ma_indices, trs))

            block_is_transposed = False  # FIXME: Handle transposes for these block types

            blockdata = block_data_t(ttypes, fi_ci, all_factors_piecewise, unames,
                                     block_restrictions, block_is_transposed, block_is_uniform,
                                     None, ma_data, None, block_is_permuted)

            # Insert in expr_ir for this quadrature loop
            block_contributions[blockmap].append(blockdata)

        # Figure out which table names are referenced by active nodes
        active_table_names = set()
        for v in F.nodes.values():
            tr = v.get('tr')
            if tr is not None and v['status'] != 'inactive':
                active_table_names.add(tr.name)

        # Figure out which table names are referenced in blocks
        for contributions in block_contributions.values():
            for blockdata in contributions:
                active_table_names.update(mad.tabledata.name for mad in blockdata.ma_data)

        # Record all table types before dropping tables
        ir["unique_table_types"].update(unique_table_types)

        # Drop tables not referenced from modified terminals, and tables of
        # zeros and ones. Names are sorted so the resulting table order is
        # deterministic, unlike iteration over a set of strings.
        unused_ttypes = ("zeros", "ones")
        unique_tables = {name: unique_tables[name]
                         for name in sorted(active_table_names)
                         if ir["unique_table_types"][name] not in unused_ttypes
                         and name in unique_tables}

        # Add to global set of all tables. A table that shares its name with a
        # table from an earlier quadrature rule must have the same values.
        for name, table in unique_tables.items():
            tbl = ir["unique_tables"].get(name)
            if tbl is not None and not numpy.allclose(
                    tbl, table, rtol=p["table_rtol"], atol=p["table_atol"]):
                raise RuntimeError("Table values mismatch with same name.")
        # Merge rather than assign: assigning would discard the tables of
        # every previously processed quadrature rule.
        ir["unique_tables"].update(unique_tables)

        # Build IR dict for the given expressions
        # Store final ir for this quadrature rule
        ir["integrand"][quadrature_rule] = {"factorization": F,
                                            "modified_arguments": [F.nodes[i]['mt'] for i in argkeys],
                                            "block_contributions": block_contributions}

        restrictions = [mt.restriction for mt in initial_terminals.values()]
        needs_facet_permutations = needs_facet_permutations or (
            "+" in restrictions and "-" in restrictions)

    ir["needs_facet_permutations"] = needs_facet_permutations
    return ir