def _add_unique_dim_name(name, dim_names): if dim_names is None: return dim_names from pytools import UniqueNameGenerator ng = UniqueNameGenerator(set(dim_names)) return (ng(name),) + tuple(dim_names)
def uniquify_instruction_ids(kernel):
    """Give every instruction of *kernel* a concrete string id.

    Ids that are *None* become fresh names based on ``"insn"``; ids that are
    :class:`loopy.UniqueName` become fresh names based on their ``name``.
    Concrete string ids are reserved and left untouched.

    This function does *not* deduplicate existing instruction ids.

    :returns: a copy of *kernel* with the rewritten instruction list.
    """
    from loopy.kernel.creation import UniqueName
    from pytools import UniqueNameGenerator

    def has_concrete_id(insn):
        return insn.id is not None and not isinstance(insn.id, UniqueName)

    id_gen = UniqueNameGenerator(
        {insn.id for insn in kernel.instructions if has_concrete_id(insn)})

    def with_unique_id(insn):
        if insn.id is None:
            return insn.copy(id=id_gen("insn"))
        if isinstance(insn.id, UniqueName):
            return insn.copy(id=id_gen(insn.id.name))
        return insn

    return kernel.copy(
        instructions=[with_unique_id(insn) for insn in kernel.instructions])
def disambiguate_identifiers(statements_a, statements_b,
        should_disambiguate_name=None):
    """Rename identifiers of *statements_b* that also occur in *statements_a*.

    :arg should_disambiguate_name: optional predicate on an identifier name;
        only shared identifiers for which it returns true are renamed. When
        *None*, every shared identifier is renamed.
    :returns: a tuple ``(new_statements_b, subst_b)`` where *subst_b* maps each
        renamed identifier to the :func:`pymbolic.var` replacing it. The new
        names clash with no identifier used in either statement list.
    """
    from pymbolic import var
    from pymbolic.imperative.analysis import get_all_used_identifiers
    from pymbolic.mapper.substitutor import make_subst_func, SubstitutionMapper
    from pytools import UniqueNameGenerator

    ids_a = get_all_used_identifiers(statements_a)
    ids_b = get_all_used_identifiers(statements_b)
    name_gen = UniqueNameGenerator(ids_a | ids_b)

    subst_b = {}
    for shared in ids_a & ids_b:
        if (should_disambiguate_name is None
                or should_disambiguate_name(shared)):
            subst_b[shared] = var(name_gen(shared))

    subst_mapper = SubstitutionMapper(make_subst_func(subst_b))
    renamed_b = [stmt.map_expressions(subst_mapper) for stmt in statements_b]

    return renamed_b, subst_b
def map_roll(self, expr: Roll) -> Array:
    """Lower a :class:`Roll` node to an :class:`IndexLambda`.

    Output axis ``d`` is indexed by ``_d``. Along the rolled axis the index is
    ``(_axis - shift) % axis_len``, where a symbolic axis length is turned
    into a scalar expression with its own bindings. The rolled array itself
    is bound as ``_in0``. Array-valued shape entries and all bindings are
    mapped through :meth:`rec`.
    """
    from pytato.utils import dim_to_index_lambda_components

    roll_axis = expr.axis
    # "_in0" is reserved so shape-expression bindings never take that name.
    axis_len_expr, bindings = dim_to_index_lambda_components(
        expr.shape[roll_axis], UniqueNameGenerator({"_in0"}))

    idx_vars = [var(f"_{d}") for d in range(expr.ndim)]
    idx_vars[roll_axis] = (idx_vars[roll_axis] - expr.shift) % axis_len_expr

    body = var("_in0")
    if idx_vars:
        body = body[tuple(idx_vars)]

    bindings["_in0"] = expr.array  # type: ignore

    new_shape = tuple(self.rec(s) if isinstance(s, Array) else s
                      for s in expr.shape)
    new_bindings = {name: self.rec(bnd) for name, bnd in bindings.items()}

    return IndexLambda(expr=body,
                       shape=new_shape,
                       dtype=expr.dtype,
                       bindings=new_bindings,
                       axes=expr.axes,
                       tags=expr.tags)
def dim_to_index_lambda_components(expr: ShapeComponent,
                                   vng: Optional[UniqueNameGenerator] = None,
                                   ) -> Tuple[ScalarExpression,
                                              Dict[str, SizeParam]]:
    """
    Convert a shape component into a scalar expression usable inside an
    index lambda, together with the bindings that expression refers to.

    Integer components are returned unchanged with no bindings. Array-valued
    components are rewritten by a :class:`ShapeExpressionMapper` whose new
    variable names come from *vng* (a fresh generator if *None*).

    .. testsetup::

        >>> import pytato as pt
        >>> from pytato.utils import dim_to_index_lambda_components
        >>> from pytools import UniqueNameGenerator

    .. doctest::

        >>> n = pt.make_size_param("n")
        >>> expr, bnds = dim_to_index_lambda_components(3*n+8, UniqueNameGenerator())
        >>> print(expr)
        3*_in + 8
        >>> bnds
        {'_in': SizeParam(name='n')}
    """
    if isinstance(expr, INT_CLASSES):
        return expr, {}

    if vng is None:
        vng = UniqueNameGenerator()

    assert isinstance(vng, UniqueNameGenerator)
    assert isinstance(expr, Array)

    shape_mapper = ShapeExpressionMapper(vng)
    scalar_expr = shape_mapper(expr)
    return scalar_expr, shape_mapper.bindings
def map_basic_index(self, expr: BasicIndex) -> IndexLambda:
    """Lower a :class:`BasicIndex` node to an :class:`IndexLambda`.

    Integer indices are reduced modulo the axis length. When the axis length
    is symbolic, it is bound as an extra input. Each
    :class:`NormalizedSlice` index becomes ``start + step * _k``, where ``_k``
    is the next output axis.

    All bindings, including symbolic axis lengths, are passed through
    :meth:`rec`. The returned node therefore only references mapped arrays.

    :raises NotImplementedError: for index types other than integers and
        :class:`NormalizedSlice`.
    """
    vng = UniqueNameGenerator()
    indices = []

    in_ary = vng("in")
    bindings = {in_ary: self.rec(expr.array)}
    islice_idx = 0

    for idx, axis_len in zip(expr.indices, expr.array.shape):
        if isinstance(idx, INT_CLASSES):
            if isinstance(axis_len, INT_CLASSES):
                indices.append(idx % axis_len)
            else:
                bnd_name = vng("in")
                # Map the length like every other binding; binding the
                # unmapped node would leak the input graph into the result.
                bindings[bnd_name] = self.rec(axis_len)
                indices.append(idx % prim.Variable(bnd_name))
        elif isinstance(idx, NormalizedSlice):
            indices.append(idx.start
                           + idx.step * prim.Variable(f"_{islice_idx}"))
            islice_idx += 1
        else:
            raise NotImplementedError

    return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary),
                                           tuple(indices)),
                       bindings=bindings,
                       shape=expr.shape,
                       dtype=expr.dtype,
                       axes=expr.axes,
                       tags=expr.tags,
                       )
def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str:
    r"""Return a string in the `dot <https://graphviz.org>`_ language depicting
    the graph of the computation of *result*.

    Input arguments are drawn inside a cluster labeled ``Inputs``. Every
    other array is drawn as a top-level node. Each edge runs from an operand
    to the array that uses it, and output names are drawn in a cluster
    labeled ``Outputs``.

    :arg result: Outputs of the computation (cf.
        :func:`pytato.generate_loopy`).
    """
    outputs: DictOfNamedArrays = normalize_outputs(result)
    del result

    info_mapper = ArrayToDotNodeInfoMapper()
    for output_ary in outputs._data.values():
        info_mapper(output_ary)
    nodes = info_mapper.nodes

    id_gen = UniqueNameGenerator()
    array_to_id: Dict[ArrayOrNames, str] = {}
    input_arrays: List[Array] = []
    internal_arrays: List[ArrayOrNames] = []

    for ary in nodes:
        array_to_id[ary] = id_gen("array")
        if isinstance(ary, InputArgumentBase):
            input_arrays.append(ary)
        else:
            internal_arrays.append(ary)

    emit = DotEmitter()

    with emit.block("digraph computation"):
        emit("node [shape=rectangle]")

        with emit.block("subgraph cluster_Inputs"):
            emit('label="Inputs"')
            for ary in input_arrays:
                info = nodes[ary]
                _emit_array(emit, info.title, info.fields, array_to_id[ary])

        for ary in internal_arrays:
            info = nodes[ary]
            _emit_array(emit, info.title, info.fields, array_to_id[ary])

        # Edge direction: operand (tail) -> consuming array (head).
        for ary, info in nodes.items():
            head = array_to_id[ary]
            for label, tail_array in info.edges.items():
                tail = array_to_id[tail_array]
                emit(f'{tail} -> {head} [label="{dot_escape(label)}"]')

        _emit_name_cluster(emit, outputs._data, array_to_id, id_gen,
                           label="Outputs")

    return emit.get()
def __init__(self, discr, function_registry, prefix="_expr",
        max_vectors_in_batch_expr=None):
    super().__init__()

    # constructor arguments
    self.discr = discr
    self.function_registry = function_registry
    self.prefix = prefix
    self.max_vectors_in_batch_expr = max_vectors_in_batch_expr

    # bookkeeping for the "discr" scope: emitted code, names created there,
    # names copied over to the "eval" scope, and expression -> variable cache
    self.discr_code = []
    self.discr_scope_names_created = set()
    self.discr_scope_names_copied_to_eval = set()
    self.discr_expr_to_var = {}

    # bookkeeping for the "eval" scope
    self.eval_code = []
    self.eval_expr_to_var = {}

    self.assigned_names = set()

    from pytools import UniqueNameGenerator
    self.name_gen = UniqueNameGenerator()
def __init__(self, get_part_id: Callable[[ArrayOrNames], PartId]) -> None:
    super().__init__()

    from pytools import UniqueNameGenerator

    # Callable mapping a node to the ID of the part it belongs to.
    self._get_part_id: Callable[[ArrayOrNames], PartId] = get_part_id

    # Names for the placeholders created where the graph is cut between parts.
    self.name_generator = UniqueNameGenerator(forced_prefix="_pt_part_ph_")

    # Edges of the partitioned graph: a (source part, target part) pair maps
    # to the set of placeholder names carrying data across that edge.
    self.part_pair_to_edges: Dict[Tuple[PartId, PartId], Set[str]] = {}

    self.var_name_to_result: Dict[str, Array] = {}
    self._seen_node_to_placeholder: Dict[ArrayOrNames, Placeholder] = {}

    # Part IDs are tracked explicitly instead of being read off
    # part_pair_to_edges: if every part is self-contained there are no
    # edges at all, yet every part must still be reported.
    self.seen_part_ids: Set[PartId] = set()

    self.pid_to_user_input_names: Dict[PartId, Set[str]] = {}
def map_non_contiguous_advanced_index(self,
                                      expr: AdvancedIndexInNoncontiguousAxes
                                      ) -> IndexLambda:
    """Lower advanced indexing over non-contiguous axes to an
    :class:`IndexLambda`.

    Integer and array indices together form the advanced indices. Their
    broadcast shape occupies the leading output axes, and slice axes are
    numbered after them. Array indices are gathered through
    :func:`~pytato.utils.get_indexing_expression` and reduced modulo the axis
    length unless tagged :class:`AssumeNonNegative`.

    :raises NotImplementedError: for array indices over parametric axis
        lengths, or for unsupported index types.
    """
    from pytato.utils import (get_shape_after_broadcasting,
                              get_indexing_expression)

    adv_positions = tuple(pos for pos, idx in enumerate(expr.indices)
                          if isinstance(idx, (Array, INT_CLASSES)))
    adv_idx_shape = get_shape_after_broadcasting(
        [expr.indices[pos] for pos in adv_positions])

    name_gen = UniqueNameGenerator()
    subscript = []

    in_name = name_gen("in")
    bindings = {in_name: self.rec(expr.array)}
    # the broadcast advanced-index axes come first; slice axes follow them
    next_slice_axis = len(adv_idx_shape)

    for idx, axis_len in zip(expr.indices, expr.array.shape):
        if isinstance(idx, INT_CLASSES):
            if isinstance(axis_len, INT_CLASSES):
                subscript.append(idx % axis_len)
            else:
                len_name = name_gen("in")
                bindings[len_name] = self.rec(axis_len)
                subscript.append(idx % prim.Variable(len_name))

        elif isinstance(idx, NormalizedSlice):
            subscript.append(
                idx.start + idx.step * prim.Variable(f"_{next_slice_axis}"))
            next_slice_axis += 1

        elif isinstance(idx, Array):
            if not isinstance(axis_len, INT_CLASSES):
                raise NotImplementedError("Advanced indexing over"
                                          " parametric axis lengths.")
            idx_name = name_gen("in")
            bindings[idx_name] = self.rec(idx)
            gathered = prim.Subscript(
                prim.Variable(idx_name),
                get_indexing_expression(idx.shape, adv_idx_shape))
            if not idx.tags_of_type(AssumeNonNegative):
                gathered = gathered % axis_len
            subscript.append(gathered)

        else:
            raise NotImplementedError(f"Indices of type {type(idx)}.")

    return IndexLambda(expr=prim.Subscript(prim.Variable(in_name),
                                           tuple(subscript)),
                       bindings=bindings,
                       shape=expr.shape,
                       dtype=expr.dtype,
                       axes=expr.axes,
                       tags=expr.tags,
                       )
def __init__(self):
    # indices for declarations and referencing values, from ImperoC
    self.indices = {}
    # gem index -> pymbolic variable
    self.active_indices = {}
    # pymbolic variable for indices -> extent (insertion-ordered)
    self.index_extent = OrderedDict()
    # gem node -> pymbolic variable
    self.gem_to_pymbolic = {}

    self.name_gen = UniqueNameGenerator()
def test_unique_name_gen_conflicting_ok():
    """Re-adding known names fails by default but is allowed on request."""
    from pytools import UniqueNameGenerator

    gen = UniqueNameGenerator()
    gen.add_names({"a", "b", "c"})

    # a name that is already registered is rejected ...
    with pytest.raises(ValueError):
        gen.add_names({"a"})

    # ... unless conflicts are explicitly permitted
    gen.add_names({"a", "b", "c"}, conflicting_ok=True)
def __init__(self, coeffs):
    self.coefficients = coeffs
    self.inames = OrderedDict()

    # boolean flags; all start out False
    self.needs_cell_orientations = False
    self.needs_cell_sizes = False
    self.needs_cell_facets = False
    self.needs_mesh_layers = False

    self.call_name_generator = UniqueNameGenerator(
        forced_prefix="tsfc_kernel_call_")
    self.index_creator = IndexCreator()
def __init__(self, array_context, mesh, order=None,
        quad_tag_to_group_factory=None, mpi_communicator=None):
    """
    :param quad_tag_to_group_factory:
        A mapping from quadrature tags (typically strings--but may be any
        hashable/comparable object) to a
        :class:`~meshmode.discretization.poly_element.ElementGroupFactory`
        indicating with which quadrature discretization the operations are
        to be carried out, or *None* to indicate that operations with this
        quadrature tag should be carried out with the standard volume
        discretization.
    :param order: if given, a
        :class:`PolynomialWarpAndBlendGroupFactory` of this order is used for
        ``sym.QTAG_NONE``. At least one of *order* and
        *quad_tag_to_group_factory* must be given.
    """
    self._setup_actx = array_context

    from meshmode.discretization.poly_element import \
        PolynomialWarpAndBlendGroupFactory

    if quad_tag_to_group_factory is None:
        if order is None:
            raise TypeError("one of 'order' and "
                    "'quad_tag_to_group_factory' must be given")
        quad_tag_to_group_factory = {
            sym.QTAG_NONE: PolynomialWarpAndBlendGroupFactory(order=order)}

    elif order is not None:
        # copy before inserting so the caller's mapping is left alone
        quad_tag_to_group_factory = quad_tag_to_group_factory.copy()
        if sym.QTAG_NONE in quad_tag_to_group_factory:
            raise ValueError("if 'order' is given, "
                    "'quad_tag_to_group_factory' must not have a "
                    "key of QTAG_NONE")
        quad_tag_to_group_factory[sym.QTAG_NONE] = \
            PolynomialWarpAndBlendGroupFactory(order=order)

    self.quad_tag_to_group_factory = quad_tag_to_group_factory

    from meshmode.discretization import Discretization
    volume_group_factory = self.group_factory_for_quadrature_tag(sym.QTAG_NONE)
    self._volume_discr = Discretization(array_context, mesh,
            volume_group_factory)

    # {{{ management of discretization-scoped common subexpressions

    from pytools import UniqueNameGenerator
    self._discr_scoped_name_gen = UniqueNameGenerator()
    self._discr_scoped_subexpr_to_name = {}
    self._discr_scoped_subexpr_name_to_value = {}

    # }}}

    self._dist_boundary_connections = \
        self._set_up_distributed_communication(
            mpi_communicator, array_context)

    self.mpi_communicator = mpi_communicator
def _gather_distributed_comm_info(partition: GraphPartition,
        pid_to_distributed_sends: Dict[PartId, List[DistributedSend]]) -> \
            DistributedGraphPartition:
    """Turn *partition* into a :class:`DistributedGraphPartition`.

    For every part, a :class:`_DistributedCommReplacer` is applied to its
    outputs and to its distributed sends. The names it records for receives
    and sends are added to the part's input and output names. The data of
    each send is added to the part's results. All names are drawn from one
    shared generator with prefix ``_pt_dist_``.
    """
    name_gen = UniqueNameGenerator(forced_prefix="_pt_dist_")
    parts: Dict[PartId, DistributedGraphPart] = {}
    var_name_to_result = {}

    for part in partition.parts.values():
        replacer = _DistributedCommReplacer(name_gen)

        # Mapping the outputs and the sends fills the replacer's recv/send
        # tables, which are read afterwards.
        results = {name: replacer(partition.var_name_to_result[name])
                   for name in part.output_names}
        sends = [replacer.map_distributed_send(send)
                 for send in pid_to_distributed_sends.get(part.pid, [])]

        for name, send_node in replacer.output_name_to_send_node.items():
            results[name] = send_node.data

        recv_names = frozenset(replacer.input_name_to_recv_node)
        send_names = frozenset(replacer.output_name_to_send_node)

        parts[part.pid] = DistributedGraphPart(
            pid=part.pid,
            needed_pids=part.needed_pids,
            user_input_names=part.user_input_names,
            partition_input_names=part.partition_input_names | recv_names,
            output_names=part.output_names | send_names,
            distributed_sends=sends,
            input_name_to_recv_node=replacer.input_name_to_recv_node,
            output_name_to_send_node=replacer.output_name_to_send_node)

        for name, value in results.items():
            assert name not in var_name_to_result
            var_name_to_result[name] = value

    distributed_partition = DistributedGraphPartition(
        parts=parts,
        var_name_to_result=var_name_to_result,
        toposorted_part_ids=partition.toposorted_part_ids)

    if __debug__:
        # Check disjointness again since we replaced a few nodes.
        from pytato.partition import _check_partition_disjointness
        _check_partition_disjointness(distributed_partition)

    return distributed_partition
def rename_callable(program, old_name, new_name=None, existing_ok=False):
    """
    Return a copy of *program* in which the callable *old_name* is renamed.

    In every :class:`CallableKernel` of the callables table, references to
    *old_name* are substituted by *new_name*. *old_name* is also replaced by
    *new_name* in the entrypoints.

    :arg program: An instance of :class:`loopy.TranslationUnit`
    :arg old_name: The callable to be renamed
    :arg new_name: New name for the callable to be renamed. If *None*, a
        name based on *old_name* that clashes with no existing callable is
        chosen.
    :arg existing_ok: An instance of :class:`bool`. Unless true, a
        :class:`LoopyError` is raised if a callable named *new_name* already
        exists.
    :raises NotImplementedError: if a callable is neither a
        :class:`CallableKernel` nor a :class:`ScalarCallable`.
    """
    from loopy.symbolic import (RuleAwareSubstitutionMapper,
            SubstitutionRuleMappingContext)
    from pymbolic import var

    assert isinstance(program, TranslationUnit)
    assert isinstance(old_name, str)

    if (new_name in program.callables_table) and not existing_ok:
        raise LoopyError(f"callables named '{new_name}' already exists")

    if new_name is None:
        # Seed the generator with a concrete set of existing callable names
        # rather than a keys view of the callables table.
        namegen = UniqueNameGenerator(set(program.callables_table))
        new_name = namegen(old_name)

    assert isinstance(new_name, str)

    new_callables_table = {}

    for name, clbl in program.callables_table.items():
        if name == old_name:
            name = new_name

        if isinstance(clbl, CallableKernel):
            knl = clbl.subkernel
            rule_mapping_context = SubstitutionRuleMappingContext(
                    knl.substitutions, knl.get_var_name_generator())
            smap = RuleAwareSubstitutionMapper(rule_mapping_context,
                    {var(old_name): var(new_name)}.get,
                    within=lambda *args: True)
            knl = rule_mapping_context.finish_kernel(smap.map_kernel(knl))
            clbl = clbl.copy(subkernel=knl.copy(name=name))
        elif isinstance(clbl, ScalarCallable):
            pass
        else:
            raise NotImplementedError(f"{type(clbl)}")

        new_callables_table[name] = clbl

    new_entrypoints = program.entrypoints.copy()
    if old_name in new_entrypoints:
        new_entrypoints = ((new_entrypoints | frozenset([new_name]))
                - frozenset([old_name]))

    return program.copy(callables_table=new_callables_table,
            entrypoints=new_entrypoints)
def __init__(self, start=None, forced_prefix="",
        key_translate_func=make_identifier_from_name,
        name_generator=None):
    """
    :arg start: optional initial mapping from keys to already-assigned names.
    :arg forced_prefix: prefix for generated names; only allowed when no
        *name_generator* is passed.
    :arg key_translate_func: turns a key into the base string used for
        generating its name.
    :arg name_generator: optional pre-existing name generator to draw from.
    :raises TypeError: if both *forced_prefix* and *name_generator* are given.
    """
    start = {} if start is None else start
    self._dict = dict(start)

    if name_generator is not None:
        if forced_prefix:
            raise TypeError("passing 'forced_prefix' is not allowed when "
                    "passing a pre-existing name generator")
    else:
        name_generator = UniqueNameGenerator(forced_prefix=forced_prefix)

    # Reserve already-assigned names that carry the generator's prefix so
    # they are never handed out again.
    for existing_name in start.values():
        if existing_name.startswith(name_generator.forced_prefix):
            name_generator.add_name(existing_name)

    self._generator = _KeyTranslatingUniqueNameGeneratorWrapper(
            name_generator, key_translate_func)
def map_loopy_call(self, expr: LoopyCall) -> LoopyCall:
    """Retarget a :class:`LoopyCall` and resolve kernel-name collisions.

    Some callable kernels have a name already recorded in ``kernels_seen``
    but a different definition. Each such kernel is renamed to the name of
    a previously seen kernel that is equal to it up to its name, if one
    exists. Otherwise it gets a fresh name. Every callable kernel is then
    recorded in ``kernels_seen``. Array-valued bindings are mapped through
    :meth:`rec`.

    :raises ValueError: if ``self.target`` is not a :class:`LoopyTarget`.
    """
    from pytato.target.loopy import LoopyTarget
    if not isinstance(self.target, LoopyTarget):
        raise ValueError("Got a LoopyCall for a non-loopy target.")

    translation_unit = expr.translation_unit.copy(
            target=self.target.get_loopy_target())
    # Fresh names must avoid both the kernels seen so far and every callable
    # of this translation unit. Otherwise lp.rename_callable could be asked
    # to rename onto a name the unit already uses, and it would raise.
    namegen = UniqueNameGenerator(
            set(self.kernels_seen) | set(translation_unit.callables_table))
    entrypoint = expr.entrypoint

    # {{{ eliminate callable name collision

    for name, clbl in translation_unit.callables_table.items():
        if isinstance(clbl, lp.kernel.function_interface.CallableKernel):
            if name in self.kernels_seen and (
                    translation_unit[name] != self.kernels_seen[name]):
                # callee name collision => must rename

                # {{{ see if it's one of the other kernels

                for other_knl in self.kernels_seen.values():
                    if other_knl.copy(name=name) == translation_unit[name]:
                        new_name = other_knl.name
                        break
                else:
                    # didn't find any other equivalent kernel, rename to
                    # something unique
                    new_name = namegen(name)

                # }}}

                if name == entrypoint:
                    # if the colliding name is the entrypoint, then rename the
                    # entrypoint as well.
                    entrypoint = new_name

                translation_unit = lp.rename_callable(
                        translation_unit, name, new_name)
                name = new_name

            self.kernels_seen[name] = translation_unit[name]

    # }}}

    bindings = {name: (self.rec(subexpr) if isinstance(subexpr, Array)
                       else subexpr)
                for name, subexpr in sorted(expr.bindings.items())}

    return LoopyCall(translation_unit=translation_unit,
                     bindings=bindings,
                     entrypoint=entrypoint)
def __init__(self, domains=None, registered_substitutions=None,
             implicit_assignments=None, data=None, substs_to_arrays=None,
             name_generator=None):
    """Forward all state to the base class, creating fresh defaults.

    Each argument that is *None* is replaced by a new empty container, and
    *name_generator* by a new :class:`UniqueNameGenerator`. The previous
    defaults were created once at function-definition time, so every
    :class:`Stack` built with defaults shared the same lists, dicts and
    name generator.
    """
    super(Stack, self).__init__(
        domains=[] if domains is None else domains,
        registered_substitutions=(
            [] if registered_substitutions is None
            else registered_substitutions),
        substs_to_arrays={} if substs_to_arrays is None else substs_to_arrays,
        implicit_assignments=(
            {} if implicit_assignments is None else implicit_assignments),
        data=[] if data is None else data,
        name_generator=(
            UniqueNameGenerator() if name_generator is None
            else name_generator))
class IndexCreator(object):
    """Create named loop indices and record their integer extents.

    The index table and the name generator are created per instance in
    ``__init__``. As class attributes they would be shared by every
    :class:`IndexCreator`: the table would keep growing across creators, and
    :attr:`domains` of a new creator would also describe indices made by
    earlier ones.
    """

    def __init__(self):
        # index name (str, produced by ``namer``) -> integer extent
        self.inames = OrderedDict()
        self.namer = UniqueNameGenerator(forced_prefix="i_")

    def __call__(self, extents):
        """Create new indices with specified extents.

        :arg extents. :class:`tuple` containing :class:`tuple` for extents of
            mixed tensors and :class:`int` for extents non-mixed tensor
        :returns: tuple of pymbolic Variable objects representing indices,
            contains tuples of Variables for mixed tensors and Variables for
            non-mixed tensors, where each Variable represents one extent."""

        # Indices for scalar tensors: an empty extent tuple gets one index
        # of extent 1.
        extents += (1, ) if len(extents) == 0 else ()

        # Stacked tuple = mixed tensor
        # -> loop over ext to generate idxs per block
        if isinstance(extents[0], tuple):
            return tuple(self._create_indices(ext_per_block)
                         for ext_per_block in extents)

        # Non-mixed tensors
        return self._create_indices(extents)

    def _create_indices(self, extents):
        """Create new indices with specified extents.

        :arg extents. :class:`tuple` or :class:`int` for extent of each index
        :returns: tuple of pymbolic Variable objects representing indices,
            one for each extent."""
        indices = []
        for ext in extents:
            name = self.namer()
            indices.append(pym.Variable(name))
            self.inames[name] = int(ext)
        return tuple(indices)

    @property
    def domains(self):
        """ISL domains for the currently known indices."""
        return create_domains(self.inames.items())
def __init__(self, cl_ctx, mesh, order, quad_tag_to_group_factory=None,
        mpi_communicator=None):
    """
    :param quad_tag_to_group_factory:
        A mapping from quadrature tags (typically strings--but may be any
        hashable/comparable object) to a
        :class:`meshmode.discretization.ElementGroupFactory` indicating with
        which quadrature discretization the operations are to be carried out,
        or *None* to indicate that operations with this quadrature tag should
        be carried out with the standard volume discretization.
    """
    self.order = order
    self.quad_tag_to_group_factory = (
        {} if quad_tag_to_group_factory is None
        else quad_tag_to_group_factory)

    from meshmode.discretization import Discretization
    self._volume_discr = Discretization(
        cl_ctx, mesh,
        self.group_factory_for_quadrature_tag(sym.QTAG_NONE))

    # {{{ management of discretization-scoped common subexpressions

    from pytools import UniqueNameGenerator
    self._discr_scoped_name_gen = UniqueNameGenerator()
    self._discr_scoped_subexpr_to_name = {}
    self._discr_scoped_subexpr_name_to_value = {}

    # }}}

    # the queue is only needed while setting up the distributed connections
    with cl.CommandQueue(cl_ctx) as queue:
        self._dist_boundary_connections = \
            self._set_up_distributed_communication(mpi_communicator, queue)

    self.mpi_communicator = mpi_communicator
def fuse_statement_streams_with_unique_ids(statements_a, statements_b):
    """Concatenate two statement lists, renaming ids of *statements_b* as
    needed so that they do not clash with ids of *statements_a*.

    Dependencies of the B statements are rewritten to the new ids. Every
    entry of a B statement's ``depends_on`` must therefore be the id of a
    B statement.

    :returns: a tuple ``(fused_statements, old_b_id_to_new_b_id)``.
    """
    from pytools import UniqueNameGenerator

    fused = list(statements_a)
    id_gen = UniqueNameGenerator({stmt.id for stmt in fused})

    b_id_map = {}
    renamed_b = []
    for stmt in statements_b:
        fresh_id = id_gen(stmt.id)
        b_id_map[stmt.id] = fresh_id
        renamed_b.append(stmt.copy(id=fresh_id))

    fused.extend(
        stmt.copy(depends_on=frozenset(b_id_map[dep]
                                       for dep in stmt.depends_on))
        for stmt in renamed_b)

    return fused, b_id_map
def substitute_into_domain(domain, param_name, expr, allowed_param_dims):
    """
    Eliminate the dimension *param_name* from *domain* by imposing
    ``param_name == expr`` and projecting that dimension out.

    :arg domain: the set to transform.
    :arg param_name: A :class:`str`, name of the dimension of *domain* to be
        eliminated.
    :arg expr: the expression *param_name* equals. Each of its variables is
        added to *domain* as a parameter.
    :arg allowed_param_dims: A :class:`list` of :class:`str` that are the only
        variables *expr* may depend on.
    :returns: the transformed domain, or *domain* unchanged if it has no
        dimension named *param_name*.
    :raises ValueError: if *expr* depends on a name outside
        *allowed_param_dims*.
    """
    import pymbolic.primitives as prim
    from loopy.symbolic import get_dependencies, isl_set_from_expr

    if param_name not in domain.get_var_dict():
        # param_name not in domain => domain will be unchanged
        return domain

    # {{{ rename 'param_name' to avoid namespace pollution with allowed_param_dims

    dt, pos = domain.get_var_dict()[param_name]
    renamed_param = UniqueNameGenerator(set(allowed_param_dims))(param_name)
    domain = domain.set_dim_name(dt, pos, renamed_param)

    # }}}

    for dep in get_dependencies(expr):
        if dep in allowed_param_dims:
            domain = domain.add_dims(isl.dim_type.param, 1)
            domain = domain.set_dim_name(
                    isl.dim_type.param,
                    domain.dim(isl.dim_type.param)-1,
                    dep)
        else:
            raise ValueError("Augmenting caller's domain "
                    f"with '{dep}' is not allowed.")

    # Constrain the (possibly renamed) dimension being eliminated. When
    # param_name is in allowed_param_dims, the original name no longer refers
    # to that dimension: it is either absent or the caller's parameter that
    # was just added.
    set_ = isl_set_from_expr(domain.space,
            prim.Comparison(prim.Variable(renamed_param),
                            "==",
                            expr))

    bset, = set_.get_basic_sets()
    domain = domain & bset

    return domain.project_out(dt, pos, 1)
def __init__(
        self, domains, instructions, args=[], schedule=None,
        name="loopy_kernel",
        preambles=[],
        preamble_generators=[],
        assumptions=None,
        local_sizes={},
        temporary_variables={},
        iname_to_tag={},
        substitutions={},
        function_manglers=[
            default_function_mangler,
            single_arg_function_mangler,
            ],
        symbol_manglers=[],

        iname_slab_increments={},
        loop_priority=frozenset(),
        silenced_warnings=[],

        applied_iname_rewrites=[],
        cache_manager=None,
        index_dtype=np.int32,
        options=None,

        state=kernel_state.INITIAL,
        target=None,

        # When kernels get intersected in slab decomposition,
        # their grid sizes shouldn't change. This provides
        # a way to forward sub-kernel grid size requests.
        get_grid_sizes_for_insn_ids=None):
    """Build a kernel record, normalizing several inputs first.

    Before the record fields are stored, this constructor:

    * creates a :class:`SetOperationCacheManager` if *cache_manager* is
      *None*;
    * gives every instruction whose id is *None* or a
      :class:`loopy.kernel.creation.UniqueName` a fresh unique string id;
    * if *assumptions* is *None*, uses the universe parameter set over the
      parameters of ``domains[0]``; if it is a :class:`str`, parses it as
      constraints over :meth:`outer_params` of *domains*;
    * converts *index_dtype* with :func:`loopy.types.to_loopy_type`.

    :raises RuntimeError: if two instructions share the same explicit id.
    :raises TypeError: if *index_dtype* is not a signed integer type.
    :raises ValueError: if *state* is not INITIAL, PREPROCESSED or SCHEDULED.
    """
    # NOTE(review): the list/dict defaults in the signature are created once
    # and shared by every call. This constructor only stores them and never
    # mutates them; confirm nothing downstream mutates them either.

    if cache_manager is None:
        from loopy.kernel.tools import SetOperationCacheManager
        cache_manager = SetOperationCacheManager()

    # {{{ make instruction ids unique

    from loopy.kernel.creation import UniqueName

    # Collect explicit (concrete string) ids, rejecting duplicates.
    insn_ids = set()
    for insn in instructions:
        if insn.id is not None and not isinstance(insn.id, UniqueName):
            if insn.id in insn_ids:
                raise RuntimeError("duplicate instruction id: %s" % insn.id)
            insn_ids.add(insn.id)

    # The explicit ids are reserved, so generated ids cannot collide with them.
    insn_id_gen = UniqueNameGenerator(insn_ids)

    new_instructions = []

    for insn in instructions:
        if insn.id is None:
            new_instructions.append(
                    insn.copy(id=insn_id_gen("insn")))
        elif isinstance(insn.id, UniqueName):
            new_instructions.append(
                    insn.copy(id=insn_id_gen(insn.id.name)))
        else:
            new_instructions.append(insn)

    instructions = new_instructions
    del new_instructions

    # }}}

    # {{{ process assumptions

    if assumptions is None:
        # Default: no constraints, over the same parameters (same names, same
        # order) as the first domain.
        dom0_space = domains[0].get_space()
        assumptions_space = isl.Space.params_alloc(
                dom0_space.get_ctx(), dom0_space.dim(dim_type.param))
        for i in range(dom0_space.dim(dim_type.param)):
            assumptions_space = assumptions_space.set_dim_name(
                    dim_type.param, i,
                    dom0_space.get_dim_name(dim_type.param, i))
        assumptions = isl.BasicSet.universe(assumptions_space)

    elif isinstance(assumptions, str):
        # Wrap the user's constraint text into a parameter-only ISL set over
        # the kernel's outer parameters.
        assumptions_set_str = "[%s] -> { : %s}" \
                % (",".join(s for s in self.outer_params(domains)),
                    assumptions)
        assumptions = isl.BasicSet.read_from_str(domains[0].get_ctx(),
                assumptions_set_str)

    assert assumptions.is_params()

    # }}}

    from loopy.types import to_loopy_type
    index_dtype = to_loopy_type(index_dtype, target=target)
    if not index_dtype.is_integral():
        raise TypeError("index_dtype must be an integer")
    if np.iinfo(index_dtype.numpy_dtype).min >= 0:
        raise TypeError("index_dtype must be signed")

    if get_grid_sizes_for_insn_ids is not None:
        # overwrites method down below
        # (the instance attribute shadows the class-level
        # get_grid_sizes_for_insn_ids)
        self.get_grid_sizes_for_insn_ids = get_grid_sizes_for_insn_ids

    if state not in [
            kernel_state.INITIAL,
            kernel_state.PREPROCESSED,
            kernel_state.SCHEDULED,
            ]:
        raise ValueError("invalid value for 'state'")

    # All ISL objects are expected to live in the default ISL context.
    assert all(dom.get_ctx() == isl.DEFAULT_CONTEXT for dom in domains)
    assert assumptions.get_ctx() == isl.DEFAULT_CONTEXT

    ImmutableRecordWithoutPickling.__init__(
            self,
            domains=domains,
            instructions=instructions,
            args=args,
            schedule=schedule,
            name=name,
            preambles=preambles,
            preamble_generators=preamble_generators,
            assumptions=assumptions,
            iname_slab_increments=iname_slab_increments,
            loop_priority=loop_priority,
            silenced_warnings=silenced_warnings,
            temporary_variables=temporary_variables,
            local_sizes=local_sizes,
            iname_to_tag=iname_to_tag,
            substitutions=substitutions,
            cache_manager=cache_manager,
            applied_iname_rewrites=applied_iname_rewrites,
            function_manglers=function_manglers,
            symbol_manglers=symbol_manglers,
            index_dtype=index_dtype,
            options=options,
            state=state,
            target=target)

    # Set after the record fields; not passed to the record constructor.
    self._kernel_executor_cache = {}
def __init__(self, target: Target) -> None:
    super().__init__()

    self.target = target

    # argument name -> data bound to it
    self.bound_arguments: Dict[str, DataInterface] = {}
    self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator()
    # kernel name -> loopy kernel recorded under that name
    self.kernels_seen: Dict[str, lp.LoopKernel] = {}
def map_einsum(self, expr: Einsum) -> Array:
    """Lower an :class:`Einsum` node to an :class:`IndexLambda`.

    Operand ``k`` is bound as ``in{k}`` and subscripted per its access
    descriptor:

    * broadcast axes (operand length 1) are indexed with ``0``;
    * elementwise axes are indexed with ``_{dim}``;
    * reduction axes are indexed with ``_r{dim}``.

    The bound of each reduction variable is derived from the operand's shape,
    and the bindings that bound needs are added. The operand subscripts are
    multiplied together. If reduction axes exist, the product is wrapped in
    a sum :class:`Reduce`.
    """
    import operator
    from functools import reduce
    from pytato.scalar_expr import Reduce
    from pytato.utils import (dim_to_index_lambda_components,
                              are_shape_components_equal)
    from pytato.array import ElementwiseAxis, ReductionAxis

    bindings = {f"in{k}": self.rec(arg) for k, arg in enumerate(expr.args)}
    redn_bounds: Dict[str, Tuple[ScalarExpression, ScalarExpression]] = {}
    operand_exprs: List[prim.Subscript] = []
    # generated names must not clash with the "in{k}" operand bindings
    namegen = UniqueNameGenerator(set(bindings))

    # {{{ add bindings coming from the shape expressions

    for iarg, (arg, access_descr) in enumerate(zip(expr.args,
                                                   expr.access_descriptors)):
        subscript = []
        for iaxis, axis in enumerate(access_descr):
            axis_len = arg.shape[iaxis]

            if not are_shape_components_equal(
                    axis_len, expr._access_descr_to_axis_len()[axis]):
                # this operand is broadcast along the axis
                assert are_shape_components_equal(axis_len, 1)
                subscript.append(0)
            elif isinstance(axis, ElementwiseAxis):
                subscript.append(prim.Variable(f"_{axis.dim}"))
            else:
                assert isinstance(axis, ReductionAxis)
                redn_name = f"_r{axis.dim}"
                if redn_name not in redn_bounds:
                    # convert the ShapeComponent to a ScalarExpression
                    upper, upper_bindings = dim_to_index_lambda_components(
                        axis_len, namegen)
                    redn_bounds[redn_name] = (0, upper)
                    bindings.update({k: self.rec(v)
                                     for k, v in upper_bindings.items()})
                subscript.append(prim.Variable(redn_name))

        operand_exprs.append(prim.Subscript(prim.Variable(f"in{iarg}"),
                                            tuple(subscript)))

    # }}}

    inner_expr = reduce(operator.mul, operand_exprs[1:], operand_exprs[0])

    if redn_bounds:
        from pytato.reductions import SumReductionOperation
        inner_expr = Reduce(inner_expr, SumReductionOperation(), redn_bounds)

    return IndexLambda(expr=inner_expr,
                       shape=tuple(self.rec(s) if isinstance(s, Array) else s
                                   for s in expr.shape),
                       dtype=expr.dtype,
                       bindings=bindings,
                       axes=expr.axes,
                       tags=expr.tags)
def __init__(
        self, array_context: ArrayContext, mesh: Mesh,
        order=None,
        discr_tag_to_group_factory=None,
        mpi_communicator=None,
        # FIXME: `quad_tag_to_group_factory` is deprecated
        quad_tag_to_group_factory=None):
    """
    :arg discr_tag_to_group_factory: A mapping from discretization tags
        (typically one of: :class:`grudge.dof_desc.DISCR_TAG_BASE`,
        :class:`grudge.dof_desc.DISCR_TAG_MODAL`, or
        :class:`grudge.dof_desc.DISCR_TAG_QUAD`) to a
        :class:`~meshmode.discretization.poly_element.ElementGroupFactory`
        indicating with which type of discretization the operations are
        to be carried out, or *None* to indicate that operations with this
        discretization tag should be carried out with the standard volume
        discretization. The mapping is copied, never modified in place.
    :arg order: if given, the group factory for ``DISCR_TAG_BASE`` is
        created from it; the mapping must then not already contain
        ``DISCR_TAG_BASE``. At least one of *order* and
        *discr_tag_to_group_factory* must be given.
    :arg quad_tag_to_group_factory: deprecated alias of
        *discr_tag_to_group_factory*.
    """

    if (quad_tag_to_group_factory is not None
            and discr_tag_to_group_factory is not None):
        raise ValueError(
            "Both `quad_tag_to_group_factory` and `discr_tag_to_group_factory` "
            "are specified. Use `discr_tag_to_group_factory` instead."
        )

    # FIXME: `quad_tag_to_group_factory` is deprecated
    if (quad_tag_to_group_factory is not None
            and discr_tag_to_group_factory is None):
        warn("`quad_tag_to_group_factory` is a deprecated kwarg and will "
             "be dropped in version 2022.x. Use `discr_tag_to_group_factory` "
             "instead.",
             DeprecationWarning, stacklevel=2)
        discr_tag_to_group_factory = quad_tag_to_group_factory

    self._setup_actx = array_context.clone()

    from meshmode.discretization.poly_element import \
        default_simplex_group_factory

    if discr_tag_to_group_factory is None:
        if order is None:
            raise TypeError(
                "one of 'order' and 'discr_tag_to_group_factory' must be given"
            )

        discr_tag_to_group_factory = {
            DISCR_TAG_BASE: default_simplex_group_factory(base_dim=mesh.dim,
                                                          order=order)
        }
    else:
        # Always work on a private copy: entries are inserted below
        # (DISCR_TAG_MODAL unconditionally), and the caller's mapping must
        # not be modified as a side effect.
        discr_tag_to_group_factory = dict(discr_tag_to_group_factory)

        if order is not None:
            if DISCR_TAG_BASE in discr_tag_to_group_factory:
                raise ValueError(
                    "if 'order' is given, 'discr_tag_to_group_factory' must "
                    "not have a key of DISCR_TAG_BASE"
                )

            discr_tag_to_group_factory[DISCR_TAG_BASE] = \
                default_simplex_group_factory(base_dim=mesh.dim, order=order)

    # Modal discr should always come from the base discretization
    discr_tag_to_group_factory[DISCR_TAG_MODAL] = \
        _generate_modal_group_factory(
            discr_tag_to_group_factory[DISCR_TAG_BASE]
        )

    self.discr_tag_to_group_factory = discr_tag_to_group_factory

    from meshmode.discretization import Discretization

    self._volume_discr = Discretization(
        array_context, mesh,
        self.group_factory_for_discretization_tag(DISCR_TAG_BASE)
    )

    # NOTE: Can be removed when symbolics are completely removed
    # {{{ management of discretization-scoped common subexpressions

    from pytools import UniqueNameGenerator

    self._discr_scoped_name_gen = UniqueNameGenerator()
    self._discr_scoped_subexpr_to_name = {}
    self._discr_scoped_subexpr_name_to_value = {}

    # }}}

    self._dist_boundary_connections = \
        self._set_up_distributed_communication(
                mpi_communicator, array_context)

    self.mpi_communicator = mpi_communicator
def __init__(self, system_args):
    from pytools import UniqueNameGenerator

    # slice copy, so later changes to the caller's list are not seen here
    self.system_args = system_args[:]

    # generator for dtype names, all prefixed with "_lpy_dtype_", plus the
    # dtype-string -> generated-name table
    self.dtype_name_generator = UniqueNameGenerator(
        forced_prefix="_lpy_dtype_")
    self.dtype_str_to_name = {}
def emit_adams_method(self, cb, explainer):
    """Emit, through the code builder *cb*, one multirate Adams macrostep.

    The emitted code does the following:

    - works on temporary copies of the per-RHS time/history variables;
    - walks substeps ``0 .. self.nsubsteps``, evaluating each right-hand
      side on its interval according to its ``rhs_policy``;
    - commits the newest ``history_length`` entries back to the permanent
      history (and, for non-static *dt*, time) variables;
    - assigns the final component states;
    - advances ``self.t`` by ``self.dt``.

    :arg cb: code builder. It is called as ``cb(lhs, rhs)`` for
        assignments and also receives ``cb.if_``, ``cb.raise_`` and
        ``cb.yield_state``.
    :arg explainer: receives a log of what is emitted (``log_hist_state``,
        ``integrate_to``, ``extrapolate_to``, ``eval_rhs``,
        ``roll_back_history``).
    """
    from pytools import UniqueNameGenerator
    name_gen = UniqueNameGenerator()

    array = var("<builtin>array")

    # {{{ make temporary copies of time/hist_vars

    # maps from (component_name, irhs) to latest-last list of values
    temp_hist_substeps = {}
    temp_time_vars = {}
    temp_hist_vars = {}

    def fill_temp_hist_vars():
        # Seed the temporaries from the permanent history. With a static
        # dt, times are stored as fractions of the macrostep instead of
        # variables.
        for comp_name, component_rhss in zip(self.component_names, self.rhss):
            for irhs, rhs in enumerate(component_rhss):
                key = comp_name, irhs

                temp_hist_substeps[key] = list(
                    range(-rhs.interval * (rhs.history_length - 1), 1,
                          rhs.interval))
                if self.static_dt:
                    temp_time_vars[key] = list(
                        rhs.interval * i / self.nsubsteps
                        for i in range(-rhs.history_length + 1, 0 + 1))
                else:
                    temp_time_vars[key] = self.time_vars[key][:]
                temp_hist_vars[key] = self.history_vars[key][:]

    fill_temp_hist_vars()

    # }}}

    def log_hist_state():
        # Report, per RHS, the substeps and variable names of the newest
        # history_length temporary history entries.
        explainer.log_hist_state({
            rhs.func_name: (
                temp_hist_substeps[comp_name, irhs][-rhs.history_length::],
                [
                    v.name
                    for v in temp_hist_vars[comp_name, irhs][
                        -rhs.history_length::]
                ])
            for comp_name, component_rhss in zip(
                self.component_names, self.rhss)
            for irhs, rhs in enumerate(component_rhss)
        })

    log_hist_state()

    # A mapping from component_name to a list of tuples
    # (substep_level, state_var). This mapping is ordered
    # by substep_level.
    computed_states = {
        comp_name: [(0, state_var)]
        for comp_name, state_var in zip(self.component_names, self.state_vars)
    }

    # {{{ get_state

    def get_state(comp_name, isubstep):
        """Return a variable holding the state of *comp_name* at *isubstep*.

        A cached state is reused if one exists for that substep.
        Otherwise an Adams integration (ODE components) or extrapolation
        (non-ODE components) is emitted, starting from the latest cached
        state.
        """
        states = computed_states[comp_name]

        # {{{ see if we've got that state ready to go

        for istate_substep, state_var in states:
            if istate_substep == isubstep:
                return state_var

        # }}}

        latest_state_substep, latest_state = states[-1]

        comp_index = self.component_names.index(comp_name)

        rhss = self.rhss[comp_index]

        contribs = []
        contrib_explanations = []

        for irhs, rhs in enumerate(rhss):
            hist_len = rhs.history_length
            relv_hist_substeps = temp_hist_substeps[comp_name, irhs][-hist_len:]
            relv_time_hist = temp_time_vars[comp_name, irhs][-hist_len:]
            relv_hist_vars = temp_hist_vars[comp_name, irhs][-hist_len:]

            t_start = latest_state_substep / self.nsubsteps
            t_end = isubstep / self.nsubsteps

            if not self.static_dt:
                # Times are runtime variables: materialize them relative
                # to self.t and scale the integration bounds by dt.
                time_hist_var = var(name_gen("time_hist"))
                cb(time_hist_var, array(hist_len))

                for ii in range(hist_len):
                    cb(time_hist_var[ii], relv_time_hist[ii] - self.t)

                time_hist = time_hist_var

                t_start *= self.dt
                t_end *= self.dt

                dt_factor = 1

            else:
                # Times are macrostep fractions; the integral is scaled
                # by dt afterwards.
                time_hist = relv_time_hist

                dt_factor = self.dt

            from leap.multistep import (
                AdamsMonomialIntegrationFunctionFamily,
                emit_adams_integration,
                emit_adams_extrapolation)

            if self.is_ode_component[comp_name]:
                contrib = dt_factor * emit_adams_integration(
                    cb, name_gen,
                    AdamsMonomialIntegrationFunctionFamily(rhs.order),
                    time_hist, relv_hist_vars, t_start, t_end)
            else:
                contrib = emit_adams_extrapolation(
                    cb, name_gen,
                    AdamsMonomialIntegrationFunctionFamily(rhs.order),
                    time_hist, relv_hist_vars, t_end)

            contribs.append(contrib)

            contrib_explanations.append(
                self.StateContribExplanation(
                    rhs=rhs.func_name,
                    from_substeps=relv_hist_substeps,
                    using=relv_hist_vars))

        state_var = var(
            name_gen("state_{comp_name}_sub{isubstep}".format(
                comp_name=comp_name, isubstep=isubstep)))

        if self.is_ode_component[comp_name]:
            state_expr = latest_state + sum(contribs)
        else:
            state_expr = sum(contribs)
        if comp_name in self.state_filters:
            state_expr = self.state_filters[comp_name](state_expr)
        cb(state_var, state_expr)

        # Only keep temporary state if integrates exactly
        # one interval ahead for the fastest right-hand side,
        # which is the expected rate.
        #
        # - If it integrates further, it's a poor-quality
        #   extrapolation that should probably not be reused.
        #
        # - If it integrates less far, then by definition it is
        #   not used for any state updates, and we don't gain
        #   anything by keeping the temporary around, since the
        #   same extrapolation can be recomputed.

        keep_temp_state = (
            isubstep - latest_state_substep
            == min(rhs.interval for rhs in rhss))
        if keep_temp_state:
            states.append((isubstep, state_var))

        if self.is_ode_component[comp_name]:
            explainer.integrate_to(comp_name, state_var.name,
                                   latest_state_substep, isubstep,
                                   latest_state, contrib_explanations)
        else:
            explainer.extrapolate_to(comp_name, state_var.name,
                                     latest_state_substep, isubstep,
                                     latest_state, contrib_explanations)

        return state_var

    # }}}

    # {{{ update_hist

    def update_hist(comp_idx, irhs, isubstep):
        """Evaluate RHS *irhs* of component *comp_idx* at *isubstep* and
        append the result (and its time) to the temporary history."""
        comp_name = self.component_names[comp_idx]

        rhs = self.rhss[comp_idx][irhs]

        # {{{ get arguments together

        progress_frac = isubstep / self.nsubsteps
        t_expr = self.t + self.dt * progress_frac

        kwargs = {
            self.comp_name_to_kwarg_name[arg_comp_name]:
                get_state(arg_comp_name, isubstep)
            for arg_comp_name in rhs.arguments
        }

        # }}}

        rhs_var = var(
            name_gen("rhs_{comp_name}_rhs{irhs}_sub{isubstep}".format(
                comp_name=comp_name, irhs=irhs, isubstep=isubstep)))

        cb(rhs_var, var(rhs.func_name)(t=t_expr, **kwargs))

        temp_hist_substeps[comp_name, irhs].append(isubstep)

        if not self.static_dt:
            t_var = var(
                name_gen("t_{comp_name}_rhs{irhs}_sub{isubstep}".format(
                    comp_name=comp_name, irhs=irhs, isubstep=isubstep)))
            cb(t_var, t_expr)
            temp_time_vars[comp_name, irhs].append(t_var)
        else:
            temp_time_vars[comp_name, irhs].append(progress_frac)

        temp_hist_vars[comp_name, irhs].append(rhs_var)

        explainer.eval_rhs(rhs_var.name, comp_name, rhs.func_name,
                           isubstep, kwargs)

        # {{{ invalidate computed states, if requested

        if rhs.invalidate_computed_state:
            for other_comp_name, other_component_rhss in zip(
                    self.component_names, self.rhss):
                do_invalidate = False
                # NOTE(review): this inner loop never inspects the RHS it
                # iterates over. do_invalidate ends up True exactly when
                # comp_name is in *this* rhs's arguments and
                # other_component_rhss is non-empty. Checking
                # `comp_name in other_rhs.arguments` may have been
                # intended; confirm before changing.
                for _other_rhs in enumerate(other_component_rhss):
                    if comp_name in rhs.arguments:
                        do_invalidate = True
                        break

                if do_invalidate:
                    computed_states[other_comp_name][:] = [
                        (istate_substep, state)
                        for istate_substep, state
                        in computed_states[other_comp_name]

                        # Only earlier states live.
                        if istate_substep < isubstep
                    ]

        # }}}

    # }}}

    def norm(expr):
        return var("<builtin>norm_2")(expr)

    def check_history_consistency():
        # At the start of a macrostep, ensure that the last computed
        # RHS history corresponds to the current state.
        # Emits code comparing, by relative 2-norm, a fresh RHS evaluation
        # at substep 0 with the newest history entry.
        for comp_name, component_rhss in zip(self.component_names, self.rhss):
            for irhs, rhs in enumerate(component_rhss):
                t_expr = self.t
                kwargs = {
                    self.comp_name_to_kwarg_name[arg_comp_name]:
                        get_state(arg_comp_name, 0)
                    for arg_comp_name in rhs.arguments
                }
                test_rhs_var = var(
                    name_gen("test_rhs_{comp_name}_rhs{irhs}_0".format(
                        comp_name=comp_name, irhs=irhs)))
                cb(test_rhs_var, var(rhs.func_name)(t=t_expr, **kwargs))

                # Compare this computed RHS with the 0th history point using
                # built-in norm.
                zeroth_hist = temp_hist_vars[comp_name, irhs][-1]
                rel_rhs_error = (
                    norm(test_rhs_var - zeroth_hist) /  # noqa: W504
                    norm(test_rhs_var))
                cb("rel_rhs_error", rel_rhs_error)
                if rhs.rhs_policy == rhs_policy.early:
                    # Check for scheme-order accuracy
                    if self.early_hist_consistency_threshold is not None:
                        with cb.if_("rel_rhs_error", ">=",
                                    self.early_hist_consistency_threshold):
                            cb((), "<builtin>print(rel_rhs_error)")
                            cb.raise_(
                                InconsistentHistoryError,
                                "MRAdams: top-of-history for RHS "
                                f"'{rhs.func_name}' is not "
                                "consistent with current state")
                    else:
                        # No relaxed threshold configured: the emitted
                        # code raises unconditionally for this RHS.
                        cb.raise_(
                            InconsistentHistoryError,
                            f"MRAdams: RHS '{rhs.func_name}' has early "
                            "policy and requires relaxed threshold input")
                else:
                    # Check for floating-point accuracy
                    with cb.if_("rel_rhs_error", ">=",
                                self.hist_consistency_threshold):
                        cb.raise_(
                            InconsistentHistoryError,
                            "MRAdams: top-of-history for RHS "
                            f"'{rhs.func_name}' is not "
                            "consistent with current state")

    # {{{ run_substep_loop

    def run_substep_loop():
        # Check last history value from previous macrostep
        if self.hist_consistency_threshold is not None:
            check_history_consistency()

        for isubstep in range(self.nsubsteps + 1):
            for comp_idx, (comp_name, component_rhss) in enumerate(
                    zip(self.component_names, self.rhss)):
                for irhs, rhs in enumerate(component_rhss):
                    # Each RHS is only touched on multiples of its interval.
                    if isubstep % rhs.interval != 0:
                        continue

                    if isubstep > 0:
                        # {{{ finish up prior step

                        # early_and_late: drop the early (extrapolated)
                        # evaluation and replace it by the late one below.
                        if rhs.rhs_policy == rhs_policy.early_and_late:
                            temp_hist_substeps[comp_name, irhs].pop()
                            temp_time_vars[comp_name, irhs].pop()
                            temp_hist_vars[comp_name, irhs].pop()
                            explainer.roll_back_history(rhs.func_name)

                        if rhs.rhs_policy in [
                                rhs_policy.early_and_late,
                                rhs_policy.late]:
                            update_hist(comp_idx, irhs, isubstep)

                        # }}}

                    if isubstep < self.nsubsteps:
                        # {{{ start up a new substep

                        if rhs.rhs_policy in [
                                rhs_policy.early,
                                rhs_policy.early_and_late]:
                            update_hist(comp_idx, irhs,
                                        isubstep + rhs.interval)

                        # }}}

    run_substep_loop()

    # }}}

    log_hist_state()

    end_states = [
        get_state(component_name, self.nsubsteps)
        for component_name in self.component_names
    ]

    # {{{ commit temp history to permanent history

    def commit_temp_hist_vars():
        # Copy the newest history_length temporary entries back into the
        # permanent time (non-static dt only) and history variables.
        for comp_name, component_rhss in zip(self.component_names, self.rhss):
            for irhs, rhs in enumerate(component_rhss):
                key = comp_name, irhs

                if not self.static_dt:
                    for time_var, time_expr in zip(
                            self.time_vars[key],
                            temp_time_vars[comp_name, irhs][
                                -rhs.history_length:]):
                        cb(time_var, time_expr)

                for hist_var, hist_expr in zip(
                        self.history_vars[key],
                        temp_hist_vars[comp_name, irhs][
                            -rhs.history_length:]):
                    cb(hist_var, hist_expr)

    commit_temp_hist_vars()

    # }}}

    # TODO: Figure out more spots to yield intermediate state
    for component_name, state in zip(self.component_names, end_states):
        if self.is_ode_component[component_name]:
            cb.yield_state(state, component_name, self.t + self.dt, "final")
        cb(var("<state>" + component_name), state)

    cb(self.t, self.t + self.dt)
def get_instruction_id_generator(self, based_on="insn"):
    """Return a :class:`pytools.UniqueNameGenerator` seeded with every
    instruction id currently present in ``self.instructions``, so that
    names it generates do not collide with existing ids.

    :arg based_on: accepted for interface compatibility but not used here.
        The base name is presumably supplied when calling the returned
        generator.
    """
    used_ids = {insn.id for insn in self.instructions}
    return UniqueNameGenerator(used_ids)