예제 #1
0
def create_grids(par):
    """Attach shock nodes and state grids to ``par`` and return it.

    Reads ``rho``, ``sigma_xi``, ``Nxi``, ``T``, ``Na``, ``a_max``,
    ``a_phi``, ``Nm``, ``Nm_b``, ``m_max``, ``m_phi``, ``Np``, ``p_max`` and
    ``p_phi`` from ``par``; writes ``xi``, ``xi_w``, ``grid_a``, ``grid_m``
    and ``grid_p``. As a side effect NumPy's global random generator is
    seeded with 2020.

    Raises:
        AssertionError: if ``par.rho`` is negative.
    """
    # Parameter sanity check
    assert (par.rho >= 0), 'not rho > 0'

    # Shock nodes and weights as returned by tools.GaussHermite_lognorm
    par.xi, par.xi_w = tools.GaussHermite_lognorm(par.sigma_xi, par.Nxi)

    # End-of-period assets: a (T, Na) array, every row built from the same
    # parameters
    par.grid_a = np.full((par.T, par.Na), np.nan)
    for row in par.grid_a:
        row[:] = tools.nonlinspace(0 + 1e-6, par.a_max, par.Na, par.a_phi)

    # Cash-on-hand: Nm_b evenly spaced points strictly inside (0, 1),
    # followed by Nm - Nm_b unequally spaced points from just above 1
    m_below_one = np.linspace(0 + 1e-6, 1 - 1e-6, par.Nm_b)
    m_above_one = tools.nonlinspace(1 + 1e-6, par.m_max, par.Nm - par.Nm_b,
                                    par.m_phi)
    par.grid_m = np.concatenate([m_below_one, m_above_one])

    # Permanent income
    par.grid_p = tools.nonlinspace(0 + 1e-4, par.p_max, par.Np, par.p_phi)

    # Fix the global NumPy seed
    np.random.seed(2020)

    return par
    def create_grids(self):
        """Fill ``self.par`` with shock nodes and the model's state grids.

        Writes ``epsi``, ``epsi_w``, ``grid_a``, ``grid_m`` and ``grid_k`` on
        ``self.par`` and seeds NumPy's global random generator with 3.

        Raises:
            AssertionError: if ``par.rho`` is negative.
        """
        par = self.par

        # Parameter sanity check
        assert (par.rho >= 0), 'not rho > 0'

        # Shock nodes and weights as returned by tools.GaussHermite_lognorm
        par.epsi, par.epsi_w = tools.GaussHermite_lognorm(par.sigma_w, par.Nw)

        # End-of-period assets: a (T, Na) array, every row built from the
        # same parameters
        par.grid_a = np.full((par.T, par.Na), np.nan)
        for row in par.grid_a:
            row[:] = tools.nonlinspace(0 + 1e-6, par.a_max, par.Na, par.a_phi)

        # Cash-on-hand: evenly spaced points strictly inside (0, 1), then
        # unequally spaced points from just above 1
        m_below_one = np.linspace(0 + 1e-6, 1 - 1e-6, par.Nm_b)
        m_above_one = tools.nonlinspace(1 + 1e-6, par.m_max,
                                        par.Nm - par.Nm_b, par.m_phi)
        par.grid_m = np.concatenate([m_below_one, m_above_one])

        # Human capital
        par.grid_k = tools.nonlinspace(0 + 1e-4, par.k_max, par.Nk, par.k_phi)

        # Fix the global NumPy seed
        np.random.seed(3)
    def create_grids(self):
        """Fill ``self.par`` with shock nodes and the asset / human capital grids.

        Writes ``xi``, ``xi_w``, ``grid_a`` and ``grid_k`` on ``self.par`` and
        seeds NumPy's global random generator with 2021. No parameter checks
        are performed.
        """
        par = self.par

        # Parameter checks are currently disabled:
        #assert (par.rho >= 0), 'not rho > 0'
        # (optionally add further parameter requirements like the one above)

        # Shock nodes and weights as returned by tools.GaussHermite_lognorm
        par.xi, par.xi_w = tools.GaussHermite_lognorm(par.sigma_w, par.Nxi)

        # End-of-period assets: a (T, Na) array, every row built from the
        # same parameters
        par.grid_a = np.full((par.T, par.Na), np.nan)
        for period in range(par.T):
            par.grid_a[period] = tools.nonlinspace(0 + 1e-6, par.a_max,
                                                   par.Na, par.a_phi)

        # Human capital
        par.grid_k = tools.nonlinspace(0 + 1e-4, par.k_max, par.Nk, par.k_phi)

        # Fix the global NumPy seed
        np.random.seed(2021)
예제 #4
0
    def create_grids(self):
        """Fill ``self.par`` with wage-shock nodes and the a, k and m grids.

        Writes ``xi``, ``xi_w``, ``grid_a``, ``grid_k`` and ``grid_m`` on
        ``self.par`` and seeds NumPy's global random generator with 2021.

        Raises:
            AssertionError: if ``par.rho`` is negative.
        """
        par = self.par

        # Parameter sanity check (TODO: decide whether more checks are needed)
        assert (par.rho >= 0), 'not rho > 0'

        # Wage shocks: the call returns both the shock values and their weights
        par.xi, par.xi_w = GaussHermite_lognorm(par.sigma_xi, par.Nxi)

        # End-of-period assets (A): the exogenously fixed monotonic grid over
        # savings, stored as one row per period
        par.grid_a = np.full((par.T, par.Na), np.nan)
        for row in par.grid_a:
            row[:] = tools.nonlinspace(0 + 1e-8, par.a_max, par.Na, par.a_phi)

        # Human capital (K): a single grid shared by all periods
        # (a per-period (T, Nk) variant was tried earlier and dropped)
        par.grid_k = tools.nonlinspace(0 + 1e-4, par.k_max, par.Nk, par.k_phi)

        # Cash-on-hand (m): evenly spaced points strictly inside (0, 1), then
        # unequally spaced points from just above 1
        # (per-period and single-nonlinspace variants were tried and dropped)
        m_below_one = np.linspace(0 + 1e-6, 1 - 1e-6, par.Nm_b)
        m_above_one = tools.nonlinspace(1 + 1e-6, par.m_max,
                                        par.Nm - par.Nm_b, par.m_phi)
        par.grid_m = np.concatenate([m_below_one, m_above_one])

        # Fix the global NumPy seed
        np.random.seed(2021)
def setup():
    """Create and return the parameter container.

    The container is a plain class called ``par`` whose class attributes
    hold the parameter values. ``grid_M`` is built with
    ``tools.nonlinspace``.
    """
    class par:
        # Demographics
        age_min = 25  # Only relevant for figures
        T = 85 - age_min
        Tr = 60 - age_min  # Retirement age, no retirement if Tr == T

        # Children
        num_n = 3  # maximum number of children
        age_fer = 45  # maximum age of fertility

        # Preferences
        rho = 0.5
        beta = 0.96

        # Income parameters
        G = 1.03
        num_M = 50
        M_max = 10
        # Non-linear spaced points: like np.linspace with unequal spacing
        grid_M = tools.nonlinspace(1.0e-6, M_max, num_M, 1.1)

        sigma_xi = 0.1
        sigma_psi = 0.1

        low_p = 0.005  # Called pi in slides
        low_val = 0  # Called mu in slides

        # Saving and borrowing
        R = 1.04
        kappa = 0.0

        # Numerical integration
        Nxi = 8  # number of quadrature points for xi
        Npsi = 8  # number of quadrature points for psi

        # Simulation
        sim_mini = 2.5  # initial m in simulation
        simN = 500000  # number of persons in simulation
        simT = 100  # number of periods in simulation

    return par
예제 #6
0
def create_grids(par):
    """Attach shock nodes and the end-of-period asset grid to ``par``.

    Writes ``xi``, ``xi_w`` and ``grid_a`` on ``par``, seeds NumPy's global
    random generator with 2020 and returns ``par``.

    Raises:
        AssertionError: if ``par.rho`` is negative.
    """
    # Parameter sanity check
    assert (par.rho >= 0), 'not rho > 0'

    # Shock nodes and weights as returned by tools.GaussHermite_lognorm
    par.xi, par.xi_w = tools.GaussHermite_lognorm(par.sigma_xi, par.Nxi)

    # End-of-period assets: a (T, Na) array, every row built from the same
    # parameters
    par.grid_a = np.full((par.T, par.Na), np.nan)
    for period in range(par.T):
        par.grid_a[period] = tools.nonlinspace(0 + 1e-8, par.a_max, par.Na,
                                               par.a_phi)

    # Fix the global NumPy seed
    np.random.seed(2020)

    return par
예제 #7
0
    def create_grids(self):
        """Fill ``self.par`` with shock nodes and the end-of-period asset grid.

        Writes ``xi``, ``xi_w`` and ``grid_a`` on ``self.par`` and seeds
        NumPy's global random generator with 2020.

        Raises:
            AssertionError: if ``par.rho`` is negative.
        """
        par = self.par

        # Parameter sanity check
        assert (par.rho >= 0), 'not rho > 0'

        # Shocks: nodes and weights for the lognormal shock
        par.xi, par.xi_w = tools.GaussHermite_lognorm(par.sigma_xi, par.Nxi)

        # End-of-period assets: a (T, Na) array, every row built from the
        # same parameters
        par.grid_a = np.full((par.T, par.Na), np.nan)
        for row in par.grid_a:
            row[:] = tools.nonlinspace(0 + 1e-8, par.a_max, par.Na, par.a_phi)

        # Fix the global NumPy seed
        np.random.seed(2020)
예제 #8
0
파일: model.py 프로젝트: emilblicher/DP2021
    def create_grids(self):
        """Build shock nodes, borrowing limits, the asset grid and model conditions.

        Reads the parameters on ``self.par`` and stores on it:

        * ``psi``, ``psi_w``, ``xi``, ``xi_w``: shock nodes and weights, with
          an extra low-income node prepended to ``xi`` when ``low_p > 0``;
        * ``xi_vec``, ``psi_vec``, ``xi_w_vec``, ``psi_w_vec``, ``w``,
          ``Nshocks``: the flattened tensor product of the two shocks;
        * ``a_min``: a ``(T, 1)`` array of lower bounds on end-of-period assets;
        * ``grid_a``: a ``(T, Na)`` array of end-of-period asset grids;
        * ``FHW``, ``AI``, ``GI``, ``RI``, ``WRI``, ``FVA``: scalar conditions.

        NumPy's global random generator is seeded with 2020.

        Raises:
            AssertionError: if ``rho`` or ``lambdaa`` is negative, or if the
                combined shock weights do not sum to one.
        """
        par = self.par

        # 1. Check parameters
        assert (par.rho >= 0), 'not rho > 0'
        assert (par.lambdaa >= 0), 'not lambda > 0'

        # 2. Shocks
        eps, eps_w = tools.GaussHermite_lognorm(par.sigma_xi, par.Nxi)
        par.psi, par.psi_w = tools.GaussHermite_lognorm(par.sigma_psi, par.Npsi)

        # Define xi: with probability low_p the shock takes the value low_val;
        # the remaining nodes are rescaled accordingly
        if par.low_p > 0:
            # +1e-8 makes it possible to take the log in simulation if low_val = 0
            par.xi = np.append(par.low_val+1e-8, (eps-par.low_p*par.low_val)/(1-par.low_p), axis=None)
            par.xi_w = np.append(par.low_p, (1-par.low_p)*eps_w, axis=None)
        else:
            par.xi = eps
            par.xi_w = eps_w

        # Vectorize all combinations of the two shocks
        par.xi_vec = np.tile(par.xi, par.psi.size)       # Repeat entire array x times
        par.psi_vec = np.repeat(par.psi, par.xi.size)    # Repeat each element of the array x times
        par.xi_w_vec = np.tile(par.xi_w, par.psi.size)
        par.psi_w_vec = np.repeat(par.psi_w, par.xi.size)

        par.w = par.xi_w_vec * par.psi_w_vec
        # Two-sided check: without abs() weights summing to MORE than one
        # would pass unnoticed
        assert (abs(1-sum(par.w)) < 1e-8), 'the weights does not sum to 1'

        par.Nshocks = par.w.size    # count number of shock nodes

        # 3. Minimum a
        if par.lambdaa == 0:
            par.a_min = np.zeros([par.T, 1])
        else:

            # Using formula from slides
            psi_min = min(par.psi)
            xi_min = min(par.xi)
            par.a_min = np.nan + np.zeros([par.T, 1])
            # Backward recursion: Omega from period t+1 feeds into period t
            for t in range(par.T-1, -1, -1):
                if t >= par.Tr:
                    Omega = 0  # No debt in final period
                elif t == par.T-1:
                    # Only reached when par.Tr > par.T-1, since the test
                    # above already catches every t >= par.Tr
                    Omega = par.R**(-1)*par.G*par.L[t+1]*psi_min*xi_min
                else:
                    Omega = par.R**(-1)*(min(Omega, par.lambdaa)+xi_min)*par.G*par.L[t+1]*psi_min

                # NOTE(review): reads par.L[t+1] for t up to T-1, so par.L
                # must have at least T+1 entries - confirm against the caller
                par.a_min[t] = -min(Omega, par.lambdaa)*par.G*par.L[t+1]*psi_min

        # 4. End of period assets
        par.grid_a = np.nan + np.zeros([par.T, par.Na])
        for t in range(par.T):
            # Pass the lower bound as a scalar (a_min has shape (T, 1)); a
            # length-1 array only works through NumPy's deprecated implicit
            # array-to-scalar conversion
            par.grid_a[t, :] = tools.nonlinspace(par.a_min[t, 0]+1e-8, par.a_max, par.Na, par.a_phi)

        # 5. Conditions
        par.FHW = par.G/par.R
        par.AI = (par.R*par.beta)**(1/par.rho)
        par.GI = par.AI*sum(par.w*par.psi_vec**(-1))/par.G
        par.RI = par.AI/par.R
        par.WRI = par.low_p**(1/par.rho)*par.AI/par.R
        par.FVA = par.beta*sum(par.w*(par.G*par.psi_vec)**(1-par.rho))

        # 6. Set seed
        np.random.seed(2020)
예제 #9
0
def run_model(par):
    """Solve the model by value function iteration and simulate it.

    The function (1) builds the shock nodes and state grids and stores them
    on ``par``, (2) solves backwards for consumption ``c`` and the inverse
    value function ``inv_v`` on a ``(T, Nm, 2, Nh)`` grid over time,
    cash-on-hand ``m``, the discrete state ``z`` and ``h``, and (3) simulates
    ``par.simN`` households for ``par.T`` periods.

    Args:
        par: parameter container. Reads ``sigma_psi``, ``Npsi``,
            ``sigma_xi``, ``Nxi``, ``sigma_d``, ``Nd``, ``pi``, ``mu``,
            ``m_max``, ``Nm``, ``m_phi``, ``h_max``, ``Nh``, ``h_phi``,
            ``T``, ``TR``, ``TH``, ``G`` (indexed by period), ``R``,
            ``rho``, ``beta``, ``alpha``, ``delta``, ``simN``, ``sim_mini``,
            ``seed`` and ``z_mode``. Writes ``psi_vec``, ``xi_vec``,
            ``d_vec``, ``w``, ``Nshocks``, ``grid_m`` and ``grid_h``.

    Returns:
        Tuple ``(sol, sim)`` of jitclass instances holding the solution
        arrays and the simulated panels.

    Raises:
        AssertionError: if the combined shock weights do not sum to one.
    """
    # 1. Prepare grids and allocate solution
    # Gauss Hermite
    psi, psi_w = tools.GaussHermite_lognorm(sigma=par.sigma_psi, n=par.Npsi)
    xi, xi_w = tools.GaussHermite_lognorm(sigma=par.sigma_xi, n=par.Nxi)
    d, d_w = tools.GaussHermite_lognorm(sigma=par.sigma_d, n=par.Nd)

    # Add low income shock to xi
    if par.pi > 0:
        # Weights
        xi_w *= (1.0 - par.pi)
        xi_w = np.insert(xi_w, 0, par.pi)

        # Values
        xi = (xi - par.mu * par.pi) / (1.0 - par.pi)
        xi = np.insert(xi, 0, par.mu)

    # Vectorize tensor product of shocks and total weight
    psi_vec, xi_vec, d_vec = np.meshgrid(psi, xi, d, indexing='ij')
    psi_w_vec, xi_w_vec, d_w_vec = np.meshgrid(psi_w, xi_w, d_w, indexing='ij')

    par.psi_vec = psi_vec.ravel()
    par.xi_vec = xi_vec.ravel()
    par.d_vec = d_vec.ravel()
    par.w = xi_w_vec.ravel() * psi_w_vec.ravel() * d_w_vec.ravel()

    # Check if weights sum to 1 (two-sided: without abs() a sum above one
    # would pass unnoticed)
    assert abs(1 - np.sum(par.w)) < 1e-8

    # Count number of shock nodes
    par.Nshocks = par.w.size

    # Create grids
    par.grid_m = tools.nonlinspace(1e-6, par.m_max, par.Nm, par.m_phi)
    par.grid_h = tools.nonlinspace(0, par.h_max, par.Nh, par.h_phi)

    # Create solution and simulation
    spec = [('c', double[:, :, :, :]), ('inv_v', double[:, :, :, :])]

    @jitclass(spec)
    class sol_:
        def __init__(self):
            pass

    spec = [('m', double[:, :]), ('c', double[:, :]), ('a', double[:, :]),
            ('p', double[:, :]), ('y', double[:, :]), ('psi', double[:, :]),
            ('xi', double[:, :]), ('d', double[:, :]), ('P', double[:, :]),
            ('Y', double[:, :]), ('M', double[:, :]), ('C', double[:, :]),
            ('A', double[:, :]), ('z', int32[:]), ('h', double[:, :]),
            ('H', double[:, :])]

    @jitclass(spec)
    class sim_:
        def __init__(self):
            pass

    # Allocate memory for solution
    sol = sol_()
    sol_shape = (par.T, par.Nm, 2, par.Nh)
    sol.c = np.zeros(sol_shape)
    sol.inv_v = np.zeros(sol_shape)

    # Allocate memory for simulation
    sim = sim_()
    sim_shape = (par.simN, par.T)
    sim.m = np.zeros(sim_shape)
    sim.c = np.zeros(sim_shape)
    sim.a = np.zeros(sim_shape)
    sim.p = np.zeros(sim_shape)
    sim.y = np.zeros(sim_shape)
    sim.psi = np.zeros(sim_shape)
    sim.xi = np.zeros(sim_shape)
    sim.d = np.zeros(sim_shape)
    sim.P = np.zeros(sim_shape)
    sim.Y = np.zeros(sim_shape)
    sim.M = np.zeros(sim_shape)
    sim.C = np.zeros(sim_shape)
    sim.A = np.zeros(sim_shape)
    # The jitclass field is declared int32[:]; plain ``int`` is 64-bit on
    # most platforms and would not match the declared field type
    sim.z = np.zeros(par.simN, dtype=np.int32)
    sim.h = np.zeros(sim_shape)
    sim.H = np.zeros(sim_shape)

    # 2. Solve model by VFI
    @njit
    def utility(c, rho):
        return c**(1 - rho) / (1 - rho)

    @njit
    def marg_utility(c, rho):
        return c**(-rho)

    @njit
    def inv_marg_utility(u, rho):
        return u**(-1 / rho)

    @njit
    def value_of_choice(c, t, m, h, z, par, sol):
        # Returns the NEGATIVE of the value of consuming c in state
        # (t, m, h, z), so that it can be handed to a minimizer.

        # End-of-period assets
        a = m - c

        # Calculate inverse value-of-choice in next period
        still_working_next_period = t + 1 <= par.TR - 1
        if still_working_next_period:
            fac_vec = par.G[t] * par.psi_vec
            w = par.w
            xi = par.xi_vec

            # Next-period h, one entry per shock node
            if t + 1 < par.TH - 1:
                h_plus_vec = np.repeat(0.0, len(fac_vec))
            elif t + 1 == par.TH - 1:
                h_plus_vec = np.repeat(par.alpha, len(fac_vec))
            elif t + 1 > par.TH - 1 and z == 0:
                # Already one entry per shock node; wrapping this vector in
                # np.repeat would yield len(fac_vec)**2 entries that no
                # longer line up with m_plus_vec
                h_plus_vec = ((par.d_vec + par.delta - 1) / fac_vec) * h
            elif t + 1 > par.TH - 1 and z == 1:
                h_plus_vec = np.repeat(0.0, len(fac_vec))

            # With z == 1, h is added to cash-on-hand in period TH
            if t + 1 == par.TH and z == 1:
                h_term_vec = ((par.d_vec + par.delta - 1) / fac_vec) * h
            else:
                h_term_vec = np.repeat(0.0, len(fac_vec))

            m_plus_vec = (par.R / fac_vec) * a + xi + h_term_vec
            inv_v_plus_vec = tools.interp_2d_vec(par.grid_m, par.grid_h,
                                                 sol.inv_v[t + 1, :, z, :],
                                                 m_plus_vec, h_plus_vec)
        else:
            # No shocks once working life is over
            fac = par.G[t]

            # With z == 0, h is added to cash-on-hand in period TR
            if t + 1 == par.TR and z == 0:
                h_term = (par.delta / fac) * h
            else:
                h_term = 0.0

            m_plus = (par.R / fac) * a + 1 + h_term
            inv_v_plus = tools.interp_2d(par.grid_m, par.grid_h,
                                         sol.inv_v[t + 1, :,
                                                   z, :], m_plus, 0.0)

        # Value-of-choice
        if still_working_next_period:
            v_plus_vec = 1 / inv_v_plus_vec
            total = utility(c, par.rho) + par.beta * np.sum(
                w * fac_vec**(1 - par.rho) * v_plus_vec)
        else:
            v_plus = 1 / inv_v_plus
            total = utility(c,
                            par.rho) + par.beta * fac**(1 - par.rho) * v_plus
        return -total

    # The whole backward recursion runs once per value of z
    for z in [0, 1]:
        # Last period (= consume all)
        for i_h in range(par.Nh):
            sol.c[-1, :, z, i_h] = par.grid_m
            for i, c in enumerate(sol.c[-1, :, z, i_h]):
                sol.inv_v[-1, i, z, i_h] = 1.0 / utility(c, par.rho)

        # Before last period
        for t in reversed(range(par.T - 1)):
            for i_h, h in enumerate(par.grid_h):
                if t >= par.TR and i_h != 0:  # After retirement, solution is independent of z and h
                    sol.c[t, :, z, i_h] = sol.c[t, :, 0, 0]
                    sol.inv_v[t, :, z, i_h] = sol.inv_v[t, :, 0, 0]
                elif t >= par.TH and z == 1 and i_h != 0:  # After early holiday pay, solution is independent of h
                    sol.c[t, :, z, i_h] = sol.c[t, :, z, 0]
                    sol.inv_v[t, :, z, i_h] = sol.inv_v[t, :, z, 0]
                elif t < par.TH - 1 and i_h != 0:  # Before holiday pay decision, solution is independent of z and h
                    sol.c[t, :, z, i_h] = sol.c[t, :, 0, 0]
                    sol.inv_v[t, :, z, i_h] = sol.inv_v[t, :, 0, 0]
                else:
                    for i_m, m in enumerate(par.grid_m):
                        # The lambda is called immediately inside this
                        # iteration, so it sees the current t, m, h and z
                        obj = lambda c: value_of_choice(
                            c, t, m, h, z, par, sol)
                        result = optimize.minimize_scalar(obj,
                                                          method='bounded',
                                                          bounds=(0, m))

                        sol.c[t, i_m, z, i_h] = result.x
                        # result.fun is the negated value, hence the sign flip
                        sol.inv_v[t, i_m, z, i_h] = -1.0 / result.fun

    # 3. Simulate model

    # Set seed
    np.random.seed(par.seed)

    # Shocks: draw one joint shock node per household and period
    _shocki = np.random.choice(par.Nshocks, size=(par.simN, par.T), p=par.w)
    sim.psi[:] = par.psi_vec[_shocki]
    sim.xi[:] = par.xi_vec[_shocki]
    sim.d[:] = par.d_vec[_shocki]

    # Initial values
    sim.m[:, 0] = par.sim_mini
    sim.p[:, 0] = 0.0

    @njit(parallel=True)
    def simulate_time_loop(par, sol, sim):
        # Unpack (helps numba)
        m = sim.m
        p = sim.p
        y = sim.y
        c = sim.c
        a = sim.a
        h = sim.h
        z = sim.z

        # loop over first households and then time
        for i in prange(par.simN):
            for t in range(par.T):
                # Consumption
                c[i, t] = tools.interp_2d(par.grid_m, par.grid_h,
                                          sol.c[t, :, z[i], :], m[i, t], h[i,
                                                                           t])
                a[i, t] = m[i, t] - c[i, t]

                if t < par.T - 1:
                    still_working_next_period = t + 1 <= par.TR - 1
                    if still_working_next_period:
                        fac = par.G[t] * sim.psi[i, t + 1]
                        xi = sim.xi[i, t + 1]

                        # Law of motion for h (mirrors value_of_choice)
                        if t + 1 < par.TH - 1:
                            h[i, t + 1] = 0.0
                        elif t + 1 == par.TH - 1:
                            h[i, t + 1] = par.alpha
                        elif t + 1 > par.TH - 1 and z[i] == 0:
                            h[i, t + 1] = ((sim.d[i, t + 1] + par.delta - 1) /
                                           fac) * h[i, t]
                        elif t + 1 > par.TH - 1 and z[i] == 1:
                            h[i, t + 1] = 0.0

                        if t + 1 == par.TH and z[i] == 1:
                            h_term = ((sim.d[i, t + 1] + par.delta - 1) /
                                      fac) * h[i, t]
                        else:
                            h_term = 0.0

                        m[i, t + 1] = (par.R / fac) * a[i, t] + xi + h_term
                        p[i, t + 1] = np.log(par.G[t]) + p[i, t] + np.log(
                            sim.psi[i, t + 1])
                        # Guard the log against a zero income shock
                        if sim.xi[i, t + 1] > 0:
                            y[i,
                              t + 1] = p[i, t + 1] + np.log(sim.xi[i, t + 1])
                        else:
                            y[i, t + 1] = -np.inf
                    else:
                        fac = par.G[t]
                        xi = 1

                        h[i, t + 1] = 0.0

                        if t + 1 == par.TR and z[i] == 0:
                            h_term = (par.delta / fac) * h[i, t]
                        else:
                            h_term = 0.0

                        m[i, t + 1] = (par.R / fac) * a[i, t] + xi + h_term
                        p[i, t + 1] = np.log(par.G[t]) + p[i, t]
                        y[i, t + 1] = p[i, t + 1]

                    # The discrete choice z is fixed once, in period TH-1
                    if t + 1 == par.TH - 1:
                        if par.z_mode == 2:
                            # Pick the alternative by comparing the two
                            # interpolated inverse values
                            inv_v0 = tools.interp_2d(par.grid_m, par.grid_h,
                                                     sol.inv_v[t + 1, :, 0, :],
                                                     m[i, t + 1], h[i, t + 1])
                            inv_v1 = tools.interp_2d(par.grid_m, par.grid_h,
                                                     sol.inv_v[t + 1, :, 1, :],
                                                     m[i, t + 1], h[i, t + 1])
                            if inv_v0 < inv_v1:
                                z[i] = 0
                            else:
                                z[i] = 1
                        elif par.z_mode == 1:
                            z[i] = 1
                        elif par.z_mode == 0:
                            z[i] = 0

    # Simulate model
    simulate_time_loop(par, sol, sim)

    # Renormalize: p and y are stored in logs, the lower-case panels are
    # scaled by P to obtain the upper-case ones
    sim.P[:, :] = np.exp(sim.p)
    sim.Y[:, :] = np.exp(sim.y)
    sim.M[:, :] = sim.m * sim.P
    sim.C[:, :] = sim.c * sim.P
    sim.A[:, :] = sim.a * sim.P
    sim.H[:, :] = sim.h * sim.P

    return sol, sim
예제 #10
0
def run_model(par):
    """Solve the model by the endogenous grid method (EGM) and simulate it.

    The function (1) builds the shock nodes and grids and stores them on
    ``par``, (2) solves backwards over end-of-period assets, storing the
    endogenous cash-on-hand grid ``m``, consumption ``c`` and the inverse
    value function ``inv_v`` in ``(T, Na+1, 2, Nh)`` arrays, and
    (3) simulates ``par.simN`` households for ``par.T`` periods using a
    nearest-neighbour lookup in the solution.

    Args:
        par: parameter container. Reads ``sigma_psi``, ``Npsi``,
            ``sigma_xi``, ``Nxi``, ``sigma_d``, ``Nd``, ``pi``, ``mu``,
            ``m_max``, ``Nm``, ``m_phi``, ``a_max``, ``Na``, ``h_max``,
            ``Nh``, ``h_phi``, ``T``, ``TR``, ``TH``, ``G`` (indexed by
            period), ``R``, ``rho``, ``beta``, ``alpha``, ``delta``,
            ``simN``, ``sim_mini``, ``seed`` and ``z_mode``. Writes
            ``psi_vec``, ``xi_vec``, ``d_vec``, ``w``, ``Nshocks``,
            ``grid_h`` and ``grid_a``.

    Returns:
        Tuple ``(sol, sim)`` of jitclass instances holding the solution
        arrays and the simulated panels.

    Raises:
        AssertionError: if the combined shock weights do not sum to one.
    """
    # 1. Prepare grids and allocate solution
    # Gauss Hermite
    psi, psi_w = tools.GaussHermite_lognorm(sigma=par.sigma_psi,n=par.Npsi)
    xi, xi_w = tools.GaussHermite_lognorm(sigma=par.sigma_xi,n=par.Nxi)
    d, d_w = tools.GaussHermite_lognorm(sigma=par.sigma_d, n=par.Nd)

    # Add low income shock to xi
    if par.pi > 0:
        # Weights
        xi_w *= (1.0-par.pi)
        xi_w = np.insert(xi_w,0,par.pi)

        # Values
        xi = (xi-par.mu*par.pi)/(1.0-par.pi)
        xi = np.insert(xi,0,par.mu)

    # Vectorize tensor product of shocks and total weight
    psi_vec,xi_vec,d_vec = np.meshgrid(psi,xi,d,indexing='ij')
    psi_w_vec,xi_w_vec,d_w_vec = np.meshgrid(psi_w,xi_w,d_w,indexing='ij')

    par.psi_vec = psi_vec.ravel()
    par.xi_vec = xi_vec.ravel()
    par.d_vec = d_vec.ravel()
    par.w = xi_w_vec.ravel()*psi_w_vec.ravel()*d_w_vec.ravel()

    # Check if weights sum to 1 (two-sided: without abs() a sum above one
    # would pass unnoticed)
    assert abs(1-np.sum(par.w)) < 1e-8

    # Count number of shock nodes
    par.Nshocks = par.w.size

    # Create grids
    par.grid_h = tools.nonlinspace(0,par.h_max,par.Nh,par.h_phi)
    # NOTE(review): grid_a has Nm points (built from the m_* parameters) but
    # is indexed up to Na below, and the simulation sizes its lookup tables
    # with Nm+1 while sol has Na+1 rows - this only lines up when Na == Nm;
    # confirm against the caller.
    par.grid_a = tools.nonlinspace(1e-6,par.m_max,par.Nm,par.m_phi)

    # Create solution and simulation
    spec = [('m', double[:,:,:,:]), ('c', double[:,:,:,:]), ('inv_v', double[:,:,:,:])]
    @jitclass(spec)
    class sol_:
        def __init__(self): pass

    spec = [('m', double[:,:]), ('c', double[:,:]), ('a', double[:,:]), ('p', double[:,:]), ('y', double[:,:]),
            ('psi', double[:,:]), ('xi', double[:,:]), ('d', double[:,:]), ('P', double[:,:]), ('Y', double[:,:]),
            ('M', double[:,:]), ('C', double[:,:]), ('A', double[:,:]), ('z', int32[:]), ('h', double[:,:]),
            ('H', double[:,:])]
    @jitclass(spec)
    class sim_:
        def __init__(self): pass

    # Allocate memory for solution (row 0 of the second axis holds the
    # zero-consumption point, rows 1..Na the EGM points)
    sol = sol_()
    sol_shape = (par.T,par.Na+1,2,par.Nh)
    sol.m = np.zeros(sol_shape)
    sol.c = np.zeros(sol_shape)
    sol.inv_v = np.zeros(sol_shape)

    # Allocate memory for simulation
    sim = sim_()
    sim_shape = (par.simN,par.T)
    sim.m = np.zeros(sim_shape)
    sim.c = np.zeros(sim_shape)
    sim.a = np.zeros(sim_shape)
    sim.p = np.zeros(sim_shape)
    sim.y = np.zeros(sim_shape)
    sim.psi = np.zeros(sim_shape)
    sim.xi = np.zeros(sim_shape)
    sim.d = np.zeros(sim_shape)
    sim.P = np.zeros(sim_shape)
    sim.Y = np.zeros(sim_shape)
    sim.M = np.zeros(sim_shape)
    sim.C = np.zeros(sim_shape)
    sim.A = np.zeros(sim_shape)
    # The jitclass field is declared int32[:]; plain ``int`` is 64-bit on
    # most platforms and would not match the declared field type
    sim.z = np.zeros(par.simN, dtype=np.int32)
    sim.h = np.zeros(sim_shape)
    sim.H = np.zeros(sim_shape)

    # 2. Solve model by EGM
    @njit
    def utility(c, rho):
        return c**(1-rho)/(1-rho)

    @njit
    def marg_utility(c, rho):
        return c**(-rho)

    @njit
    def inv_marg_utility(u, rho):
        return u**(-1/rho)


    @njit
    def solve_egm(par,sol,sim):
        # Fills sol.m, sol.c and sol.inv_v in place; sim is not used here.
        # The whole backward recursion runs once per value of z.
        for z in [0, 1]:
            # Last period (= consume all)
            for i_h in range(par.Nh):
                sol.m[-1,:,z,i_h] = np.linspace(0,par.a_max,par.Na+1)
                sol.c[-1,:,z,i_h] = sol.m[-1,:,z,i_h]
                sol.inv_v[-1,0,z,i_h] = 0.0
                sol.inv_v[-1,1:,z,i_h] = 1.0/utility(sol.c[-1,1:,z,i_h],par.rho)

            # Before last period
            for t in range(par.T-2,-1,-1):
                for i_h, h in enumerate(par.grid_h):
                    # i. solve by EGM
                    # loop over end-of-period assets
                    for i_a in range(1,par.Na+1):

                        a = par.grid_a[i_a-1]
                        still_working_next_period = t+1 <= par.TR-1

                        # a. prep
                        if still_working_next_period:
                            fac_vec = par.G[t]*par.psi_vec
                            w = par.w
                            xi = par.xi_vec

                            # Next-period h, one entry per shock node
                            if t+1 < par.TH-1:
                                h_plus_vec = np.repeat(0.0, len(fac_vec))
                            elif t+1 == par.TH-1:
                                h_plus_vec = np.repeat(par.alpha, len(fac_vec))
                            elif t+1 > par.TH-1 and z == 0:
                                # Already one entry per shock node; wrapping
                                # this vector in np.repeat would yield
                                # len(fac_vec)**2 entries that no longer line
                                # up with m_plus_vec
                                h_plus_vec = ((par.d_vec+par.delta-1)/fac_vec)*h
                            elif t+1 > par.TH-1 and z == 1:
                                h_plus_vec = np.repeat(0.0, len(fac_vec))

                            # With z == 1, h is added to cash-on-hand in period TH
                            if t+1 == par.TH and z == 1:
                                h_term_vec = ((par.d_vec+par.delta-1)/fac_vec)*h
                            else:
                                h_term_vec = np.repeat(0.0,len(fac_vec))

                            m_plus_vec = (par.R/fac_vec)*a + xi + h_term_vec
                            inv_v_plus_vec = tools.interp_2d_vec(sol.m[t+1,:,z,i_h], par.grid_h, sol.inv_v[t+1,:,z,:], m_plus_vec, h_plus_vec)

                            # b. future m and c
                            c_plus_vec = tools.interp_2d_vec(sol.m[t+1,:,z,i_h], par.grid_h, sol.c[t+1,:,z,:], m_plus_vec, h_plus_vec)
                            v_plus_vec = 1.0/inv_v_plus_vec

                            # c. average future marginal utility
                            marg_u_plus_vec = marg_utility(fac_vec*c_plus_vec,par.rho)
                            avg_marg_u_plus_vec = np.sum(w*marg_u_plus_vec)
                            avg_v_plus_vec = np.sum(w*(fac_vec**(1-par.rho))*v_plus_vec)

                            # d. current c
                            sol.c[t,i_a,z,i_h] = inv_marg_utility(par.beta*par.R*avg_marg_u_plus_vec,par.rho)

                            # e. current m
                            sol.m[t,i_a,z,i_h] = a + sol.c[t,i_a,z,i_h]

                            # f. current v
                            if sol.c[t,i_a,z,i_h] > 0:
                                sol.inv_v[t,i_a,z,i_h] = 1.0/(utility(sol.c[t,i_a,z,i_h],par.rho) + par.beta*avg_v_plus_vec)
                            else:
                                sol.inv_v[t,i_a,z,i_h] = 0

                        else:
                            # No shocks once working life is over
                            fac = par.G[t]

                            # With z == 0, h is added to cash-on-hand in period TR
                            if t+1 == par.TR and z == 0:
                                h_term = (par.delta/fac)*h
                            else:
                                h_term = 0.0

                            m_plus = (par.R/fac)*a + 1 + h_term
                            inv_v_plus = tools.interp_2d(sol.m[t+1,:,z,i_h], par.grid_h, sol.inv_v[t+1,:,z,:], m_plus, 0.0)

                            # b. future m and c
                            c_plus = tools.interp_2d(sol.m[t+1,:,z,i_h], par.grid_h, sol.c[t+1,:,z,:], m_plus, 0.0)
                            v_plus = 1.0/inv_v_plus

                            # c. average future marginal utility
                            marg_u_plus = marg_utility(fac*c_plus,par.rho)
                            avg_marg_u_plus = marg_u_plus
                            avg_v_plus = (fac**(1-par.rho))*v_plus

                            # d. current c
                            sol.c[t,i_a,z,i_h] = inv_marg_utility(par.beta*par.R*avg_marg_u_plus,par.rho)

                            # e. current m
                            sol.m[t,i_a,z,i_h] = a + sol.c[t,i_a,z,i_h]

                            # f. current v
                            if sol.c[t,i_a,z,i_h] > 0:
                                sol.inv_v[t,i_a,z,i_h] = 1.0/(utility(sol.c[t,i_a,z,i_h],par.rho) + par.beta*avg_v_plus)
                            else:
                                sol.inv_v[t,i_a,z,i_h] = 0


                    # ii. add zero consumption
                    sol.m[t,0,z,i_h] = 0.0
                    sol.c[t,0,z,i_h] = 0.0
                    sol.inv_v[t,0,z,i_h] = 0.0

    # Solve by EGM
    solve_egm(par,sol,sim)

    # 3. Simulate model

    # Set seed
    np.random.seed(par.seed)

    # Shocks: draw one joint shock node per household and period
    _shocki = np.random.choice(par.Nshocks,size=(par.simN,par.T),p=par.w)
    sim.psi[:] = par.psi_vec[_shocki]
    sim.xi[:] = par.xi_vec[_shocki]
    sim.d[:] = par.d_vec[_shocki]

    # Initial values
    sim.m[:,0] = par.sim_mini
    sim.p[:,0] = 0.0

    @njit
    def nearest(points,target,norm_fact):
        # Index of the row of points closest (squared Euclidean distance) to
        # target. NOTE: rescales target[0] IN PLACE, so callers must pass a
        # fresh array on every call.
        target[0] = target[0]/norm_fact
        dist = np.sum((points - target)**2, axis=1)
        return np.argmin(dist)

    @njit(parallel=True)
    def simulate_time_loop(par,sol,sim):
        # Prepare irregular interpolation: flatten the (m, h) solution nodes
        # of every (t, z) into a list of points with matching values
        num_points = par.Nh*(par.Nm+1)
        points = np.zeros((par.T,2,num_points,2))
        values_c = np.zeros((par.T,2,num_points))
        values_inv_v = np.zeros((par.T,2,num_points))
        for t in range(par.T):
            for z in [0,1]:
                for i_h,h_loop in enumerate(par.grid_h):
                    m_grid = sol.m[t,:,z,i_h]
                    for i_m,m_loop in enumerate(m_grid):
                        values_c[t,z,i_h*(par.Nm+1)+i_m] = sol.c[t,i_m,z,i_h]
                        values_inv_v[t,z,i_h*(par.Nm+1)+i_m] = sol.inv_v[t,i_m,z,i_h]
                    points[t,z,(par.Nm+1)*i_h:(par.Nm+1)*(i_h+1),:] = np.array(list(zip(m_grid, [h_loop]*len(m_grid))))

        # Normalize points so that the m and h axes have comparable scale
        norm_fact = par.m_max/par.h_max
        points[:,:,:,0] = points[:,:,:,0]/norm_fact

        # Unpack (helps numba)
        m = sim.m
        p = sim.p
        y = sim.y
        c = sim.c
        a = sim.a
        h = sim.h
        z = sim.z

        # loop over first households and then time
        for i in prange(par.simN):
            for t in range(par.T):
                # Consumption: value at the nearest solution node
                # (a linear griddata interpolation was tried earlier)
                nearest_row = nearest(points[t,z[i],:,:], np.array([m[i,t], h[i,t]]), norm_fact)
                c[i,t] = values_c[t,z[i],nearest_row]
                a[i,t] = m[i,t] - c[i,t]

                if t < par.T-1:
                    still_working_next_period = t+1 <= par.TR-1
                    if still_working_next_period:
                        fac = par.G[t]*sim.psi[i,t+1]
                        xi = sim.xi[i,t+1]

                        # Law of motion for h (mirrors solve_egm)
                        if t+1 < par.TH-1:
                            h[i,t+1] = 0.0
                        elif t+1 == par.TH-1:
                            h[i,t+1] = par.alpha
                        elif t+1 > par.TH-1 and z[i] == 0:
                            h[i,t+1] = ((sim.d[i,t+1]+par.delta-1)/fac)*h[i,t]
                        elif t+1 > par.TH-1 and z[i] == 1:
                            h[i,t+1] = 0.0

                        if t+1 == par.TH and z[i] == 1:
                            h_term = ((sim.d[i,t+1]+par.delta-1)/fac)*h[i,t]
                        else:
                            h_term = 0.0

                        m[i,t+1] = (par.R/fac)*a[i,t] + xi + h_term
                        p[i,t+1] = np.log(par.G[t]) + p[i,t] + np.log(sim.psi[i,t+1])
                        # Guard the log against a zero income shock
                        if sim.xi[i,t+1] > 0:
                            y[i,t+1] = p[i,t+1] + np.log(sim.xi[i,t+1])
                        else:
                            y[i,t+1] = -np.inf
                    else:
                        fac = par.G[t]
                        xi = 1

                        h[i,t+1] = 0.0

                        if t+1 == par.TR and z[i] == 0:
                            h_term = (par.delta/fac)*h[i,t]
                        else:
                            h_term = 0.0

                        m[i,t+1] = (par.R/fac)*a[i,t] + xi + h_term
                        p[i,t+1] = np.log(par.G[t]) + p[i,t]
                        y[i,t+1] = p[i,t+1]

                    # The discrete choice z is fixed once, in period TH-1
                    if t+1 == par.TH-1:
                        if par.z_mode == 2:
                            # NOTE(review): the row is located in the point
                            # set of the current z[i] and then reused for
                            # both alternatives, although the endogenous
                            # m-grids of z = 0 and z = 1 may differ - confirm
                            # this is intended.
                            nearest_row = nearest(points[t+1,z[i],:,:], np.array([m[i,t+1], h[i,t+1]]), norm_fact)
                            inv_v0 = values_inv_v[t+1,0,nearest_row]
                            inv_v1 = values_inv_v[t+1,1,nearest_row]
                            if inv_v0 < inv_v1:
                                z[i] = 0
                            else:
                                z[i] = 1
                        elif par.z_mode == 1:
                            z[i] = 1
                        elif par.z_mode == 0:
                            z[i] = 0


    simulate_time_loop(par,sol,sim)

    # Renormalize: p and y are stored in logs, the lower-case panels are
    # scaled by P to obtain the upper-case ones
    sim.P[:,:] = np.exp(sim.p)
    sim.Y[:,:] = np.exp(sim.y)
    sim.M[:,:] = sim.m*sim.P
    sim.C[:,:] = sim.c*sim.P
    sim.A[:,:] = sim.a*sim.P
    sim.H[:,:] = sim.h*sim.P

    return sol, sim