def step_statements(self, stage, f, dt, rhs):
    """Return the update assignments for one stage of a two-stage scheme.

    ``f`` is indexed twice through ``index_fields``, once with ``0`` and once
    with ``1`` prepended, which yields two copies of the unknown.  The
    coefficients look like Heun's second-order Runge-Kutta method --
    NOTE(review): confirm against the enclosing class.

    :arg stage: index of the stage being generated (``0`` or ``1``).
    :arg f: the quantity being advanced, in a form ``index_fields`` accepts.
    :arg dt: the timestep entering the update expressions.
    :arg rhs: the right-hand side entering this stage's update.
    :returns: a :class:`dict` mapping each assignee to its new value, or
        *None* when ``stage`` is neither ``0`` nor ``1``.
    """
    primary, scratch = (
        index_fields(f, prepend_with=(copy,)) for copy in (0, 1)
    )

    if stage == 0:
        # Insertion order is part of the result: the scratch copy is listed
        # before the entry that overwrites copy 0.
        return {
            scratch: primary + dt * rhs,
            primary: primary + dt / 2 * rhs,
        }

    if stage == 1:
        return {primary: primary + dt / 2 * rhs}

    # any other stage index: no statements (same implicit None as before)
    return None
def step_statements(self, stage, f, dt, rhs):
    """Return the update assignments for one stage of a three-stage scheme.

    Three indexed copies of ``f`` are built with ``index_fields`` (prepending
    ``0``, ``1`` and ``2``); only the first two are referenced by the update
    expressions below.

    :arg stage: index of the stage being generated (``0``, ``1`` or ``2``).
    :arg f: the quantity being advanced, in a form ``index_fields`` accepts.
    :arg dt: the timestep entering the update expressions.
    :arg rhs: the right-hand side entering this stage's update.
    :returns: a single-entry :class:`dict` mapping the assignee to its new
        value, or *None* for any other ``stage``.
    """
    copies = tuple(index_fields(f, prepend_with=(n,)) for n in range(3))
    y0 = copies[0]
    y1 = copies[1]

    updates = None
    if stage == 0:
        # plain forward step from copy 0 into copy 1
        updates = {y1: y0 + dt * rhs}
    elif stage == 1:
        # weighted mix of both copies, written back into copy 1
        updates = {y1: 3 / 4 * y0 + 1 / 4 * y1 + dt / 4 * rhs}
    elif stage == 2:
        # final weighted mix, written into copy 0
        updates = {y0: 1 / 3 * y0 + 2 / 3 * y1 + dt * 2 / 3 * rhs}
    return updates
def make_steps(self, MapKernel=ElementWiseMap, **kwargs):
    """Build one kernel per stage of the stepper.

    For every stage, the statements returned by :meth:`step_statements` for
    each key of ``self.rhs_dict`` are merged into a single dict and handed to
    ``MapKernel`` together with temporary instructions that assign ``rhs[i]``
    to the ``i``-th value of ``self.rhs_dict`` (with the symbolic index ``q``
    prepended to its fields by ``index_fields``).

    :arg MapKernel: the callable used to construct each stage's kernel.
    :arg kwargs: forwarded to ``MapKernel``.  An optional
        ``fixed_parameters`` mapping is extended, per stage, with ``q``
        (``0`` for the first stage, ``1`` for every later stage).  The
        mapping passed in is copied and never modified.
    :returns: a :class:`list` with one ``MapKernel`` result per stage.
    :raises ValueError: if a key of ``self.rhs_dict`` is neither a ``Field``
        nor a ``Subscript`` whose aggregate is a ``Field``.
    """
    rhs = var("rhs")
    dt = var("dt")
    q = var("q")

    # Copy the caller's mapping: updating it in place would leak ``q`` into
    # the caller's dict and make every stage share one mutable object.
    base_fixed_parameters = dict(kwargs.pop("fixed_parameters", {}))

    # Keys must be a Field or a Subscript of a Field so that index_fields
    # can prepend the q index.  The check does not depend on the stage, so
    # it is done once up front.
    for f in self.rhs_dict:
        is_field = isinstance(f, Field)
        is_field_subscript = (
            isinstance(f, Subscript) and isinstance(f.aggregate, Field)
        )
        if not (is_field or is_field_subscript):
            raise ValueError("rhs_dict keys must be Field instances")

    rhs_statements = {
        rhs[i]: index_fields(value, prepend_with=(q,))
        for i, value in enumerate(self.rhs_dict.values())
    }

    steps = []
    for stage in range(self.num_stages):
        RK_dict = {}
        for i, f in enumerate(self.rhs_dict):
            RK_dict.update(self.step_statements(stage, f, dt, rhs[i]))

        # Each stage gets its own parameter dict, so later stages cannot
        # alter what an earlier kernel was given.
        stage_parameters = {
            **base_fixed_parameters,
            "q": 0 if stage == 0 else 1,
        }

        # A new Options object per stage, as before, so stages never share
        # one.
        options = lp.Options(enforce_variable_access_ordered="no_check")
        steps.append(
            MapKernel(RK_dict, tmp_instructions=rhs_statements,
                      args=self.args, **kwargs, options=options,
                      fixed_parameters=stage_parameters)
        )

    return steps
def make_kernel(self, map_instructions, tmp_instructions, args, domains,
                **kwargs):
    """Assemble the loopy translation unit for the given instructions.

    Both instruction collections are first passed through ``index_fields``.
    Entries that already are ``lp.InstructionBase`` objects are used as-is;
    every other entry is unpacked as an ``(assignee, expression)`` pair and
    turned into an instruction with ``self._assignment``.  The final
    instruction list is all temporary statements followed by all map
    statements, each group in the order given.

    :arg map_instructions: the output assignments of the kernel.
    :arg tmp_instructions: assignments to temporaries, placed before the map
        instructions.
    :arg args: explicitly specified kernel arguments; arguments inferred by
        ``get_field_args`` are added via ``append_new_args``.
    :arg domains: passed to ``lp.make_kernel`` as the loop domains.
    :arg kwargs: forwarded to ``lp.make_kernel``.  An ``options`` entry, if
        present, is used as the kernel options (and is modified in place for
        single-instruction kernels).
    :returns: the translation unit, with ``self.dtype`` filled in for
        arguments lacking a dtype, unused arguments removed and ``"round"``
        registered as a callable.
    """
    # function-scope import kept from the original
    from pystella.field import index_fields

    indexed_tmp_insns = index_fields(tmp_instructions)
    indexed_map_insns = index_fields(map_instructions)

    temp_statements = []
    temp_vars = []
    for statement in indexed_tmp_insns:
        if isinstance(statement, lp.InstructionBase):
            temp_statements.append(statement)
            continue

        assignee, expression = statement

        # only declare temporary variables once: the first assignment to a
        # name gets lp.Optional(None), later ones get lp.Optional()
        if isinstance(assignee, pp.Variable):
            current_tmp = assignee
        elif isinstance(assignee, pp.Subscript):
            current_tmp = assignee.aggregate
        else:
            current_tmp = None

        if current_tmp is not None and current_tmp not in temp_vars:
            temp_vars.append(current_tmp)
            tvt = lp.Optional(None)
        else:
            tvt = lp.Optional()

        temp_statements.append(
            self._assignment(assignee, expression, temp_var_type=tvt)
        )

    output_statements = []
    for statement in indexed_map_insns:
        if isinstance(statement, lp.InstructionBase):
            output_statements.append(statement)
        else:
            assignee, expression = statement
            # Map assignments belong with the other output statements.
            # Appending them to temp_statements placed them ahead of any
            # ready-made map instructions and broke the given order.
            output_statements.append(
                self._assignment(assignee, expression)
            )

    options = kwargs.pop("options", lp.Options())
    # ignore lack of supposed dependency for single-instruction kernels
    if len(map_instructions) + len(tmp_instructions) == 1:
        options.check_dep_resolution = False

    # function-scope import kept from the original
    from pystella import get_field_args

    inferred_args = get_field_args([map_instructions, tmp_instructions])
    all_args = append_new_args(args, inferred_args)

    t_unit = lp.make_kernel(
        domains,
        temp_statements + output_statements,
        all_args + [lp.ValueArg("Nx, Ny, Nz", dtype="int"), ...],
        options=options,
        **kwargs,
    )

    # fill in self.dtype for every kernel argument that has no dtype yet
    knl = t_unit.default_entrypoint
    new_args = [
        arg.copy(dtype=self.dtype)
        if isinstance(arg, lp.KernelArgument) and arg.dtype is None
        else arg
        for arg in knl.args
    ]
    t_unit = t_unit.with_kernel(knl.copy(args=new_args))

    t_unit = lp.remove_unused_arguments(t_unit)
    t_unit = lp.register_callable(
        t_unit, "round", UnaryOpenCLCallable("round")
    )

    return t_unit