def build(path, is_virtual=False):
    """Build the world, one PressLightAgent per intersection, and the env.

    Reads the module-level ``args`` for ``thread``, ``load_model`` and
    ``save_dir``.

    Args:
        path: Passed to ``World`` as its first argument.
        is_virtual: Forwarded to every ``PressLightAgent`` as its last
            positional argument.

    Returns:
        A ``(world, agents, env)`` tuple, where ``env`` is a ``TSCEnv`` built
        from the world, the agent list and a ``TravelTimeMetric``.
    """
    world = World(path, thread_num=args.thread)

    agents = []
    for intersection in world.intersections:
        # One discrete action per entry in this intersection's phase list.
        phase_space = gym.spaces.Discrete(len(intersection.phases))
        # NOTE(review): presumably the observation generator.
        lane_count_gen = LaneVehicleGenerator(
            world, intersection, ["lane_count"], in_only=True, average=None
        )
        # NOTE(review): presumably the reward generator (negated).
        waiting_count_gen = LaneVehicleGenerator(
            world,
            intersection,
            ["lane_waiting_count"],
            in_only=True,
            average="all",
            negative=True,
        )
        agent = PressLightAgent(
            phase_space,
            lane_count_gen,
            waiting_count_gen,
            intersection.id,
            world,
            is_virtual,
        )
        # Pin epsilon to its minimum value.
        agent.epsilon = agent.epsilon_min
        agents.append(agent)
        if args.load_model:
            agent.load_model(args.save_dir)

    env = TSCEnv(world, agents, TravelTimeMetric(world))
    return world, agents, env
# Example #2
# 0
def _configure_main_logger(log_dir):
    """Attach a timestamped DEBUG file handler and an INFO console handler to 'main'.

    Handlers installed by a previous call are closed and detached first, so
    calling ``init`` more than once neither duplicates every log line nor
    leaks open log files. Handlers attached to 'main' by other code are left
    untouched.

    Args:
        log_dir: Existing directory in which the ``YYYYmmdd-HHMMSS.log`` file
            is created.

    Returns:
        The configured ``'main'`` logger.
    """
    logger = logging.getLogger('main')
    logger.setLevel(logging.DEBUG)

    # Drop only the handlers this helper added on an earlier call.
    for handler in list(logger.handlers):
        if getattr(handler, "_installed_by_init", False):
            logger.removeHandler(handler)
            handler.close()

    fh = logging.FileHandler(
        os.path.join(log_dir,
                     datetime.now().strftime('%Y%m%d-%H%M%S') + ".log"))
    fh.setLevel(logging.DEBUG)
    sh = logging.StreamHandler()
    sh.setLevel(logging.INFO)
    for handler in (fh, sh):
        # Tag so a later call can recognise and replace these handlers.
        handler._installed_by_init = True
        logger.addHandler(handler)
    return logger


def init(args, test=False):
    """Prepare output directories and logging, then build the traffic env.

    Args:
        args: Parsed options. Reads ``config_file``, ``log_dir``, ``thread``
            and ``load_model``, and sets ``args.save_dir``.
        test: Selects the alternate suffix length used when deriving
            ``args.save_dir`` from ``args.config_file``.

    Returns:
        A ``TSCEnv`` wrapping the world, one ``PressLightAgent`` per
        intersection, and a ``TravelTimeMetric``.
    """
    tf_mute_warning()

    # save_dir = module-level save_dir prefix + a slice of the config path.
    # NOTE(review): the fixed offsets look like they strip a 7-char directory
    # prefix and a 5-char extension (10 chars in test mode) — confirm against
    # the config naming scheme.
    suffix_len = 10 if test else 5
    args.save_dir = save_dir + args.config_file[7:-suffix_len]

    # makedirs(exist_ok=True) also creates missing parents and avoids the
    # exists-then-mkdir race.
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    _configure_main_logger(args.log_dir)

    # create world
    world = World(args.config_file, thread_num=args.thread, silent=True)

    # create agents: one per intersection, one discrete action per phase.
    agents = []
    for i in world.intersections:
        action_space = gym.spaces.Discrete(len(i.phases))
        agent = PressLightAgent(
            action_space,
            LaneVehicleGenerator(world,
                                 i, ["lane_count"],
                                 in_only=True,
                                 average=None),
            LaneVehicleGenerator(world,
                                 i, ["lane_waiting_count"],
                                 in_only=True,
                                 average="all",
                                 negative=True), i.id, world)
        agents.append(agent)
        if args.load_model:
            agent.load_model(args.save_dir)

    # create metric and env
    metric = TravelTimeMetric(world)
    return TSCEnv(world, agents, metric)
# Example #3
# 0
# Attach the stream handler `sh` to `logger`.
# NOTE(review): both are created before this excerpt (not visible here).
logger.addHandler(sh)

# create world
# Built from the configured file with the configured thread count.
world = World(args.config_file, thread_num=args.thread)

# create agents
# One PressLightAgent per intersection; its action space has one discrete
# action per entry in that intersection's `phases`.
agents = []
for i in world.intersections:
    action_space = gym.spaces.Discrete(len(i.phases))
    agents.append(
        PressLightAgent(
            action_space,
            # 'lane_count' generator, in_only, no averaging
            # (presumably the agent's observation — confirm).
            LaneVehicleGenerator(world,
                                 i, ["lane_count"],
                                 in_only=True,
                                 average=None),
            # 'lane_waiting_count' generator, averaged ("all") and negated
            # (presumably the agent's reward — confirm).
            LaneVehicleGenerator(world,
                                 i, ["lane_waiting_count"],
                                 in_only=True,
                                 average="all",
                                 negative=True), i.id, world))
    # Optionally load a saved model for the agent just created.
    if args.load_model:
        agents[-1].load_model(args.save_dir)
# Debug output for the first agent; raises IndexError if the world has no
# intersections.
print(agents[0].ob_length)
print(agents[0].action_space)

# create metric
metric = TravelTimeMetric(world)

# create env
# Combine world, agents and metric into the environment object.
env = TSCEnv(world, agents, metric)