Example 1
 def visit_node(node, substitution_dict, default_type='double'):
     substitution_dict = substitution_dict.copy()
     for arg in node.args:
         if isinstance(arg, ast.SympyAssignment):
             assignment = arg
             subs_expr = fast_subs(assignment.rhs, substitution_dict,
                                   skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
             assignment.rhs = visit_expr(subs_expr, default_type)
             rhs_type = get_type_of_expression(assignment.rhs)
             if isinstance(assignment.lhs, TypedSymbol):
                 lhs_type = assignment.lhs.dtype
                 if type(rhs_type) is VectorType and type(lhs_type) is not VectorType:
                     new_lhs_type = VectorType(lhs_type, rhs_type.width)
                     new_lhs = TypedSymbol(assignment.lhs.name, new_lhs_type)
                     substitution_dict[assignment.lhs] = new_lhs
                     assignment.lhs = new_lhs
             elif isinstance(assignment.lhs, vector_memory_access):
                 assignment.lhs = visit_expr(assignment.lhs, default_type)
         elif isinstance(arg, ast.Conditional):
             arg.condition_expr = fast_subs(arg.condition_expr, substitution_dict,
                                            skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
             arg.condition_expr = visit_expr(arg.condition_expr, default_type)
             visit_node(arg, substitution_dict, default_type)
         else:
             visit_node(arg, substitution_dict, default_type)
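The fast_subs(expr, subs_dict, skip=...) call pattern seen here recurs in all of the following examples. Below is a minimal standalone sketch of that pattern; the import path pystencils.sympyextensions and the exact skip semantics are assumptions inferred from these examples rather than taken from this listing.

import sympy as sp
from pystencils.sympyextensions import fast_subs  # assumed import path

x, y, z = sp.symbols("x y z")
expr = x * y + sp.sin(x)

# substitute x -> z in a single pass over the expression tree
print(fast_subs(expr, {x: z}))
# a skip predicate excludes whole subtrees from substitution, analogous to
# skipping ResolvedFieldAccess nodes in the visitor above
print(fast_subs(expr, {x: z}, skip=lambda e: isinstance(e, sp.sin)))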
Example 2
    def _simplify_lower_order_moments(self, ac, moment_base, search_in_main_assignments):
        if self.cqe is None:
            return ac

        moment_symbols = [sq_sym(moment_base, e) for e in moments_up_to_order(1, dim=self.dim)]

        if search_in_main_assignments:
            f_to_cm_dict = ac.main_assignments_dict
            f_to_cm_dict_reduced = ac.new_without_subexpressions().main_assignments_dict
        else:
            f_to_cm_dict = ac.subexpressions_dict
            f_to_cm_dict_reduced = ac.new_without_subexpressions(moment_symbols).subexpressions_dict

        cqe_subs = self.cqe.new_without_subexpressions().main_assignments_dict
        for m in moment_symbols:
            m_eq = fast_subs(fast_subs(f_to_cm_dict_reduced[m], cqe_subs), cqe_subs)
            m_eq = m_eq.expand().cancel()
            for cqe_sym, cqe_exp in cqe_subs.items():
                m_eq = subs_additive(m_eq, cqe_sym, cqe_exp)
            f_to_cm_dict[m] = m_eq

        if search_in_main_assignments:
            main_assignments = [Assignment(lhs, rhs) for lhs, rhs in f_to_cm_dict.items()]
            return ac.copy(main_assignments=main_assignments)
        else:
            subexpressions = [Assignment(lhs, rhs) for lhs, rhs in f_to_cm_dict.items()]
            return ac.copy(subexpressions=subexpressions)
Example 3
def __cumulant_raw_moment_transform(index, dependent_var_dict, outer_function,
                                    default_prefix, centralized):
    """Function to express cumulants as function of moments and vice versa.

    Uses multivariate version of Faa di Bruno's formula.

    Args:
        index: tuple describing the index of the cumulant/moment to express as function of moments/cumulants
        dependent_var_dict: a dictionary from index tuple to moments/cumulants symbols, or None to use default symbols
        outer_function: logarithm to transform from moments->cumulants, exp for inverse direction
        default_prefix: if dependent_var_dict is None, this is used to construct symbols of the form prefix_i_j_k
        centralized: if True the first order moments/cumulants are set to zero
    """
    dim = len(index)
    subs_dict = {}

    def create_moment_symbol(idx):
        idx = tuple(idx)
        result_symbol = sp.Symbol(default_prefix + "_" +
                                  "_".join(["%d"] * len(idx)) % idx)
        if dependent_var_dict is not None and idx in dependent_var_dict:
            subs_dict[result_symbol] = dependent_var_dict[idx]
        return result_symbol

    zeroth_moment = create_moment_symbol((0, ) * dim)

    def outer_function_derivative(n):
        x = zeroth_moment
        return sp.diff(outer_function(x), *tuple([x] * n))

    # index (2,1,0) means: differentiate twice w.r.t. the first variable and once w.r.t. the second variable
    # here this is transformed into the representation [0,0,1], where each entry stands for one diff operation
    partition_list = []
    for i, index_component in enumerate(index):
        for j in range(index_component):
            partition_list.append(i)

    if len(partition_list) == 0:  # special case for zero index
        return fast_subs(outer_function(zeroth_moment), subs_dict)

    # implementation of Faa di Bruno's formula:
    result = 0
    for partition in __partition(partition_list):
        factor = outer_function_derivative(len(partition))
        for elements in partition:
            moment_index = [0] * dim
            for i in elements:
                moment_index[i] += 1
            factor *= create_moment_symbol(moment_index)
        result += factor

    if centralized:
        for i in range(dim):
            index = [0] * dim
            index[i] = 1
            result = result.subs(create_moment_symbol(index), 0)

    return fast_subs(result, subs_dict)
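As a concrete 1D cross-check of what this Faa di Bruno based transform computes, the second cumulant expressed through raw moments can be reproduced directly with sympy. The sketch below is purely illustrative and does not call the module-private function above.

import sympy as sp

m0, m1, m2, t = sp.symbols("m_0 m_1 m_2 t")
# truncated moment generating function and its logarithm, the cumulant generating function
mgf = m0 + m1 * t + m2 * t**2 / 2
kappa_2 = sp.diff(sp.log(mgf), t, 2).subs(t, 0)
# equals m_2/m_0 - m_1**2/m_0**2, the same relation the transform yields for index (2,) with outer_function=log
print(sp.simplify(kappa_2))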
Example 4
    def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
        """Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
        own_definitions = set([e.lhs for e in self.main_assignments])
        other_definitions = set([e.lhs for e in other.main_assignments])
        assert len(own_definitions.intersection(other_definitions)) == 0, \
            "Cannot merge collections, since both define the same symbols"

        own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
        substitution_dict = {}

        processed_other_subexpression_equations = []
        for other_subexpression_eq in other.subexpressions:
            if other_subexpression_eq.lhs in own_subexpression_symbols:
                if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
                    continue  # exactly the same subexpression equation already exists
                else:
                    # different definition - a new name has to be introduced
                    new_lhs = next(self.subexpression_symbol_generator)
                    new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict))
                    processed_other_subexpression_equations.append(new_eq)
                    substitution_dict[other_subexpression_eq.lhs] = new_lhs
            else:
                processed_other_subexpression_equations.append(fast_subs(other_subexpression_eq, substitution_dict))

        processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments]
        return self.copy(self.main_assignments + processed_other_main_assignments,
                         self.subexpressions + processed_other_subexpression_equations)
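A hypothetical usage sketch of new_merged follows; the symbols and collections are made up for illustration and the AssignmentCollection constructor call is an assumption. Both collections define a subexpression for b with different right-hand sides, which triggers the renaming branch.

import sympy as sp
from pystencils import Assignment, AssignmentCollection

a, b, x, y = sp.symbols("a b x y")
ac1 = AssignmentCollection([Assignment(x, a + b)], subexpressions=[Assignment(b, 2 * a)])
ac2 = AssignmentCollection([Assignment(y, a * b)], subexpressions=[Assignment(b, a - 1)])

# b := a - 1 clashes with ac1's definition of b, so it is renamed to a fresh symbol
# and ac2's remaining assignments are rewritten via fast_subs before concatenation
merged = ac1.new_merged(ac2)
print(merged)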
Example 5
    def new_without_subexpressions(
        self, subexpressions_to_keep: Set[sp.Symbol] = set()
    ) -> 'AssignmentCollection':
        """Returns a new collection where all subexpressions have been inserted."""
        if len(self.subexpressions) == 0:
            return self.copy()

        subexpressions_to_keep = set(subexpressions_to_keep)

        kept_subexpressions = []
        if self.subexpressions[0].lhs in subexpressions_to_keep:
            substitution_dict = {}
            kept_subexpressions = [self.subexpressions[0]]  # keep a list, entries may be appended below
        else:
            substitution_dict = {
                self.subexpressions[0].lhs: self.subexpressions[0].rhs
            }

        subexpression = [e for e in self.subexpressions]
        for i in range(1, len(subexpression)):
            subexpression[i] = fast_subs(subexpression[i], substitution_dict)
            if subexpression[i].lhs in subexpressions_to_keep:
                kept_subexpressions.append(subexpression[i])
            else:
                substitution_dict[subexpression[i].lhs] = subexpression[i].rhs

        new_assignment = [
            fast_subs(eq, substitution_dict) for eq in self.main_assignments
        ]
        return self.copy(new_assignment, kept_subexpressions)
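A companion sketch for new_without_subexpressions, under the same assumptions as the previous sketch (made-up symbols, assumed constructor call):

import sympy as sp
from pystencils import Assignment, AssignmentCollection

a, b, c, x = sp.symbols("a b c x")
ac = AssignmentCollection([Assignment(x, b + c)],
                          subexpressions=[Assignment(b, 2 * a), Assignment(c, b + 1)])

print(ac.new_without_subexpressions())                            # b and c fully inlined into x
print(ac.new_without_subexpressions(subexpressions_to_keep={b}))  # c is inlined, b is kept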
Example 6
 def fast_subs(self, subs_dict, skip=None):
     self.body = fast_subs(self.body, subs_dict, skip)
     if isinstance(self.start, sp.Basic):
         self.start = fast_subs(self.start, subs_dict, skip)
     if isinstance(self.stop, sp.Basic):
         self.stop = fast_subs(self.stop, subs_dict, skip)
     if isinstance(self.step, sp.Basic):
         self.step = fast_subs(self.step, subs_dict, skip)
     return self
Example 7
    def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection':
        """Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere."""
        new_subexpressions = []
        subs_dict = None
        for se in self.subexpressions:
            if se.lhs == symbol:
                subs_dict = {se.lhs: se.rhs}
            else:
                new_subexpressions.append(se)
        if subs_dict is None:
            return self

        new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions]
        new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
        return self.copy(new_eqs, new_subexpressions)
Example 8
def create_lbm_kernel(collision_rule,
                      src_field,
                      dst_field=None,
                      accessor=StreamPullTwoFieldsAccessor()):
    """Replaces the pre- and post collision symbols in the collision rule by field accesses.

    Args:
        collision_rule:  instance of LbmCollisionRule, defining the collision step
        src_field: field used for reading pdf values
        dst_field: field used for writing pdf values; if ``accessor.is_inplace`` this parameter is ignored
        accessor: instance of PdfFieldAccessor, defining where to read and write values,
                  e.g. to create a fused stream-collide kernel. See 'fieldaccess.PdfFieldAccessor'

    Returns:
        LbmCollisionRule where pre- and post collision symbols have been replaced
    """
    if accessor.is_inplace:
        dst_field = src_field

    if not accessor.is_inplace and dst_field is None:
        raise ValueError(
            "For two field accessors a destination field has to be provided")

    method = collision_rule.method
    pre_collision_symbols = method.pre_collision_pdf_symbols
    post_collision_symbols = method.post_collision_pdf_symbols
    substitutions = {}

    input_accesses = accessor.read(src_field, method.stencil)
    output_accesses = accessor.write(dst_field, method.stencil)

    for (idx,
         offset), input_access, output_access in zip(enumerate(method.stencil),
                                                     input_accesses,
                                                     output_accesses):
        substitutions[pre_collision_symbols[idx]] = input_access
        substitutions[post_collision_symbols[idx]] = output_access

    result = collision_rule.new_with_substitutions(substitutions)

    if 'split_groups' in result.simplification_hints:
        new_split_groups = []
        for split_group in result.simplification_hints['split_groups']:
            new_split_groups.append(
                [fast_subs(e, substitutions) for e in split_group])
        result.simplification_hints['split_groups'] = new_split_groups

    if accessor.is_inplace:
        result = add_subexpressions_for_field_reads(result,
                                                    subexpressions=True,
                                                    main_assignments=True)

    return result
Example 9
def discretize_center(term, symbols_to_field_dict, dx, dim=3):
    """
    Expects a term that contains the given symbols and gradient components of these symbols and replaces them
    by field accesses. Gradients are replaced by central difference approximations:
    ``(upper neighbor - lower neighbor) / (2*dx)``

    Args:
        term: term where symbols and gradient(symbol) should be replaced
        symbols_to_field_dict: mapping of symbols to Field
        dx: width and height of one cell
        dim: dimension

    Example:
      >>> x = sp.Symbol("x")
      >>> grad_x = grad(x, dim=3)
      >>> term = x * grad_x[0]
      >>> term
      x*x^Delta^0
      >>> f = Field.create_generic('f', spatial_dimensions=3)
      >>> expected_output = f[0, 0, 0] * (-f[-1, 0, 0]/2 + f[1, 0, 0]/2)
      >>> sp.simplify(discretize_center(term, { x: f }, dx=1, dim=3) - expected_output)
      0
    """
    substitutions = {}
    for symbols, field in symbols_to_field_dict.items():
        if not hasattr(symbols, "__getitem__"):
            symbols = [symbols]
        g = grad(symbols, dim)
        substitutions.update(
            {symbol: field(i)
             for i, symbol in enumerate(symbols)})
        for d in range(dim):
            up, down = __up_down_offsets(d, dim)
            substitutions.update({
                g[d][i]: (field[up](i) - field[down](i)) / dx / 2
                for i in range(len(symbols))
            })
    return fast_subs(term, substitutions)
Example 10
def update_rule_with_push_boundaries(collision_rule,
                                     field,
                                     boundary_spec,
                                     streaming_pattern='pull',
                                     timestep=Timestep.BOTH):
    method = collision_rule.method
    accessor = get_accessor(streaming_pattern, timestep)
    loads = [
        Assignment(a, b) for a, b in zip(method.pre_collision_pdf_symbols,
                                         accessor.read(field, method.stencil))
    ]
    stores = [
        Assignment(a, b) for a, b in zip(accessor.write(field, method.stencil),
                                         method.post_collision_pdf_symbols)
    ]

    result = collision_rule.copy()
    result.subexpressions = loads + result.subexpressions
    result.main_assignments += stores
    for direction, boundary in boundary_spec.items():
        cond = boundary_conditional(boundary, direction, streaming_pattern,
                                    timestep, method, field)
        result.main_assignments.append(cond)

    if 'split_groups' in result.simplification_hints:
        substitutions = {
            b: a
            for a, b in zip(accessor.write(field, method.stencil),
                            method.post_collision_pdf_symbols)
        }
        new_split_groups = []
        for split_group in result.simplification_hints['split_groups']:
            new_split_groups.append(
                [fast_subs(e, substitutions) for e in split_group])
        result.simplification_hints['split_groups'] = new_split_groups

    return result
Example 11
def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields,
                                                strided, keep_loop_stop, assume_sufficient_line_padding):
    """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type."""
    all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
    inner_loops = [n for n in all_loops if n.is_innermost_loop]
    zero_loop_counters = {l.loop_counter_symbol: 0 for l in all_loops}

    for loop_node in inner_loops:
        loop_range = loop_node.stop - loop_node.start

        # cut off the loop tail that is not a multiple of the vector width
        if keep_loop_stop:
            pass
        elif assume_aligned and assume_sufficient_line_padding:
            loop_range = loop_node.stop - loop_node.start
            new_stop = loop_node.start + modulo_ceil(loop_range, vector_width)
            loop_node.stop = new_stop
        else:
            cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start
            loop_nodes = [l for l in cut_loop(loop_node, [cutting_point]).args if isinstance(l, ast.LoopOverCoordinate)]
            assert len(loop_nodes) in (0, 1, 2)  # 2 for main and tail loop, 1 if loop range divisible by vector width
            if len(loop_nodes) == 0:
                continue
            loop_node = loop_nodes[0]

        # Find all array accesses (indexed) that depend on the loop counter as offset
        loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over)
        substitutions = {}
        successful = True
        for indexed in loop_node.atoms(sp.Indexed):
            base, index = indexed.args
            if loop_counter_symbol in index.atoms(sp.Symbol):
                loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms()
                aligned_access = (index - loop_counter_symbol).subs(zero_loop_counters) == 0
                stride = sp.simplify(index.subs({loop_counter_symbol: loop_counter_symbol + 1}) - index)
                if not loop_counter_is_offset and (not strided or loop_counter_symbol in stride.atoms()):
                    successful = False
                    break
                typed_symbol = base.label
                assert type(typed_symbol.dtype) is PointerType, \
                    f"Type of access is {typed_symbol.dtype}, {indexed}"

                vec_type = VectorType(typed_symbol.dtype.base_type, vector_width)
                use_aligned_access = aligned_access and assume_aligned
                nontemporal = False
                if hasattr(indexed, 'field'):
                    nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields)
                substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True,
                                                              stride if strided else 1)
                if nontemporal:
                    # insert NontemporalFence after the outermost loop
                    parent = loop_node.parent
                    while type(parent.parent.parent) is not ast.KernelFunction:
                        parent = parent.parent
                    parent.parent.insert_after(NontemporalFence(), parent, if_not_exists=True)
                    # insert CachelineSize at the beginning of the kernel
                    parent.parent.insert_front(CachelineSize(), if_not_exists=True)
        if not successful:
            warnings.warn("Could not vectorize loop because of non-consecutive memory access")
            continue

        loop_node.step = vector_width
        loop_node.subs(substitutions)
        vector_int_width = ast_node.instruction_set['intwidth']
        vector_loop_counter = cast_func(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) \
            + cast_func(tuple(range(vector_int_width if type(vector_int_width) is int else 2)),
                        VectorType(loop_counter_symbol.dtype, vector_int_width))

        fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter},
                  skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access))

        mask_conditionals(loop_node)

        from pystencils.rng import RNGBase
        substitutions = {}
        for rng in loop_node.atoms(RNGBase):
            new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width))
                                  for s in rng.result_symbols]
            substitutions.update({s[0]: s[1] for s in zip(rng.result_symbols, new_result_symbols)})
            rng._symbols_defined = set(new_result_symbols)
        fast_subs(loop_node, substitutions, skip=lambda e: isinstance(e, RNGBase))
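The tail handling above rounds the loop range down or up to a multiple of the vector width. The helpers below are hypothetical plain-integer stand-ins for pystencils' modulo_floor/modulo_ceil (which also handle symbolic arguments), shown only to illustrate the arithmetic.

def modulo_floor(n, width):
    """Round n down to the next multiple of width (plain-integer case)."""
    return (n // width) * width

def modulo_ceil(n, width):
    """Round n up to the next multiple of width (plain-integer case)."""
    return ((n + width - 1) // width) * width

loop_range, vector_width = 13, 4
print(modulo_floor(loop_range, vector_width))  # 12: vectorized main loop plus a scalar tail loop of 1 iteration
print(modulo_ceil(loop_range, vector_width))   # 16: padded stop, valid only with sufficient line padding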
Example 12
 def fast_subs(self, subs_dict, skip):
     rng = copy.deepcopy(self)
     rng._args = [fast_subs(a, subs_dict, skip) for a in rng._args]
     return rng
Example 13
 def subs(self, subs_dict):
     self.lhs = fast_subs(self.lhs, subs_dict)
     self.rhs = fast_subs(self.rhs, subs_dict)
Example 14
 def fast_subs(self, subs_dict, skip=None):
     self._nodes = [fast_subs(a, subs_dict, skip) for a in self._nodes]
     return self
Example 15
    def forward_transform(self, pdf_symbols, simplification=True, subexpression_base='sub_f_to_k',
                          return_monomials=False):
        r"""Returns equations for polynomial central moments, computed from pre-collision populations
        through a cascade of three steps.

        First, the monomial raw moment vector :math:`\mathbf{m}` is computed using the raw-moment
        chimera transform (see `lbmpy.moment_transforms.PdfsToMomentsByChimeraTransform`). Then, the
        monomial shift matrix :math:`N` provided by `lbmpy.moments.set_up_shift_matrix` is used to compute
        the monomial central moment vector as :math:`\mathbf{\kappa} = N \mathbf{m}`. Lastly, the polynomial
        central moments are computed using the polynomialization matrix as :math:`\mathbf{K} = P \mathbf{\kappa}`.

        **Conserved Quantity Equations**

        If provided, this transform absorbs the conserved quantity equations and, if simplification
        is enabled, simplifies them using the raw moment equations.


        **Simplification**

        If simplification is enabled, the absorbed conserved quantity equations are rewritten using the
        monomial symbols where possible. If the conserved quantities originate from somewhere other than
        the lower-order moments (e.g. from an external field), they are not affected by this
        simplification.

        The relations between conserved quantities and raw moments are used to simplify the equations
        obtained from the shift matrix. Further, these equations are simplified by recursively inserting
        lower-order moments into equations for higher-order moments.

        **De-Aliasing**

        If more than :math:`q` monomial moments are extracted from the polynomial set, they
        are de-aliased and reduced to a set of only :math:`q` moments using the same rules
        as for raw moments. For polynomialization, a special reduced matrix :math:`\tilde{P}`
        is used, which is computed using `lbmpy.moments.central_moment_reduced_monomial_to_polynomial_matrix`.


        Args:
            pdf_symbols: List of symbols that represent the pre-collision populations
            simplification: Simplification specification. See :class:`AbstractMomentTransform`
            subexpression_base: The base name used for any subexpressions of the transformation.
            return_monomials: Return equations for monomial moments. Use only when specifying 
                              ``moment_exponents`` in constructor!

        """
        simplification = self._get_simp_strategy(simplification, 'forward')

        raw_moment_base = self.raw_moment_transform.mono_base_pre
        central_moment_base = self.mono_base_pre

        mono_rm_symbols = self.raw_moment_transform.pre_collision_monomial_symbols
        mono_cm_symbols = self.pre_collision_monomial_symbols

        rm_ac = self.raw_moment_transform.forward_transform(pdf_symbols, simplification=False, return_monomials=True)
        cq_symbols_to_moments = self.raw_moment_transform.get_cq_to_moment_symbols_dict(raw_moment_base)
        rm_to_cm_vec = self.shift_matrix * sp.Matrix(mono_rm_symbols)

        cq_subs = dict()
        if simplification:
            from lbmpy.methods.momentbased.momentbasedsimplifications import (
                substitute_moments_in_conserved_quantity_equations)
            rm_ac = substitute_moments_in_conserved_quantity_equations(rm_ac)

            #   Compute replacements for conserved moments in terms of the CQE
            rm_asm_dict = rm_ac.main_assignments_dict
            for cq_sym, moment_sym in cq_symbols_to_moments.items():
                cq_eq = rm_asm_dict[cq_sym]
                solutions = sp.solve(cq_eq - cq_sym, moment_sym)
                if len(solutions) > 0:
                    cq_subs[moment_sym] = solutions[0]

            rm_to_cm_vec = fast_subs(rm_to_cm_vec, cq_subs)

        rm_to_cm_dict = {cm: rm for cm, rm in zip(mono_cm_symbols, rm_to_cm_vec)}

        if simplification:
            rm_to_cm_dict = self._simplify_raw_to_central_moments(
                rm_to_cm_dict, self.moment_exponents, raw_moment_base, central_moment_base)
            rm_to_cm_dict = self._undo_remaining_cq_subexpressions(rm_to_cm_dict, cq_subs)

        subexpressions = rm_ac.all_assignments

        if return_monomials:
            main_assignments = [Assignment(lhs, rhs) for lhs, rhs in rm_to_cm_dict.items()]
        else:
            subexpressions += [Assignment(lhs, rhs) for lhs, rhs in rm_to_cm_dict.items()]
            poly_eqs = self.mono_to_poly_matrix * sp.Matrix(mono_cm_symbols)
            main_assignments = [Assignment(m, v) for m, v in zip(self.pre_collision_symbols, poly_eqs)]

        symbol_gen = SymbolGen(subexpression_base)
        ac = AssignmentCollection(main_assignments=main_assignments, subexpressions=subexpressions,
                                  subexpression_symbol_generator=symbol_gen)

        if simplification:
            ac = simplification.apply(ac)
        return ac
Example 16
def discretize_staggered(term,
                         symbols_to_field_dict,
                         coordinate,
                         coordinate_offset,
                         dx,
                         dim=3):
    """
    Expects a term that contains the given symbols and gradient components of these symbols and replaces them
    by field accesses. Gradients in the coordinate direction are replaced by a staggered version at the cell boundary.
    Symbols themselves and gradients in other directions are replaced by an interpolated version at the cell face.

    Args:
        term: input term where symbols and gradients are replaced
        symbols_to_field_dict: mapping of symbols to Field
        coordinate: id of the coordinate (0 for x, 1 for y, ...) defining the cell boundary.
                    Only gradients in this direction are replaced, i.e. terms of the form symbol^Delta^coordinate
        coordinate_offset: either +1 or -1 for the upper or lower face in the coordinate direction
        dx: width and height of one cell
        dim: dimension

    Examples:
      Discretizing at the right/east face of a cell (i.e. coordinate=0, coordinate_offset=1):
      >>> x, dx = sp.symbols("x dx")
      >>> grad_x = grad(x, dim=3)
      >>> term = x * grad_x[0]
      >>> term
      x*x^Delta^0
      >>> f = Field.create_generic('f', spatial_dimensions=3)
      >>> discretize_staggered(term, symbols_to_field_dict={ x: f}, dx=dx, coordinate=0, coordinate_offset=1, dim=3)
      (-f_C + f_E)*(f_C/2 + f_E/2)/dx
    """
    assert coordinate_offset == 1 or coordinate_offset == -1
    assert 0 <= coordinate < dim

    substitutions = {}
    for symbols, field in symbols_to_field_dict.items():
        if not hasattr(symbols, "__getitem__"):
            symbols = [symbols]

        offset = [0] * dim
        offset[coordinate] = coordinate_offset
        offset = np.array(offset, dtype=int)  # builtin int: np.int was removed in newer NumPy versions

        gradient = grad(symbols)[coordinate]
        substitutions.update({
            s: (field[offset](i) + field(i)) / 2
            for i, s in enumerate(symbols)
        })
        substitutions.update({
            g: (field[offset](i) - field(i)) / dx * coordinate_offset
            for i, g in enumerate(gradient)
        })
        for d in range(dim):
            if d == coordinate:
                continue
            up, down = __up_down_offsets(d, dim)
            for i, s in enumerate(symbols):
                center_grad = (field[up](i) - field[down](i)) / (2 * dx)
                neighbor_grad = (field[up + offset](i) -
                                 field[down + offset](i)) / (2 * dx)
                substitutions[grad(s)[d]] = (center_grad + neighbor_grad) / 2

    return fast_subs(term, substitutions)
Example 17
 def equilibrium_exprs(self):
     subs_dict = {rr: 1 for rr in self._free_relaxation_rates}
     subs_dict.update({rr: 1 for rr in self._fixed_relaxation_rates})
     update_equations = self._collisionRule.main_assignments
     return [fast_subs(eq.rhs, subs_dict) for eq in update_equations]