예제 #1
0
def run(use_gui=True, runs=1, fixed_tl=False):
    """Run ``runs`` episodes on the 'nets/double' SUMO network.

    One ``TrueOnlineSarsaLambda`` agent is built per id in ``env.ts_ids``.
    Each episode resets the environment and steps it until
    ``done['__all__']`` is true. After each episode it calls
    ``env.save_csv(out_csv, episode)``.

    Args:
        use_gui: Forwarded to ``SumoEnvironment(use_gui=...)``.
        runs: Number of episodes; they are numbered ``1..runs`` when passed
            to ``env.save_csv``.
        fixed_tl: When true, step the environment with ``None`` actions
            instead of querying or updating the agents. Defaults to False,
            which matches the previously hard-coded behaviour.
    """
    out_csv = 'outputs/double/sarsa-double'

    env = SumoEnvironment(net_file='nets/double/network.net.xml',
                          single_agent=False,
                          route_file='nets/double/flow.rou.xml',
                          out_csv_name=out_csv,
                          use_gui=use_gui,
                          num_seconds=86400,
                          yellow_time=3,
                          min_green=5,
                          max_green=60)

    # One independent agent per traffic-signal id.
    agents = {
        ts_id: TrueOnlineSarsaLambda(env.observation_spaces(ts_id),
                                     env.action_spaces(ts_id),
                                     alpha=0.000000001,
                                     gamma=0.95,
                                     epsilon=0.05,
                                     lamb=0.1,
                                     fourier_order=7)
        for ts_id in env.ts_ids
    }

    try:
        # Named `episode` (not `run`) so the loop variable does not shadow
        # this function's own name.
        for episode in range(1, runs + 1):
            obs = env.reset()
            done = {'__all__': False}

            if fixed_tl:
                while not done['__all__']:
                    _, _, done, _ = env.step(None)
            else:
                while not done['__all__']:
                    # Ask every agent that has an observation for an action.
                    actions = {ts_id: agents[ts_id].act(obs[ts_id]) for ts_id in obs}

                    next_obs, r, done, _ = env.step(action=actions)

                    # Only ids present in next_obs are updated; every other
                    # id keeps its previous entry in `obs`.
                    for ts_id, next_state in next_obs.items():
                        agents[ts_id].learn(state=obs[ts_id],
                                            action=actions[ts_id],
                                            reward=r[ts_id],
                                            next_state=next_state,
                                            done=done[ts_id])
                        obs[ts_id] = next_state

            env.save_csv(out_csv, episode)
    finally:
        # Close the environment even if an episode raises.
        # NOTE(review): assumes SumoEnvironment.close() shuts down the
        # SUMO/TraCI session -- confirm against the installed sumo-rl version.
        env.close()
예제 #2
0
                                    '').replace('.net.xml', '')
    out_csv = f'outputs/5x5-Raphael/{scenario}_{experiment_time}_alpha{args.alpha}_gamma{args.gamma}_eps{args.epsilon}_decay{args.decay}'

    env = SumoEnvironment(net_file=args.network,
                          route_file=args.route,
                          out_csv_name=out_csv,
                          use_gui=args.gui,
                          num_seconds=args.seconds,
                          min_green=args.min_green,
                          max_green=args.max_green,
                          max_depart_delay=0)

    initial_states = env.reset()
    ql_agents = {
        ts: QLAgent(starting_state=env.encode(initial_states[ts], ts),
                    state_space=env.observation_spaces(ts),
                    action_space=env.action_spaces(ts),
                    alpha=args.alpha,
                    gamma=args.gamma,
                    exploration_strategy=EpsilonGreedy(
                        initial_epsilon=args.epsilon,
                        min_epsilon=args.min_epsilon,
                        decay=args.decay))
        for ts in env.ts_ids
    }
    infos = []
    done = {'__all__': False}
    if args.fixed:
        while not done['__all__']:
            _, _, done, _ = env.step({})
    else: