def expr_to_contract_convert(self, expr):
    """
    Substitute contraction objects into a DSL expression.

    Any object in expr which matches a key in
    contraction_info.contraction_convert_dict is swapped for the
    corresponding value in that dictionary.

    expr: a dsl_binop / dsl_unop expression tree, or a single object.

    Returns:
      - for an expression tree, a copy of expr (made with
        expr_deepish_copy) with every dictionary key replaced by its value;
      - for a single object, the dictionary value matching it, or expr
        itself, unchanged, if the dictionary has no matching key.

    Raises AssertionError if no contraction_info object is attached to the
    converter.
    """
    contraction_info = self.contraction_info()
    # Identity test rather than "!= None", so that the check cannot be
    # intercepted by an overloaded comparison operator.
    assert contraction_info is not None,\
        "a contraction_info object must be attached to converter for "+\
        "expr_to_contract_convert to function"
    contract_expr = expr_deepish_copy( expr )
    convert_dict = contraction_info.contraction_convert_dict
    if not isinstance( contract_expr, ( dsl_binop, dsl_unop ) ):
        # expr is not a dsl_* expression tree, so simply check the
        # dictionary for a matching object
        try:
            return convert_dict[ contract_expr ]
        except KeyError:
            # No match found: hand back the original object unchanged
            return expr
    # expr is an expression tree: apply every substitution in turn
    for k, v in convert_dict.items():
        contract_expr = expr_find_and_replace( contract_expr, k, v )
    return contract_expr
def expr_to_src_convert(self, expr):
    """
    Convert a DSL expression into its src_dsl_* equivalent.

    Takes a DSL expr containing dsl_* objects and replaces the dsl_* objects
    with corresponding src_dsl_* wrappers. The converted expression is then
    passed through expr_simplify to remove any terms which unambiguously
    evaluate to zero.

    expr: a dsl_binop / dsl_unop expression tree, or a single object.

    Returns:
      - for a single object, the result of self.dsl_to_src_dsl_convert( expr );
      - for an expression tree, a converted and simplified copy of expr.

    Raises AssertionError if self._wintegral is set and a dsl_integral in
    expr has a different name() from self._wintegral.integral().
    """
    if not isinstance( expr, ( dsl_binop, dsl_unop ) ):
        # If expr is not an expression, but just a single object, call the
        # faster routine
        return self.dsl_to_src_dsl_convert( expr )
    wintegral = self._wintegral
    src_expr = expr_deepish_copy( expr )
    # List all objects to be swapped before any replacement is made.
    dintegral_list = expr_list_objects( src_expr, dsl_integral )
    # NOTE(review): dsl_value objects are not swapped here -- the previous
    # version listed them but never used the list. Confirm this is intended.
    dobj_list = []
    for dobj_type in ( dsl_index, dsl_variable, dsl_function, dsl_pointer,
                       general_array ):
        dobj_list += expr_list_objects( src_expr, dobj_type )
    # swap dsl_integral objects for src_dsl_integral objects
    for dintegral in dintegral_list:
        if wintegral is not None:
            # if wintegral is defined, all dsl_integral objects should be the
            # same class as the integral_wrapper object with this method
            assert dintegral.name() == wintegral.integral().name(),\
                "all dsl_integral objects in expr should be of same class as "+\
                "the integral class being generated."
        # The swap itself is done whether or not wintegral is set
        sintegral = src_dsl_integral( dintegral, self )
        src_expr = expr_find_and_replace( src_expr, dintegral, sintegral )
    # swap remaining dsl_* and general_array objects for their src_dsl_*
    # counterparts
    for dobj in dobj_list:
        sobj = self.dsl_to_src_dsl_convert( dobj )
        src_expr = expr_find_and_replace( src_expr, dobj, sobj )
    # Remove terms that unambiguously evaluate to zero
    src_expr = expr_simplify( src_expr )
    return src_expr