def split_linear(self, functions):
    """Apply the linearity property of ``Diff``.

    ``Diff(c*a + b)`` is transformed to ``c * Diff(a) + Diff(b)``: a sum
    argument is split term by term and, within a product, every factor that
    is not a function is pulled out of the derivative as a constant.

    Args:
        functions: collection of all symbols that are considered functions,
            not constants. For the example above: ``functions=[a, b]``.

    Returns:
        The split expression. A derivative whose argument contains no
        function factor (only constants / numbers) is replaced by ``0``.
    """
    target, superscript = self.target, self.superscript

    if self.arg.func == sp.Add:
        # Linearity over sums: split every summand separately. Without this
        # branch a sum argument was returned unchanged, contradicting the
        # documented ``Diff(c*a + b)`` example.
        parts = []
        for term in self.arg.args:
            term_diff = Diff(term, target=target, superscript=superscript)
            # defensive: only recurse when the constructor produced a Diff node
            if isinstance(term_diff, Diff):
                term_diff = term_diff.split_linear(functions)
            parts.append(term_diff)
        return sp.Add(*parts)

    constant, variable = 1, 1
    if self.arg.func != sp.Mul:
        variable = self.arg
    else:
        # NOTE(review): factors are classified by exact membership in
        # ``functions``; a factor that merely contains a function symbol
        # (e.g. 1/a or sin(a)) is presumably treated as a constant here --
        # confirm this is intended.
        for factor in normalize_product(self.arg):
            if factor in functions or isinstance(factor, Diff):
                variable *= factor
            else:
                constant *= factor

    # A lone non-function symbol, or a product of constants only, gives zero.
    if isinstance(variable, sp.Symbol) and variable not in functions:
        return 0
    if isinstance(variable, int) or variable.is_number:
        return 0
    return constant * Diff(variable, target=target, superscript=superscript)
def expand_diff_products(expr):
    """Fully expand all derivatives by applying the product rule.

    ``expr`` is walked recursively. Every ``Diff`` node whose (expanded)
    argument is a sum becomes a sum of fully expanded derivatives. An argument
    that is a product or power becomes ``sum_i (prod_{j != i} f_j) * Diff(f_i)``
    over its factors ``f``. Any other argument is wrapped in a ``Diff`` again.

    Args:
        expr: sympy expression possibly containing ``Diff`` nodes.

    Returns:
        The expression with all derivatives expanded.
    """
    if not isinstance(expr, Diff):
        new_args = [expand_diff_products(e) for e in expr.args]
        return expr.func(*new_args) if new_args else expr

    diff_kwargs = {'target': expr.target, 'superscript': expr.superscript}
    arg = expand_diff_products(expr.args[0])

    if arg.func == sp.Add:
        # Recurse into each summand so products inside a sum also receive the
        # product rule; otherwise Diff(a*b + c) would stop at Diff(a*b) + Diff(c).
        return sp.Add(*[expand_diff_products(Diff(term, **diff_kwargs))
                        for term in arg.args])

    if arg.func not in (sp.Mul, sp.Pow):
        return Diff(arg, **diff_kwargs)

    factors = normalize_product(arg)
    result = 0
    for i, factor in enumerate(factors):
        pre_factor = prod(f for j, f in enumerate(factors) if j != i)
        result += pre_factor * Diff(factor, **diff_kwargs)
    return result
def handle_product(product_term):
    """Convert a single product term into a (possibly differentiated) CeMoment.

    The factors of ``product_term`` are sorted as follows:

    * a ``Diff`` factor is remembered as the derivative to re-apply. Its
      innermost argument determines the f-index. This is only allowed while
      no f-index has been set yet (asserted).
    * a factor contained in ``velocity_terms`` raises the exponent of the
      corresponding velocity in the moment tuple.
    * any other factor either yields an f-index via ``determine_f_index`` or
      is multiplied into a scalar remainder.

    Returns:
        ``CeMoment`` (wrapped in the derivative, if any) times the remainder.
    """
    f_index = None
    derivative_term = None
    velocity_positions = []
    remainder = 1

    for factor in normalize_product(product_term):
        if isinstance(factor, Diff):
            assert f_index is None
            f_index = determine_f_index(factor.get_arg_recursive())
            derivative_term = factor
            continue
        if factor in velocity_terms:
            velocity_positions.append(velocity_terms.index(factor))
            continue
        candidate = determine_f_index(factor)
        if candidate is None:
            remainder *= factor
            continue
        assert not (candidate and f_index)
        f_index = candidate

    # exponent of each velocity = how often its position was collected
    moment_tuple = tuple(velocity_positions.count(position)
                         for position in range(len(velocity_terms)))
    if use_one_neighborhood_aliasing:
        moment_tuple = non_aliased_moment(moment_tuple)

    moment = CeMoment(f_index.moment_name, moment_tuple, f_index.superscript)
    if derivative_term is not None:
        moment = derivative_term.change_arg_recursive(moment)
    return moment * remainder
def count_vars(expr, variables):
    """Count how often any of ``variables`` occurs among the factors of ``expr``.

    Factors hidden inside (possibly nested) ``Diff`` nodes are unpacked via
    their ``arg`` and counted as well.
    """
    top_level = normalize_product(expr)
    plain_factors = [f for f in top_level if not isinstance(f, Diff)]
    pending = [f for f in top_level if isinstance(f, Diff)]

    while pending:
        derivative = pending.pop()
        for inner in normalize_product(derivative.arg):
            if isinstance(inner, Diff):
                pending.append(inner)
            else:
                plain_factors.append(inner)

    return sum(plain_factors.count(v) for v in variables)
def handle_mul(mul):
    """Turn a product containing ``DiffOperator`` factors into nested ``Diff`` nodes.

    If ``mul`` has no ``DiffOperator`` factor, it is multiplied by ``argument``
    when ``apply_to_constants`` is set and returned unchanged otherwise.
    Otherwise the operators are ordered by ``_default_diff_sort_key``. The
    first operator in that order becomes the outermost ``Diff`` around
    ``argument``. The result is multiplied by the product of the remaining
    factors.
    """
    factors = normalize_product(mul)
    operators = [f for f in factors if isinstance(f, DiffOperator)]
    if not operators:
        if apply_to_constants:
            return mul * argument
        return mul

    others = [f for f in factors if not isinstance(f, DiffOperator)]
    ordered = sorted(operators, key=_default_diff_sort_key)

    # wrap from the innermost (last in sort order) to the outermost operator
    nested = argument
    for op in ordered[::-1]:
        nested = Diff(nested, target=op.target, superscript=op.superscript)
    return prod(others) * nested
def _compute_moments(self, recombined_eq, symbols_to_values):
    """Evaluate every product term of ``recombined_eq`` using ``symbols_to_values``.

    Each summand of the expanded ``recombined_eq`` must be a product
    (asserted). Its factors are replaced by their values from
    ``symbols_to_values``. A ``Diff`` factor contributes its innermost
    argument, and the derivative is re-applied to the evaluated product
    afterwards. The replaced factors are multiplied together, the elements of
    that product are summed and expanded, and vanishing products are dropped.

    Args:
        recombined_eq: sympy expression that expands to a sum of products,
            each containing at most one ``Diff`` factor (asserted).
        symbols_to_values: mapping from factors to replacement values;
            factors not found in it are used unchanged.

    Returns:
        The sum of the evaluated products, passed through
        ``expand_diff_linear`` (with ``self.physical_variables`` as the
        functions) and then ``normalize_diff_order``.
    """
    eq = recombined_eq.expand()
    assert eq.func is sp.Add

    new_products = []
    for product in eq.args:
        assert product.func is sp.Mul

        derivative = None

        new_prod = 1
        # Iterating in reverse and left-multiplying (arg * new_prod) keeps the
        # original left-to-right factor order, which matters when the
        # substituted values are matrices.
        for arg in reversed(normalize_product(product)):
            if isinstance(arg, Diff):
                assert derivative is None, "More than one derivative term in the product"
                derivative = arg
                arg = arg.get_arg_recursive( )  # new argument is inner part of derivative

            if arg in symbols_to_values:
                arg = symbols_to_values[arg]

            have_shape = hasattr(arg, 'shape') and hasattr( new_prod, 'shape')
            # two equally shaped column vectors are multiplied element-wise
            if have_shape and arg.shape == new_prod.shape and arg.shape[1] == 1:
                # since sympy 1.9 sp.matrix_multiply_elementwise does not work anymore in this case
                new_prod = sp.Matrix(np.multiply(new_prod, arg))
            else:
                new_prod = arg * new_prod
            # a zero factor makes the whole product vanish: stop early
            if new_prod == 0:
                break

        # vanishing products contribute nothing
        if new_prod == 0:
            continue

        # NOTE(review): sum() needs new_prod to be iterable -- presumably a
        # sympy Matrix, i.e. at least one factor maps to a matrix value; confirm.
        new_prod = sp.expand(sum(new_prod))

        if derivative is not None:
            new_prod = derivative.change_arg_recursive(new_prod)

        new_products.append(new_prod)

    return normalize_diff_order(expand_diff_linear(sp.Add(*new_products), functions=self.physical_variables))
def extract_gamma(free_energy, order_parameters):
    """Extract the coefficients in front of the gradient terms.

    Every term of the expanded ``free_energy`` that contains exactly two
    ``Diff`` factors contributes its non-derivative pre-factor to the entry
    keyed by the sorted positions (in ``order_parameters``) of the two
    differentiated arguments. When both ``Diff`` factors are identical, the
    pre-factor is counted twice. Terms without any ``Diff`` factor are ignored.

    Args:
        free_energy: sympy expression of the free energy. It may expand to a
            sum, a single term or zero.
        order_parameters: sequence containing the argument of every ``Diff``
            factor that occurs.

    Returns:
        ``defaultdict`` mapping sorted index tuples ``(i, j)`` to coefficients.

    Raises:
        ValueError: if a term has a number of ``Diff`` factors other than 0 or
            2, or (via ``.index``) if a ``Diff`` argument is not in
            ``order_parameters``.
    """
    result = defaultdict(lambda: 0)
    # Add.make_args also yields the single term of a non-sum expansion, which
    # the former ``assert free_energy.func == sp.Add`` rejected.
    for term in sp.Add.make_args(free_energy.expand()):
        factors = normalize_product(term)
        diff_factors = [e for e in factors if e.func == Diff]
        if not diff_factors:
            continue

        if len(diff_factors) != 2:
            raise ValueError(f"Could not determine Λ because of term {str(factors)}")

        indices = sorted(order_parameters.index(d.args[0]) for d in diff_factors)
        increment = prod(e for e in factors if e.func != Diff)
        # a squared gradient (identical Diff factors) is counted with factor 2
        if diff_factors[0] == diff_factors[1]:
            increment *= 2
        result[tuple(indices)] += increment
    return result
def expr_to_diff_decomposition(expression):
    """Split a sum containing ``Diff`` nodes into derivative parts and a remainder.

    Returns a tuple ``(diff_dict, rest)``:

    * ``diff_dict`` maps ``DiffInfo(target, superscript)`` to a list of
      ``DiffSplit(pre_factor, argument)`` entries, each standing for
      ``pre_factor * partial(argument)``. For example, ``a * partial(b)``
      gives pre-factor ``a`` and argument ``b``. A term with several
      derivatives is shared equally between them: ``partial(a) * partial(b)``
      yields ``(0.5 partial(b), a)`` and ``(0.5 partial(a), b)``.
    * ``rest`` is the sum of all terms that contain no ``Diff`` at all.

    ``expression`` must be a ``sp.Add`` (asserted).
    """
    DiffInfo = namedtuple("DiffInfo", ["target", "superscript"])

    class DiffSplit:
        def __init__(self, fac, argument):
            self.pre_factor = fac
            self.argument = argument

        def __repr__(self):
            return str((self.pre_factor, self.argument))

    assert isinstance(expression, sp.Add)
    diff_dict = defaultdict(list)
    rest = 0

    for term in expression.args:
        if isinstance(term, Diff):
            key = DiffInfo(term.target, term.superscript)
            diff_dict[key].append(DiffSplit(1, term.arg))
            continue

        factors = normalize_product(term)
        derivatives = [f for f in factors if isinstance(f, Diff)]
        coefficient = prod([f for f in factors if not isinstance(f, Diff)])
        if not derivatives:
            rest += coefficient
            continue

        share = sp.Rational(1, len(derivatives))
        for position, current in enumerate(derivatives):
            others = derivatives[:position] + derivatives[position + 1:]
            weight = coefficient * prod(others) * share
            key = DiffInfo(current.target, current.superscript)
            diff_dict[key].append(DiffSplit(weight, current.arg))

    return diff_dict, rest
def visit(e):
    """Recursively expand the derivatives in ``e`` (product rule plus linearity).

    Uses ``functions`` and ``constants`` from the enclosing scope. Inside a
    ``Diff`` node, every factor that is in ``functions`` or is itself a
    ``Diff`` counts as dependent. All other factors are constant pre-factors.
    The product rule is applied over the dependent factors of each summand,
    and every resulting derivative goes through ``normalize_diff_order``.
    ``Piecewise`` branches are handled via ``expand_diff_full``. ``Tuple``
    nodes are rebuilt element-wise. Any other node is rebuilt from its
    visited arguments.
    """
    if not isinstance(e, sp.Tuple):
        e = e.expand()

    if e.func == Diff:
        diff_args = {'target': e.target, 'superscript': e.superscript}
        diff_inner = visit(e.args[0])
        # NOTE(review): this returns the (expanded) original Diff, discarding
        # any changes visit() made inside diff_inner -- confirm intended.
        if diff_inner.func not in (sp.Add, sp.Mul):
            return e

        result = 0
        terms = diff_inner.args if diff_inner.func == sp.Add else [diff_inner]
        for term in terms:
            independent_terms = 1
            dependent_terms = []
            for factor in normalize_product(term):
                if factor in functions or isinstance(factor, Diff):
                    dependent_terms.append(factor)
                else:
                    independent_terms *= factor
            for i, dependent_term in enumerate(dependent_terms):
                other_dependent_terms = dependent_terms[:i] + dependent_terms[i + 1:]
                processed_diff = normalize_diff_order(Diff(dependent_term, **diff_args))
                result += independent_terms * prod(other_dependent_terms) * processed_diff
        return result
    elif isinstance(e, sp.Piecewise):
        return sp.Piecewise(*((expand_diff_full(a, functions, constants), b) for a, b in e.args))
    elif isinstance(e, sp.Tuple):
        # Must test the visited node ``e``, not the outer ``expr``: testing the
        # outer expression rebuilt every subexpression of a top-level Tuple as
        # a Tuple (e.g. a + b -> Tuple(a, b)).
        new_args = [visit(arg) for arg in e.args]
        return sp.Tuple(*new_args)
    else:
        new_args = [visit(arg) for arg in e.args]
        return e.func(*new_args) if new_args else e
def handle_postcollision_values(expr):
    """Replace powers of the collision operator acting on a moment.

    ``expr`` is expanded, and summands without a ``collision_operator``
    factor are kept unchanged. In every other summand, with ``k`` the number
    of ``collision_operator`` factors (as counted in
    ``normalize_product(summand)``), two substitutions are made:

    * the collision operator is replaced by 1;
    * the summand's ``CeMoment`` ``m`` is replaced by
      ``-moment_computation.get_post_collision_moment(m, -k)``.

    Raises:
        KeyError: if a summand containing the collision operator has no
            ``CeMoment`` (from ``set.pop`` on an empty set).
    """
    result = 0
    # Add.make_args also handles an expression that expands to a single summand.
    for summand in sp.Add.make_args(expr.expand()):
        collision_operator_exponent = normalize_product(summand).count(collision_operator)
        if collision_operator_exponent == 0:
            result += summand
            continue

        # Look the moment up only when it is substituted, so summands without
        # collision operator and without CeMoment no longer raise KeyError.
        # NOTE(review): if a summand holds several CeMoments, set.pop() picks
        # an arbitrary one -- confirm at most one can occur.
        moment = summand.atoms(CeMoment).pop()
        substitutions = {
            collision_operator: 1,
            moment: -moment_computation.get_post_collision_moment(
                moment, -collision_operator_exponent),
        }
        result += summand.subs(substitutions)
    return result