def tune_parameters(cutoff: int, jobs: int):
    """Random-search hyper-parameters and print the best trial found.

    Runs ``cutoff`` trials against the first half of the session data. Every
    trial keeps ``n_factors=100`` and ``n_epochs=20`` fixed and draws
    ``init_mean``, ``init_std_dev``, ``lr_all`` and ``reg_all`` uniformly at
    random; it is scored through ``test('cf', 10, ...)``. The trial with the
    highest ``precision + recall`` is reported on stdout.

    Args:
        cutoff: number of random trials to run.
        jobs: forwarded unchanged to ``test``.
    """
    session_data, products_data = load_sessions_products()
    session_data = session_data.head(session_data.index.size // 2)

    best_params = None
    best_precision = 0
    best_recall = 0

    for trial in range(cutoff):
        print(trial, '/', cutoff, end='\r')
        candidate = [
            100, 20,
            uniform(0, 0.1),
            uniform(0, 1),
            uniform(0, 0.05),
            uniform(0, 0.2)
        ]
        precision_test, recall_test = test('cf', 10, session_data,
                                           products_data, jobs, candidate,
                                           True)

        # Always keep the first trial: otherwise a run in which every trial
        # scores 0 would leave no parameters to report.
        improved = (precision_test + recall_test >
                    best_precision + best_recall)
        if best_params is None or improved:
            best_params = candidate
            best_precision = precision_test
            best_recall = recall_test

    if best_params is None:
        # cutoff <= 0: no trial was run, so there is nothing to report.
        print('\nNo trials were run; nothing to report.')
        return

    print('\nBest precision@10 =', best_precision)
    print('Best recall@10 =', best_recall)
    print('Best params:\n\tn_factors =', best_params[0], '\n\tn_epochs =',
          best_params[1], '\n\tinit_mean =', best_params[2],
          '\n\tinit_std_dev =', best_params[3], '\n\tlr_all =',
          best_params[4], '\n\treg_all =', best_params[5])
Пример #2
0
    def run(self, multi_process=False):
        """Run the optional warm-up phases and then the main training rounds.

        Phases, in order:

        1. ``PRETRAIN`` (if set in ``dic_exp_conf``): when the pretrain model
           directory is non-empty, copy one ``round_0_inter_<i>.h5`` per
           agent into the model directory. Otherwise generate pretrain
           samples (only if the pretrain work directory is empty) and, if
           the model is listed in ``LIST_MODEL_NEED_TO_UPDATE``, call the
           updater for round 0.
        2. ``AGGREGATE`` (if set): copy ``model/initial/aggregate.h5`` to
           ``round_0.h5`` when it exists, otherwise call the updater for
           round 0.
        3. ``NUM_ROUNDS`` training rounds: generators, sample construction,
           network update, log downsampling, test evaluation, optional early
           stopping and optional model-pool selection. One tab-separated row
           of timings per round is appended to ``running_time.csv`` in the
           work directory.

        ``PRETRAIN`` and ``AGGREGATE`` are both set to False in
        ``dic_exp_conf`` before the training rounds start.

        Args:
            multi_process: if True, generators, the updater, the test
                evaluation and the model-pool step are run in ``Process``
                children; if False they are called inline.
        """
        best_round, bar_round = None, None

        # Start a fresh timing log with its header row.
        with open(
                os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"],
                             "running_time.csv"), "w") as f_time:
            f_time.write(
                "generator_time\tmaking_samples_time\tupdate_network_time\ttest_evaluation_times\tall_times\n"
            )

        if self.dic_exp_conf["PRETRAIN"]:
            if os.listdir(self.dic_path["PATH_TO_PRETRAIN_MODEL"]):
                # Pretrained weights exist: seed round 0 with one file per agent.
                for i in range(self.dic_traffic_env_conf["NUM_AGENTS"]):
                    # TODO: only suitable for DGN
                    shutil.copy(
                        os.path.join(self.dic_path["PATH_TO_PRETRAIN_MODEL"],
                                     "round_0_inter_%d.h5" % i),
                        os.path.join(self.dic_path["PATH_TO_MODEL"],
                                     "round_0_inter_%d.h5" % i))
            else:
                # Only generate pretrain samples when none are on disk yet.
                if not os.listdir(
                        self.dic_path["PATH_TO_PRETRAIN_WORK_DIRECTORY"]):
                    for cnt_round in range(
                            self.dic_exp_conf["PRETRAIN_NUM_ROUNDS"]):
                        print("round %d starts" % cnt_round)

                        process_list = []

                        # ==============  generator =============
                        if multi_process:
                            for cnt_gen in range(
                                    self.
                                    dic_exp_conf["PRETRAIN_NUM_GENERATORS"]):
                                p = Process(target=self.generator_wrapper,
                                            args=(cnt_round, cnt_gen,
                                                  self.dic_path,
                                                  self.dic_exp_conf,
                                                  self.dic_agent_conf,
                                                  self.dic_traffic_env_conf,
                                                  best_round))
                                print("before")
                                p.start()
                                print("end")
                                process_list.append(p)
                            print("before join")
                            for p in process_list:
                                p.join()
                            print("end join")
                        else:
                            for cnt_gen in range(
                                    self.
                                    dic_exp_conf["PRETRAIN_NUM_GENERATORS"]):
                                self.generator_wrapper(
                                    cnt_round=cnt_round,
                                    cnt_gen=cnt_gen,
                                    dic_path=self.dic_path,
                                    dic_exp_conf=self.dic_exp_conf,
                                    dic_agent_conf=self.dic_agent_conf,
                                    dic_traffic_env_conf=self.
                                    dic_traffic_env_conf,
                                    best_round=best_round)

                        # ==============  make samples =============
                        # make samples and determine which samples are good
                        train_round = os.path.join(
                            self.dic_path["PATH_TO_PRETRAIN_WORK_DIRECTORY"],
                            "train_round")
                        if not os.path.exists(train_round):
                            os.makedirs(train_round)
                        cs = ConstructSample(
                            path_to_samples=train_round,
                            cnt_round=cnt_round,
                            dic_traffic_env_conf=self.dic_traffic_env_conf)
                        cs.make_reward()

                if self.dic_exp_conf["MODEL_NAME"] in self.dic_exp_conf[
                        "LIST_MODEL_NEED_TO_UPDATE"]:
                    if multi_process:
                        p = Process(target=self.updater_wrapper,
                                    args=(0, self.dic_agent_conf,
                                          self.dic_exp_conf,
                                          self.dic_traffic_env_conf,
                                          self.dic_path, best_round))
                        p.start()
                        p.join()
                    else:
                        self.updater_wrapper(
                            cnt_round=0,
                            dic_agent_conf=self.dic_agent_conf,
                            dic_exp_conf=self.dic_exp_conf,
                            dic_traffic_env_conf=self.dic_traffic_env_conf,
                            dic_path=self.dic_path,
                            best_round=best_round)

        # train with aggregate samples
        if self.dic_exp_conf["AGGREGATE"]:
            if "aggregate.h5" in os.listdir("model/initial"):
                shutil.copy(
                    "model/initial/aggregate.h5",
                    os.path.join(self.dic_path["PATH_TO_MODEL"], "round_0.h5"))
            else:
                if multi_process:
                    p = Process(target=self.updater_wrapper,
                                args=(0, self.dic_agent_conf,
                                      self.dic_exp_conf,
                                      self.dic_traffic_env_conf, self.dic_path,
                                      best_round))
                    p.start()
                    p.join()
                else:
                    self.updater_wrapper(
                        cnt_round=0,
                        dic_agent_conf=self.dic_agent_conf,
                        dic_exp_conf=self.dic_exp_conf,
                        dic_traffic_env_conf=self.dic_traffic_env_conf,
                        dic_path=self.dic_path,
                        best_round=best_round)

        # The warm-up phases are done; clear the flags in the shared config
        # before it is handed to the training-round wrappers.
        self.dic_exp_conf["PRETRAIN"] = False
        self.dic_exp_conf["AGGREGATE"] = False

        # train
        for cnt_round in range(self.dic_exp_conf["NUM_ROUNDS"]):
            print("round %d starts" % cnt_round)
            round_start_time = time.time()

            process_list = []

            print("==============  generator =============")
            generator_start_time = time.time()
            if multi_process:
                for cnt_gen in range(self.dic_exp_conf["NUM_GENERATORS"]):
                    p = Process(target=self.generator_wrapper,
                                args=(cnt_round, cnt_gen, self.dic_path,
                                      self.dic_exp_conf, self.dic_agent_conf,
                                      self.dic_traffic_env_conf, best_round))
                    print("before")
                    p.start()
                    print("end")
                    process_list.append(p)
                print("before join")
                for i, p in enumerate(process_list):
                    print("generator %d to join" % i)
                    p.join()
                    print("generator %d finish join" % i)
                print("end join")
            else:
                for cnt_gen in range(self.dic_exp_conf["NUM_GENERATORS"]):
                    self.generator_wrapper(
                        cnt_round=cnt_round,
                        cnt_gen=cnt_gen,
                        dic_path=self.dic_path,
                        dic_exp_conf=self.dic_exp_conf,
                        dic_agent_conf=self.dic_agent_conf,
                        dic_traffic_env_conf=self.dic_traffic_env_conf,
                        best_round=best_round)
            generator_end_time = time.time()
            generator_total_time = generator_end_time - generator_start_time

            print("==============  make samples =============")
            # make samples and determine which samples are good
            making_samples_start_time = time.time()

            train_round = os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"],
                                       "train_round")
            if not os.path.exists(train_round):
                os.makedirs(train_round)

            cs = ConstructSample(
                path_to_samples=train_round,
                cnt_round=cnt_round,
                dic_traffic_env_conf=self.dic_traffic_env_conf)
            cs.make_reward_for_system()

            making_samples_end_time = time.time()
            making_samples_total_time = making_samples_end_time - making_samples_start_time

            print("==============  update network =============")
            update_network_start_time = time.time()
            if self.dic_exp_conf["MODEL_NAME"] in self.dic_exp_conf[
                    "LIST_MODEL_NEED_TO_UPDATE"]:
                if multi_process:
                    p = Process(target=self.updater_wrapper,
                                args=(cnt_round, self.dic_agent_conf,
                                      self.dic_exp_conf,
                                      self.dic_traffic_env_conf, self.dic_path,
                                      best_round, bar_round))
                    p.start()
                    print("update to join")
                    p.join()
                    print("update finish join")
                else:
                    self.updater_wrapper(
                        cnt_round=cnt_round,
                        dic_agent_conf=self.dic_agent_conf,
                        dic_exp_conf=self.dic_exp_conf,
                        dic_traffic_env_conf=self.dic_traffic_env_conf,
                        dic_path=self.dic_path,
                        best_round=best_round,
                        bar_round=bar_round)

            # Downsampling of the generator logs is counted in the
            # update-network timing below.
            if not self.dic_exp_conf["DEBUG"]:
                for cnt_gen in range(self.dic_exp_conf["NUM_GENERATORS"]):
                    path_to_log = os.path.join(
                        self.dic_path["PATH_TO_WORK_DIRECTORY"], "train_round",
                        "round_" + str(cnt_round), "generator_" + str(cnt_gen))
                    self.downsample_for_system(path_to_log,
                                               self.dic_traffic_env_conf)
            update_network_end_time = time.time()
            update_network_total_time = update_network_end_time - update_network_start_time

            print("==============  test evaluation =============")
            test_evaluation_start_time = time.time()
            # With SPARSE_TEST set, only every fifth round is evaluated.
            skip_test = (self.dic_exp_conf["SPARSE_TEST"]
                         and cnt_round % 5 != 0)
            if not skip_test:
                if multi_process:
                    p = Process(target=model_test.test,
                                args=(self.dic_path["PATH_TO_MODEL"],
                                      cnt_round,
                                      self.dic_exp_conf["RUN_COUNTS"],
                                      self.dic_traffic_env_conf, False))
                    p.start()
                    # The test process is only awaited when early stopping
                    # is enabled; otherwise it keeps running in the
                    # background while the next round starts.
                    if self.dic_exp_conf["EARLY_STOP"]:
                        p.join()
                else:
                    model_test.test(self.dic_path["PATH_TO_MODEL"],
                                    cnt_round,
                                    self.dic_exp_conf["RUN_COUNTS"],
                                    self.dic_traffic_env_conf,
                                    if_gui=False)

            test_evaluation_end_time = time.time()
            test_evaluation_total_time = test_evaluation_end_time - test_evaluation_start_time

            print('==============  early stopping =============')
            if self.dic_exp_conf["EARLY_STOP"]:
                flag = self.early_stopping(self.dic_path, cnt_round)
                if flag == 1:
                    print("early stopping!")
                    print("training ends at round %s" % cnt_round)
                    break

            print('==============  model pool evaluation =============')
            if self.dic_exp_conf["MODEL_POOL"] and cnt_round > 50:
                if multi_process:
                    p = Process(
                        target=self.model_pool_wrapper,
                        args=(self.dic_path, self.dic_exp_conf, cnt_round),
                    )
                    p.start()
                    print("model_pool to join")
                    p.join()
                    print("model_pool finish join")
                else:
                    self.model_pool_wrapper(dic_path=self.dic_path,
                                            dic_exp_conf=self.dic_exp_conf,
                                            cnt_round=cnt_round)
                model_pool_dir = os.path.join(
                    self.dic_path["PATH_TO_WORK_DIRECTORY"], "best_model.pkl")
                if os.path.exists(model_pool_dir):
                    # NOTE(review): pickle.load is only safe on trusted
                    # files; presumably this one is written by
                    # model_pool_wrapper - confirm.
                    with open(model_pool_dir, "rb") as f_pool:
                        model_pool = pickle.load(f_pool)
                    # Pick a random pool entry as the next round's best_round.
                    ind = random.randint(0, len(model_pool) - 1)
                    best_round = model_pool[ind][0]
                    # A second, preferably different, index is still drawn
                    # (up to 10 retries) but is currently unused: bar_round
                    # stays None. The draws are kept so the random stream is
                    # unchanged.
                    ind_bar = random.randint(0, len(model_pool) - 1)
                    flag = 0
                    while ind_bar == ind and flag < 10:
                        ind_bar = random.randint(0, len(model_pool) - 1)
                        flag += 1
                    bar_round = None
                else:
                    best_round = None
                    bar_round = None

                # downsample
                if not self.dic_exp_conf["DEBUG"]:
                    path_to_log = os.path.join(
                        self.dic_path["PATH_TO_WORK_DIRECTORY"], "test_round",
                        "round_" + str(cnt_round))
                    self.downsample_for_system(path_to_log,
                                               self.dic_traffic_env_conf)
            else:
                best_round = None

            print("best_round: ", best_round)

            print("Generator time: ", generator_total_time)
            print("Making samples time:", making_samples_total_time)
            print("update_network time:", update_network_total_time)
            print("test_evaluation time:", test_evaluation_total_time)

            print("round {0} ends, total_time: {1}".format(
                cnt_round,
                time.time() - round_start_time))
            with open(
                    os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"],
                                 "running_time.csv"), "a") as f_time:
                f_time.write("{0}\t{1}\t{2}\t{3}\t{4}\n".format(
                    generator_total_time, making_samples_total_time,
                    update_network_total_time, test_evaluation_total_time,
                    time.time() - round_start_time))
Пример #3
0
    def run(self, multi_process=False):
        """Run ``NUM_ROUNDS`` training rounds and log per-round timings.

        Each round runs, in order: the generators, sample construction
        (``ConstructSample.make_reward_for_system``), the network update (only
        when ``MODEL_NAME`` is in ``LIST_MODEL_NEED_TO_UPDATE``), downsampling
        of the generator logs (skipped when ``DEBUG`` is set), the test
        evaluation, and an early-stopping check when ``EARLY_STOP`` is set.
        A header row is written to ``running_time.csv`` in the work directory
        up front and one tab-separated timing row is appended per completed
        round; a round that triggers early stopping writes no row.

        Args:
            multi_process: if True, generators, the updater and the test
                evaluation are run in ``Process`` children; if False they are
                called inline.
        """
        best_round, bar_round = None, None

        # Start a fresh timing log with its header row.
        with open(
                os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"],
                             "running_time.csv"), "w") as f_time:
            f_time.write(
                "generator_time\tmaking_samples_time\tupdate_network_time\ttest_evaluation_times\tall_times\n"
            )

        # train
        for cnt_round in range(self.dic_exp_conf["NUM_ROUNDS"]):
            print("round %d starts" % cnt_round)
            round_start_time = time.time()

            process_list = []

            print("==============  generator =============")
            generator_start_time = time.time()
            if multi_process:
                for cnt_gen in range(self.dic_exp_conf["NUM_GENERATORS"]):
                    p = Process(target=self.generator_wrapper,
                                args=(cnt_round, cnt_gen, self.dic_path,
                                      self.dic_exp_conf, self.dic_agent_conf,
                                      self.dic_traffic_env_conf, best_round))
                    print("before")
                    p.start()
                    print("end")
                    process_list.append(p)
                print("before join")
                for i, p in enumerate(process_list):
                    print("generator %d to join" % i)
                    p.join()
                    print("generator %d finish join" % i)
                print("end join")
            else:
                for cnt_gen in range(self.dic_exp_conf["NUM_GENERATORS"]):
                    self.generator_wrapper(
                        cnt_round=cnt_round,
                        cnt_gen=cnt_gen,
                        dic_path=self.dic_path,
                        dic_exp_conf=self.dic_exp_conf,
                        dic_agent_conf=self.dic_agent_conf,
                        dic_traffic_env_conf=self.dic_traffic_env_conf,
                        best_round=best_round)
            generator_end_time = time.time()
            generator_total_time = generator_end_time - generator_start_time

            print("==============  make samples =============")
            # make samples and determine which samples are good
            making_samples_start_time = time.time()

            train_round = os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"],
                                       "train_round")
            if not os.path.exists(train_round):
                os.makedirs(train_round)

            cs = ConstructSample(
                path_to_samples=train_round,
                cnt_round=cnt_round,
                dic_traffic_env_conf=self.dic_traffic_env_conf)
            cs.make_reward_for_system()

            making_samples_end_time = time.time()
            making_samples_total_time = making_samples_end_time - making_samples_start_time

            print("==============  update network =============")
            update_network_start_time = time.time()
            if self.dic_exp_conf["MODEL_NAME"] in self.dic_exp_conf[
                    "LIST_MODEL_NEED_TO_UPDATE"]:
                if multi_process:
                    p = Process(target=self.updater_wrapper,
                                args=(cnt_round, self.dic_agent_conf,
                                      self.dic_exp_conf,
                                      self.dic_traffic_env_conf, self.dic_path,
                                      best_round, bar_round))
                    p.start()
                    print("update to join")
                    p.join()
                    print("update finish join")
                else:
                    self.updater_wrapper(
                        cnt_round=cnt_round,
                        dic_agent_conf=self.dic_agent_conf,
                        dic_exp_conf=self.dic_exp_conf,
                        dic_traffic_env_conf=self.dic_traffic_env_conf,
                        dic_path=self.dic_path,
                        best_round=best_round,
                        bar_round=bar_round)

            # Downsampling of the generator logs is counted in the
            # update-network timing below.
            if not self.dic_exp_conf["DEBUG"]:
                for cnt_gen in range(self.dic_exp_conf["NUM_GENERATORS"]):
                    path_to_log = os.path.join(
                        self.dic_path["PATH_TO_WORK_DIRECTORY"], "train_round",
                        "round_" + str(cnt_round), "generator_" + str(cnt_gen))
                    self.downsample_for_system(path_to_log,
                                               self.dic_traffic_env_conf)
            update_network_end_time = time.time()
            update_network_total_time = update_network_end_time - update_network_start_time

            print("==============  test evaluation =============")
            test_evaluation_start_time = time.time()
            if multi_process:
                p = Process(target=model_test.test,
                            args=(self.dic_path["PATH_TO_MODEL"], cnt_round,
                                  self.dic_exp_conf["RUN_COUNTS"],
                                  self.dic_traffic_env_conf, False))
                p.start()
                # The test process is only awaited when early stopping is
                # enabled; otherwise it keeps running in the background while
                # the next round starts.
                if self.dic_exp_conf["EARLY_STOP"]:
                    p.join()
            else:
                model_test.test(self.dic_path["PATH_TO_MODEL"],
                                cnt_round,
                                self.dic_exp_conf["RUN_COUNTS"],
                                self.dic_traffic_env_conf,
                                if_gui=False)

            test_evaluation_end_time = time.time()
            test_evaluation_total_time = test_evaluation_end_time - test_evaluation_start_time

            print('==============  early stopping =============')
            if self.dic_exp_conf["EARLY_STOP"]:
                flag = self.early_stopping(self.dic_path, cnt_round)
                if flag == 1:
                    print("early stopping!")
                    print("training ends at round %s" % cnt_round)
                    break

            print("Generator time: ", generator_total_time)
            print("Making samples time:", making_samples_total_time)
            print("update_network time:", update_network_total_time)
            print("test_evaluation time:", test_evaluation_total_time)

            print("round {0} ends, total_time: {1}".format(
                cnt_round,
                time.time() - round_start_time))
            with open(
                    os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"],
                                 "running_time.csv"), "a") as f_time:
                f_time.write("{0}\t{1}\t{2}\t{3}\t{4}\n".format(
                    generator_total_time, making_samples_total_time,
                    update_network_total_time, test_evaluation_total_time,
                    time.time() - round_start_time))
Пример #4
0
    def run(self, multi_process=False):
        """Run the optional warm-up phases and then the main training rounds.

        Phases, in order:

        1. ``PRETRAIN`` (if set in ``dic_exp_conf``): when the pretrain model
           directory is non-empty, copy ``<TRAFFIC_FILE[0]>.h5`` to
           ``round_0.h5``. Otherwise generate pretrain samples (only if the
           pretrain work directory is empty) and, if the model is listed in
           ``LIST_MODEL_NEED_TO_UPDATE``, call the updater for round 0.
        2. ``AGGREGATE`` (if set): copy ``model/initial/aggregate.h5`` to
           ``round_0.h5`` when it exists, otherwise call the updater for
           round 0.
        3. ``NUM_ROUNDS`` training rounds: generators, sample construction,
           network update, log downsampling, test evaluation, optional early
           stopping and optional model-pool selection. The duration of each
           completed round is appended to ``timing.txt`` in the work
           directory.

        ``PRETRAIN`` and ``AGGREGATE`` are both set to False in
        ``dic_exp_conf`` before the training rounds start.

        The environment configuration is passed as ``dic_traffic_env_conf``
        in every phase, matching the keyword the training rounds use for the
        same wrappers and for ``ConstructSample``.

        Args:
            multi_process: if True, generators, the updater, the test
                evaluation and the model-pool step are run in ``Process``
                children; if False they are called inline.
        """
        best_round, bar_round = None, None

        # pretrain for acceleration
        if self.dic_exp_conf["PRETRAIN"]:
            if os.listdir(self.dic_path["PATH_TO_PRETRAIN_MODEL"]):
                shutil.copy(
                    os.path.join(
                        self.dic_path["PATH_TO_PRETRAIN_MODEL"],
                        "%s.h5" % self.dic_exp_conf["TRAFFIC_FILE"][0]),
                    os.path.join(self.dic_path["PATH_TO_MODEL"], "round_0.h5"))
            else:
                # Only generate pretrain samples when none are on disk yet.
                if not os.listdir(
                        self.dic_path["PATH_TO_PRETRAIN_WORK_DIRECTORY"]):
                    for cnt_round in range(
                            self.dic_exp_conf["PRETRAIN_NUM_ROUNDS"]):
                        print("round %d starts" % cnt_round)

                        process_list = []

                        # ==============  generator =============
                        if multi_process:
                            for cnt_gen in range(
                                    self.
                                    dic_exp_conf["PRETRAIN_NUM_GENERATORS"]):
                                p = Process(
                                    target=self.generator_wrapper,
                                    args=(cnt_round, cnt_gen, self.dic_path,
                                          self.dic_exp_conf,
                                          self.dic_agent_conf,
                                          self.dic_traffic_env_conf,
                                          best_round))
                                print("before")
                                p.start()
                                print("end")
                                process_list.append(p)
                            print("before join")
                            for p in process_list:
                                p.join()
                            print("end join")
                        else:
                            for cnt_gen in range(
                                    self.
                                    dic_exp_conf["PRETRAIN_NUM_GENERATORS"]):
                                self.generator_wrapper(
                                    cnt_round=cnt_round,
                                    cnt_gen=cnt_gen,
                                    dic_path=self.dic_path,
                                    dic_exp_conf=self.dic_exp_conf,
                                    dic_agent_conf=self.dic_agent_conf,
                                    dic_traffic_env_conf=self.
                                    dic_traffic_env_conf,
                                    best_round=best_round)

                        # ==============  make samples =============
                        # make samples and determine which samples are good
                        train_round = os.path.join(
                            self.dic_path["PATH_TO_PRETRAIN_WORK_DIRECTORY"],
                            "train_round")
                        if not os.path.exists(train_round):
                            os.makedirs(train_round)
                        cs = ConstructSample(
                            path_to_samples=train_round,
                            cnt_round=cnt_round,
                            dic_traffic_env_conf=self.dic_traffic_env_conf)
                        cs.make_reward()

                if self.dic_exp_conf["MODEL_NAME"] in self.dic_exp_conf[
                        "LIST_MODEL_NEED_TO_UPDATE"]:
                    if multi_process:
                        p = Process(target=self.updater_wrapper,
                                    args=(0, self.dic_agent_conf,
                                          self.dic_exp_conf,
                                          self.dic_traffic_env_conf,
                                          self.dic_path, best_round))
                        p.start()
                        p.join()
                    else:
                        self.updater_wrapper(
                            cnt_round=0,
                            dic_agent_conf=self.dic_agent_conf,
                            dic_exp_conf=self.dic_exp_conf,
                            dic_traffic_env_conf=self.dic_traffic_env_conf,
                            dic_path=self.dic_path,
                            best_round=best_round)

        # train with aggregate samples
        if self.dic_exp_conf["AGGREGATE"]:
            if "aggregate.h5" in os.listdir("model/initial"):
                shutil.copy(
                    "model/initial/aggregate.h5",
                    os.path.join(self.dic_path["PATH_TO_MODEL"], "round_0.h5"))
            else:
                if multi_process:
                    p = Process(target=self.updater_wrapper,
                                args=(0, self.dic_agent_conf,
                                      self.dic_exp_conf,
                                      self.dic_traffic_env_conf, self.dic_path,
                                      best_round))
                    p.start()
                    p.join()
                else:
                    self.updater_wrapper(
                        cnt_round=0,
                        dic_agent_conf=self.dic_agent_conf,
                        dic_exp_conf=self.dic_exp_conf,
                        dic_traffic_env_conf=self.dic_traffic_env_conf,
                        dic_path=self.dic_path,
                        best_round=best_round)

        # The warm-up phases are done; clear the flags in the shared config
        # before it is handed to the training-round wrappers.
        self.dic_exp_conf["PRETRAIN"] = False
        self.dic_exp_conf["AGGREGATE"] = False

        # train
        for cnt_round in range(self.dic_exp_conf["NUM_ROUNDS"]):
            print("round %d starts" % cnt_round)

            round_start_t = time.time()

            process_list = []

            # ==============  generator =============
            if multi_process:
                for cnt_gen in range(self.dic_exp_conf["NUM_GENERATORS"]):
                    p = Process(target=self.generator_wrapper,
                                args=(cnt_round, cnt_gen, self.dic_path,
                                      self.dic_exp_conf, self.dic_agent_conf,
                                      self.dic_traffic_env_conf, best_round))
                    p.start()
                    process_list.append(p)
                for p in process_list:
                    p.join()
            else:
                for cnt_gen in range(self.dic_exp_conf["NUM_GENERATORS"]):
                    self.generator_wrapper(
                        cnt_round=cnt_round,
                        cnt_gen=cnt_gen,
                        dic_path=self.dic_path,
                        dic_exp_conf=self.dic_exp_conf,
                        dic_agent_conf=self.dic_agent_conf,
                        dic_traffic_env_conf=self.dic_traffic_env_conf,
                        best_round=best_round)

            # ==============  make samples =============
            # make samples and determine which samples are good
            train_round = os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"],
                                       "train_round")
            if not os.path.exists(train_round):
                os.makedirs(train_round)
            cs = ConstructSample(
                path_to_samples=train_round,
                cnt_round=cnt_round,
                dic_traffic_env_conf=self.dic_traffic_env_conf)
            cs.make_reward()

            # ==============  update network =============
            if self.dic_exp_conf["MODEL_NAME"] in self.dic_exp_conf[
                    "LIST_MODEL_NEED_TO_UPDATE"]:
                if multi_process:
                    p = Process(target=self.updater_wrapper,
                                args=(cnt_round, self.dic_agent_conf,
                                      self.dic_exp_conf,
                                      self.dic_traffic_env_conf, self.dic_path,
                                      best_round, bar_round))
                    p.start()
                    p.join()
                else:
                    self.updater_wrapper(
                        cnt_round=cnt_round,
                        dic_agent_conf=self.dic_agent_conf,
                        dic_exp_conf=self.dic_exp_conf,
                        dic_traffic_env_conf=self.dic_traffic_env_conf,
                        dic_path=self.dic_path,
                        best_round=best_round,
                        bar_round=bar_round)

            if not self.dic_exp_conf["DEBUG"]:
                for cnt_gen in range(self.dic_exp_conf["NUM_GENERATORS"]):
                    path_to_log = os.path.join(
                        self.dic_path["PATH_TO_WORK_DIRECTORY"], "train_round",
                        "round_" + str(cnt_round), "generator_" + str(cnt_gen))
                    self.downsample(path_to_log)

            # ==============  test evaluation =============
            # NOTE(review): the multi-process branch reads "TEST_RUN_COUNTS"
            # while the inline branch reads "RUN_COUNTS" - confirm which key
            # is intended; both are kept as they were.
            if multi_process:
                p = Process(target=model_test.test,
                            args=(self.dic_path["PATH_TO_MODEL"], cnt_round,
                                  self.dic_exp_conf["TEST_RUN_COUNTS"],
                                  self.dic_traffic_env_conf, False))
                p.start()
                # Only wait for the test process when a later step in this
                # round is enabled (early stopping or the model pool);
                # otherwise it keeps running in the background.
                if self.dic_exp_conf["EARLY_STOP"] or self.dic_exp_conf[
                        "MODEL_POOL"]:
                    p.join()
            else:
                model_test.test(self.dic_path["PATH_TO_MODEL"],
                                cnt_round,
                                self.dic_exp_conf["RUN_COUNTS"],
                                self.dic_traffic_env_conf,
                                if_gui=False)

            # ==============  early stopping =============
            if self.dic_exp_conf["EARLY_STOP"]:
                flag = self.early_stopping(self.dic_path, cnt_round)
                if flag == 1:
                    break

            # ==============  model pool evaluation =============
            if self.dic_exp_conf["MODEL_POOL"]:
                if multi_process:
                    p = Process(
                        target=self.model_pool_wrapper,
                        args=(self.dic_path, self.dic_exp_conf, cnt_round),
                    )
                    p.start()
                    p.join()
                else:
                    self.model_pool_wrapper(dic_path=self.dic_path,
                                            dic_exp_conf=self.dic_exp_conf,
                                            cnt_round=cnt_round)
                model_pool_dir = os.path.join(
                    self.dic_path["PATH_TO_WORK_DIRECTORY"], "best_model.pkl")
                if os.path.exists(model_pool_dir):
                    # NOTE(review): pickle.load is only safe on trusted
                    # files; presumably this one is written by
                    # model_pool_wrapper - confirm.
                    with open(model_pool_dir, "rb") as f_pool:
                        model_pool = pickle.load(f_pool)
                    # Pick a random pool entry as the next round's best_round.
                    ind = random.randint(0, len(model_pool) - 1)
                    best_round = model_pool[ind][0]
                    # A second, preferably different, index is still drawn
                    # (up to 10 retries) but is currently unused: bar_round
                    # stays None. The draws are kept so the random stream is
                    # unchanged.
                    ind_bar = random.randint(0, len(model_pool) - 1)
                    flag = 0
                    while ind_bar == ind and flag < 10:
                        ind_bar = random.randint(0, len(model_pool) - 1)
                        flag += 1
                    bar_round = None
                else:
                    best_round = None
                    bar_round = None

                # downsample
                if not self.dic_exp_conf["DEBUG"]:
                    path_to_log = os.path.join(
                        self.dic_path["PATH_TO_WORK_DIRECTORY"], "test_round",
                        "round_" + str(cnt_round))
                    self.downsample(path_to_log)
            else:
                best_round = None

            print("best_round: ", best_round)

            print("round %s ends" % cnt_round)

            round_end_t = time.time()
            with open(
                    os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"],
                                 "timing.txt"), "a+") as f_timing:
                f_timing.write("round_{0}: {1}\n".format(
                    cnt_round, round_end_t - round_start_t))