    def __init__(self, values, model=None):
        """Compile a decision rule from symbolic expressions.

        Parameters
        ----------
        values : mapping of str to expression
            Entries keyed by one of the model's controls supply the equations
            of the rule; every other entry is treated as a preamble
            definition.
        model : model object
            Provides ``symbols``, ``calibration``, ``exogenous`` and
            ``get_grid()``. It is dereferenced immediately, so the ``None``
            default is not actually usable.
        """

        from dolang.symbolic import sanitize, stringify

        exogenous = model.symbols['exogenous']
        states = model.symbols['states']
        controls = model.symbols['controls']
        parameters = model.symbols['parameters']

        # Split the user input: controls give the equations, the rest is
        # the preamble.
        preamble = {s: values[s] for s in values.keys() if s not in controls}
        equations = [values[s] for s in controls]

        variables = exogenous + states + controls + [*preamble.keys()]

        preamble_str = {}
        for k, v in preamble.items():
            # Names written without parentheses get an explicit `(0)` suffix;
            # names that already carry one are used unchanged.
            if '(' not in k:
                vv = f'{k}(0)'
            else:
                vv = k

            preamble_str[stringify(vv)] = stringify(sanitize(v, variables))

        # let's reorder the preamble
        from dolang.triangular_solver import get_incidence, triangular_solver
        incidence = get_incidence(preamble_str)
        sol = triangular_solver(incidence)
        kk = [*preamble_str.keys()]
        # `sol` holds positions into `kk`; rebuild the dict in that order.
        preamble_str = {kk[i]: preamble_str[kk[i]] for i in sol}

        equations = [
            dolang.symbolic.sanitize(eq, variables) for eq in equations
        ]
        equations_strings = [
            dolang.stringify(eq, variables) for eq in equations
        ]

        # Argument groups of the generated function: exogenous and states
        # are passed as (name, 0) pairs, parameters as bare names.
        args = {
            'm': [(e, 0) for e in exogenous],
            's': [(e, 0) for e in states],
            'p': [e for e in parameters],
        }
        args = {k: [stringify_symbol(e) for e in v] for k, v in args.items()}

        targets = [stringify_symbol((e, 0)) for e in controls]

        # One output per control, in the order of `controls`.
        eqs = {targets[i]: eq for i, eq in enumerate(equations_strings)}

        fff = FlatFunctionFactory(preamble_str, eqs, args, 'custom_dr')

        # Only the second element of the returned pair is kept.
        _, gufun = make_method_from_factory(fff)

        self.p = model.calibration['parameters']
        self.exo_grid = model.exogenous.discretize()  # this is never used
        self.endo_grid = model.get_grid()
        self.gufun = gufun
def test_stringify():
    """Check the flat names that ``stringify`` produces for a raw string."""

    from dolang.symbolic import parse_string

    expression = "sin(a(1) + b + a(0) + f(-1) + f(4) + a(1))"
    expected = "sin(a__1_ + b_ + a__0_ + f_m1_ + f__4_ + a__1_)"

    assert stringify(expression) == expected
    def ℰ(self):
        # Build (once) and return the compiled "equilibrium" function; the
        # result is cached on self.__equilibrium__.
        # NOTE(review): this copy looks truncated — closing braces, an
        # `else:` and some call arguments are missing — so it does not parse
        # as shown. Confirm against the upstream source.

        if self.__equilibrium__ is None:
            # Argument groups: key = argument name, value = list of
            # (symbol, integer) pairs; the integer is presumably a time index.
            if self.features["with-aggregate-states"]:
                arguments_ = {
                    "e": [(e, 0) for e in self.model.symbols["exogenous"]],
                    "s": [(e, 0) for e in self.model.symbols["states"]],
                    "x": [(e, 0) for e in self.model.symbols["controls"]],
                    "m": [(e, 0) for e in self.symbols["exogenous"]],
                    "S": [(e, 0) for e in self.symbols["states"]],
                    "X": [(e, 0) for e in self.symbols["aggregate"]],
                    "m_1": [(e, 1) for e in self.symbols["exogenous"]],
                    "S_1": [(e, 1) for e in self.symbols["states"]],
                    "X_1": [(e, 1) for e in self.symbols["aggregate"]],
                    "p": self.symbols["parameters"],
                # NOTE(review): a closing "}" and an `else:` seem to be
                # missing here; the literal below is the same mapping without
                # the "S" and "S_1" groups.
                arguments_ = {
                    "e": [(e, 0) for e in self.model.symbols["exogenous"]],
                    "s": [(e, 0) for e in self.model.symbols["states"]],
                    "x": [(e, 0) for e in self.model.symbols["controls"]],
                    "m": [(e, 0) for e in self.symbols["exogenous"]],
                    "X": [(e, 0) for e in self.symbols["aggregate"]],
                    "m_1": [(e, 1) for e in self.symbols["exogenous"]],
                    "X_1": [(e, 1) for e in self.symbols["aggregate"]],
                    "p": self.symbols["parameters"],

            # Symbol names from every group except the last one ("p"), whose
            # entries are bare symbols rather than pairs. Note that `vars`
            # shadows the builtin of the same name.
            vars = sum([[e[0] for e in h]
                        for h in [*arguments_.values()][:-1]], [])

            # Same groups, with every entry converted by stringify_symbol.
            arguments = {
                k: [dolang.symbolic.stringify_symbol(e) for e in v]
                for k, v in arguments_.items()

            preamble = {}  # for now

            from dolang.symbolic import sanitize, stringify

            # NOTE(review): the call below is cut off (the object indexed by
            # ["equilibrium"] and the remaining arguments are not visible).
            eqs = parse_string(["equilibrium"],
            eqs = sanitize(eqs, variables=vars)
            eqs = stringify(eqs)
            content = {}
            # Each equation is stored under "eq_<i>" as "(rhs)-(lhs)".
            for i, eq in enumerate(eqs.children):
                lhs, rhs = eq.children
                content[f"eq_{i}"] = "({1})-({0})".format(
                    str_expression(lhs), str_expression(rhs))

            # NOTE(review): the last argument(s) of this call are missing.
            fff = FlatFunctionFactory(preamble, content, arguments,
            _, gufun = dolang.function_compiler.make_method_from_factory(
                fff, debug=self.debug)
            from dolang.vectorize import standard_function

            # Wrap the compiled gufunc; the second argument is the number of
            # equations.
            self.__equilibrium__ = standard_function(gufun, len(content))

        return self.__equilibrium__
def test_time_shift():
    """Shift an expression by +1 and check the stringified result."""

    from dolang.symbolic import time_shift, stringify_parameter

    source = "sin(a(1) + b + a(0) + f(-1) + f(4) + a(1))"
    expected = "sin(a__2_ + b_ + a__1_ + f__0_ + f__5_ + a__2_)"

    shifted = time_shift(source, +1)
    assert stringify(shifted) == expected
def test_time_shift():
    """Stringify a parsed expression with ``a`` and ``f`` declared as variables."""

    from dolang.symbolic import parse_string

    parsed = parse_string('sin(a(1) + b + a(0) + f(-(1)) + f(4) + a(1))')
    flattened = stringify(parsed, variables=['a', 'f'])
    expected = "sin(a__1_ + b_ + a__0_ + f_m1_ + f__4_ + a__1_)"

    assert to_source(flattened) == expected
    def projection(self):  # , m: 'n_e', y: "n_y", p: "n_p"):
        # Build (once) and return the compiled "projection" function; the
        # result is cached on self.__projection__.
        # NOTE(review): this copy looks truncated — closing braces, an
        # `else:` and some call arguments are missing — so it does not parse
        # as shown. Confirm against the upstream source.

        # TODO:
        # behaves in a very misleading way if wrong number of argument is supplied
        #  if no aggregate states, projection(m,x) (instead of projection(m,x,p)) returns zeros

        if self.__projection__ is None:
            # Argument groups: key = argument name, value = list of
            # (symbol, 0) pairs, except "p" which holds bare symbols.
            if self.features["with-aggregate-states"]:
                arguments_ = {
                    "m": [(e, 0) for e in self.symbols["exogenous"]],
                    "S": [(e, 0) for e in self.symbols["states"]],
                    "X": [(e, 0) for e in self.symbols["aggregate"]],
                    "p": self.symbols["parameters"],
                # NOTE(review): a closing "}" and an `else:` seem to be
                # missing here; the literal below omits the "S" group.
                arguments_ = {
                    "m": [(e, 0) for e in self.symbols["exogenous"]],
                    "X": [(e, 0) for e in self.symbols["aggregate"]],
                    "p": self.symbols["parameters"],

            # Symbol names from every group except the last one ("p").
            # `vars` shadows the builtin of the same name.
            vars = sum([[e[0] for e in h]
                        for h in [*arguments_.values()][:-1]], [])

            # Same groups, with every entry converted by stringify_symbol.
            arguments = {
                k: [dolang.symbolic.stringify_symbol(e) for e in v]
                for k, v in arguments_.items()

            preamble = {}  # for now

            from dolang.symbolic import sanitize, stringify

            # NOTE(review): the call below is cut off (the object indexed by
            # ["projection"] and the remaining arguments are not visible).
            eqs = parse_string(["projection"],
            eqs = sanitize(eqs, variables=vars)
            eqs = stringify(eqs)

            content = {}
            # Each equation maps its left-hand side (as a string) to its
            # right-hand side (as a string).
            for eq in eqs.children:
                lhs, rhs = eq.children
                content[str_expression(lhs)] = str_expression(rhs)

            # NOTE(review): the last argument(s) of this call are missing.
            fff = FlatFunctionFactory(preamble, content, arguments,
            _, gufun = dolang.function_compiler.make_method_from_factory(
                fff, debug=self.debug)

            from dolang.vectorize import standard_function

            # Wrap the compiled gufunc; the second argument is the number of
            # equations.
            self.__projection__ = standard_function(gufun, len(content))

        return self.__projection__
    def 𝒢(self):
        # Build (once) and return the compiled "transition" function, cached
        # on self.__transition__. The function is only built when the
        # "with-aggregate-states" feature is set; otherwise the cached value
        # is returned unchanged (None if it was never set).
        # NOTE(review): this copy looks truncated — closing braces, an import
        # list and some call arguments are missing — so it does not parse as
        # shown. Confirm against the upstream source.

        if (self.__transition__ is
                None) and self.features["with-aggregate-states"]:
            # Argument groups: the "_m1" groups pair each symbol with -1,
            # "m" pairs with 0, and "p" holds bare symbols.
            arguments_ = {
                "m_m1": [(e, -1) for e in self.symbols["exogenous"]],
                "S_m1": [(e, -1) for e in self.symbols["states"]],
                "X_m1": [(e, -1) for e in self.symbols["aggregate"]],
                "m": [(e, 0) for e in self.symbols["exogenous"]],
                "p": self.symbols["parameters"],

            # Symbol names from every group except the last one ("p").
            # `vars` shadows the builtin of the same name.
            vars = sum([[e[0] for e in h]
                        for h in [*arguments_.values()][:-1]], [])

            # Same groups, with every entry converted by stringify_symbol.
            arguments = {
                k: [dolang.symbolic.stringify_symbol(e) for e in v]
                for k, v in arguments_.items()

            preamble = {}  # for now

            # NOTE(review): the imported names are missing from this
            # statement; `sanitize` and `stringify` are used just below.
            from dolang.symbolic import (

            # NOTE(review): the call below is cut off (the object indexed by
            # ["transition"] and the remaining arguments are not visible).
            eqs = parse_string(["transition"],
            eqs = sanitize(eqs, variables=vars)
            eqs = stringify(eqs)

            content = {}
            # Each equation maps its left-hand side (as a string) to its
            # right-hand side (as a string); the index `i` is not used.
            for i, eq in enumerate(eqs.children):
                lhs, rhs = eq.children
                content[str_expression(lhs)] = str_expression(rhs)

            from dolang.factory import FlatFunctionFactory

            # NOTE(review): the last argument(s) of this call are missing.
            fff = FlatFunctionFactory(preamble, content, arguments,

            _, gufun = dolang.function_compiler.make_method_from_factory(
                fff, debug=self.debug)

            from dolang.vectorize import standard_function

            # Wrap the compiled gufunc; the second argument is the number of
            # equations.
            self.__transition__ = standard_function(gufun, len(content))

        return self.__transition__
def eval_formula(expr: str, dataframe=None, context=None):
    """Evaluate a symbolic expression, optionally against time-series data.

    Parameters
    ----------
    expr: string
        Symbolic expression to evaluate.
        Example: `k(1)-delta*k(0)-i`
    dataframe: (optional) pandas dataframe
        Each column is a time series, which can be indexed with dolo notations.
    context: dict or CalibrationDict
        Values for the symbols of ``expr``. A ``CalibrationDict`` is read
        through its ``flat`` attribute. The object passed in is not modified;
        a copy is used.

    Returns
    -------
    The value obtained by evaluating the stringified expression.
    """

    print("Evaluating: {}".format(expr))
    if context is None:
        dd = {}  # context dictionary
    elif isinstance(context, CalibrationDict):
        dd = context.flat.copy()
    else:
        dd = context.copy()

    # compat since normalize form for parameters doesn't match calib dict.
    # Iterate over a snapshot of the keys because `dd` grows inside the loop.
    for k in [*dd.keys()]:
        dd[stringify_symbol(k)] = dd[k]

    expr_ast = parse_string(expr).value
    variables = list_variables(expr_ast)
    nexpr = stringify(expr_ast)

    dd['log'] = log
    dd['exp'] = exp

    if dataframe is not None:

        import pandas as pd
        # Bind each (name, shift) found in the expression to the matching
        # column, shifted by that amount.
        for (k, t) in variables:
            dd[stringify_symbol((k, t))] = dataframe[k].shift(t)
        dd['t'] = pd.Series(dataframe.index, index=dataframe.index)

    source = to_source(nexpr)
    # NOTE(security): `eval` executes arbitrary Python; only call this with
    # expressions from a trusted source.
    res = eval(source, dd)

    return res
def parse_equation(eq_string, vars, substract_lhs=True, to_sympy=False):
    """Parse one equation string and stringify its variables.

    Parameters
    ----------
    eq_string : str
        Equation text. Everything after the first ``|`` is dropped, and if no
        ``==`` is present every ``=`` is rewritten as ``==`` before parsing.
    vars : list
        Passed to ``stringify`` as ``variables``.
    substract_lhs : bool
        For an equation ``lhs == rhs``: if true, the result is the single
        expression ``rhs - lhs``; if false, the two sides are returned
        separately as ``[lhs, rhs]``.
    to_sympy : bool
        If true, the returned expression(s) are converted with
        ``ast_to_sympy``; otherwise AST nodes are returned.

    Returns
    -------
    A single expression, or a two-element list ``[lhs, rhs]`` when the input
    is an equation and ``substract_lhs`` is false.
    """

    eq = eq_string.split("|")[0]  # ignore complentarity constraints

    if "==" not in eq:
        eq = eq.replace("=", "==")

    expr = ast.parse(eq).body[0].value
    expr_std = stringify(expr, variables=vars)

    if isinstance(expr_std, Compare):
        lhs = expr_std.left
        rhs = expr_std.comparators[0]
        if substract_lhs:
            # Fall through to the common return below with `rhs - lhs`.
            expr_std = BinOp(left=rhs, right=lhs, op=Sub())
        else:
            if to_sympy:
                return [ast_to_sympy(lhs), ast_to_sympy(rhs)]
            return [lhs, rhs]

    if to_sympy:
        return ast_to_sympy(expr_std)
    return expr_std
def parse_equation(eq_string, vars, substract_lhs=True, to_sympy=False):
    """Parse one equation string and stringify its variables.

    Parameters
    ----------
    eq_string : str
        Equation text. Everything after the first ``|`` is dropped, and if no
        ``==`` is present every ``=`` is rewritten as ``==`` before parsing.
    vars : list
        Passed to ``stringify`` as ``variables``.
    substract_lhs : bool
        For an equation ``lhs == rhs``: if true, the result is the single
        expression ``rhs - lhs``; if false, the two sides are returned
        separately as ``[lhs, rhs]``.
    to_sympy : bool
        If true, the returned expression(s) are converted with
        ``ast_to_sympy``; otherwise AST nodes are returned.

    Returns
    -------
    A single expression, or a two-element list ``[lhs, rhs]`` when the input
    is an equation and ``substract_lhs`` is false.
    """

    eq = eq_string.split('|')[0]  # ignore complentarity constraints

    if '==' not in eq:
        eq = eq.replace('=', '==')

    expr = ast.parse(eq).body[0].value
    expr_std = stringify(expr, variables=vars)

    if isinstance(expr_std, Compare):
        lhs = expr_std.left
        rhs = expr_std.comparators[0]
        if substract_lhs:
            # Fall through to the common return below with `rhs - lhs`.
            expr_std = BinOp(left=rhs, right=lhs, op=Sub())
        else:
            if to_sympy:
                return [ast_to_sympy(lhs), ast_to_sympy(rhs)]
            return [lhs, rhs]

    if to_sympy:
        return ast_to_sympy(expr_std)
    return expr_std
    equations = [compat(eq) for eq in equations]
    equations = ['{} - ({})'.format(*str.split(eq, '==')) for eq in equations]
    equations = [parse_string(e) for e in equations]

with timeit("stringify equations"):

    # Every model variable paired with 1, 0 and -1, followed by every shock
    # paired with 0.
    all_variables = [(v, 1) for v in model['symbols']['variables']] + \
                    [(v, 0) for v in model['symbols']['variables']]  + \
                    [(v, -1) for v in model['symbols']['variables']]  + \
                    [(v, 0) for v in model['symbols']['shocks']]
    all_vnames = [e[0] for e in all_variables]
    all_constants = model['symbols']['parameters']

    # here comes the costly step
    equations_stringified = [
        stringify(e, variables=all_vnames) for e in equations
    ]
    equations_stringified_strings = [
        to_source(e) for e in equations_stringified
    ]
    variables_stringified_strings = [
        stringify_variable(e) for e in all_variables
    ]


with timeit("Sympify equations"):
    # Convert each stringified equation source into a sympy expression.
    equations_stringified_sympy = [
        sympy.sympify(e) for e in equations_stringified_strings
    ]
def get_factory(model, eq_type: str, tshift: int = 0):
    # Build and return a FlatFunctionFactory for the equations of type
    # `eq_type` of `model`, with the spec and the equations shifted by
    # `tshift`.
    # NOTE(review): this copy looks truncated — several `else:` lines,
    # closing braces and a module name are missing — so it does not parse as
    # shown. Confirm against the upstream source.

    from dolo.compiler.model import decode_complementarity

    # NOTE(review): the module name is missing from this import.
    from import recipes
    from dolang.symbolic import stringify, stringify_symbol

    equations = model.equations

    if eq_type == "auxiliary":
        # Auxiliary "equations" are just the auxiliary symbols, each written
        # with an explicit (0).
        eqs = [('{}({})'.format(s, 0)) for s in model.symbols['auxiliaries']]
        specs = {
            'eqs': [['exogenous', 0, 'm'], ['states', 0, 's'],
                    ['controls', 0, 'x'], ['parameters', 0, 'p']]
        # NOTE(review): a closing "}" and an `else:` seem to be missing here.
        eqs = equations[eq_type]
        if eq_type in ('controls_lb', 'controls_ub'):
            # NOTE(review): the content of this literal and the following
            # `else:` are not visible.
            specs = {
            specs = recipes['dtcc']['specs'][eq_type]

    specs = shift_spec(specs, tshift=tshift)

    # Second entries of the spec rows at which both a 'states' row and a
    # 'controls' row appear.
    preamble_tshift = set([s[1] for s in specs['eqs'] if s[0] == 'states'])
    preamble_tshift = preamble_tshift.intersection(
        set([s[1] for s in specs['eqs'] if s[0] == 'controls']))

    # One argument group per spec row: bare names for parameters,
    # (name, sg[1]) pairs for every other symbol group.
    args = []
    for sg in specs['eqs']:
        if sg[0] == 'parameters':
            args.append([s for s in model.symbols["parameters"]])
            # NOTE(review): an `else:` seems to be missing before this line.
            args.append([(s, sg[1]) for s in model.symbols[sg[0]]])
    args = [[stringify_symbol(e) for e in vg] for vg in args]

    # Keyed by the third entry of each spec row.
    arguments = dict(zip([sg[2] for sg in specs['eqs']], args))

    # temp
    # Collapse "==" to "=" first so that the second replace yields exactly
    # one "==" per equality sign.
    eqs = [eq.replace("==","=").replace("=","==") for eq in eqs]

    if 'target' in specs:
        # With a target, outputs are the target symbols and only the text
        # after "==" is kept.
        sg = specs['target']
        targets = [(s, sg[1]) for s in model.symbols[sg[0]]]
        eqs = [eq.split('==')[1] for eq in eqs]
        # NOTE(review): an `else:` seems to be missing here; the lines below
        # rewrite "lhs == rhs" as "(rhs)-(lhs)" and name outputs out<i>.
        eqs = [("({1})-({0})".format(*eq.split('==')) if '==' in eq else eq)
               for eq in eqs]
        targets = [('out{}'.format(i), 0) for i in range(len(eqs))]

    # Parse, sanitize, shift, stringify and finally turn back into source.
    eqs = [str.strip(eq) for eq in eqs]
    eqs = [dolang.parse_string(eq) for eq in eqs]
    es = ExpressionSanitizer(model.variables)
    eqs = [es.visit(eq) for eq in eqs]

    eqs = [time_shift(eq, tshift) for eq in eqs]
    eqs = [stringify(eq) for eq in eqs]
    eqs = [dolang.to_source(eq) for eq in eqs]

    targets = [stringify_symbol(e) for e in targets]

    # sanitize defs ( should be )
    defs = dict()
    for k in model.definitions:
        # Only definitions whose key contains no "(" are processed.
        if '(' not in k:
            # This value of `s` is overwritten inside the loop below before
            # it is used.
            s = "{}(0)".format(k)
            val = model.definitions[k]
            val = es.visit(dolang.parse_string(val))
            # One preamble entry per shift in preamble_tshift.
            for t in preamble_tshift:
                s = stringify_symbol((k, t))
                vv = stringify(time_shift(val, t))
                defs[s] = dolang.to_source(vv)

    preamble = reorder_preamble(defs)

    eqs = dict(zip(targets, eqs))
    ff = FlatFunctionFactory(preamble, eqs, arguments, eq_type)

    return ff
# Builds an `ast.Module` holding a single generated function and returns it
# after `ast.fix_missing_locations`.
# NOTE(review): this copy is heavily truncated — the rest of the signature,
# several `else:` lines, closing brackets and call arguments are missing —
# so it does not parse as shown. Confirm against the upstream source.
def make_method(equations,

    # Normalise the equation text: "^" becomes "**" and every equality sign
    # ends up as exactly one "==".
    compat = lambda s: s.replace("^", "**").replace('==', '=').replace(
        '=', '==')
    equations = [compat(eq) for eq in equations]

    # A plain list of argument groups is given generated names arg_<i>.
    if isinstance(arguments, list):
        arguments = OrderedDict([('arg_{}'.format(i), k)
                                 for i, k in enumerate(arguments)])

    ## replace = by ==
    known_variables = [a[0] for a in sum(arguments.values(), [])]
    known_definitions = [a for a in definitions.keys()]
    known_constants = [a[0] for a in constants]
    all_variables = known_variables + known_definitions
    known_functions = []
    # Resets the list computed three lines above to empty.
    known_constants = []

    if targets is not None:
        all_variables.extend([o[0] for o in targets])
        targets = [stringify(o) for o in targets]
        # NOTE(review): an `else:` seems to be missing before this line,
        # which names the outputs _out_<n>.
        targets = ['_out_{}'.format(n) for n in range(len(equations))]

    all_symbols = all_variables  # + known_constants

    equations = [parse(eq) for eq in equations]
    definitions = {k: parse(v) for k, v in definitions.items()}

    defs_incidence = {}
    for sym, val in definitions.items():
        lvars = list_variables(val, vars=known_definitions)
        # NOTE(review): the start of this assignment is missing.
                        0)] = [v for v in lvars if v[0] in known_definitions]
    # return defs_incidence

    equations_incidence = {}
    to_be_defined = set([])
    for i, eq in enumerate(equations):
        cn = CountNames(all_variables, known_functions, known_constants)
        # NOTE(review): no call that feeds `eq` to `cn` is visible here.
        equations_incidence[i] = cn.variables
        to_be_defined = to_be_defined.union(
            [a for a in cn.variables if a[0] in known_definitions])

    deps = []
    for tv in to_be_defined:
        ndeps = get_deps(defs_incidence, tv)
        # NOTE(review): nothing visible adds `ndeps` to `deps`.
    deps = [d for d in unique(deps)]

    # One stringified, normalised definition per (name, shift) dependency.
    new_definitions = OrderedDict()
    for k in deps:
        val = definitions[k[0]]
        nval = time_shift(val, k[1], all_variables)
        new_definitions[stringify(k)] = normalize(nval, variables=all_symbols)

    new_equations = []

    for n, eq in enumerate(equations):
        # Match the equation against the pattern "_x == _y".
        d = match(parse("_x == _y"), eq)
        if d is not False:
            lhs = d['_x']
            rhs = d['_y']
            if rhs_only:
                val = rhs
                # NOTE(review): an `else:` seems to be missing before this
                # line, which builds rhs - lhs.
                val = ast.BinOp(left=rhs, op=Sub(), right=lhs)
            # NOTE(review): an `else:` seems to be missing before this line.
            val = eq
        new_equations.append(normalize(val, variables=all_symbols))

    # preambleIndex(Num(x))
    preamble = []
    for i, (arg_group_name, arg_group) in enumerate(arguments.items()):
        for pos, t in enumerate(arg_group):
            sym = stringify(t)
            # NOTE(review): the rest of this Subscript call is missing, and
            # nothing visible appends `val` to `preamble`.
            rhs = Subscript(value=Name(id=arg_group_name, ctx=Load()),
            val = Assign(targets=[Name(id=sym, ctx=Store())], value=rhs)

    for pos, p in enumerate(constants):
        sym = stringify(p)
        # NOTE(review): the rest of this Subscript call is missing.
        rhs = Subscript(value=Name(id='p', ctx=Load()),
        val = Assign(targets=[Name(id=sym, ctx=Store())], value=rhs)

    # now construct the function per se
    body = []
    for k, v in new_definitions.items():
        # NOTE(review): nothing visible appends `line` to `body`.
        line = Assign(targets=[Name(id=k, ctx=Store())], value=v)

    for n, neq in enumerate(new_equations):
        # NOTE(review): the rest of this Assign call is missing.
        line = Assign(targets=[Name(id=targets[n], ctx=Store())],

    for n, neq in enumerate(new_equations):
        # NOTE(review): parts of this Assign/Subscript expression are missing.
        line = Assign(targets=[
            Subscript(value=Name(id='out', ctx=Load()),
                      value=Name(id=targets[n], ctx=Load()))

    from ast import arg, FunctionDef, Module
    from ast import arguments as ast_arguments

    # The generated function takes one parameter per argument group, then
    # `p` and `out`.
    # NOTE(review): the remaining FunctionDef arguments are missing.
    f = FunctionDef(
        args=ast_arguments(args=[arg(arg=a) for a in arguments.keys()] +
                           [arg(arg='p'), arg(arg='out')],
        body=preamble + body,

    mod = Module(body=[f])
    mod = ast.fix_missing_locations(mod)
    return mod
def get_factory(model, eq_type: str, tshift: int = 0):
    # Build and return a FlatFunctionFactory for the equations of type
    # `eq_type` of `model`, with the spec and the equations shifted by
    # `tshift`.
    # NOTE(review): this copy looks truncated — several `else:` lines,
    # closing brackets and a module name are missing — so it does not parse
    # as shown. Confirm against the upstream source.

    from dolo.compiler.model import decode_complementarity

    # NOTE(review): the module name is missing from this import.
    from import recipes
    from dolang.symbolic import stringify, stringify_symbol

    equations = model.equations

    if eq_type == "auxiliary":
        # Auxiliary "equations" are just the auxiliary symbol names.
        eqs = ["{}".format(s) for s in model.symbols["auxiliaries"]]
        specs = {
            "eqs": [
                ["exogenous", 0, "m"],
                ["states", 0, "s"],
                ["controls", 0, "x"],
                ["parameters", 0, "p"],
        # NOTE(review): the closing "]" and "}" and an `else:` seem to be
        # missing here.
        eqs = equations[eq_type]
        if eq_type in ("arbitrage_lb", "arbitrage_ub"):
            # NOTE(review): the content of this literal and the following
            # `else:` are not visible.
            specs = {
            specs = recipes["dtcc"]["specs"][eq_type]

    specs = shift_spec(specs, tshift=tshift)

    # Second entries of the spec rows at which both a "states" row and a
    # "controls" row appear.
    preamble_tshift = set([s[1] for s in specs["eqs"] if s[0] == "states"])
    preamble_tshift = preamble_tshift.intersection(
        set([s[1] for s in specs["eqs"] if s[0] == "controls"]))

    # One argument group per spec row: bare names for parameters,
    # (name, sg[1]) pairs for every other symbol group.
    args = []
    for sg in specs["eqs"]:
        if sg[0] == "parameters":
            args.append([s for s in model.symbols["parameters"]])
            # NOTE(review): an `else:` seems to be missing before this line.
            args.append([(s, sg[1]) for s in model.symbols[sg[0]]])
    args = [[stringify_symbol(e) for e in vg] for vg in args]

    # Keyed by the third entry of each spec row.
    arguments = dict(zip([sg[2] for sg in specs["eqs"]], args))

    # temp
    # Keep only the text before the first "⟂" of each equation.
    eqs = [eq.split("⟂")[0].strip() for eq in eqs]

    if "target" in specs:
        # With a target, outputs are the target symbols and only the text
        # between the first and second "=" is kept.
        sg = specs["target"]
        targets = [(s, sg[1]) for s in model.symbols[sg[0]]]
        eqs = [eq.split("=")[1] for eq in eqs]
        # NOTE(review): an `else:` seems to be missing here; the lines below
        # rewrite "lhs = rhs" as "(rhs)-(lhs)" and name outputs out<i>.
        eqs = [("({1})-({0})".format(*eq.split("=")) if "=" in eq else eq)
               for eq in eqs]
        targets = [("out{}".format(i), 0) for i in range(len(eqs))]

    eqs = [str.strip(eq) for eq in eqs]

    # Parse, sanitize, shift, stringify and finally render as strings.
    eqs = [dolang.parse_string(eq) for eq in eqs]
    es = Sanitizer(variables=model.variables)
    eqs = [es.transform(eq) for eq in eqs]

    eqs = [time_shift(eq, tshift) for eq in eqs]

    eqs = [stringify(eq) for eq in eqs]

    eqs = [str_expression(eq) for eq in eqs]

    targets = [stringify_symbol(e) for e in targets]

    # sanitize defs ( should be )
    defs = dict()

    for k in model.definitions:
        val = model.definitions[k]
        # val = es.transform(dolang.parse_string(val))
        # One preamble entry per shift in preamble_tshift.
        for t in preamble_tshift:
            s = stringify(time_shift(k, t))
            if isinstance(val, str):
                vv = stringify(time_shift(val, t))
                # NOTE(review): an `else:` seems to be missing before this
                # line; as shown it overwrites the value computed just above.
                vv = str(val)
            defs[s] = vv

    preamble = reorder_preamble(defs)

    eqs = dict(zip(targets, eqs))
    ff = FlatFunctionFactory(preamble, eqs, arguments, eq_type)

    return ff
    equations = [compat(eq) for eq in equations]
    equations = ['{} - ({})'.format(*str.split(eq,'==')) for eq in equations]
    equations = [parse_string(e) for e in equations]

with timeit("stringify equations"):

    # Every model variable paired with 1, then 0, then -1, followed by every
    # shock paired with 0.
    all_variables = [
        (v, lag) for lag in (1, 0, -1) for v in model['symbols']['variables']
    ]
    all_variables = all_variables + [
        (v, 0) for v in model['symbols']['shocks']
    ]
    all_vnames = [name for name, _ in all_variables]
    all_constants = model['symbols']['parameters']

    # here comes the costly step
    equations_stringified = [
        stringify(eq, variables=all_vnames) for eq in equations
    ]
    equations_stringified_strings = list(map(to_source, equations_stringified))
    variables_stringified_strings = list(map(stringify_variable, all_variables))

stringify_variable(all_variables[-3])

with timeit("Sympify equations"):
    # Convert each stringified equation source into a sympy expression.
    equations_stringified_sympy = [
        sympy.sympify(src) for src in equations_stringified_strings
    ]

with timeit("Compute jacobian (sympy)"):
    jac = []
    for eq in equations_stringified_strings:
        line = []