def compute_integral_ir(cell, integral_type, entitytype, integrands,
                        argument_shape, p, visualise):
    """Build the intermediate representation (IR) for the integrands of an integral.

    Each ``(quadrature_rule, integrand)`` pair in ``integrands`` is turned
    into a scalar expression graph, its element tables are built, the
    graph is factorized with respect to the arguments, and one block of
    data is recorded per factorization term.

    Args:
        cell: Forwarded unchanged to ``build_optimized_tables``.
        integral_type: Forwarded unchanged to ``build_optimized_tables``.
        entitytype: Forwarded unchanged to ``build_optimized_tables``.
        integrands: Mapping from quadrature rule to integrand expression.
        argument_shape: Only its length is used; it is the rank passed to
            ``compute_argument_factorization``.
        p: Parameter mapping. This function reads ``p["table_rtol"]`` and
            ``p["table_atol"]`` and stores ``p`` itself under
            ``ir["params"]``.
        visualise: If truthy, the scalar graph and the factorized graph
            are passed to ``visualise_graph`` with the file names
            ``'S.pdf'`` and ``'F.pdf'``.

    Returns:
        dict: The IR with the keys ``"params"``, ``"unique_tables"``,
        ``"unique_table_types"``, ``"integrand"`` (one entry per
        quadrature rule), ``"table_dofmaps"`` (created empty and not
        filled in here) and, once at least one integrand has been
        processed, ``"needs_facet_permutations"``.

    Raises:
        RuntimeError: If a table kept for this quadrature rule has the
            same name as an already stored table but different values
            (compared with ``numpy.allclose``).
    """
    # The intermediate representation dict we're building and returning
    ir = {}

    # Pass on parameters for consumption in code generation
    ir["params"] = p

    # Shared unique tables for all quadrature loops
    ir["unique_tables"] = {}
    ir["unique_table_types"] = {}

    ir["integrand"] = {}

    ir["table_dofmaps"] = {}

    for quadrature_rule, integrand in integrands.items():

        expression = integrand

        # Rebalance order of nested terminal modifiers
        expression = balance_modifiers(expression)

        # Remove QuadratureWeight terminals from expression and replace with 1.0
        expression = replace_quadratureweight(expression)

        # Build initial scalar list-based graph representation
        S = build_scalar_graph(expression)

        # Analyse the modified terminals before factorization so that
        # table properties can be derived for all of them and used to
        # simplify the scalar graph before argument factorization.
        initial_terminals = {i: analyse_modified_terminal(v['expression'])
                             for i, v in S.nodes.items()
                             if is_modified_terminal(v['expression'])}

        mt_unique_table_reference = build_optimized_tables(
            quadrature_rule, cell, integral_type, entitytype,
            initial_terminals.values(), ir["unique_tables"],
            rtol=p["table_rtol"], atol=p["table_atol"])
        unique_tables = {v.name: v.values for v in mt_unique_table_reference.values()}
        unique_table_types = {v.name: v.ttype for v in mt_unique_table_reference.values()}

        S_targets = [i for i, v in S.nodes.items() if v.get('target', False)]

        if 'zeros' in unique_table_types.values() and len(S_targets) == 1:
            # If there are any 'zero' tables, replace symbolically and rebuild graph
            #
            # TODO: Implement zero table elimination for non-scalar graphs
            for i, mt in initial_terminals.items():
                # Set modified terminals with zero tables to zero
                tr = mt_unique_table_reference.get(mt)
                if tr is not None and tr.ttype == "zeros":
                    S.nodes[i]['expression'] = ufl.as_ufl(0.0)

            # Propagate expression changes using dependency list
            for i, v in S.nodes.items():
                deps = [S.nodes[j]['expression'] for j in S.out_edges[i]]
                if deps:
                    v['expression'] = v['expression']._ufl_expr_reconstruct_(*deps)

            # Rebuild scalar target expressions and graph (this may be
            # overkill and possible to optimize away if it turns out to be
            # costly)
            expression = S.nodes[S_targets[0]]['expression']

            # Rebuild scalar list-based graph representation
            S = build_scalar_graph(expression)

        # Output diagnostic graph as pdf
        if visualise:
            visualise_graph(S, 'S.pdf')

        # Compute factorization of arguments
        rank = len(argument_shape)
        F = compute_argument_factorization(S, rank)

        # Get the 'target' nodes that are factors of arguments, and insert in dict
        FV_targets = [i for i, v in F.nodes.items() if v.get('target', False)]
        argument_factorization = {}

        for fi in FV_targets:
            node = F.nodes[fi]
            # Number of blocks using this factor must agree with number of
            # components to which this factor contributes, i.e. there are
            # more blocks iff there are more components.
            assert len(node['target']) == len(node['component'])

            for k, w in enumerate(node['target']):
                comp = node['component'][k]
                # Store tuple of (factor index, component index)
                argument_factorization.setdefault(w, []).append((fi, comp))

        # Get list of indices in F which are the arguments (should be at start)
        argkeys = set()
        for w in argument_factorization:
            argkeys = argkeys | set(w)
        argkeys = list(argkeys)

        # Build set of modified_terminals for each mt factorized vertex in F
        # and attach tables, if appropriate
        for i, v in F.nodes.items():
            expr = v['expression']
            if is_modified_terminal(expr):
                mt = analyse_modified_terminal(expr)
                F.nodes[i]['mt'] = mt
                tr = mt_unique_table_reference.get(mt)
                if tr is not None:
                    F.nodes[i]['tr'] = tr

        # Attach 'status' to each node: 'inactive', 'piecewise' or 'varying'
        analyse_dependencies(F, mt_unique_table_reference)

        # Output diagnostic graph as pdf
        if visualise:
            visualise_graph(F, 'F.pdf')

        # Loop over factorization terms
        block_contributions = collections.defaultdict(list)
        for ma_indices, fi_ci in sorted(argument_factorization.items()):
            # Get a bunch of information about this term
            assert rank == len(ma_indices)
            trs = tuple(F.nodes[ai]['tr'] for ai in ma_indices)

            unames = tuple(tr.name for tr in trs)
            ttypes = tuple(tr.ttype for tr in trs)
            assert not any(tt == "zeros" for tt in ttypes)

            # One dofmap per argument: offset + i * block_size for each of
            # the table's dofs (the dof count is read from axis 3 of the
            # table values).
            blockmap = []
            for tr in trs:
                begin = tr.offset
                num_dofs = tr.values.shape[3]
                dofmap = tuple(begin + i * tr.block_size for i in range(num_dofs))
                blockmap.append(dofmap)
            blockmap = tuple(blockmap)

            block_is_uniform = all(tr.is_uniform for tr in trs)

            # Collect relevant restrictions to identify blocks correctly
            # in interior facet integrals
            block_restrictions = tuple(
                None if trs[i].is_uniform else F.nodes[ai]['mt'].restriction
                for i, ai in enumerate(ma_indices))

            # Check if *each* factor corresponding to this argument is piecewise
            all_factors_piecewise = all(
                F.nodes[ifi[0]]["status"] == 'piecewise' for ifi in fi_ci)

            # The block is flagged as permuted when any of its tables has
            # more than one entry along its first axis.
            block_is_permuted = any(unique_tables[n].shape[0] > 1 for n in unames)

            ma_data = [ma_data_t(ma, trs[i]) for i, ma in enumerate(ma_indices)]

            block_is_transposed = False  # FIXME: Handle transposes for these block types

            block_unames = unames
            blockdata = block_data_t(ttypes, fi_ci,
                                     all_factors_piecewise, block_unames,
                                     block_restrictions, block_is_transposed,
                                     block_is_uniform, None, tuple(ma_data),
                                     None, block_is_permuted)

            # Insert in expr_ir for this quadrature loop
            block_contributions[blockmap].append(blockdata)

        # Figure out which table names are referenced
        active_table_names = set()
        for i, v in F.nodes.items():
            tr = v.get('tr')
            if tr is not None and F.nodes[i]['status'] != 'inactive':
                active_table_names.add(tr.name)

        # Figure out which table names are referenced in blocks
        for contributions in block_contributions.values():
            for blockdata in contributions:
                for mad in blockdata.ma_data:
                    active_table_names.add(mad.tabledata.name)

        # Record all table types before dropping tables
        ir["unique_table_types"].update(unique_table_types)

        # Drop tables not referenced from modified terminals
        # and tables of zeros and ones
        unused_ttypes = ("zeros", "ones")
        keep_table_names = set()
        for name in active_table_names:
            ttype = ir["unique_table_types"][name]
            if ttype not in unused_ttypes and name in unique_tables:
                keep_table_names.add(name)
        unique_tables = {name: unique_tables[name] for name in keep_table_names}

        # Add to global set of all tables
        for name, table in unique_tables.items():
            tbl = ir["unique_tables"].get(name)
            if tbl is not None and not numpy.allclose(
                    tbl, table, rtol=p["table_rtol"], atol=p["table_atol"]):
                raise RuntimeError("Table values mismatch with same name.")
        # Merge rather than rebind: rebinding would discard the tables kept
        # for previously processed quadrature rules, and would also make
        # the mismatch check above ineffective for them.
        ir["unique_tables"].update(unique_tables)

        # Build IR dict for the given expressions
        # Store final ir for this quadrature rule
        ir["integrand"][quadrature_rule] = {
            "factorization": F,
            "modified_arguments": [F.nodes[i]['mt'] for i in argkeys],
            "block_contributions": block_contributions}

        # NOTE(review): this flag is overwritten for every quadrature rule,
        # so it reflects only the most recently processed integrand --
        # confirm that this is intended.
        restrictions = [i.restriction for i in initial_terminals.values()]
        ir["needs_facet_permutations"] = "+" in restrictions and "-" in restrictions

    return ir
def build_uflacs_ir(cell, integral_type, entitytype, integrands, tensor_shape,
                    quadrature_rules, parameters):
    """Build the uflacs intermediate representation (IR) for one integral.

    For every number of quadrature points in ``integrands`` the integrand
    is converted to a scalar graph, element tables are built, the graph
    is factorized with respect to the arguments, and every factorization
    term is assigned a block mode ("preintegrated", "premultiplied",
    "partial", "full" or "safe"). Blocks found to be piecewise are stored
    in the shared ``ir["piecewise_ir"]``; all others go into the entry of
    ``ir["varying_irs"]`` for that number of points.

    Args:
        cell: Object providing ``topological_dimension()``; also forwarded
            to ``build_optimized_tables``.
        integral_type: Integral type name; compared against
            ``point_integral_types``, ``facet_integral_types``,
            ``custom_integral_types``, ``"cell"`` and ``"interior_facet"``.
        entitytype: Entity type name (compared against ``"cell"`` and
            ``"facet"``).
        integrands: ``dict`` mapping number of quadrature points to the
            integrand expression.
        tensor_shape: Only its length is used, as the rank of the tensor.
        quadrature_rules: Indexed as ``quadrature_rules[num_points][1]``
            to obtain the weights used for preintegration; also forwarded
            to ``build_optimized_tables``.
        parameters: Mapping; ``parameters['visualise']`` is read here and
            the whole mapping is passed to
            ``parse_uflacs_optimization_parameters``.

    Returns:
        dict: The IR with the keys ``"params"``, ``"unique_tables"``,
        ``"unique_table_types"``, ``"piecewise_ir"``, ``"varying_irs"``
        and ``"all_num_points"``.

    Raises:
        RuntimeError: On an invalid block mode, or if a kept table has the
            same name as an already stored table but different values.
    """
    # Extract uflacs specific optimization and code generation parameters
    p = parse_uflacs_optimization_parameters(parameters, integral_type)

    # The intermediate representation dict we're building and returning
    ir = {
        # Passed on for consumption in code generation
        "params": p,
        # Unique tables shared between all quadrature loops
        "unique_tables": {},
        "unique_table_types": {},
        # Piecewise expr_ir shared between all quadrature loops
        "piecewise_ir": {
            "factorization": None,
            "modified_arguments": [],
            "preintegrated_blocks": {},
            "premultiplied_blocks": {},
            "preintegrated_contributions": collections.defaultdict(list),
            "block_contributions": collections.defaultdict(list),
        },
        # { num_points: expr_ir for one integrand }
        "varying_irs": {"factorization": None},
    }
    piecewise_ir = ir["piecewise_ir"]

    # Whether we expect the quadrature weight to be applied or not (in
    # some cases it's just set to 1 in ufl integral scaling)
    tdim = cell.topological_dimension()
    expect_weight = (integral_type not in point_integral_types
                     and (entitytype == "cell"
                          or (entitytype == "facet" and tdim > 1)
                          or (integral_type in custom_integral_types)))

    # Analyse each num_points/integrand separately
    assert isinstance(integrands, dict)
    all_num_points = sorted(integrands.keys())
    ir["all_num_points"] = all_num_points

    for num_points in all_num_points:
        expression = integrands[num_points]

        # Rebalance order of nested terminal modifiers
        expression = balance_modifiers(expression)

        # Remove QuadratureWeight terminals from expression and replace with 1.0
        expression = replace_quadratureweight(expression)

        # Build initial scalar list-based graph representation
        S = build_scalar_graph(expression)
        S_targets = [idx for idx, node in S.nodes.items() if node.get('target', False)]
        assert len(S_targets) == 1
        S_target = S_targets[0]

        # Analyse the modified terminals before factorization, so table
        # properties can be derived for all of them and used to simplify
        # the scalar graph before argument factorization.
        initial_terminals = {}
        for idx, node in S.nodes.items():
            if is_modified_terminal(node['expression']):
                initial_terminals[idx] = analyse_modified_terminal(node['expression'])

        (unique_tables, unique_table_types, unique_table_num_dofs,
         mt_unique_table_reference) = build_optimized_tables(
            num_points, quadrature_rules, cell, integral_type, entitytype,
            initial_terminals.values(), ir["unique_tables"],
            p["enable_table_zero_compression"],
            rtol=p["table_rtol"], atol=p["table_atol"])

        # If there are any 'zero' tables, replace symbolically and rebuild graph
        if 'zeros' in unique_table_types.values():
            for idx, mt in initial_terminals.items():
                # Set modified terminals with zero tables to zero
                tr = mt_unique_table_reference.get(mt)
                if tr is not None and tr.ttype == "zeros":
                    S.nodes[idx]['expression'] = ufl.as_ufl(0.0)

            # Propagate expression changes using dependency list
            for idx, node in S.nodes.items():
                deps = [S.nodes[j]['expression'] for j in S.out_edges[idx]]
                if deps:
                    node['expression'] = node['expression']._ufl_expr_reconstruct_(*deps)

            # Rebuild scalar target expressions and graph (this may be
            # overkill and possible to optimize away if it turns out to be
            # costly)
            expression = S.nodes[S_target]['expression']

            # Rebuild scalar list-based graph representation
            S = build_scalar_graph(expression)

        # Output diagnostic graph as pdf
        if parameters['visualise']:
            visualise(S, 'S.pdf')

        # Compute factorization of arguments
        rank = len(tensor_shape)
        F = compute_argument_factorization(S, rank)

        # Get the 'target' nodes that are factors of arguments, and insert in dict
        FV_targets = [idx for idx, node in F.nodes.items() if node.get('target', False)]
        argument_factorization = {}
        for factor_idx in FV_targets:
            for w in F.nodes[factor_idx]['target']:
                argument_factorization[w] = factor_idx

        # Get list of indices in F which are the arguments (should be at start)
        argkeys = set()
        for w in argument_factorization:
            argkeys = argkeys | set(w)
        argkeys = list(argkeys)

        # Output diagnostic graph as pdf
        if parameters['visualise']:
            visualise(F, 'F.pdf')

        # Build set of modified_terminals for each mt factorized vertex in F
        # and attach tables, if appropriate
        for idx, node in F.nodes.items():
            expr = node['expression']
            if not is_modified_terminal(expr):
                continue
            mt = analyse_modified_terminal(expr)
            F.nodes[idx]['mt'] = mt
            tr = mt_unique_table_reference.get(mt)
            if tr is not None:
                F.nodes[idx]['tr'] = tr

        # Attach 'status' to each node: 'inactive', 'piecewise' or 'varying'
        analyse_dependencies(F, mt_unique_table_reference)

        # Save the factorisation graph to the piecewise IR
        piecewise_ir["factorization"] = F
        piecewise_ir["modified_arguments"] = [F.nodes[i]['mt'] for i in argkeys]

        # Loop over factorization terms
        block_contributions = collections.defaultdict(list)
        for ma_indices, fi in sorted(argument_factorization.items()):
            # Get a bunch of information about this term
            assert rank == len(ma_indices)
            trs = tuple(F.nodes[ai]['tr'] for ai in ma_indices)

            unames = tuple(tr.name for tr in trs)
            ttypes = tuple(tr.ttype for tr in trs)
            assert not any(tt == "zeros" for tt in ttypes)

            blockmap = tuple(tr.dofmap for tr in trs)

            block_is_uniform = all(tr.is_uniform for tr in trs)

            # Collect relevant restrictions to identify blocks correctly
            # in interior facet integrals
            block_restrictions = tuple(
                None if tr.is_uniform else F.nodes[ai]['mt'].restriction
                for ai, tr in zip(ma_indices, trs))

            factor_is_piecewise = F.nodes[fi]['status'] == 'piecewise'

            # TODO: Add separate block modes for quadrature
            # Both arguments in quadrature elements
            #
            #   for iq
            #       fw = f*w
            #       #for i
            #       #    for j
            #       #        B[i,j] = fw*U[i]*V[j] = 0 if i != iq or j != iq
            #       BQ[iq] = B[iq,iq] = fw
            #   for (iq)
            #       A[iq+offset0, iq+offset1] = BQ[iq]
            #
            # One argument in quadrature element
            #
            #   for iq
            #       fw[iq] = f*w
            #       #for i
            #       #    for j
            #       #        B[i,j] = fw*UQ[i]*V[j] = 0 if i != iq
            #       for j
            #           BQ[iq,j] = fw[iq]*V[iq,j]
            #   for (iq) for (j)
            #       A[iq+offset, j+offset] = BQ[iq,j]

            # Decide how to handle code generation for this block
            if p["enable_preintegration"] and (factor_is_piecewise and rank > 0
                                               and "quadrature" not in ttypes):
                # - Piecewise factor is an absolute prerequisite
                # - Could work for rank 0 as well but currently doesn't
                # - Haven't considered how quadrature elements work out
                block_mode = "preintegrated"
            elif p["enable_premultiplication"] and (
                    rank > 0 and all(tt in piecewise_ttypes for tt in ttypes)):
                # Integrate functional in quadloop, scale block after
                # quadloop
                block_mode = "premultiplied"
            elif p["enable_sum_factorization"]:
                if rank == 2 and any(tt in piecewise_ttypes for tt in ttypes):
                    # Partial computation in quadloop of f*u[i], compute
                    # (f*u[i])*v[i] outside quadloop, (or with u,v
                    # swapped)
                    block_mode = "partial"
                else:
                    # Full runtime integration of f*u[i]*v[j], can still
                    # do partial computation in quadloop of f*u[i] but
                    # must compute (f*u[i])*v[i] as well inside
                    # quadloop. (or with u,v swapped)
                    block_mode = "full"
            else:
                # Use full runtime integration with nothing fancy going
                # on
                block_mode = "safe"

            # Carry out decision
            if block_mode == "preintegrated":
                # Add to contributions:
                # P = sum_q weight*u*v;      preintegrated here
                # B[...] = f * P[...];       generated after quadloop
                # A[blockmap] += B[...];     generated after quadloop

                cache = piecewise_ir["preintegrated_blocks"]

                block_is_transposed = False
                pname = cache.get(unames)

                # Reuse transpose to save memory
                if p["enable_block_transpose_reuse"] and pname is None and len(unames) == 2:
                    pname = cache.get((unames[1], unames[0]))
                    # Cache hit on transpose
                    block_is_transposed = pname is not None

                if pname is None:
                    # Cache miss, precompute block
                    weights = quadrature_rules[num_points][1]
                    if integral_type == "interior_facet":
                        ptable = integrate_block_interior_facets(
                            weights, unames, ttypes, unique_tables, unique_table_num_dofs)
                    else:
                        ptable = integrate_block(
                            weights, unames, ttypes, unique_tables, unique_table_num_dofs)
                    ptable = clamp_table_small_numbers(
                        ptable, rtol=p["table_rtol"], atol=p["table_atol"])

                    pname = "PI%d" % (len(cache), )
                    cache[unames] = pname
                    unique_tables[pname] = ptable
                    unique_table_types[pname] = "preintegrated"

                assert factor_is_piecewise
                block_unames = (pname, )
                blockdata = block_data_t(
                    block_mode, ttypes, fi, factor_is_piecewise, block_unames,
                    block_restrictions, block_is_transposed, block_is_uniform,
                    pname, None, None)
                block_is_piecewise = True

            elif block_mode == "premultiplied":
                # Add to contributions:
                # P = u*v;                        computed here
                # FI = sum_q weight * f;          generated inside quadloop
                # B[...] = FI * P[...];           generated after quadloop
                # A[blockmap] += B[...];          generated after quadloop

                cache = piecewise_ir["premultiplied_blocks"]

                block_is_transposed = False
                pname = cache.get(unames)

                # Reuse transpose to save memory
                if p["enable_block_transpose_reuse"] and pname is None and len(unames) == 2:
                    pname = cache.get((unames[1], unames[0]))
                    # Cache hit on transpose
                    block_is_transposed = pname is not None

                if pname is None:
                    # Cache miss, precompute block
                    if integral_type == "interior_facet":
                        ptable = multiply_block_interior_facets(
                            0, unames, ttypes, unique_tables, unique_table_num_dofs)
                    else:
                        ptable = multiply_block(
                            0, unames, ttypes, unique_tables, unique_table_num_dofs)

                    pname = "PM%d" % (len(cache), )
                    cache[unames] = pname
                    unique_tables[pname] = ptable
                    unique_table_types[pname] = "premultiplied"

                block_unames = (pname, )
                blockdata = block_data_t(
                    block_mode, ttypes, fi, factor_is_piecewise, block_unames,
                    block_restrictions, block_is_transposed, block_is_uniform,
                    pname, None, None)
                block_is_piecewise = False

            # elif block_mode == "scaled":
            #     # TODO: Add mode, block is piecewise but choose not to be premultiplied
            #     # Add to contributions:
            #     # FI = sum_q weight * f;          generated inside quadloop
            #     # B[...] = FI * u * v;            generated after quadloop
            #     # A[blockmap] += B[...];          generated after quadloop
            #     raise NotImplementedError("scaled block mode not implemented.")
            #     # (probably need mostly the same data as
            #     # premultiplied, except no P table name or values)
            #     block_is_piecewise = False

            elif block_mode in ("partial", "full", "safe"):
                # The block is piecewise only if the factor is piecewise,
                # no weight is expected and every argument table is
                # piecewise.
                block_is_piecewise = factor_is_piecewise and not expect_weight
                ma_data = []
                for ma, tr in zip(ma_indices, trs):
                    if not tr.is_piecewise:
                        block_is_piecewise = False
                    ma_data.append(ma_data_t(ma, tr))

                block_is_transposed = False  # FIXME: Handle transposes for these block types

                if block_mode == "partial":
                    # Add to contributions:
                    # P[i] = sum_q weight * f * u[i];  generated inside quadloop
                    # B[i,j] = P[i] * v[j];            generated after quadloop
                    #                                  (where v is the piecewise ma)
                    # A[blockmap] += B[...];           generated after quadloop

                    # Find first piecewise index
                    # TODO: Is last better? just reverse range here
                    for i in range(rank):
                        if trs[i].is_piecewise:
                            piecewise_ma_index = i
                            break
                    assert rank == 2
                    not_piecewise_ma_index = 1 - piecewise_ma_index
                    block_unames = (unames[not_piecewise_ma_index], )
                    blockdata = block_data_t(
                        block_mode, ttypes, fi, factor_is_piecewise, block_unames,
                        block_restrictions, block_is_transposed, None, None,
                        tuple(ma_data), piecewise_ma_index)
                elif block_mode in ("full", "safe"):
                    # Add to contributions:
                    # B[i] = sum_q weight * f * u[i] * v[j];  generated inside quadloop
                    # A[blockmap] += B[i];                    generated after quadloop

                    block_unames = unames
                    blockdata = block_data_t(
                        block_mode, ttypes, fi, factor_is_piecewise, block_unames,
                        block_restrictions, block_is_transposed, None, None,
                        tuple(ma_data), None)
            else:
                raise RuntimeError("Invalid block_mode %s" % (block_mode, ))

            if block_is_piecewise:
                # Insert in piecewise expr_ir
                piecewise_ir["block_contributions"][blockmap].append(blockdata)
            else:
                # Insert in varying expr_ir for this quadrature loop
                block_contributions[blockmap].append(blockdata)

        # Figure out which table names are referenced in unstructured
        # partition
        active_table_names = set()
        for idx, node in F.nodes.items():
            tr = node.get('tr')
            if tr is not None and F.nodes[idx]['status'] != 'inactive':
                active_table_names.add(tr.name)

        # Figure out which table names are referenced in blocks
        for blockmap, contributions in itertools.chain(
                block_contributions.items(),
                piecewise_ir["block_contributions"].items()):
            for blockdata in contributions:
                if blockdata.block_mode in ("preintegrated", "premultiplied"):
                    active_table_names.add(blockdata.name)
                elif blockdata.block_mode in ("partial", "full", "safe"):
                    for mad in blockdata.ma_data:
                        active_table_names.add(mad.tabledata.name)

        # Record all table types before dropping tables
        ir["unique_table_types"].update(unique_table_types)

        # Drop tables not referenced from modified terminals
        # and tables of zeros and ones
        unused_ttypes = ("zeros", "ones", "quadrature")
        keep_table_names = set()
        for name in active_table_names:
            ttype = ir["unique_table_types"][name]
            if ttype not in unused_ttypes and name in unique_tables:
                keep_table_names.add(name)
        unique_tables = {name: unique_tables[name] for name in keep_table_names}

        # Add to global set of all tables
        for name, table in unique_tables.items():
            known = ir["unique_tables"].get(name)
            if known is not None and not numpy.allclose(
                    known, table, rtol=p["table_rtol"], atol=p["table_atol"]):
                raise RuntimeError("Table values mismatch with same name.")
        ir["unique_tables"].update(unique_tables)

        # Analyse active terminals to check what we'll need to generate code for
        active_mts = []
        for idx, node in F.nodes.items():
            mt = node.get('mt', False)
            if mt and F.nodes[idx]['status'] != 'inactive':
                active_mts.append(mt)

        # Figure out if we need to access CellCoordinate to avoid
        # generating quadrature point table otherwise
        if integral_type == "cell":
            need_points = any(isinstance(mt.terminal, CellCoordinate) for mt in active_mts)
        elif integral_type in facet_integral_types:
            need_points = any(isinstance(mt.terminal, FacetCoordinate) for mt in active_mts)
        elif integral_type in custom_integral_types:
            need_points = True  # TODO: Always?
        else:
            need_points = False

        # Figure out if we need to access QuadratureWeight to avoid
        # generating quadrature point table otherwise
        # need_weights = any(isinstance(mt.terminal, QuadratureWeight)
        #                    for mt in active_mts)

        # Count blocks of each mode (varying blocks of this loop only)
        block_modes = collections.defaultdict(int)
        for contributions in block_contributions.values():
            for blockdata in contributions:
                block_modes[blockdata.block_mode] += 1

        # Debug output
        summary = "\n".join(
            " {}\t{}".format(count, mode) for mode, count in sorted(block_modes.items()))
        logger.debug("Blocks of each mode: {}".format(summary))

        # If there are any blocks other than preintegrated we need weights
        if expect_weight and any(mode != "preintegrated" for mode in block_modes):
            need_weights = True
        elif integral_type in custom_integral_types:
            need_weights = True  # TODO: Always?
        else:
            need_weights = False

        # Build IR dict for the given expressions
        # Store final ir for this num_points
        ir["varying_irs"][num_points] = {
            "factorization": F,
            "modified_arguments": [F.nodes[i]['mt'] for i in argkeys],
            "block_contributions": block_contributions,
            "need_points": need_points,
            "need_weights": need_weights}

    return ir