Example 1
    def __init__(self, args, base_policy="max_pressure", rollout_agent_num=0):
        self.args = args
        self.base_policy = base_policy
        self.tmp_archive_name = self.args.prefix + "_snapshot"
        # Build a throwaway world just to read the network layout: the number
        # of intersections and the size of the phase/action space.
        tmp_world = World(args.config_file,
                          thread_num=args.thread,
                          silent=True)
        self.n_intersections = len(tmp_world.intersections)
        self.action_space = len(tmp_world.intersections[0].phases)
        # rollout_agent_num <= 0 means "roll out every intersection"
        self.rollout_agent_num = self.n_intersections if rollout_agent_num <= 0 else rollout_agent_num
        print("creating subprocesses....")
        self.action_queue = multiprocessing.Queue(self.action_space)
        self.result_queue = multiprocessing.Queue(self.action_space)
        tf_mute_warning()
        self.pool = []
        # one rollout worker per candidate action; each worker builds its own
        # simulator copy in RolloutProcess.run (see Example 6)
        for process_id in range(self.action_space):
            process = RolloutProcess(process_id,
                                     self.args,
                                     self.action_queue,
                                     self.result_queue,
                                     policy=self.base_policy,
                                     archive_name=self.tmp_archive_name)
            self.pool.append(process)
        for p in self.pool:
            p.start()

        self.env = create_env(args, policy=base_policy)
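
The constructor starts the workers but never stops them; Example 6 shows that a worker exits when it reads the string "end" from the action queue. A minimal shutdown helper matching that protocol might look like the sketch below (the method name close is an assumption; the sentinel comes from Example 6):

    def close(self):
        # hypothetical shutdown helper; the "end" sentinel matches the
        # check in RolloutProcess.run (Example 6)
        for _ in self.pool:
            self.action_queue.put("end")  # one sentinel per worker
        for p in self.pool:
            p.join()  # wait for every worker to exit cleanly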
Example 2
def init(args, test=False):
    tf_mute_warning()
    # Derive the save directory from the config path by slicing off the first
    # seven characters and the file extension (a longer suffix in test mode).
    args.save_dir = save_dir + args.config_file[7:-5]
    if test:
        args.save_dir = save_dir + args.config_file[7:-10]

    # config_name = args.config_file.split('/')[1].split('.')[0]
    # args.agent_save_dir = args.save_dir + "/" + config_name
    if not os.path.exists(args.save_dir):
        os.mkdir(args.save_dir)

    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    # DEBUG and above go to a timestamped log file, INFO and above to stderr
    logger = logging.getLogger('main')
    logger.setLevel(logging.DEBUG)
    fh = logging.FileHandler(
        os.path.join(args.log_dir,
                     datetime.now().strftime('%Y%m%d-%H%M%S') + ".log"))
    fh.setLevel(logging.DEBUG)
    sh = logging.StreamHandler()
    sh.setLevel(logging.INFO)
    logger.addHandler(fh)
    logger.addHandler(sh)

    # create world
    world = World(args.config_file, thread_num=args.thread, silent=True)

    # create agents: per-lane vehicle counts as observations, negated average
    # lane waiting count as the reward signal
    agents = []
    for i in world.intersections:
        action_space = gym.spaces.Discrete(len(i.phases))
        agents.append(
            DQNAgent(
                action_space,
                LaneVehicleGenerator(world,
                                     i, ["lane_count"],
                                     in_only=True,
                                     average=None),
                LaneVehicleGenerator(world,
                                     i, ["lane_waiting_count"],
                                     in_only=True,
                                     average="all",
                                     negative=True), i.id))
        if args.load_model:
            agents[-1].load_model(args.save_dir)
    if args.share_weights:
        # parameter sharing: every agent points at the same underlying model
        model = agents[0].model
        for agent in agents:
            agent.model = model

    # create metric
    metric = TravelTimeMetric(world)

    # create env
    env = TSCEnv(world, agents, metric)

    return env
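
init returns a TSCEnv wired up with one DQN agent per intersection. A minimal driver loop under the interface shown in Example 6 (step returns obs, rewards, dones, info; reset is assumed to follow the usual gym convention):

# hypothetical driver; the get_action/step signatures follow Example 6
env = init(args)
obs = env.reset()  # reset() is assumed from the gym-style interface
for _ in range(args.steps):
    actions = [agent.get_action(obs[i]) for i, agent in enumerate(env.agents)]
    obs, rewards, dones, info = env.step(actions)
print(env.eng.get_average_travel_time())  # same metric call as Example 6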
Example 3
def rollout(args):
    MODEL_PATH = args.load_dir
    MODEL_ID = args.model_id
    tf_mute_warning()
    traj_buffer_list = [
        TrajectoryBuffer(10000,
                         file_name="{}_headstart_{}_rollout".format(
                             args.prefix, args.base_policy))
    ]
    print("creating rollout controller...")
    # rollout_agent_num=-1: any value <= 0 makes the controller roll out
    # every intersection (see Example 1)
    rollout_controller = RolloutControllerParallel(
        args, base_policy=args.base_policy, rollout_agent_num=-1)
    # model_path, model_id = (model_path_22, id22) if rollout_controller.n_intersections < 10 else (model_path_44, id44)

    if args.base_policy != "max_pressure":
        if args.base_policy == "frap":
            for agent in rollout_controller.env.agents:
                agent.load_model(dir=MODEL_PATH)
        elif args.parameter_sharing:
            rollout_controller.env.agents[0].load_model(dir=MODEL_PATH,
                                                        model_id=MODEL_ID)
        else:
            raise NotImplementedError

    last_result = test_multi(rollout_controller.get_env(), args)
    print("initial result:{}".format(last_result))
    for i in range(args.walks_per_iter):
        print("performing rollout {}/{}".format(i + 1, args.walks_per_iter))
        # walk_rate grows linearly (0.00, 0.05, ...), presumably letting later
        # walks deviate further from the base policy
        rollout_controller.perform_rollout(trj_buffer_list=traj_buffer_list,
                                           random_walk=True,
                                           walk_rate=i * 0.05,
                                           model_dirs=[MODEL_PATH],
                                           model_id=MODEL_ID,
                                           save_traj=True,
                                           verbose=1)
        for buf in traj_buffer_list:
            buf.save_to_file()
            print("successfully saved trajectories for rollout {}".format(i + 1))
Example 4
def init(args, test=False):
    tf_mute_warning()
    args.save_dir = save_dir + args.config_file[7:-5]
    if test:
        args.save_dir = save_dir + args.config_file[7:-10]

    # config_name = args.config_file.split('/')[1].split('.')[0]
    # args.agent_save_dir = args.save_dir + "/" + config_name
    if not os.path.exists(args.save_dir):
        os.mkdir(args.save_dir)

    # create world
    world = World(args.config_file, thread_num=args.thread, silent=True)

    # create agents (FRAP variant: the agent constructor also takes the world)
    agents = []
    for i in world.intersections:
        action_space = gym.spaces.Discrete(len(i.phases))
        agents.append(
            FRAP_DQNAgent(
                action_space,
                LaneVehicleGenerator(world,
                                     i, ["lane_count"],
                                     in_only=True,
                                     average=None),
                LaneVehicleGenerator(world,
                                     i, ["lane_waiting_count"],
                                     in_only=True,
                                     average="all",
                                     negative=True), world, i.id))

    # create metric
    metric = TravelTimeMetric(world)

    # create env
    env = TSCEnv(world, agents, metric)

    return env
Example 5
# import the required libraries
import os
import numpy as np
import tensorflow as tf
from tensorflow import gfile
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import backend as K
from mute_tf_warnings import tf_mute_warning  # this is just a library to mute tf warnings

#mute warnings
tf_mute_warning()

#model directories
save_dir = "models/"
keras_model_name = "keras_model.h5"
tf_lite_model = "tf_lite.tflite"
tf_quant_model = "tf_quant_model.tflite"

# build a toy data set: 1000 positive-valued rows labeled 1,
# 1000 negative-valued rows labeled 0
x = np.vstack((np.random.rand(1000, 10), -np.random.rand(1000, 10)))
y = np.vstack((np.ones((1000, 1)), np.zeros((1000, 1))))

# build a sequential model
model = Sequential()
model.add(Dense(units=64, input_shape=(10, ), activation="relu"))
model.add(Dense(units=32, activation="relu"))
model.add(Dense(units=16, activation="relu"))
model.add(Dense(units=8, activation="relu"))
model.add(Dense(units=1, activation="sigmoid"))
model.compile(loss="binary_crossentropy",
Example 6
    def run(self):
        def rollout_single_action(args, env):
            # Play the base policy forward (until args.steps or the optional
            # time horizon) and score the rollout by average travel time.
            obs = env.get_current_obs()
            start_step = env.eng.get_current_time()
            i = start_step
            while i < args.steps:
                if args.time_horizon:
                    if i >= start_step + args.time_horizon:
                        break
                actions = []
                for agent_id, agent in enumerate(env.agents):
                    if self.args.parameter_sharing:
                        # a single shared agent emits actions for everyone
                        actions = env.agents[0].get_actions(obs)
                        break
                    actions.append(agent.get_action(obs[agent_id]))
                for _ in range(args.action_interval):
                    obs, rewards, dones, info = env.step(actions)
                    i += 1
            result = env.eng.get_average_travel_time()
            return result

        import time
        t0 = time.time()
        # TensorFlow is imported inside the child process so each worker owns
        # a private graph (and thus its own session state)
        import tensorflow as tf
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.env = create_env(self.args, policy=self.policy)
        self.n_intersections = len(self.env.world.intersections)
        t1 = time.time()
        # print("environment for process {} created!, time: {}".format(self.process_id, t1 - t0))
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
        tf_mute_warning()

        # Worker loop: a dict on the queue means "reload models", the string
        # "end" shuts the worker down, and anything else is an action vector
        # to evaluate from the saved snapshot.
        while True:
            if not self.action_queue.empty():
                actions = self.action_queue.get()
            else:
                time.sleep(1)
                continue

            if actions == "end":
                break
            elif isinstance(actions, dict):
                model_id = actions["model_id"]
                model_dirs = actions["model_dirs"]
                with self.graph.as_default():
                    if not self.policy == "max_pressure":
                        if self.policy == "frap":
                            for agent in self.env.agents:
                                agent.load_model(dir=model_dirs[0])
                        elif self.args.parameter_sharing:
                            self.env.agents[0].load_model(dir=model_dirs[0],
                                                          model_id=model_id)
                        else:
                            for agent_id in range(len(model_dirs)):
                                self.env.agents[agent_id].load_model(
                                    dir=model_dirs[agent_id],
                                    model_id=model_id)
                time.sleep(1)
            else:
                # "actions" carries one forced action (encoded as a string)
                # for the intersection under evaluation; None entries are
                # filled in by the base policy below.
                rollout_action_id = None  # guard against an unbound name below
                self.env.load_snapshot(from_file=True,
                                       dir="./archive",
                                       file_name=self.archive_name)
                obs = self.env.get_current_obs()
                with self.graph.as_default():
                    if self.args.parameter_sharing:
                        default_actions = self.env.agents[0].get_actions(obs)
                        # print(default_actions)
                    # print(actions)
                    for agent_id in range(self.n_intersections):
                        if actions[agent_id] is None:
                            if self.args.parameter_sharing:
                                actions[agent_id] = default_actions[agent_id]
                            else:
                                actions[agent_id] = self.env.agents[
                                    agent_id].get_action(obs[agent_id])
                        elif isinstance(actions[agent_id], str):
                            # a string marks the forced action this worker is
                            # evaluating; keep its value to tag the result
                            actions[agent_id] = int(actions[agent_id])
                            rollout_action_id = actions[agent_id]
                for _ in range(self.args.action_interval):
                    # print(actions)
                    self.env.step(actions)
                with self.graph.as_default():
                    result = rollout_single_action(self.args, self.env)
                self.result_queue.put([rollout_action_id, result])
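
The loop above fixes an implicit queue protocol: the controller broadcasts either a model-reload dict or one action vector per worker, then collects [rollout_action_id, travel_time] pairs from the result queue. A hypothetical controller-side dispatch under those assumptions (the method name and argument are illustrative):

    def evaluate_actions(self, candidate_actions):
        # candidate_actions: one action vector per worker, each with exactly
        # one forced action encoded as a string (see the elif branch above)
        for actions in candidate_actions:
            self.action_queue.put(actions)
        results = {}
        for _ in candidate_actions:
            action_id, travel_time = self.result_queue.get()  # blocks
            results[action_id] = travel_time
        # lower average travel time is better
        return min(results, key=results.get)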