コード例 #1
0
def run_agent(par_list, trials, T, ns, na, nr, nc, deval=False, ESS=None):
    """Set up a hierarchical Bayesian agent on a multi-armed bandit and run it.

    Parameters
    ----------
    par_list : sequence
        ``(learn_pol, trans_prob, avg, Rho, utility)``:

        * learn_pol -- initial concentration parameter of the policy prior.
        * trans_prob -- diagonal entry of the context transition matrix
          (probability that the context stays the same); the remaining mass
          is spread equally over the other contexts.
        * avg -- True for averaged action selection, False for maximum
          selection (ignored when ``ESS`` is not None).
        * Rho -- reward generation probabilities handed to
          ``env.MultiArmedBandid``.
        * utility -- preferences over reward outcomes (goal prior p(o)).
    trials, T : int
        Number of trials and time steps per trial.
    ns, na, nr, nc : int
        Number of states, actions, reward outcomes and contexts.
        Action ``i`` leads to state ``i + 1``, so ``ns > na`` is required.
    deval : bool
        If True, simulate the first ``trials // 2`` trials, then overwrite the
        agent's reward preferences (devaluation, see below) and simulate the
        remaining trials.
    ESS : optional
        Only compared against None: when not None a ``DirichletSelector`` is
        used for action selection.

    Returns
    -------
    world.World
        The world after the simulation.
    """
    learn_pol, trans_prob, avg, Rho, utility = par_list

    # ---- generative model ------------------------------------------------
    # observations identify the state exactly
    A = np.eye(ns)

    # action i deterministically moves the agent to state i + 1
    B = np.zeros((ns, ns, na))
    for i in range(na):
        B[i + 1, :, i] += 1

    # Dirichlet concentration parameters of the agent's reward beliefs.
    # State 0 (in front of the levers) gets a strong prior for outcome 0 so
    # the agent knows it yields no reward; everything else starts flat (1).
    C_alphas = np.ones((nr, ns, nc))
    C_alphas[0, 0, :] = 100

    # agent's initial reward model: each (state, context) column of the
    # concentration parameters normalised over outcomes
    C_agent = C_alphas / C_alphas.sum(axis=0, keepdims=True)

    # context transition matrix
    if nc > 1:
        p = trans_prob
        q = 1. - p
        transition_matrix_context = np.zeros((nc, nc)) + q / (nc - 1)
        np.fill_diagonal(transition_matrix_context, p)
    else:
        # a single context cannot switch; avoids dividing by nc - 1 == 0
        transition_matrix_context = np.ones((1, 1))

    # ---- environment -------------------------------------------------------
    environment = env.MultiArmedBandid(A, B, Rho, trials=trials, T=T)

    # ---- policies ----------------------------------------------------------
    # every sequence of T - 1 actions is a policy
    pol = np.array(list(itertools.product(range(na), repeat=T - 1)))
    npi = pol.shape[0]

    # policy concentration parameters and the resulting (uniform) prior
    alphas = np.zeros((npi, nc)) + learn_pol
    prior_pi = alphas / alphas.sum(axis=0)

    # ---- state prior: the agent believes it starts in state 0 --------------
    state_prior = np.zeros(ns)
    state_prior[0] = 1.

    # ---- action selection --------------------------------------------------
    if ESS is not None:
        ac_sel = asl.DirichletSelector(trials=trials, T=T,
                                       number_of_actions=na)
    elif avg:
        ac_sel = asl.AveragedSelector(trials=trials, T=T, number_of_actions=na)
    else:
        ac_sel = asl.MaxSelector(trials=trials, T=T, number_of_actions=na)

    # ---- context prior: 0.9 on context 0, rest spread evenly ---------------
    if nc > 1:
        prior_context = np.zeros(nc) + 0.1 / (nc - 1)
        prior_context[0] = 0.9
    else:
        prior_context = np.ones(1)

    # ---- agent -------------------------------------------------------------
    bayes_prc = prc.HierarchicalPerception(A,
                                           B,
                                           C_agent,
                                           transition_matrix_context,
                                           state_prior,
                                           utility,
                                           prior_pi,
                                           alphas,
                                           C_alphas,
                                           T=T)

    bayes_pln = agt.BayesianPlanner(
        bayes_prc,
        ac_sel,
        pol,
        trials=trials,
        T=T,
        prior_states=state_prior,
        prior_policies=prior_pi,
        number_of_states=ns,
        prior_context=prior_context,
        learn_habit=True,
        learn_rew=True,
        number_of_policies=npi,
        number_of_rewards=nr)

    # ---- world and simulation ----------------------------------------------
    w = world.World(environment, bayes_pln, trials=trials, T=T)

    if not deval:
        w.simulate_experiment(range(trials))
    else:
        w.simulate_experiment(range(trials // 2))

        # Devaluation: the preference mass of outcomes 1.. is moved to
        # outcomes 2.. (shared equally), and outcomes 0 and 1 split the rest
        # equally, i.e. outcome 1 becomes as (un)attractive as outcome 0.
        # Meaningful only for nr > 2.
        # NOTE(review): if HierarchicalPerception stores ``utility`` without
        # copying, this also changes the caller's array -- confirm.
        ut = utility[1:].sum()
        bayes_prc.prior_rewards[2:] = ut / (nr - 2)
        bayes_prc.prior_rewards[:2] = (1 - ut) / 2

        w.simulate_experiment(range(trials // 2, trials))

    return w
コード例 #2
0
def run_agent(par_list, w_old, trials=trials):
    """Run a new bandit experiment with an agent initialised from ``w_old``.

    A new ``env.MultiArmedBandid`` is built from ``Rho``. The new agent's
    reward preferences, reward Dirichlet parameters, reward model, context
    transition matrix, policies and policy Dirichlet parameters are copied
    from ``w_old``'s agent, so learning continues from where it stopped.

    Parameters
    ----------
    par_list : sequence
        ``(trans_prob, avg, Rho)``. ``trans_prob`` is accepted for
        compatibility but unused (the context transition matrix is copied
        from ``w_old``). ``avg`` selects averaged (True) or maximum (False)
        action selection. ``Rho`` is the reward generation array handed to
        ``env.MultiArmedBandid``.
    w_old : world.World
        Previously simulated world supplying dimensions and learned beliefs.
    trials : int
        Number of trials to simulate (defaults to the module-level ``trials``).

    Returns
    -------
    world.World
        The new world after the simulation.
    """
    _trans_prob, avg, Rho = par_list
    old_perception = w_old.agent.perception

    # dimensions taken from the previous world so the new model matches it
    ns = w_old.environment.Theta.shape[0]
    nr = w_old.environment.Rho.shape[1]
    na = w_old.environment.Theta.shape[2]
    T = w_old.T
    utility = old_perception.prior_rewards.copy()

    # observations identify the state exactly
    A = np.eye(ns)

    # action i deterministically moves the agent to state i + 1
    B = np.zeros((ns, ns, na))
    for i in range(na):
        B[i + 1, :, i] += 1

    # learned reward beliefs carried over from the previous agent
    C_alphas = old_perception.dirichlet_rew_params.copy()
    C_agent = old_perception.generative_model_rewards.copy()

    # Number of contexts of the carried-over model. Previously this silently
    # relied on a module-level ``nc`` that need not match ``w_old``.
    nc = C_agent.shape[2]

    transition_matrix_context = old_perception.transition_matrix_context.copy()

    environment = env.MultiArmedBandid(A, B, Rho, trials=trials, T=T)

    # policies and their learned concentration parameters
    pol = w_old.agent.policies.copy()
    npi = pol.shape[0]

    alphas = old_perception.dirichlet_pol_params.copy()
    # exp(digamma(alpha) - digamma(sum(alpha))) per context column, i.e. the
    # exponentiated Dirichlet expected log-probability, then renormalised
    prior_pi = np.exp(scs.digamma(alphas)
                      - scs.digamma(alphas.sum(axis=0))[np.newaxis, :])
    prior_pi /= prior_pi.sum(axis=0)

    # the agent believes it starts in state 0
    state_prior = np.zeros(ns)
    state_prior[0] = 1.

    if avg:
        ac_sel = asl.AveragedSelector(trials=trials, T=T,
                                      number_of_actions=na)
    else:
        ac_sel = asl.MaxSelector(trials=trials, T=T,
                                 number_of_actions=na)

    # uniform prior over contexts
    prior_context = np.zeros(nc) + 1. / nc

    bayes_prc = prc.HierarchicalPerception(A, B, C_agent,
                                           transition_matrix_context,
                                           state_prior, utility, prior_pi,
                                           alphas, C_alphas, T=T)

    bayes_pln = agt.BayesianPlanner(bayes_prc, ac_sel, pol,
                                    trials=trials, T=T,
                                    prior_states=state_prior,
                                    prior_policies=prior_pi,
                                    number_of_states=ns,
                                    prior_context=prior_context,
                                    learn_habit=True,
                                    number_of_policies=npi,
                                    number_of_rewards=nr)

    w = world.World(environment, bayes_pln, trials=trials, T=T)
    w.simulate_experiment(range(trials))

    return w
コード例 #3
0
    def setup_agent(self, w, first_trial=0, test_trials=None):
        """Build a replay agent from ``w`` and precompute the likelihood grid.

        The new agent reuses ``w``'s observation and state models, reward
        preferences, context transition matrix and policies. It starts from
        flat policy concentration parameters (all 1) and fresh reward beliefs.
        It is wrapped in ``world.FakeWorld`` together with the observations,
        rewards and actions recorded in ``w``.

        Parameters
        ----------
        w : world.World
            Simulated world whose environment, agent and recorded data are read.
        first_trial : int
            Not used by this method.
        test_trials : array-like of int, optional
            Trials handed to ``fit_model``; defaults to every trial of ``w``.

        Side effects
        ------------
        Sets ``self.agent`` (the FakeWorld) and ``self.fixed`` (the fresh
        reward model and its concentration parameters). Also sets
        ``self.likelihood``, an array of shape
        ``(self.nruns, len(self.sample_space))`` whose entry ``[i, j]`` is
        ``self.agent.fit_model(1 / h_j, self.fixed, test_trials)`` for the
        j-th element ``h_j`` of ``self.sample_space``.
        """
        old_perception = w.agent.perception

        ns = w.environment.Theta.shape[0]
        nr = w.environment.Rho.shape[1]
        nc = old_perception.generative_model_rewards.shape[2]
        T = w.T
        trials = w.trials

        # recorded behaviour the replay agent is evaluated on
        observations = w.observations.copy()
        rewards = w.rewards.copy()
        actions = w.actions.copy()

        # parts of the generative model carried over unchanged
        utility = old_perception.prior_rewards.copy()
        A = old_perception.generative_model_observations.copy()
        B = old_perception.generative_model_states.copy()
        transition_matrix_context = old_perception.transition_matrix_context.copy()

        if test_trials is None:
            test_trials = np.arange(trials, dtype=int)

        # Fresh reward beliefs: flat Dirichlet counts, except that state 0
        # (in front of the levers) strongly favours outcome 0 (no reward).
        C_alphas = np.ones((nr, ns, nc))
        C_alphas[0, 0, :] = 100
        # expected reward model: every (state, context) column normalised
        C_agent = C_alphas / C_alphas.sum(axis=0, keepdims=True)

        pol = w.agent.policies.copy()
        npi = pol.shape[0]

        # flat policy concentration parameters and the matching prior
        alphas = np.full_like(old_perception.dirichlet_pol_params, 1)
        prior_pi = alphas.copy()
        prior_pi /= prior_pi.sum(axis=0)

        # the agent believes it starts in state 0
        state_prior = np.zeros(ns)
        state_prior[0] = 1.

        # uniform prior over contexts
        prior_context = np.full(nc, 1. / nc)

        bayes_prc = prc.HierarchicalPerception(A,
                                               B,
                                               C_agent,
                                               transition_matrix_context,
                                               state_prior,
                                               utility,
                                               prior_pi,
                                               alphas,
                                               C_alphas,
                                               T=T)

        # no action selector: actions are taken from the recorded data
        bayes_pln = agt.BayesianPlanner(bayes_prc,
                                        None,
                                        pol,
                                        trials=trials,
                                        T=T,
                                        prior_states=state_prior,
                                        prior_policies=prior_pi,
                                        number_of_states=ns,
                                        prior_context=prior_context,
                                        learn_habit=True,
                                        number_of_policies=npi,
                                        number_of_rewards=nr)

        self.agent = world.FakeWorld(bayes_pln,
                                     observations,
                                     rewards,
                                     actions,
                                     trials=trials,
                                     T=T)

        self.fixed = {'rew_mod': C_agent, 'beta_rew': C_alphas}

        self.likelihood = np.zeros((self.nruns, len(self.sample_space)),
                                   dtype=np.float64)

        for run in range(self.nruns):
            print("precalculating likelihood run ", run)
            for col, h in enumerate(self.sample_space):
                self.likelihood[run, col] = self.agent.fit_model(
                    1. / h, self.fixed, test_trials)
コード例 #4
0
def run_agent(par_list, trials, T, ns, na, nr, nc, f, contexts, states,
              state_trans=None, correct_choice=None, congruent=None,
              num_in_run=None, random_draw=False, pol_lambda=0, r_lambda=0,
              one_context=False):
    """Set up a hierarchical Bayesian agent on a task-switching environment and run it.

    Parameters
    ----------
    par_list : sequence
        ``(learn_pol, trans_prob, Rho, utility, unc)``:

        * learn_pol -- initial concentration parameter of the policy prior.
        * trans_prob -- diagonal of the context transition matrix
          (probability the context stays the same).
        * Rho -- reward generation array passed to the environment.
        * utility -- preferences over reward outcomes (goal prior p(o)).
        * unc -- off-diagonal entry of the context observation matrix ``D``.
    trials, T : int
        Number of trials and time steps per trial.
    ns, na, nr, nc : int
        Number of states, actions, reward outcomes and contexts.
        The reward prior below assumes ``nr == 2`` and ``ns >= 4``.
    f : float
        Passed to ``asl.DirichletSelector`` as ``factor``.
    contexts, states, correct_choice, congruent, num_in_run :
        Passed through to the environment constructor.
    state_trans : ndarray, optional
        Transition tensor ``B`` to use instead of the default, in which
        action i moves to state i + 1. A copy is taken.
    random_draw : bool
        Passed to ``asl.DirichletSelector`` as ``draw_true_post``.
    pol_lambda, r_lambda : float
        Passed through to ``prc.HierarchicalPerception``.
    one_context : bool
        If True use ``env.TaskSwitchingOneConext``, else ``env.TaskSwitching``.

    Returns
    -------
    world.World
        The world after the simulation.
    """
    learn_pol, trans_prob, Rho, utility, unc = par_list

    # the first four states are the start states: the agent starts in one of
    # them uniformly and is certain they yield reward outcome 0
    n_start_states = 4

    # ---- generative model ------------------------------------------------
    # observations identify the state exactly
    A = np.eye(ns)

    if state_trans is None:
        # action i deterministically moves the agent to state i + 1
        B = np.zeros((ns, ns, na))
        for i in range(na):
            B[i + 1, :, i] += 1
    else:
        B = state_trans.copy()

    # Dirichlet concentration parameters of the reward beliefs: flat (1)
    # except the start states, which get [100, 1] over the two outcomes
    # (broadcast requires nr == 2)
    C_alphas = np.ones((nr, ns, nc))
    C_alphas[:, :n_start_states, :] = np.array([100, 1])[:, None, None]

    # agent's initial reward model: each (state, context) column normalised
    C_agent = C_alphas / C_alphas.sum(axis=0, keepdims=True)

    # context transition matrix
    if nc > 1:
        p = trans_prob
        q = 1. - p
        transition_matrix_context = np.zeros((nc, nc)) + q / (nc - 1)
        np.fill_diagonal(transition_matrix_context, p)
    else:
        transition_matrix_context = np.array([[1]])

    # context observation matrix: off-diagonal ``unc``, columns sum to 1
    if nc > 1:
        D = np.zeros((nc, nc)) + unc
        np.fill_diagonal(D, 1 - (unc * (nc - 1)))
    else:
        D = np.array([[1]])

    # ---- environment -------------------------------------------------------
    env_class = env.TaskSwitchingOneConext if one_context else env.TaskSwitching
    environment = env_class(A, B, Rho, D, states, contexts,
                            trials=trials, T=T,
                            correct_choice=correct_choice,
                            congruent=congruent,
                            num_in_run=num_in_run)

    # ---- policies ----------------------------------------------------------
    # every sequence of T - 1 actions is a policy
    pol = np.array(list(itertools.product(range(na), repeat=T - 1)))
    npi = pol.shape[0]

    alphas = np.zeros((npi, nc)) + learn_pol
    prior_pi = alphas / alphas.sum(axis=0)

    # ---- state prior: uniform over the start states ------------------------
    state_prior = np.zeros(ns)
    state_prior[:n_start_states] = 1. / n_start_states

    # ---- action selection --------------------------------------------------
    ac_sel = asl.DirichletSelector(trials=trials, T=T, number_of_actions=na,
                                   factor=f, calc_dkl=False,
                                   calc_entropy=False,
                                   draw_true_post=random_draw)

    # ---- context prior: 0.9 on context 0, rest spread evenly ---------------
    if nc > 1:
        prior_context = np.zeros(nc) + 0.1 / (nc - 1)
        prior_context[0] = 0.9
    else:
        prior_context = np.array([1])

    # ---- agent -------------------------------------------------------------
    # NOTE(review): non_decaying=4 presumably refers to the same four start
    # states -- confirm against prc.HierarchicalPerception.
    bayes_prc = prc.HierarchicalPerception(A, B, C_agent,
                                           transition_matrix_context,
                                           state_prior, utility, prior_pi,
                                           alphas, C_alphas, T=T,
                                           generative_model_context=D,
                                           pol_lambda=pol_lambda,
                                           r_lambda=r_lambda,
                                           non_decaying=4)

    bayes_pln = agt.BayesianPlanner(bayes_prc, ac_sel, pol,
                                    trials=trials, T=T,
                                    prior_states=state_prior,
                                    prior_policies=prior_pi,
                                    number_of_states=ns,
                                    prior_context=prior_context,
                                    learn_habit=True,
                                    learn_rew=True,
                                    number_of_policies=npi,
                                    number_of_rewards=nr)

    # ---- world and simulation ----------------------------------------------
    w = world.World(environment, bayes_pln, trials=trials, T=T)
    w.simulate_experiment(range(trials))

    return w
コード例 #5
0
def run_agent(par_list, trials=trials, T=T, ns=ns, na=na):
    """Set up an agent on a two-stage task and run it.

    Parameters
    ----------
    par_list : sequence
        ``(learn_pol, avg, Rho, learn_habit, utility)``:

        * learn_pol -- initial concentration parameter of the policy prior.
        * avg -- averaged (True) or maximum (False) action selection.
        * Rho -- reward generation array passed to ``env.MultiArmedBandid``.
        * learn_habit -- passed to ``agt.BayesianPlanner`` (Bayesian agent only).
        * utility -- preferences over reward outcomes (goal prior p(o)).
    trials, T, ns, na : int
        Number of trials, time steps per trial, states and actions; the
        defaults are module-level globals. The transition tensor below is
        7 x 7 per action, so ``ns == 7`` and ``na == 2`` are required.

    This function also reads the module-level globals ``no``, ``nr``, ``nc``
    and ``agent``. ``agent == 'bethe'`` selects the hierarchical Bayesian
    agent; any other value selects the model-free agent.

    Returns
    -------
    world.World
        The world after the simulation.
    """
    learn_pol, avg, Rho, learn_habit, utility = par_list
    # concentration parameter for reward beliefs not fixed below
    learn_rew = 1

    # ---- generative model ------------------------------------------------
    # observations identify the state exactly
    # NOTE(review): uses the global ``no``; presumably no == ns -- confirm
    A = np.eye(no)

    # Two-stage transitions (rows: next state, columns: current state).
    # From state 0, action 0 reaches state 1 w.p. b1 (else 2), and action 1
    # reaches state 2 w.p. b2 (else 1). From state 1, action 0 -> 3 and
    # action 1 -> 5. From state 2, action 0 -> 4 and action 1 -> 6.
    # States 3-6 are absorbing.
    b1 = 0.7
    nb1 = 1. - b1
    b2 = 0.7
    nb2 = 1. - b2

    B = np.zeros((ns, ns, na))
    B[:, :, 0] = np.array([[  0,  0,  0,  0,  0,  0,  0,],
                           [ b1,  0,  0,  0,  0,  0,  0,],
                           [nb1,  0,  0,  0,  0,  0,  0,],
                           [  0,  1,  0,  1,  0,  0,  0,],
                           [  0,  0,  1,  0,  1,  0,  0,],
                           [  0,  0,  0,  0,  0,  1,  0,],
                           [  0,  0,  0,  0,  0,  0,  1,],])

    B[:, :, 1] = np.array([[  0,  0,  0,  0,  0,  0,  0,],
                           [nb2,  0,  0,  0,  0,  0,  0,],
                           [ b2,  0,  0,  0,  0,  0,  0,],
                           [  0,  0,  0,  1,  0,  0,  0,],
                           [  0,  0,  0,  0,  1,  0,  0,],
                           [  0,  1,  0,  0,  0,  1,  0,],
                           [  0,  0,  1,  0,  0,  0,  1,],])

    # Dirichlet concentration parameters of the reward beliefs: outcome 0 is
    # strongly expected in states 0-2; all other entries start at learn_rew
    C_alphas = np.zeros((nr, ns, nc)) + learn_rew
    C_alphas[0, :3, :] = 100
    for i in range(1, nr):
        C_alphas[i, 0, :] = 1

    # agent's initial reward model: each (state, context) column normalised
    C_agent = C_alphas / C_alphas.sum(axis=0, keepdims=True)

    # NOTE(review): 1-D here, whereas the context transition matrix is
    # (nc, nc) elsewhere -- confirm HierarchicalPerception accepts this
    transition_matrix_context = np.ones(1)

    # ---- environment -------------------------------------------------------
    environment = env.MultiArmedBandid(A, B, Rho, trials=trials, T=T)

    # ---- policies ----------------------------------------------------------
    # every sequence of T - 1 actions is a policy
    pol = np.array(list(itertools.product(range(na), repeat=T - 1)))
    npi = pol.shape[0]

    alphas = np.zeros((npi, nc)) + learn_pol
    prior_pi = alphas / alphas.sum(axis=0)

    # ---- state prior: the agent believes it starts in state 0 --------------
    state_prior = np.zeros(ns)
    state_prior[0] = 1.

    # ---- action selection --------------------------------------------------
    if avg:
        ac_sel = asl.AveragedSelector(trials=trials, T=T,
                                      number_of_actions=na)
    else:
        ac_sel = asl.MaxSelector(trials=trials, T=T,
                                 number_of_actions=na)

    prior_context = np.array([1.])

    # ---- agent -------------------------------------------------------------
    if agent == 'bethe':
        # hierarchical Bayesian agent
        bayes_prc = prc.HierarchicalPerception(A, B, C_agent,
                                               transition_matrix_context,
                                               state_prior, utility, prior_pi,
                                               alphas, C_alphas, T=T,
                                               pol_lambda=0.3, r_lambda=0.6,
                                               non_decaying=3, dec_temp=4.)

        bayes_pln = agt.BayesianPlanner(bayes_prc, ac_sel, pol,
                                        trials=trials, T=T,
                                        prior_states=state_prior,
                                        prior_policies=prior_pi,
                                        number_of_states=ns,
                                        prior_context=prior_context,
                                        learn_habit=learn_habit,
                                        learn_rew=True,
                                        number_of_policies=npi,
                                        number_of_rewards=nr)
    else:
        # model-free agent
        # NOTE(review): confirm argument order against prc.MFPerception's
        # signature (utility is passed before state_prior here)
        bayes_prc = prc.MFPerception(A, B, utility, state_prior, T=T)

        bayes_pln = agt.BayesianMFPlanner(bayes_prc, [], ac_sel,
                                          trials=trials, T=T,
                                          prior_states=state_prior,
                                          policies=pol,
                                          number_of_states=ns,
                                          number_of_policies=npi)

    # ---- world and simulation ----------------------------------------------
    w = world.World(environment, bayes_pln, trials=trials, T=T)
    w.simulate_experiment(range(trials))

    return w
コード例 #6
0
def run_agent(par_list, trials=trials, T=T, Lx=Lx, Ly=Ly, ns=ns, na=na):

    #set parameters:
    #obs_unc: observation uncertainty condition
    #state_unc: state transition uncertainty condition
    #goal_pol: evaluate only policies that lead to the goal
    #utility: goal prior, preference p(o)
    obs_unc, state_unc, goal_pol, avg, context, utility, h, q = par_list
    """
    create matrices
    """

    vals = np.array([1., 2 / 3., 1 / 2., 1. / 2.])

    #generating probability of observations in each state
    A = np.eye(ns) + const
    np.fill_diagonal(A, 1 - (ns - 1) * const)

    #generate horizontal gradient for observation uncertainty condition
    # if obs_unc:

    #     condition = 'obs'

    #     for s in range(ns):
    #         x = s//Ly
    #         y = s%Ly

    #         c = 1#vals[L - y - 1]

    #         # look for neighbors
    #         neighbors = []
    #         if (s-4)>=0 and (s-4)!=g1:
    #             neighbors.append(s-4)

    #         if (s%4)!=0 and (s-1)!=g1:
    #             neighbors.append(s-1)

    #         if (s+4)<=(ns-1) and (s+4)!=g1:
    #             neighbors.append(s+4)

    #         if ((s+1)%4)!=0 and (s+1)!=g1:
    #             neighbors.append(s+1)

    #         A[s,s] = c
    #         for n in neighbors:
    #             A[n,s] = (1-c)/len(neighbors)

    #state transition generative probability (matrix)
    B = np.zeros((ns, ns, na)) + const

    cert_arr = np.zeros(ns)
    for s in range(ns):
        x = s // Ly
        y = s % Ly

        #state uncertainty condition
        if state_unc:
            if (x == 0) or (y == 3):
                c = vals[0]
            elif (x == 1) or (y == 2):
                c = vals[1]
            elif (x == 2) or (y == 1):
                c = vals[2]
            else:
                c = vals[3]

            condition = 'state'

        else:
            c = 1.

        cert_arr[s] = c
        for u in range(na):
            x = s // Ly + actions[u][0]
            y = s % Ly + actions[u][1]

            #check if state goes over boundary
            if x < 0:
                x = 0
            elif x == Lx:
                x = Lx - 1

            if y < 0:
                y = 0
            elif y == Ly:
                y = Ly - 1

            s_new = Ly * x + y
            if s_new == s:
                B[s, s, u] = 1 - (ns - 1) * const
            else:
                B[s, s, u] = 1 - c + const
                B[s_new, s, u] = c - (ns - 1) * const

    B_c = np.broadcast_to(B[:, :, :, np.newaxis], (ns, ns, na, nc))
    print(B.shape)
    """
    create environment (grid world)
    """
    Rho = np.zeros((nr, ns)) + const
    Rho[0, :] = 1 - (nr - 1) * const
    Rho[:, np.argmax(utility)] = [0 + const, 1 - (nr - 1) * const]
    print(Rho)
    util = np.array([1 - np.amax(utility), np.amax(utility)])

    environment = env.GridWorld(A,
                                B,
                                Rho,
                                trials=trials,
                                T=T,
                                initial_state=start)

    Rho_agent = np.ones((nr, ns, nc)) / nr

    if True:
        templates = np.ones_like(Rho_agent)
        templates[0] *= 100
        assert ns == nc
        for s in range(ns):
            templates[0, s, s] = 1
            templates[1, s, s] = 100
        dirichlet_rew_params = templates
    else:
        dirichlet_rew_params = np.ones_like(Rho_agent)
    """
    create policies
    """

    if goal_pol:
        pol = []
        su = 3
        for p in itertools.product([0, 1], repeat=T - 1):
            if (np.array(p)[0:6].sum() == su) and (np.array(p)[-1] != 1):
                pol.append(list(p))

        pol = np.array(pol) + 2
    else:
        pol = np.array(list(itertools.product(list(range(na)), repeat=T - 1)))

    #pol = pol[np.where(pol[:,0]>1)]

    npi = pol.shape[0]

    prior_policies = np.ones((npi, nc)) / npi
    dirichlet_pol_param = np.zeros_like(prior_policies) + h
    """
    set state prior (where agent thinks it starts)
    """

    state_prior = np.zeros((ns))

    # state_prior[0] = 1./4.
    # state_prior[1] = 1./4.
    # state_prior[4] = 1./4.
    # state_prior[5] = 1./4.
    state_prior[start] = 1
    """
    set context prior and matrix
    """

    context_prior = np.ones(nc)
    trans_matrix_context = np.ones((nc, nc))
    if nc > 1:
        # context_prior[0] = 0.9
        # context_prior[1:] = 0.1 / (nc-1)
        context_prior /= nc
        trans_matrix_context[:] = (1 - q) / (nc - 1)
        np.fill_diagonal(trans_matrix_context, q)
    """
    set action selection method
    """

    if avg:

        sel = 'avg'

        ac_sel = asl.DirichletSelector(trials=trials,
                                       T=T,
                                       factor=0.5,
                                       number_of_actions=na,
                                       calc_entropy=False,
                                       calc_dkl=False,
                                       draw_true_post=True)
    else:

        sel = 'max'

        ac_sel = asl.MaxSelector(trials=trials, T=T, number_of_actions=na)


#    ac_sel = asl.AveragedPolicySelector(trials = trials, T = T,
#                                        number_of_policies = npi,
#                                        number_of_actions = na)
    """
    set up agent
    """
    #bethe agent
    if agent == 'bethe':

        agnt = 'bethe'

        # perception and planning

        bayes_prc = prc.HierarchicalPerception(
            A,
            B_c,
            Rho_agent,
            trans_matrix_context,
            state_prior,
            util,
            prior_policies,
            dirichlet_pol_params=dirichlet_pol_param,
            dirichlet_rew_params=dirichlet_rew_params)

        bayes_pln = agt.BayesianPlanner(
            bayes_prc,
            ac_sel,
            pol,
            trials=trials,
            T=T,
            prior_states=state_prior,
            prior_policies=prior_policies,
            prior_context=context_prior,
            number_of_states=ns,
            learn_habit=True,
            learn_rew=True,
            #save_everything = True,
            number_of_policies=npi,
            number_of_rewards=nr)
    #MF agent
    else:

        agnt = 'mf'

        # perception and planning

        bayes_prc = prc.MFPerception(A, B, state_prior, utility, T=T)

        bayes_pln = agt.BayesianMFPlanner(bayes_prc, [],
                                          ac_sel,
                                          trials=trials,
                                          T=T,
                                          prior_states=state_prior,
                                          policies=pol,
                                          number_of_states=ns,
                                          number_of_policies=npi)
    """
    create world
    """

    w = world.World(environment, bayes_pln, trials=trials, T=T)
    """
    simulate experiment
    """

    if not context:
        w.simulate_experiment()
    else:
        w.simulate_experiment(curr_trials=range(0, trials // 2))
        Rho_new = np.zeros((nr, ns)) + const
        Rho_new[0, :] = 1 - (nr - 1) * const
        Rho_new[:, g2] = [0 + const, 1 - (nr - 1) * const]
        print(Rho_new)
        w.environment.Rho[:] = Rho_new
        #w.agent.perception.generative_model_rewards = Rho_new
        w.simulate_experiment(curr_trials=range(trials // 2, trials))
    """
    plot and evaluate results
    """
    #find successful and unsuccessful runs
    #goal = np.argmax(utility)
    successfull_g1 = np.where(environment.hidden_states[:, -1] == g1)[0]
    if context:
        successfull_g2 = np.where(environment.hidden_states[:, -1] == g2)[0]
        unsuccessfull1 = np.where(environment.hidden_states[:, -1] != g1)[0]
        unsuccessfull2 = np.where(environment.hidden_states[:, -1] != g2)[0]
        unsuccessfull = np.intersect1d(unsuccessfull1, unsuccessfull2)
    else:
        unsuccessfull = np.where(environment.hidden_states[:, -1] != g1)[0]

    #total  = len(successfull)

    #plot start and goal state
    start_goal = np.zeros((Lx, Ly))

    x_y_start = (start // Ly, start % Ly)
    start_goal[x_y_start] = 1.
    x_y_g1 = (g1 // Ly, g1 % Ly)
    start_goal[x_y_g1] = -1.
    x_y_g2 = (g2 // Ly, g2 % Ly)
    start_goal[x_y_g2] = -2.

    palette = [(159 / 255, 188 / 255, 147 / 255),
               (135 / 255, 170 / 255, 222 / 255),
               (242 / 255, 241 / 255, 241 / 255),
               (242 / 255, 241 / 255, 241 / 255),
               (199 / 255, 174 / 255, 147 / 255),
               (199 / 255, 174 / 255, 147 / 255)]

    #set up figure params
    # ~ factor = 3
    # ~ grid_plot_kwargs = {'vmin': -2, 'vmax': 2, 'center': 0, 'linecolor': '#D3D3D3',
    # ~ 'linewidths': 7, 'alpha': 1, 'xticklabels': False,
    # ~ 'yticklabels': False, 'cbar': False,
    # ~ 'cmap': palette}#sns.diverging_palette(120, 45, as_cmap=True)} #"RdBu_r",

    # ~ # plot grid
    # ~ fig = plt.figure(figsize=[factor*5,factor*4])

    # ~ ax = fig.gca()

    # ~ annot = np.zeros((Lx,Ly))
    # ~ for i in range(Lx):
    # ~ for j in range(Ly):
    # ~ annot[i,j] = i*Ly+j

    # ~ u = sns.heatmap(start_goal, ax = ax, **grid_plot_kwargs, annot=annot, annot_kws={"fontsize": 40})
    # ~ ax.invert_yaxis()
    # ~ plt.savefig('grid.svg', dpi=600)
    # ~ #plt.show()

    # ~ # set up paths figure
    # ~ fig = plt.figure(figsize=[factor*5,factor*4])

    # ~ ax = fig.gca()

    # ~ u = sns.heatmap(start_goal, zorder=2, ax = ax, **grid_plot_kwargs)
    # ~ ax.invert_yaxis()

    # ~ #find paths and count them
    # ~ n1 = np.zeros((ns, na))

    # ~ for i in successfull_g1:

    # ~ for j in range(T-1):
    # ~ d = environment.hidden_states[i, j+1] - environment.hidden_states[i, j]
    # ~ if d not in [1,-1,Ly,-Ly,0]:
    # ~ print("ERROR: beaming")
    # ~ if d == 1:
    # ~ n1[environment.hidden_states[i, j],0] +=1
    # ~ if d == -1:
    # ~ n1[environment.hidden_states[i, j]-1,0] +=1
    # ~ if d == Ly:
    # ~ n1[environment.hidden_states[i, j],1] +=1
    # ~ if d == -Ly:
    # ~ n1[environment.hidden_states[i, j]-Ly,1] +=1

    # ~ n2 = np.zeros((ns, na))

    # ~ if context:
    # ~ for i in successfull_g2:

    # ~ for j in range(T-1):
    # ~ d = environment.hidden_states[i, j+1] - environment.hidden_states[i, j]
    # ~ if d not in [1,-1,Ly,-Ly,0]:
    # ~ print("ERROR: beaming")
    # ~ if d == 1:
    # ~ n2[environment.hidden_states[i, j],0] +=1
    # ~ if d == -1:
    # ~ n2[environment.hidden_states[i, j]-1,0] +=1
    # ~ if d == Ly:
    # ~ n2[environment.hidden_states[i, j],1] +=1
    # ~ if d == -Ly:
    # ~ n2[environment.hidden_states[i, j]-Ly,1] +=1

    # ~ un = np.zeros((ns, na))

    # ~ for i in unsuccessfull:

    # ~ for j in range(T-1):
    # ~ d = environment.hidden_states[i, j+1] - environment.hidden_states[i, j]
    # ~ if d not in [1,-1,Ly,-Ly,0]:
    # ~ print("ERROR: beaming")
    # ~ if d == 1:
    # ~ un[environment.hidden_states[i, j],0] +=1
    # ~ if d == -1:
    # ~ un[environment.hidden_states[i, j]-1,0] +=1
    # ~ if d == Ly:
    # ~ un[environment.hidden_states[i, j],1] +=1
    # ~ if d == -Ly:
    # ~ un[environment.hidden_states[i, j]-4,1] +=1

    # ~ total_num = n1.sum() + n2.sum() + un.sum()

    # ~ if np.any(n1 > 0):
    # ~ n1 /= total_num

    # ~ if np.any(n2 > 0):
    # ~ n2 /= total_num

    # ~ if np.any(un > 0):
    # ~ un /= total_num

    # ~ #plotting
    # ~ for i in range(ns):

    # ~ x = [i%Ly + .5]
    # ~ y = [i//Ly + .5]

    # ~ #plot uncertainties
    # ~ if obs_unc:
    # ~ plt.plot(x,y, 'o', color=(219/256,122/256,147/256), markersize=factor*12/(A[i,i])**2, alpha=1.)
    # ~ if state_unc:
    # ~ plt.plot(x,y, 'o', color=(100/256,149/256,237/256), markersize=factor*12/(cert_arr[i])**2, alpha=1.)

    # ~ #plot unsuccessful paths
    # ~ for j in range(2):

    # ~ if un[i,j]>0.0:
    # ~ if j == 0:
    # ~ xp = x + [x[0] + 1]
    # ~ yp = y + [y[0] + 0]
    # ~ if j == 1:
    # ~ xp = x + [x[0] + 0]
    # ~ yp = y + [y[0] + 1]

    # ~ plt.plot(xp,yp, '-', color='#D5647C', linewidth=factor*75*un[i,j],
    # ~ zorder = 9, alpha=1)

    # ~ #set plot title
    # ~ #plt.title("Planning: successful "+str(round(100*total/trials))+"%", fontsize=factor*9)

    # ~ #plot successful paths on top
    # ~ for i in range(ns):

    # ~ x = [i%Ly + .5]
    # ~ y = [i//Ly + .5]

    # ~ for j in range(2):

    # ~ if n1[i,j]>0.0:
    # ~ if j == 0:
    # ~ xp = x + [x[0] + 1]
    # ~ yp = y + [y[0]]
    # ~ if j == 1:
    # ~ xp = x + [x[0] + 0]
    # ~ yp = y + [y[0] + 1]
    # ~ plt.plot(xp,yp, '-', color='#4682B4', linewidth=factor*75*n1[i,j],
    # ~ zorder = 10, alpha=1)

    # ~ #plot successful paths on top
    # ~ if context:
    # ~ for i in range(ns):

    # ~ x = [i%Ly + .5]
    # ~ y = [i//Ly + .5]

    # ~ for j in range(2):

    # ~ if n2[i,j]>0.0:
    # ~ if j == 0:
    # ~ xp = x + [x[0] + 1]
    # ~ yp = y + [y[0]]
    # ~ if j == 1:
    # ~ xp = x + [x[0] + 0]
    # ~ yp = y + [y[0] + 1]
    # ~ plt.plot(xp,yp, '-', color='#55ab75', linewidth=factor*75*n2[i,j],
    # ~ zorder = 10, alpha=1)

    # ~ #print("percent won", total/trials, "state prior", np.amax(utility))

    # ~ plt.savefig('chosen_paths_'+name_str+'h'+str(h)+'.svg')
    #plt.show()

    # max_RT = np.amax(w.agent.action_selection.RT[:,0])
    # plt.figure()
    # plt.plot(w.agent.action_selection.RT[:,0], '.')
    # plt.ylim([0,1.05*max_RT])
    # plt.xlim([0,trials])
    # plt.savefig("Gridworld_Dir_h"+str(h)+".svg")
    # plt.show()
    """
    save data
    """

    if save_data:
        jsonpickle_numpy.register_handlers()

        ut = np.amax(utility)
        p_o = '{:02d}'.format(round(ut * 10).astype(int))
        fname = agnt + '_' + condition + '_' + sel + '_initUnc_' + p_o + '.json'
        fname = os.path.join(data_folder, fname)
        pickled = pickle.encode(w)
        with open(fname, 'w') as outfile:
            json.dump(pickled, outfile)

    return w