def simulate_household_ss(self,D=None,do_print=False):
        """ gateway for simulating the model towards the steady state

        Args:
            D (ndarray, optional): initial guess for the distribution, copied
                into sim.D before iterating. When None (default), all mass is
                put in asset column 0 with the ergodic distribution
                par.e_ergodic over the e-states.
            do_print (bool, optional): print timing / convergence information

        Returns:
            bool: True if simulate_ss reported convergence (it >= 0)
        """

        t0 = time.time()

        with jit(self) as model:

            par = model.par
            sol = model.sol
            sim = model.sim

            # a. initial guess
            if D is None:
                sim.D[:,:] = 0.0
                sim.D[:,0] = par.e_ergodic # start with a = 0.0 for everybody
            else:
                # honor a user-supplied guess (it used to be silently ignored)
                sim.D[:,:] = D

            # b. simulate
            it = simulate_ss(par,sol,sim)

            if do_print:
                if it >= 0:
                    print(f'household problem simulated in {elapsed(t0)} [{it} iterations]')
                else:
                    print('household problem simulation did not converge')

        return (it >= 0)
    def solve_household_path(self,path_r,path_w=None,do_print=False):
        """ solve the household problem backwards along a given price path

        Args:
            path_r: interest rate path, indexed by t = 0,...,par.path_T-1
            path_w (optional): wage path; when None the wage in period t is
                self.implied_w(path_r[t])
            do_print (bool, optional): print the time used

        In the final path period the continuation marginal value of assets is
        the steady state one, sol.Va.
        """

        start = time.time()

        with jit(self) as model:

            par = model.par
            sol = model.sol

            # a. (re-)build the grids
            self.create_grids()

            # b. backwards induction along the path
            last = par.path_T-1
            for t in range(last,-1,-1):

                # i. prices in period t
                r = path_r[t]
                if path_w is None:
                    w = self.implied_w(path_r[t])
                else:
                    w = path_w[t]

                # ii. continuation marginal value of assets
                Va_next = sol.Va if t == last else sol.path_Va[t+1]

                # iii. beginning-of-period resources (1+r)*a + w*e on the grid
                sol.path_m[t] = (1+r)*par.a_grid[np.newaxis,:] + w*par.e_grid[:,np.newaxis]

                # iv. one step of time iteration, written into the path arrays for t
                time_iteration(par,r,w,Va_next,sol.path_Va[t],sol.path_a[t],sol.path_c[t],sol.path_m[t])

        if do_print:
            print(f'household problem solved in {elapsed(start)}')
    def solve_household_ss(self,r,Va=None,do_print=False):
        """ solve the household problem in steady state

        Args:
            r (float): interest rate
            Va (ndarray, optional): starting guess for sol.Va; when None it is
                (1+r)*(0.1*m)**(-sigma)
            do_print (bool, optional): print timing / convergence information

        Returns:
            bool: True if solve_ss reported convergence (it >= 0)
        """

        start = time.time()

        with jit(self) as model:

            par = model.par
            sol = model.sol

            # a. wage implied by the firm's optimality condition
            w = self.implied_w(r,par.Z)

            # b. resources on the grid and starting guess for Va
            sol.m[:,:] = (1+r)*par.a_grid[np.newaxis,:] + w*par.e_grid[:,np.newaxis]
            if Va is None:
                sol.Va[:,:] = (1+r)*(0.1*sol.m)**(-par.sigma)
            else:
                sol.Va[:,:] = Va

            # c. iterate to the steady state
            it = solve_ss(par,sol,r,w)

            # d. indices and weights derived from sol.a (stored in sol.i, sol.w)
            find_i_and_w(par,sol.a,sol.i,sol.w)

        converged = it >= 0
        if do_print:
            if converged:
                print(f'household problem solved in {elapsed(start)} [{it} iterations]')
            else:
                print(f'household problem solution did not converge')

        return converged
    def simulate_household_ss(self, D=None, do_print=False):
        """ gateway for simulating the model towards the steady state

        Args:
            D (ndarray, optional): initial distribution, presumably of shape
                (par.Ne, par.Na) like the default. When None, each e-state's
                ergodic mass par.e_ergodic is spread uniformly over the par.Na
                asset grid points.
            do_print (bool, optional): print timing information, and a notice
                when the simulation fails

        Returns:
            bool: True if the simulation finished without raising an
            Exception, False otherwise
        """

        t0 = time.time()

        with jit(self) as model:

            success = True
            try:

                par = model.par
                sol = model.sol
                sim = model.sim

                # a. initial guess
                if D is None:
                    D = (np.repeat(par.e_ergodic, par.Na) /
                         par.Na).reshape(par.Ne, par.Na)

                # b. simulate
                it = simulate_ss(par, sol, sim, D)

                if do_print:
                    print(
                        f'household problem simulated in {elapsed(t0)} [{it} iterations]'
                    )

            except Exception as exc:
                # best-effort: failure is reported through the return value,
                # but KeyboardInterrupt/SystemExit are no longer swallowed and
                # the failure is no longer silent when printing is requested
                success = False
                if do_print:
                    print(f'household problem simulation failed: {exc!r}')

        return success
Example #5
0
    def solve_vfi(self, do_print):
        """ solve the model by value function iteration (backwards induction)

        In the last period everything is consumed (c = m). In every earlier
        period the consumption at each point of par.grid_m[t] is found by
        bounded scalar minimization of self.value_of_choice over [0, m].

        Args:
            do_print (bool): print the time used
        """

        start = time.time()

        with jit(self) as model:

            par = model.par
            sol = model.sol

            # a. terminal period: consume all
            sol.m[-1, :] = par.grid_m[-1, :]
            sol.c[-1, :] = sol.m[-1, :]
            for idx in range(sol.c.shape[1]):
                sol.inv_v[-1, idx] = 1.0 / utility(sol.c[-1, idx], par)

            # b. earlier periods, solved backwards
            for t in range(par.T - 2, -1, -1):
                for i_m in range(par.Nm):

                    m = par.grid_m[t, i_m]

                    # bind t and m now so the objective cannot see later values
                    def objective(c, t=t, m=m):
                        return self.value_of_choice(c, t, m)

                    res = optimize.minimize_scalar(objective, method='bounded', bounds=(0, m))

                    sol.c[t, i_m] = res.x
                    # value_of_choice is minimized, so presumably it returns the
                    # negated value; inv_v stores -1/fun accordingly
                    sol.inv_v[t, i_m] = -1.0 / res.fun

                # store the m grid used for period t
                sol.m[t, :] = par.grid_m[t, :]

        if do_print:
            print(f'model solved in {elapsed(start)}')
    def solve_G2EGM(self):
        """ solve the model with G2EGM

        Steps:
            a. the retirement problem
            b. the last working period
            c. the remaining working periods, backwards; each step computes
               the post-decision value function and then applies G2EGM

        Per-period timings are stored in par.time_w, par.time_egm and
        par.time_work.
        """

        with jit(self) as model:

            par = model.par
            sol = model.sol

            if par.do_print:
                print('Solving with G2EGM:')

            # a. retirement
            tic = time.time()
            retirement.solve(sol,par)
            if par.do_print:
                print(f'solved retirement problem in {time.time()-tic:.2f} secs')

            # b. last working period
            tic = time.time()
            last_period.solve(sol,par)
            if par.do_print:
                print(f'solved last period working in {time.time()-tic:.2f} secs')

            # c. remaining working periods, backwards
            for t in range(par.T-2,-1,-1):

                tic_period = time.time()

                if par.do_print:
                    print(f' t = {t}:')

                # i. post-decision value function
                tic_step = time.time()
                post_decision.compute(t,sol,par)
                par.time_w[t] = time.time()-tic_step
                if par.do_print:
                    print(f'   computed post decision value function in {par.time_w[t]:.2f} secs')

                # ii. G2EGM step
                tic_step = time.time()
                G2EGM.solve(t,sol,par)
                par.time_egm[t] = time.time()-tic_step
                if par.do_print:
                    print(f'   applied G2EGM  in {par.time_egm[t]:.2f} secs')

                par.time_work[t] = time.time()-tic_period

            if par.do_print:
                print(f'solved working problem in {np.sum(par.time_work):.2f} secs')
    def calculate_euler(self):
        """ calculate Euler errors via simulate.euler on the jitted
        sim, sol and par namespaces """

        with jit(self) as model:
            jsim, jsol, jpar = model.sim, model.sol, model.par
            simulate.euler(jsim,jsol,jpar)
    def solve_household_ss(self, r, Va=None, do_print=False):
        """ gateway for solving the model in steady state

        Repeats time_iteration until the policy sol.a changes by less than
        par.solve_tol (maximum absolute difference).

        Args:
            r (float): interest rate
            Va (ndarray, optional): starting guess for sol.Va; when None it is
                (1+r)*(0.1*m)**(-sigma)
            do_print (bool, optional): print timing / convergence information

        Returns:
            bool: True on convergence. False if an Exception was raised,
            including when more than par.max_iter_solve iterations are needed.
        """

        t0 = time.time()
        it = 0 # set before the try so the report below never hits an unbound name
        failure = None

        with jit(self) as model:

            success = True
            try:

                par = model.par
                sol = model.sol

                # a. find wage from optimal firm behavior
                w = self.implied_w(r)

                # b. create (or re-create) grids
                self.create_grids()

                # c. initial guess
                sol.m[:, :] = (1 + r) * par.a_grid[
                    np.newaxis, :] + w * par.e_grid[:, np.newaxis]
                sol.Va[:, :] = (1 + r) * (0.1 * sol.m)**(
                    -par.sigma) if Va is None else Va

                # d. iterate until sol.a has converged
                while True:

                    # i. save
                    a_old = sol.a.copy()

                    # ii. egm (sol.Va is both input and output)
                    time_iteration(par, r, w, sol.Va, sol.Va, sol.a, sol.c,
                                   sol.m)

                    # iii. check
                    if np.max(np.abs(sol.a - a_old)) < par.solve_tol: break

                    # iv. increment
                    it += 1
                    if it > par.max_iter_solve:
                        raise Exception(
                            'too many iterations when solving for steady state'
                        )

            except Exception as exc:
                # best-effort: failure is reported through the return value
                # (KeyboardInterrupt/SystemExit are no longer swallowed)
                success = False
                failure = exc

        if do_print:
            if success:
                print(
                    f'household problem solved in {elapsed(t0)} [{it} iterations]')
            else:
                print(f'household problem solution failed ({failure!r})')

        return success
    def solve(self):
        """ solve the model by backwards induction using par.solmethod

        Supported methods are 'vfi', 'nvfi' and 'egm'. Before the Bellman step:
        - 'nvfi' computes the post-decision function w;
        - 'egm' computes q;
        - par.do_simple_w selects post_decision.compute_wq_simple over
          post_decision.compute_wq.

        Raises:
            ValueError: unknown par.solmethod (checked for every t < par.T-1)
        """

        with jit(self) as model: # jitted functions can be called from here

            par = model.par
            sol = model.sol

            for t in range(par.T-1,-1,-1):

                tic = time.time()
                is_last = t == par.T-1

                if is_last:

                    # a. terminal period
                    last_period.solve(t,sol,par)

                else:

                    # b.i. post-decision functions
                    tic_w = time.time()

                    compute_w = par.solmethod in ['nvfi']
                    compute_q = par.solmethod in ['egm']

                    if compute_w or compute_q:
                        wq_func = post_decision.compute_wq_simple if par.do_simple_w else post_decision.compute_wq
                        wq_func(t,sol,par,compute_w=compute_w,compute_q=compute_q)

                    toc_w = time.time()

                    # b.ii. Bellman equation
                    if par.solmethod == 'vfi':
                        vfi.solve_bellman(t,sol,par)
                    elif par.solmethod == 'nvfi':
                        nvfi.solve_bellman(t,sol,par)
                    elif par.solmethod == 'egm':
                        egm.solve_bellman(t,sol,par)
                    else:
                        raise ValueError(f'unknown solution method, {par.solmethod}')

                # c. report
                if par.do_print:
                    msg = f' t = {t} solved in {elapsed(tic)}'
                    if not is_last:
                        msg += f' (w: {elapsed(tic_w,toc_w)})'
                    print(msg)
    def calc_moments(self, do_timing=False):
        """ calculate moments

        For every (momname, args) entry in self.specs the moment is first tried
        with the module's `fast` function, called outside jit with the model
        itself. If that returns a falsy value, `<module>.<momname>` is called
        inside jit as func(par, sim, *extra_args). The result is stored in
        self.moms[(momname, args)].

        The module is `theoretical_moments` when par.use_theoretical is set,
        otherwise `moments`. Both are looked up among this file's module
        globals.

        Args:
            do_timing (bool, optional): print the time used per moment
        """

        self.moms = {}

        for momname, infolist in self.specs.items():

            for info in infolist:
                args = info['args']

                # i. skip if already calculated
                if (momname, args) in self.moms: continue

                # module holding the moment functions (explicit global lookup
                # instead of eval on a name string)
                module_name = 'theoretical_moments' if self.par.use_theoretical else 'moments'
                module = globals()[module_name]

                # ii. potentially fast calculations
                t0 = time.time()

                found = module.fast(self, momname, args)

                t1 = time.time()
                if do_timing and found:
                    print(f'{momname:30s}: {t1-t0:2f} secs')

                if found: continue

                # iii. calculate
                # normalize args into extra positional arguments
                if args is None:  # no arguments
                    extra_args = ()
                elif np.isscalar(args):  # single argument
                    extra_args = (args,)
                else:  # multiple arguments
                    extra_args = tuple(args)

                t0 = time.time()

                with jit(self) as model:

                    moment_func = getattr(module, momname)
                    model.moms[(momname, args)] = moment_func(model.par, model.sim, *extra_args)

                t1 = time.time()
                if do_timing: print(f'{momname:30s}: {t1-t0:2f} secs')
    def solve_household_path(self,path_r,path_w,do_print=False):
        """ solve the household problem along the transition path

        Delegates to solve_path with the jitted par and sol.

        Args:
            path_r: interest rate path
            path_w: wage path
            do_print (bool, optional): print the time used
        """

        start = time.time()

        with jit(self) as model:
            solve_path(model.par,model.sol,path_r,path_w)

        if do_print:
            print(f'household problem solved in {elapsed(start)}')
Example #12
0
def main_mixture_results(model, k, omegas):
    """ compute the main mixture results for k and omegas

    Creates the weights on model.par, fills them inside jit and then calls
    main_mixture_results_. The three results are also stored in the module
    globals skew, kurt and leq.

    Returns:
        tuple: (skew, kurt, leq) as returned by main_mixture_results_
    """

    global skew, kurt, leq

    create_weights(model.par, k)

    with jit(model) as jitted:

        fill_weights(jitted.par, k)
        skew, kurt, leq = main_mixture_results_(jitted.par, jitted.sim, k, omegas)

    return skew, kurt, leq
    def simulate_household_path_jac(self,D0,dprice,do_print=False):
        """ gateway for simulating the model along the path

        The `_jac` variant (presumably used for Jacobian calculations):
        D0 and dprice are passed on unchanged to simulate_path_jac.

        Args:
            D0: starting distribution passed to simulate_path_jac
            dprice: price change passed to simulate_path_jac
            do_print (bool, optional): print the time used
        """

        start = time.time()

        with jit(self) as model:
            simulate_path_jac(model.par,model.sol,model.sim,D0,dprice)

        if do_print:
            print(f'household problem simulated in {elapsed(start)}')
    def simulate_household_path(self, D0=None, do_print=False):
        """ gateway for simulating the model along the path

        Args:
            D0 (optional): starting distribution; defaults to sim.D (the
                steady state distribution)
            do_print (bool, optional): print the time used
        """

        start = time.time()

        with jit(self) as model:

            par, sol, sim = model.par, model.sol, model.sim

            # fall back to the stored steady state distribution
            initial = sim.D if D0 is None else D0

            # simulate forward along the path
            simulate_path(par, sol, sim, initial)

        if do_print:
            print(f'household problem simulated in {elapsed(start)}')
Example #15
0
def cond_mixture_results(model, k, omegas):
    """ compute conditional mixture results for k and omegas

    Best-effort: if the jitted computation raises an Exception, the
    NaN-filled placeholder is returned instead of propagating the error.

    Args:
        model: model whose par and sim namespaces are used
        k: passed on to create_weights, fill_weights and cond_mixture_results_
        omegas (ndarray): its size sets the number of result columns

    Returns:
        ndarray: array of shape (3, omegas.size), all NaN on failure
    """

    leq = np.full((3, omegas.size), np.nan)
    create_weights(model.par, k)

    try:

        with jit(model) as jitted:

            fill_weights(jitted.par, k)
            leq[:, :] = cond_mixture_results_(jitted.par, jitted.sim, k, omegas)

    except Exception:
        # deliberate best-effort: keep the NaN placeholder, but do not
        # swallow KeyboardInterrupt / SystemExit as the bare except did
        pass

    return leq
Example #16
0
    def simulate(self, do_print=True, seed=2017):
        """ simulate the model

        Args:
            do_print (bool, optional): print the time used
            seed (int or None, optional): seed for the global np.random state;
                None leaves the state untouched
        """

        with jit(self) as model:

            par = model.par
            sol = model.sol
            sim = model.sim

            start = time.time()

            # a. seed the global numpy random state
            if seed is not None:
                np.random.seed(seed)

            # b. draw shock indices with probabilities par.w and map them to psi, xi
            draws = np.random.choice(par.Nshocks, size=(par.simN, par.simT), p=par.w)
            sim.psi[:] = par.psi_vec[draws]
            sim.xi[:] = par.xi_vec[draws]

            # c. initial values
            sim.m[:, 0] = par.sim_mini
            sim.p[:, 0] = 0.0

            # d. simulation
            simulate_time_loop(par, sol, sim)

            # e. levels: P = exp(p), Y = exp(y), and m, c, a scaled by P
            sim.P[:, :] = np.exp(sim.p)
            sim.Y[:, :] = np.exp(sim.y)
            for level, normalized in ((sim.M, sim.m), (sim.C, sim.c), (sim.A, sim.a)):
                level[:, :] = normalized * sim.P

            if do_print:
                print(f'model simulated in {elapsed(start)}')
Example #17
0
    def solve_egm(self, do_print):
        """ solve the model with the endogenous grid method (EGM)

        Column 0 of sol.m, sol.c and sol.inv_v holds a zero-consumption point.
        Columns 1..Na hold the period solution.

        Args:
            do_print (bool): print the time used
        """

        start = time.time()

        with jit(self) as model:

            par = model.par
            sol = model.sol

            # a. work arrays that egm() fills in each period
            m_buf = np.zeros(par.Na)
            c_buf = np.zeros(par.Na)
            inv_v_buf = np.zeros(par.Na)

            # b. terminal period: consume everything; inv_v is 0 at c = 0
            sol.m[-1, :] = np.linspace(0, par.a_max, par.Na + 1)
            sol.c[-1, :] = sol.m[-1, :]
            sol.inv_v[-1, 0] = 0
            sol.inv_v[-1, 1:] = 1.0 / utility(sol.c[-1, 1:], par)

            # c. earlier periods, backwards
            for t in range(par.T - 2, -1, -1):

                egm(par, sol, t, m_buf, c_buf, inv_v_buf)

                # zero-consumption point at m = par.a_min[t]
                sol.m[t, 0] = par.a_min[t]
                sol.c[t, 0] = 0
                sol.inv_v[t, 0] = 0

                # EGM solution in the remaining columns
                sol.m[t, 1:] = m_buf
                sol.c[t, 1:] = c_buf
                sol.inv_v[t, 1:] = inv_v_buf

        if do_print:
            print(f'model solved in {elapsed(start)}')
    def simulate(self):
        """ simulate the model via simulate.lifecycle

        Shock indices are drawn from the global np.random state with
        probabilities par.psi_w*par.xi_w.
        """

        with jit(self) as model: # jitted functions can be called from here

            par = model.par
            sol = model.sol
            sim = model.sim

            start = time.time()

            # a. draw shock indices and map them to psi and xi
            probs = par.psi_w*par.xi_w
            idx = np.random.choice(par.Nshocks, size=(par.T,par.simN), p=probs)
            sim.psi[:] = par.psi[idx]
            sim.xi[:] = par.xi[idx]

            # b. simulate
            simulate.lifecycle(sim,sol,par)

        if par.do_print:
            print(f'model simulated in {elapsed(start)}')
    def solve(self, do_assert=True):
        """ solve the model

        Delegates to self.solve_2d() when par.do_2d is set. Otherwise works
        backwards over t = par.T-1,...,0:
        - the last period is solved by last_period.solve;
        - every earlier period solves (o) the post-decision functions,
          (oo) the keeper problem and (ooo) the adjuster problem, with the
          method chosen by par.solmethod.

        Timings go to par.time_w, par.time_keep and par.time_adj.

        Args:

            do_assert (bool,optional): make assertions on the solution
                (policies and values are non-negative and not NaN)

        """

        if self.par.do_2d: return self.solve_2d()
        cpp = self.cpp

        # backwards induction
        for t in reversed(range(self.par.T)):

            # start the per-period timer here; it used to be started once
            # before the loop, so the ' t = ... solved in' message reported
            # cumulative time instead of the time for period t
            tic = time.time()

            self.par.t = t

            with jit(self) as model:

                par = model.par
                sol = model.sol

                # i. last period
                if t == par.T - 1:

                    last_period.solve(t, sol, par)

                    if do_assert:
                        assert np.all((sol.c_keep[t] >= 0)
                                      & ~np.isnan(sol.c_keep[t]))
                        assert np.all((sol.inv_v_keep[t] >= 0)
                                      & ~np.isnan(sol.inv_v_keep[t]))
                        assert np.all((sol.d_adj[t] >= 0)
                                      & ~np.isnan(sol.d_adj[t]))
                        assert np.all((sol.c_adj[t] >= 0)
                                      & ~np.isnan(sol.c_adj[t]))
                        assert np.all((sol.inv_v_adj[t] >= 0)
                                      & ~np.isnan(sol.inv_v_adj[t]))

                # ii. all other periods
                else:

                    # o. compute post-decision functions (not needed for vfi)
                    tic_w = time.time()

                    if par.solmethod in ['nvfi']:
                        post_decision.compute_wq(t, sol, par)
                    elif par.solmethod in ['negm']:
                        post_decision.compute_wq(t, sol, par, compute_q=True)
                    elif par.solmethod == 'nvfi_cpp':
                        cpp.compute_wq_nvfi(par, sol)
                    elif par.solmethod == 'negm_cpp':
                        cpp.compute_wq_negm(par, sol)

                    toc_w = time.time()
                    par.time_w[t] = toc_w - tic_w
                    if par.do_print:
                        print(f'  w computed in {toc_w-tic_w:.1f} secs')

                    if do_assert and par.solmethod in ['nvfi', 'negm']:
                        assert np.all((sol.inv_w[t] > 0)
                                      & ~np.isnan(sol.inv_w[t])), t
                        if par.solmethod in ['negm']:
                            assert np.all((sol.q[t] > 0)
                                          & ~np.isnan(sol.q[t])), t

                    # oo. solve keeper problem
                    tic_keep = time.time()

                    if par.solmethod == 'vfi':
                        vfi.solve_keep(t, sol, par)
                    elif par.solmethod == 'nvfi':
                        nvfi.solve_keep(t, sol, par)
                    elif par.solmethod == 'negm':
                        negm.solve_keep(t, sol, par)
                    elif par.solmethod == 'vfi_cpp':
                        cpp.solve_vfi_keep(par, sol)
                    elif par.solmethod == 'nvfi_cpp':
                        cpp.solve_nvfi_keep(par, sol)
                    elif par.solmethod == 'negm_cpp':
                        cpp.solve_negm_keep(par, sol)

                    toc_keep = time.time()
                    par.time_keep[t] = toc_keep - tic_keep
                    if par.do_print:
                        print(
                            f'  solved keeper problem in {toc_keep-tic_keep:.1f} secs'
                        )

                    if do_assert:
                        assert np.all((sol.c_keep[t] >= 0)
                                      & ~np.isnan(sol.c_keep[t])), t
                        assert np.all((sol.inv_v_keep[t] >= 0)
                                      & ~np.isnan(sol.inv_v_keep[t])), t

                    # ooo. solve adjuster problem (negm reuses the nvfi adjuster)
                    tic_adj = time.time()

                    if par.solmethod == 'vfi':
                        vfi.solve_adj(t, sol, par)
                    elif par.solmethod in ['nvfi', 'negm']:
                        nvfi.solve_adj(t, sol, par)
                    elif par.solmethod == 'vfi_cpp':
                        cpp.solve_vfi_adj(par, sol)
                    elif par.solmethod in ['nvfi_cpp', 'negm_cpp']:
                        cpp.solve_nvfi_adj(par, sol)

                    toc_adj = time.time()
                    par.time_adj[t] = toc_adj - tic_adj
                    if par.do_print:
                        print(
                            f'  solved adjuster problem in {toc_adj-tic_adj:.1f} secs'
                        )

                    if do_assert:
                        assert np.all((sol.d_adj[t] >= 0)
                                      & ~np.isnan(sol.d_adj[t])), t
                        assert np.all((sol.c_adj[t] >= 0)
                                      & ~np.isnan(sol.c_adj[t])), t
                        assert np.all((sol.inv_v_adj[t] >= 0)
                                      & ~np.isnan(sol.inv_v_adj[t])), t

                # iii. print
                toc = time.time()
                if par.do_print or par.do_print_period:
                    print(f' t = {t} solved in {toc-tic:.1f} secs')
    def solve_2d(self):
        """ solve the 2d version of the model by backwards induction

        The last period is solved by last_period.solve_2d inside jit. Every
        earlier period calls the C++ routines on self.cpp in three steps:
        post-decision functions (w/q), the keeper problem, and the adjuster
        problems (full, free d1, free d2). Only par.solmethod values
        'vfi_2d_cpp', 'nvfi_2d_cpp' and 'negm_2d_cpp' are dispatched.

        Timings are stored in par.time_w, par.time_keep, par.time_adj,
        par.time_adj_full, par.time_adj_d1 and par.time_adj_d2.
        """

        # the non-jitted namespaces are passed directly to the C++ routines
        par = self.par
        sol = self.sol
        cpp = self.cpp

        # backwards induction
        for t in reversed(range(par.T)):

            par.t = t # current period, read via par by the C++ routines
            tic = time.time()

            # i. last period
            if t == par.T - 1:

                with jit(self) as model:
                    last_period.solve_2d(t, model.sol, model.par)

            # ii. all other periods
            else:

                # o. compute post-decision functions
                # (vfi_2d_cpp has no branch here, so nothing is computed for it)
                tic_w = time.time()

                if par.solmethod == 'nvfi_2d_cpp':
                    cpp.compute_wq_nvfi_2d(par, sol)
                elif par.solmethod == 'negm_2d_cpp':
                    cpp.compute_wq_negm_2d(par, sol)

                toc_w = time.time()
                par.time_w[t] = toc_w - tic_w
                if par.do_print:
                    print(f'  w computed in {toc_w-tic_w:.1f} secs')

                # oo. solve keeper problem
                # NOTE(review): a par.solmethod outside these branches silently skips this step
                tic_keep = time.time()

                if par.solmethod == 'vfi_2d_cpp':
                    cpp.solve_vfi_2d_keep(par, sol)
                elif par.solmethod == 'nvfi_2d_cpp':
                    cpp.solve_nvfi_2d_keep(par, sol)
                elif par.solmethod == 'negm_2d_cpp':
                    cpp.solve_negm_2d_keep(par, sol)

                toc_keep = time.time()
                par.time_keep[t] = toc_keep - tic_keep
                if par.do_print:
                    print(
                        f'  solved keeper problem in {toc_keep-tic_keep:.1f} secs'
                    )

                # ooo. solve adjuster problems (full, free d1, free d2), each timed separately
                tic_adj = time.time()

                if par.solmethod == 'vfi_2d_cpp':

                    tic_adj_full = time.time()
                    cpp.solve_vfi_2d_adj_full(par, sol)
                    par.time_adj_full[t] = time.time() - tic_adj_full
                    if par.do_print:
                        print(
                            f'  solved full adjuster problem {par.time_adj_full[t]:.1f} secs'
                        )

                    tic_adj_d1 = time.time()
                    cpp.solve_vfi_2d_adj_d1(par, sol)
                    par.time_adj_d1[t] = time.time() - tic_adj_d1
                    if par.do_print:
                        print(
                            f'  solved adjuster problem with free d1 {par.time_adj_d1[t]:.1f} secs'
                        )

                    tic_adj_d2 = time.time()
                    cpp.solve_vfi_2d_adj_d2(par, sol)
                    par.time_adj_d2[t] = time.time() - tic_adj_d2
                    if par.do_print:
                        print(
                            f'  solved adjuster problem with free d2 {par.time_adj_d2[t]:.1f} secs'
                        )

                # negm_2d_cpp reuses the nvfi adjuster routines
                elif par.solmethod in ['nvfi_2d_cpp', 'negm_2d_cpp']:

                    tic_adj_full = time.time()
                    cpp.solve_nvfi_2d_adj_full(par, sol)
                    par.time_adj_full[t] = time.time() - tic_adj_full
                    if par.do_print:
                        print(
                            f'  solved full adjuster problem {par.time_adj_full[t]:.1f} secs'
                        )

                    tic_adj_d1 = time.time()
                    cpp.solve_nvfi_2d_adj_d1(par, sol)
                    par.time_adj_d1[t] = time.time() - tic_adj_d1
                    if par.do_print:
                        print(
                            f'  solved adjuster problem with free d1 {par.time_adj_d1[t]:.1f} secs'
                        )

                    tic_adj_d2 = time.time()
                    cpp.solve_nvfi_2d_adj_d2(par, sol)
                    par.time_adj_d2[t] = time.time() - tic_adj_d2
                    if par.do_print:
                        print(
                            f'  solved adjuster problem with free d2 {par.time_adj_d2[t]:.1f} secs'
                        )

                toc_adj = time.time()
                par.time_adj[t] = toc_adj - tic_adj
                if par.do_print:
                    print(
                        f'  solved adjuster problems in {toc_adj-tic_adj:.1f} secs'
                    )

            # iii. print (tic is reset at the top of every period)
            toc = time.time()
            if par.do_print or par.do_print_period:
                print(f' t = {t} solved in {toc-tic:.1f} secs')
    def simulate(self, do_utility=False, do_euler_error=False):
        """ simulate the model

        Draws initial states and shocks from the global np.random state, runs
        simulate.lifecycle inside jit, and optionally computes Euler errors
        and utility.

        Args:
            do_utility (bool, optional): call simulate.calc_utility afterwards
            do_euler_error (bool, optional): compute Euler errors and store
                log10(|euler_error/euler_error_c| + 1e-8) in sim.euler_error_rel
        """

        par = self.par
        sol = self.sol
        sim = self.sim

        tic = time.time()

        # a. random shocks
        # initial permanent level, durable stock(s) and assets are lognormal draws
        sim.p0[:] = np.random.lognormal(mean=0,
                                        sigma=par.sigma_p0,
                                        size=par.simN)
        if par.do_2d:
            # with two stocks each starts around half of mu_d0
            sim.d10[:] = par.mu_d0 / 2 * np.random.lognormal(
                mean=0, sigma=par.sigma_d0, size=par.simN)
            sim.d20[:] = par.mu_d0 / 2 * np.random.lognormal(
                mean=0, sigma=par.sigma_d0, size=par.simN)
        else:
            sim.d0[:] = par.mu_d0 * np.random.lognormal(
                mean=0, sigma=par.sigma_d0, size=par.simN)
        sim.a0[:] = par.mu_a0 * np.random.lognormal(
            mean=0, sigma=par.sigma_a0, size=par.simN)

        # shock indices with joint probabilities psi_w*xi_w
        I = np.random.choice(par.Nshocks,
                             size=(par.T, par.simN),
                             p=par.psi_w * par.xi_w)
        sim.psi[:, :] = par.psi[I]
        sim.xi[:, :] = par.xi[I]

        # b. call
        with jit(self) as model:

            # NOTE(review): par/sol/sim are rebound here to the objects exposed
            # inside the jit context and keep pointing to them after the block
            # (used below, e.g. by calc_utility outside jit) — confirm intended
            par = model.par
            sol = model.sol
            sim = model.sim

            simulate.lifecycle(sim, sol, par)

        toc = time.time()

        if par.do_print:
            print(f'model simulated in {toc-tic:.1f} secs')

        # d. euler errors
        def norm_euler_errors(model):
            # log10 of the relative Euler error; 1e-8 guards against log10(0)
            return np.log10(
                abs(model.sim.euler_error / model.sim.euler_error_c) + 1e-8)

        tic = time.time()
        if do_euler_error:

            with jit(self) as model:

                par = model.par
                sol = model.sol
                sim = model.sim

                simulate.euler_errors(sim, sol, par)

            sim.euler_error_rel[:] = norm_euler_errors(self)

        # NOTE(review): the timing message is printed even when do_euler_error is False
        toc = time.time()
        if par.do_print:
            print(f'euler errors calculated in {toc-tic:.1f} secs')

        # e. utility
        # NOTE(review): called outside any jit context with the par/sol/sim
        # bound above — confirm calc_utility accepts these objects
        tic = time.time()
        if do_utility:
            simulate.calc_utility(sim, sol, par)

        toc = time.time()
        if par.do_print:
            print(f'utility calculated in {toc-tic:.1f} secs')
 def simulate(self):
     """ simulate the model by calling simulate.all inside jit """
     with jit(self) as jitted_model:
         simulate.all(jitted_model)
    def solve_NEGM(self):
        """ solve the model with NEGM

        Steps, all called with G2EGM=False where applicable:
            a. the retirement problem
            b. the last working period
            c. the remaining working periods, backwards; each step computes the
               post-decision value function, then the pure consumption
               problem, then the outer problem

        Per-period timings are stored in par.time_w, par.time_egm,
        par.time_vfi and par.time_work.
        """

        with jit(self) as model:

            par = model.par
            sol = model.sol

            if par.do_print:
                print('Solving with NEGM:')

            # a. retirement
            tic = time.time()
            retirement.solve(sol,par,G2EGM=False)
            if par.do_print:
                print(f'solved retirement problem in {time.time()-tic:.2f} secs')

            # b. last working period
            tic = time.time()
            last_period.solve(sol,par,G2EGM=False)
            if par.do_print:
                print(f'solved last period working in {time.time()-tic:.2f} secs')

            # c. remaining working periods, backwards
            for t in range(par.T-2,-1,-1):

                tic_period = time.time()

                if par.do_print:
                    print(f' t = {t}:')

                # i. post-decision value function
                tic_step = time.time()
                post_decision.compute(t,sol,par,G2EGM=False)
                par.time_w[t] = time.time() - tic_step
                if par.do_print:
                    print(f'   computed post decision value function in {par.time_w[t]:.2f} secs')

                # ii. pure consumption problem
                tic_step = time.time()
                NEGM.solve_pure_c(t,sol,par)
                par.time_egm[t] = time.time()-tic_step
                if par.do_print:
                    print(f'   solved pure consumption problem in {par.time_egm[t]:.2f} secs')

                # iii. outer problem
                tic_step = time.time()
                NEGM.solve_outer(t,sol,par)
                par.time_vfi[t] = time.time()-tic_step
                if par.do_print:
                    print(f'   solved outer problem in {par.time_vfi[t] :.2f} secs')

                par.time_work[t] = time.time()-tic_period

            if par.do_print:
                print(f'solved working problem in {np.sum(par.time_work):.2f} secs')