def arg_ordering_key(i):
    """Return the sort key used to order argument vertex indices.

    The key is the ``argument_ordering_key()`` of the modified terminal
    found in the ``'expression'`` entry of node ``i`` of ``S``.  ``S`` is
    not a parameter; it is resolved from an enclosing or module scope.
    """
    node_expression = S.nodes[i]['expression']
    modified_terminal = analyse_modified_terminal(node_expression)
    return modified_terminal.argument_ordering_key()
def _modified_terminal(self, v):
    """Allocate value-numbering symbols for the modified terminal ``v``.

    One symbol is returned per scalar component of ``v``, in the order of
    base components (outer) and derivative components (inner).  Components
    that map to the same key share a single symbol.  The key is the
    element's base symmetry applied to the base component, combined with
    the sorted derivative directions.

    Modifiers considered by the analysis:
    terminal           -- the underlying Terminal object
    global_derivatives -- tuple of ints, each meaning derivative in that
                          global direction
    local_derivatives  -- tuple of ints, each meaning derivative in that
                          local direction
    reference_value    -- bool, whether this is represented in reference frame
    averaged           -- None, 'facet' or 'cell'
    restriction        -- None, '+' or '-'
    component          -- tuple of ints, the global component of the Terminal
    flat_component     -- single int, flattened local component of the
                          Terminal, considering symmetry

    Raises RuntimeError when the number of symbols differs from the number
    of scalar components of ``v``.
    """
    # How distinct symbols arise:
    # (1) mt.terminal.ufl_shape defines a core indexing space, UNLESS
    #     mt.reference_value is set, in which case the reference value shape
    #     of the element must be used.
    # (2) mt.terminal.ufl_element().symmetry() defines core symmetries.
    # (3) Averaging and restrictions define distinct symbols, with no
    #     additional symmetries.
    # (4) Two or more grad/reference_grad define distinct symbols with
    #     additional symmetries (derivative directions are sorted below).

    # v is not necessarily scalar; its first scalar component (0, ..., 0)
    # is enough to recover the base shape and the derivatives.
    shape = v.ufl_shape
    probe = v[(0, ) * len(shape)] if shape else v
    mt = analyse_modified_terminal(probe)

    n_local = len(mt.local_derivatives)
    n_global = len(mt.global_derivatives)
    assert not (n_local and n_global)

    # Enumerate derivative multi-indices: local derivatives range over the
    # topological dimension, global ones over the geometric dimension.
    if n_local or n_global:
        domain = mt.terminal.ufl_domain()
        if n_local:
            extent, order = domain.topological_dimension(), n_local
        else:
            extent, order = domain.geometric_dimension(), n_global
        derivative_indices = ufl.permutation.compute_indices((extent, ) * order)
    else:
        derivative_indices = [()]

    # The base shape excludes the derivative axes.
    base_indices = ufl.permutation.compute_indices(mt.base_shape)

    # Share one symbol between all components that map to the same key.
    known = {}
    symbols = []
    for base_index in base_indices:
        for deriv_index in derivative_indices:
            key = mt.base_symmetry.get(base_index, base_index) + tuple(sorted(deriv_index))
            symbol = known.get(key)
            if symbol is None:
                symbol = known[key] = self.new_symbol()
            symbols.append(symbol)

    # Consistency check before handing the symbols back.
    assert not v.ufl_free_indices
    if len(symbols) != ufl.product(v.ufl_shape):
        raise RuntimeError("Internal error in value numbering.")
    return symbols
def _lookup_block_cache(cache, unames, allow_transpose):
    """Look up a cached precomputed block; return ``(pname, is_transposed)``.

    ``cache`` maps tuples of table names to precomputed block names, and
    ``pname`` is None on a cache miss.  When ``allow_transpose`` is true and
    ``unames`` names exactly two tables, an entry stored under the reversed
    pair is reused and reported with ``is_transposed=True``.
    """
    pname = cache.get(unames)
    if pname is None and allow_transpose and len(unames) == 2:
        pname = cache.get((unames[1], unames[0]))
        if pname is not None:
            # Cache hit on transpose
            return pname, True
    return pname, False


def _select_block_mode(p, factor_is_piecewise, rank, ttypes):
    """Choose the code generation strategy for one argument factorization term.

    Returns one of "preintegrated", "premultiplied", "partial", "full" or
    "safe".  The choice depends on the optimizations enabled in ``p``,
    whether the factor is piecewise, the form rank and the argument table
    types.
    """
    # TODO: Add separate block modes for quadrature elements, sketched as:
    # Both arguments in quadrature elements:
    #     for iq: fw = f*w
    #             B[i,j] = fw*U[i]*V[j] = 0 if i != iq or j != iq,
    #             so BQ[iq] = B[iq,iq] = fw
    #     for iq: A[iq+offset0, iq+offset1] = BQ[iq]
    # One argument in a quadrature element:
    #     for iq: fw[iq] = f*w
    #             B[i,j] = fw*UQ[i]*V[j] = 0 if i != iq,
    #             so for j: BQ[iq,j] = fw[iq]*V[iq,j]
    #     for iq, j: A[iq+offset, j+offset] = BQ[iq,j]
    if (p["enable_preintegration"] and factor_is_piecewise and rank > 0
            and "quadrature" not in ttypes):
        # - Piecewise factor is an absolute prerequisite
        # - Could work for rank 0 as well but currently doesn't
        # - Haven't considered how quadrature elements work out
        return "preintegrated"
    if (p["enable_premultiplication"] and rank > 0
            and all(tt in piecewise_ttypes for tt in ttypes)):
        # Integrate functional in quadloop, scale block after quadloop
        return "premultiplied"
    if p["enable_sum_factorization"]:
        if rank == 2 and any(tt in piecewise_ttypes for tt in ttypes):
            # Partial computation in quadloop of f*u[i], compute
            # (f*u[i])*v[i] outside quadloop (or with u,v swapped)
            return "partial"
        # Full runtime integration of f*u[i]*v[j]; can still do partial
        # computation in quadloop of f*u[i] but must compute (f*u[i])*v[i]
        # as well inside quadloop (or with u,v swapped)
        return "full"
    # Use full runtime integration with nothing fancy going on
    return "safe"


def build_uflacs_ir(cell, integral_type, entitytype, integrands, tensor_shape,
                    quadrature_rules, parameters):
    """Build the uflacs intermediate representation for one integral.

    Args:
        cell: cell object; its ``topological_dimension()`` is queried.
        integral_type: integral type name (compared against e.g. "cell",
            "interior_facet" and the point/facet/custom type groups).
        entitytype: entity type name (compared against "cell" and "facet").
        integrands: dict mapping a number of quadrature points to the
            integrand expression for that quadrature rule.
        tensor_shape: shape of the element tensor; its length is the rank.
        quadrature_rules: indexed by number of points; item ``[1]`` of an
            entry is used as the quadrature weights.
        parameters: parameter dict; parsed into uflacs options, and its
            ``'visualise'`` entry enables diagnostic pdf output.

    Returns:
        dict with keys "params", "unique_tables", "unique_table_types",
        "piecewise_ir", "varying_irs" and "all_num_points".  "varying_irs"
        gets one expr_ir dict per number of quadrature points.

    Raises:
        RuntimeError: if two tables with the same name have different
            values, or on internal inconsistencies.
    """
    # The intermediate representation dict we're building and returning here
    ir = {}

    # Extract uflacs specific optimization and code generation parameters
    p = parse_uflacs_optimization_parameters(parameters, integral_type)

    # Pass on parameters for consumption in code generation
    ir["params"] = p

    # Shared unique tables for all quadrature loops
    ir["unique_tables"] = {}
    ir["unique_table_types"] = {}

    # Shared piecewise expr_ir for all quadrature loops
    ir["piecewise_ir"] = {
        "factorization": None,
        "modified_arguments": [],
        "preintegrated_blocks": {},
        "premultiplied_blocks": {},
        "preintegrated_contributions": collections.defaultdict(list),
        "block_contributions": collections.defaultdict(list),
    }

    # { num_points: expr_ir for one integrand }
    # NOTE(review): the "factorization" key sits next to the num_points keys;
    # kept as-is in case downstream code reads it.
    ir["varying_irs"] = {"factorization": None}

    # Whether we expect the quadrature weight to be applied or not (in
    # some cases it's just set to 1 in ufl integral scaling)
    tdim = cell.topological_dimension()
    expect_weight = (integral_type not in point_integral_types
                     and (entitytype == "cell"
                          or (entitytype == "facet" and tdim > 1)
                          or integral_type in custom_integral_types))

    # Analyse each num_points/integrand separately
    assert isinstance(integrands, dict)
    all_num_points = sorted(integrands.keys())
    ir["all_num_points"] = all_num_points

    for num_points in all_num_points:
        expression = integrands[num_points]

        # Rebalance order of nested terminal modifiers
        expression = balance_modifiers(expression)

        # Remove QuadratureWeight terminals from expression and replace with 1.0
        expression = replace_quadratureweight(expression)

        # Build initial scalar list-based graph representation
        S = build_scalar_graph(expression)
        S_targets = [i for i, v in S.nodes.items() if v.get('target', False)]
        assert len(S_targets) == 1
        S_target = S_targets[0]

        # Build terminal_data from V here before factorization. Then we
        # can use it to derive table properties for all modified
        # terminals, and then use that to rebuild the scalar graph more
        # efficiently before argument factorization. We can build
        # terminal_data again after factorization if that's necessary.
        initial_terminals = {i: analyse_modified_terminal(v['expression'])
                             for i, v in S.nodes.items()
                             if is_modified_terminal(v['expression'])}

        (unique_tables, unique_table_types, unique_table_num_dofs,
         mt_unique_table_reference) = build_optimized_tables(
             num_points, quadrature_rules, cell, integral_type, entitytype,
             initial_terminals.values(), ir["unique_tables"],
             p["enable_table_zero_compression"],
             rtol=p["table_rtol"], atol=p["table_atol"])

        # If there are any 'zero' tables, replace symbolically and rebuild graph
        if 'zeros' in unique_table_types.values():
            for i, mt in initial_terminals.items():
                # Set modified terminals with zero tables to zero
                tr = mt_unique_table_reference.get(mt)
                if tr is not None and tr.ttype == "zeros":
                    S.nodes[i]['expression'] = ufl.as_ufl(0.0)

            # Propagate expression changes using dependency list.
            # NOTE(review): a single pass presumably relies on S.nodes
            # yielding dependencies before their dependents - confirm.
            for i, v in S.nodes.items():
                deps = [S.nodes[j]['expression'] for j in S.out_edges[i]]
                if deps:
                    v['expression'] = v['expression']._ufl_expr_reconstruct_(*deps)

            # Rebuild scalar target expression and scalar list-based graph
            # (this may be overkill and possible to optimize away if it
            # turns out to be costly)
            expression = S.nodes[S_target]['expression']
            S = build_scalar_graph(expression)

        # Output diagnostic graph as pdf
        if parameters['visualise']:
            visualise(S, 'S.pdf')

        # Compute factorization of arguments
        rank = len(tensor_shape)
        F = compute_argument_factorization(S, rank)

        # Map each tuple of argument vertex indices to the index of the
        # 'target' node that is its factor
        argument_factorization = {}
        for i, v in F.nodes.items():
            if v.get('target', False):
                for w in v['target']:
                    argument_factorization[w] = i

        # Indices in F which are the arguments (should be at start).
        # Sorted so modified_arguments has a deterministic order; the
        # iteration order of a set is an implementation detail.
        argkeys = sorted(set().union(*argument_factorization))

        # Output diagnostic graph as pdf
        if parameters['visualise']:
            visualise(F, 'F.pdf')

        # Attach the modified terminal, and its table where one exists, to
        # every modified-terminal vertex of F
        for v in F.nodes.values():
            expr = v['expression']
            if is_modified_terminal(expr):
                mt = analyse_modified_terminal(expr)
                v['mt'] = mt
                tr = mt_unique_table_reference.get(mt)
                if tr is not None:
                    v['tr'] = tr

        # Attach 'status' to each node: 'inactive', 'piecewise' or 'varying'
        analyse_dependencies(F, mt_unique_table_reference)

        # Save the factorisation graph to the piecewise IR
        ir["piecewise_ir"]["factorization"] = F
        ir["piecewise_ir"]["modified_arguments"] = [F.nodes[i]['mt'] for i in argkeys]

        # Loop over factorization terms
        block_contributions = collections.defaultdict(list)
        for ma_indices, fi in sorted(argument_factorization.items()):
            # Get a bunch of information about this term
            assert rank == len(ma_indices)
            trs = tuple(F.nodes[ai]['tr'] for ai in ma_indices)
            unames = tuple(tr.name for tr in trs)
            ttypes = tuple(tr.ttype for tr in trs)
            assert not any(tt == "zeros" for tt in ttypes)

            blockmap = tuple(tr.dofmap for tr in trs)
            block_is_uniform = all(tr.is_uniform for tr in trs)

            # Collect relevant restrictions to identify blocks correctly
            # in interior facet integrals
            block_restrictions = tuple(
                None if tr.is_uniform else F.nodes[ai]['mt'].restriction
                for ai, tr in zip(ma_indices, trs))

            factor_is_piecewise = F.nodes[fi]['status'] == 'piecewise'

            # Decide how to handle code generation for this block
            block_mode = _select_block_mode(p, factor_is_piecewise, rank, ttypes)

            # Carry out decision
            if block_mode == "preintegrated":
                # Add to contributions:
                # P = sum_q weight*u*v;   preintegrated here
                # B[...] = f * P[...];    generated after quadloop
                # A[blockmap] += B[...];  generated after quadloop
                cache = ir["piecewise_ir"]["preintegrated_blocks"]

                # Reuse transpose to save memory
                pname, block_is_transposed = _lookup_block_cache(
                    cache, unames, p["enable_block_transpose_reuse"])

                if pname is None:
                    # Cache miss, precompute block
                    weights = quadrature_rules[num_points][1]
                    if integral_type == "interior_facet":
                        ptable = integrate_block_interior_facets(
                            weights, unames, ttypes, unique_tables, unique_table_num_dofs)
                    else:
                        ptable = integrate_block(
                            weights, unames, ttypes, unique_tables, unique_table_num_dofs)
                    ptable = clamp_table_small_numbers(
                        ptable, rtol=p["table_rtol"], atol=p["table_atol"])

                    pname = "PI%d" % len(cache)
                    cache[unames] = pname
                    unique_tables[pname] = ptable
                    unique_table_types[pname] = "preintegrated"

                assert factor_is_piecewise
                block_unames = (pname, )
                blockdata = block_data_t(
                    block_mode, ttypes, fi, factor_is_piecewise, block_unames,
                    block_restrictions, block_is_transposed, block_is_uniform,
                    pname, None, None)
                block_is_piecewise = True

            elif block_mode == "premultiplied":
                # Add to contributions:
                # P = u*v;                computed here
                # FI = sum_q weight * f;  generated inside quadloop
                # B[...] = FI * P[...];   generated after quadloop
                # A[blockmap] += B[...];  generated after quadloop
                cache = ir["piecewise_ir"]["premultiplied_blocks"]

                # Reuse transpose to save memory
                pname, block_is_transposed = _lookup_block_cache(
                    cache, unames, p["enable_block_transpose_reuse"])

                if pname is None:
                    # Cache miss, precompute block
                    if integral_type == "interior_facet":
                        ptable = multiply_block_interior_facets(
                            0, unames, ttypes, unique_tables, unique_table_num_dofs)
                    else:
                        ptable = multiply_block(
                            0, unames, ttypes, unique_tables, unique_table_num_dofs)

                    pname = "PM%d" % len(cache)
                    cache[unames] = pname
                    unique_tables[pname] = ptable
                    unique_table_types[pname] = "premultiplied"

                block_unames = (pname, )
                blockdata = block_data_t(
                    block_mode, ttypes, fi, factor_is_piecewise, block_unames,
                    block_restrictions, block_is_transposed, block_is_uniform,
                    pname, None, None)
                block_is_piecewise = False

            # TODO: a "scaled" block mode (piecewise block, but not
            # premultiplied: FI = sum_q weight * f inside quadloop,
            # B[...] = FI * u * v after quadloop) is not implemented; it would
            # need mostly the same data as premultiplied, minus the P table.

            elif block_mode in ("partial", "full", "safe"):
                block_is_piecewise = factor_is_piecewise and not expect_weight
                ma_data = []
                for ma, tr in zip(ma_indices, trs):
                    if not tr.is_piecewise:
                        block_is_piecewise = False
                    ma_data.append(ma_data_t(ma, tr))

                # FIXME: Handle transposes for these block types
                block_is_transposed = False

                if block_mode == "partial":
                    # Add to contributions:
                    # P[i] = sum_q weight * f * u[i];  generated inside quadloop
                    # B[i,j] = P[i] * v[j];            generated after quadloop
                    #                                  (v is the piecewise ma)
                    # A[blockmap] += B[...];           generated after quadloop
                    assert rank == 2

                    # Find first piecewise index. TODO: Is last better?
                    # Computed fresh for every term so a stale value from a
                    # previous term can never leak in.
                    piecewise_ma_index = next(
                        (i for i, tr in enumerate(trs) if tr.is_piecewise), None)
                    if piecewise_ma_index is None:
                        raise RuntimeError(
                            "Partial block mode requires a piecewise argument table.")
                    not_piecewise_ma_index = 1 - piecewise_ma_index
                    block_unames = (unames[not_piecewise_ma_index], )
                    blockdata = block_data_t(
                        block_mode, ttypes, fi, factor_is_piecewise, block_unames,
                        block_restrictions, block_is_transposed, None, None,
                        tuple(ma_data), piecewise_ma_index)
                else:
                    # "full" or "safe". Add to contributions:
                    # B[i] = sum_q weight * f * u[i] * v[j];  generated inside quadloop
                    # A[blockmap] += B[i];                    generated after quadloop
                    block_unames = unames
                    blockdata = block_data_t(
                        block_mode, ttypes, fi, factor_is_piecewise, block_unames,
                        block_restrictions, block_is_transposed, None, None,
                        tuple(ma_data), None)
            else:
                raise RuntimeError("Invalid block_mode %s" % (block_mode, ))

            if block_is_piecewise:
                # Insert in piecewise expr_ir
                ir["piecewise_ir"]["block_contributions"][blockmap].append(blockdata)
            else:
                # Insert in varying expr_ir for this quadrature loop
                block_contributions[blockmap].append(blockdata)

        # Figure out which table names are referenced in unstructured partition
        active_table_names = {
            v['tr'].name for v in F.nodes.values()
            if v.get('tr') is not None and v['status'] != 'inactive'}

        # Figure out which table names are referenced in blocks
        for contributions in itertools.chain(
                block_contributions.values(),
                ir["piecewise_ir"]["block_contributions"].values()):
            for blockdata in contributions:
                if blockdata.block_mode in ("preintegrated", "premultiplied"):
                    active_table_names.add(blockdata.name)
                elif blockdata.block_mode in ("partial", "full", "safe"):
                    active_table_names.update(
                        mad.tabledata.name for mad in blockdata.ma_data)

        # Record all table types before dropping tables
        ir["unique_table_types"].update(unique_table_types)

        # Drop tables not referenced from modified terminals
        # and tables of zeros and ones
        unused_ttypes = ("zeros", "ones", "quadrature")
        unique_tables = {
            name: unique_tables[name] for name in active_table_names
            if ir["unique_table_types"][name] not in unused_ttypes
            and name in unique_tables}

        # Add to global set of all tables
        for name, table in unique_tables.items():
            tbl = ir["unique_tables"].get(name)
            if tbl is not None and not numpy.allclose(
                    tbl, table, rtol=p["table_rtol"], atol=p["table_atol"]):
                raise RuntimeError("Table values mismatch with same name.")
        ir["unique_tables"].update(unique_tables)

        # Analyse active terminals to check what we'll need to generate code for
        active_mts = [v['mt'] for v in F.nodes.values()
                      if v.get('mt', False) and v['status'] != 'inactive']

        # Figure out if we need to access CellCoordinate to avoid
        # generating quadrature point table otherwise
        if integral_type == "cell":
            need_points = any(isinstance(mt.terminal, CellCoordinate)
                              for mt in active_mts)
        elif integral_type in facet_integral_types:
            need_points = any(isinstance(mt.terminal, FacetCoordinate)
                              for mt in active_mts)
        elif integral_type in custom_integral_types:
            need_points = True  # TODO: Always?
        else:
            need_points = False

        # Count blocks of each mode
        block_modes = collections.Counter(
            blockdata.block_mode
            for contributions in block_contributions.values()
            for blockdata in contributions)

        # Debug output
        summary = "\n".join(
            " {}\t{}".format(count, mode)
            for mode, count in sorted(block_modes.items()))
        logger.debug("Blocks of each mode: %s", summary)

        # Weights are decided from block modes rather than by searching
        # active_mts for QuadratureWeight terminals, since those were
        # replaced by 1.0 above. If there are any blocks other than
        # preintegrated we need weights.
        if expect_weight and any(mode != "preintegrated" for mode in block_modes):
            need_weights = True
        elif integral_type in custom_integral_types:
            need_weights = True  # TODO: Always?
        else:
            need_weights = False

        # Store final ir for this num_points
        ir["varying_irs"][num_points] = {
            "factorization": F,
            "modified_arguments": [F.nodes[i]['mt'] for i in argkeys],
            "block_contributions": block_contributions,
            "need_points": need_points,
            "need_weights": need_weights,
        }

    return ir