def callback():
    """Build the equations that interpolate ``expr`` onto the sparse points.

    Closure: reads ``expr``, ``offset``, ``increment`` and ``self_subs`` from
    the enclosing scope, and uses ``self`` (the interpolator object).

    Returns
    -------
    list
        The ``temps`` returned by ``self._interpolation_indices``, followed by
        an ``Eq`` zeroing a scalar ``sum`` temporary, one ``Inc`` of that
        temporary per (coefficient, index substitution) pair, and finally an
        ``Inc`` (when ``increment`` is truthy) or ``Eq`` writing the temporary
        into ``self.sfunction.subs(self_subs)``.

    Raises
    ------
    ValueError
        If the (evaluated) expression contains no function carriers, e.g. it
        is a plain number: the staggering origin is taken from the first
        carrier, so there is nothing to take it from.
    """
    # Derivatives must be evaluated before the introduction of indirect accesses
    try:
        _expr = expr.evaluate
    except AttributeError:
        # E.g., a generic SymPy expression or a number
        _expr = expr

    variables = list(retrieve_function_carriers(_expr))
    if not variables:
        # Fail with a clear message rather than an opaque IndexError from
        # `variables[0]` below
        raise ValueError("Cannot interpolate `%s`: it contains no function "
                         "to take the grid origin from" % (_expr,))

    # Need to get origin of the field in case it is staggered
    # TODO: handle each variable's staggering separately (currently the
    # origin of the first carrier is used for all of them)
    field_offset = variables[0].origin
    # List of indirection indices for all adjacent grid points
    idx_subs, temps = self._interpolation_indices(variables, offset,
                                                  field_offset=field_offset)

    # Substitute coordinate base symbols into the interpolation coefficients
    args = [_expr.xreplace(v_sub) * b.xreplace(v_sub)
            for b, v_sub in zip(self._interpolation_coeffs, idx_subs)]

    # Accumulate point-wise contributions into a temporary
    rhs = Symbol(name='sum', dtype=self.sfunction.dtype)
    summands = [Eq(rhs, 0., implicit_dims=self.sfunction.dimensions)]
    summands.extend([Inc(rhs, i, implicit_dims=self.sfunction.dimensions)
                     for i in args])

    # Write/Incr `self`
    lhs = self.sfunction.subs(self_subs)
    last = [Inc(lhs, rhs)] if increment else [Eq(lhs, rhs)]

    return temps + summands + last
def callback():
    """Build the equations that inject ``expr`` into ``field``.

    Closure: reads ``expr``, ``field`` and ``offset`` from the enclosing
    scope, and uses ``self`` (the interpolator object).

    Returns
    -------
    list
        The ``temps`` returned by ``self._interpolation_indices``, followed by
        one ``Inc`` per (coefficient, index substitution) pair that adds the
        re-indexed expression, times the coefficient, into the re-indexed
        ``field``.
    """
    # Evaluate derivatives up front; this has to happen before any indirect
    # access is introduced. Objects without ``evaluate`` (a generic SymPy
    # expression or a number) are used unchanged.
    evaluated = getattr(expr, 'evaluate', expr)

    # Re-index every function carrier of the expression, plus the target field
    carriers = list(retrieve_function_carriers(evaluated))
    carriers.append(field)

    # The target field may be staggered, so pass its origin along
    indices, temps = self._interpolation_indices(carriers, offset,
                                                 field_offset=field.origin)

    # One increment of the field per adjacent grid point, weighted by the
    # matching interpolation coefficient
    injections = []
    for coeff, subs in zip(self._interpolation_coeffs, indices):
        target = field.xreplace(subs)
        value = evaluated.xreplace(subs) * coeff
        injections.append(Inc(target, value,
                              implicit_dims=self.sfunction.dimensions))
    return temps + injections