예제 #1
0
파일: DQN.py 프로젝트: axzs0987/firefly
    def create_notebook_pool():
        """Return notebook ids eligible for training, in first-seen order.

        A notebook qualifies when it (a) has rows in `result`, (b) is marked
        isRandom=1 in `notebook`, and (c) is paired with a dataset hosted on
        the server `ip` (closure variable).

        Returns:
            list[int]: de-duplicated notebook ids, ordered as returned by
            the pair/dataset join.
        """
        cursor, db = create_connection()

        # Build membership sets up front: O(1) lookups instead of O(n)
        # list scans per candidate row.
        cursor.execute('select distinct notebook_id from result')
        in_result = {int(row[0]) for row in cursor.fetchall()}

        cursor.execute('select distinct id from notebook where isRandom=1')
        in_notebook = {int(row[0]) for row in cursor.fetchall()}

        # Parameterized query instead of concatenating `ip` into the SQL
        # string (injection-safe). %s is the DB-API paramstyle used by
        # MySQL drivers — TODO confirm against create_connection()'s driver.
        cursor.execute(
            'select pair.nid from pair,dataset '
            'where pair.did=dataset.id and dataset.server_ip = %s',
            (ip,))

        notebook_pool = []
        seen = set()  # preserves original order while de-duplicating
        for row in cursor.fetchall():
            nid = int(row[0])
            if nid in in_result and nid in in_notebook and nid not in seen:
                seen.add(nid)
                notebook_pool.append(nid)
        return notebook_pool
예제 #2
0
파일: DQN.py 프로젝트: axzs0987/firefly
 def check_model(notebook_id):
     """Build a one-hot vector of the model types used by a notebook.

     Looks up `result.model_type` rows for `notebook_id` and marks each
     configured model type that appears.

     Args:
         notebook_id: id of the notebook (coerced to int).

     Returns:
         (check, model_list): `check` is True when at least one known model
         type was found; `model_list` is a numpy one-hot vector over the
         configured model types (config indices are 1-based).
     """
     # NOTE(review): eval() executes arbitrary code from the config file;
     # ast.literal_eval would be safer if model_dic is a plain dict literal.
     model_dic = eval(CONFIG.get('models', 'model_dic'))
     cursor, db = create_connection()
     # int() cast sanitizes the id before it is embedded in the SQL string,
     # so no non-numeric input can alter the query.
     cursor.execute('select model_type from result where notebook_id = '
                    + str(int(notebook_id)))
     model_list = np.zeros([len(model_dic)])
     check = False
     for row in cursor.fetchall():
         if row[0] in model_dic:
             model_list[model_dic[row[0]] - 1] = 1
             check = True
     return check, model_list
예제 #3
0
def train(notebook_root, dataset_root, ip):
    """Run policy-gradient training episodes over a pool of notebooks.

    Args:
        notebook_root: directory containing `<notebook_id>.ipynb` files.
        dataset_root: dataset directory (not referenced in this body —
            TODO confirm whether callers rely on it).
        ip: dataset server IP; candidate notebooks are restricted to
            datasets hosted on it.

    Side effects: loads/saves `reward_list.npy`, `loss_pg.npy`,
    `act_reward_pg.npy` checkpoints and reads `type.npy` in the working
    directory; increments `notebook.trained_time` in the database when an
    episode finishes.
    """

    def check_model(notebook_id):
        # Build a one-hot vector over the configured model types found in the
        # notebook's `result` rows; `check` is True when at least one known
        # model type appears.
        model_dic = eval(CONFIG.get('models', 'model_dic'))
        cursor, db = create_connection()
        sql = 'select model_type from result where notebook_id = ' + str(
            notebook_id)
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        model_list = np.zeros([len(model_dic.keys())])
        check = False
        for row in sql_res:
            if row[0] in model_dic.keys():
                model_id = model_dic[row[0]] - 1  # config indices are 1-based
                model_list[model_id] = 1
                check = True
        return check, model_list

    def create_notebook_pool():
        # A notebook qualifies when it has rows in `result`, is marked
        # isRandom=1 in `notebook`, and pairs with a dataset hosted on `ip`.
        notebook_pool = []
        in_result = []
        cursor, db = create_connection()
        sql = 'select distinct notebook_id from result'
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        for row in sql_res:
            in_result.append(int(row[0]))

        in_notebook = []
        sql = 'select distinct id from notebook where isRandom=1'
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        for row in sql_res:
            in_notebook.append(int(row[0]))

        # NOTE(review): `ip` is concatenated straight into the SQL — safe
        # only if it is trusted; a parameterized query would be safer.
        sql = 'select pair.nid from pair,dataset where pair.did=dataset.id and dataset.server_ip = \'' + ip + '\''
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        for row in sql_res:
            if int(row[0]) not in in_result:
                continue
            if int(row[0]) not in in_notebook:
                continue
            if int(row[0]) not in notebook_pool:  # de-dupe, keep order
                notebook_pool.append(int(row[0]))
        return notebook_pool

    # --- configuration and agent setup ---
    notebook_pool = create_notebook_pool()
    train_config = eval(CONFIG.get('train', 'train'))
    nepisode = train_config['nepisode']
    obs_dim = train_config['obs_dim']
    ope_dic = eval(CONFIG.get('operators', 'operations'))
    learning_rate = train_config['learning_rate']
    gamma = train_config['gamma']
    dense_dim = train_config['dense_dim']
    column_num = train_config['column_num']
    # Action-1 dimension = largest operator index declared in the config.
    act_1_dim = 0
    for item in ope_dic:
        if ope_dic[item]['index'] > act_1_dim:
            act_1_dim = ope_dic[item]['index']  # 27

    agent = PolicyGradient(act_1_dim=act_1_dim + 1,
                           act_2_dim=column_num,
                           obs_dim=obs_dim,
                           dense_dim=dense_dim,
                           lr=learning_rate,
                           gamma=gamma)

    # Resume reward/loss/per-action-reward histories from checkpoints.
    if os.path.exists('reward_list.npy'):
        print('exists')
        reward_list = list(np.load('./reward_list.npy', allow_pickle=True))
    else:
        reward_list = []

    if os.path.exists('loss_pg.npy'):
        print('loss exists')
        loss_list = list(np.load('./loss_pg.npy', allow_pickle=True))
    else:
        loss_list = []
    iteration = 0  # NOTE(review): never used below — apparently dead
    if os.path.exists('act_reward_pg.npy'):
        print('exists')
        act_reward = np.load('./act_reward_pg.npy', allow_pickle=True).item()
    else:
        act_reward = {}
    for i_episode in range(nepisode):
        ep_rwd = 0
        # Sample a notebook and fetch its initial state via RPC; notebooks
        # that fail to run or have no known model are dropped from the pool.
        notebook_id = random.choice(notebook_pool)
        print("\033[0;35;40m" + "notebook_id:" + str(notebook_id) + "\033[0m")
        notebook_path = notebook_root + str(notebook_id) + '.ipynb'
        notebook_code = get_code_txt(notebook_path)
        res_line_number = -1
        s_t, len_data = rpc_client_get_origin_state(notebook_id, notebook_code,
                                                    column_num, ip)
        # print(s_t)
        check_result, model_list = check_model(notebook_id)
        while s_t == 'run failed' or check_result == False:
            notebook_pool.remove(notebook_id)
            notebook_id = random.choice(notebook_pool)
            print("\033[0;34;40m" + "notebook_id:" + str(notebook_id) +
                  "\033[0m")
            notebook_path = notebook_root + str(notebook_id) + '.ipynb'
            notebook_code = get_code_txt(notebook_path)
            s_t, len_data = rpc_client_get_origin_state(
                notebook_id, notebook_code, column_num, ip)
            check_result, model_list = check_model(notebook_id)

        s_t_p = s_t
        s_t = np.ravel(s_t)
        # Pad the flat state: prepend the task type (1900 -> 1901), then
        # append the model one-hot (1901 -> full length, presumably 1942 ==
        # obs_dim — TODO confirm against the config).
        type_ = np.array([int(np.load('type.npy', allow_pickle=True))])
        if len(s_t) == 1900:
            s_t = np.concatenate((type_, s_t), axis=0)
        if len(s_t) == 1901:
            s_t = np.concatenate((s_t, model_list), axis=0)
        # pprint.pprint(s_t)
        if len(s_t) == 0:
            continue
        while True:
            terminal1 = False
            if int(np.load('type.npy', allow_pickle=True)) != 1:
                terminal1 = True
            action1, action2 = agent.step(s_t)  # predict the next action from the current state via the network (to be revised)
            # print(len_data)
            # print(len(s_t_p))
            check_res = check_action_by_rule(action1 + 1,
                                             action2,
                                             s_t_p,
                                             len_data,
                                             column_num=column_num)
            count = 0
            s_t_plus_1 = np.zeros([1942])
            if check_res == False:
                # Rule check rejected the action: penalize and end the episode.
                reward = -1.0
                terminal = True
            # while cehck_res == False and count < 10:
            #     print('changed_act1:',action1)
            #     count += 1
            #     action1,action2 = agent.step_1()  # predict the next action via the network (to be revised)
            #     cehck_res = check_action_by_rule(action1 + 1, action2 + 1, s_t_p,len_data,column_num=column_num)
            else:
                # action2 == column_num - 1 encodes "no specific column".
                if action2 == column_num - 1:
                    target_content = {
                        'operation': action1 + 1,
                        'data_object': -1,
                    }
                else:
                    target_content = {
                        'operation': action1 + 1,
                        'data_object': action2,
                    }
                # print('?>>?')
                # print('act:',act)
                # Execute the action: get new state, immediate reward, terminal flag.
                # s_t = []
                eventlet.monkey_patch()
                try:
                    with eventlet.Timeout(60, False):  # 60-second cap on the RPC call
                        s_t, action, reward, terminal, s_t_plus_1, notebook_code, res_line_number, len_data_plus_1 = rpc_client_do_an_action(
                            notebook_id, notebook_code, target_content,
                            column_num, res_line_number, ip)
                except:
                    # NOTE(review): bare except hides the real failure mode;
                    # consider narrowing and logging before abandoning the episode.
                    break
                if s_t == []:
                    break
                # print('?>>?')
                if reward == -2:
                    continue
                # print("\033[0;36;40m" + "s_t:" + str(s_t) + "\033[0m")

                # Pad both states to the full observation length, as above.
                s_t = np.ravel(s_t)
                type_ = np.array([int(np.load('type.npy', allow_pickle=True))])
                if int(np.load('type.npy', allow_pickle=True)) != 1:
                    terminal = True
                if len(s_t) == 1900:
                    s_t = np.concatenate((type_, s_t), axis=0)
                if len(s_t) == 1901:
                    s_t = np.concatenate((s_t, model_list), axis=0)

                s_t_p = s_t_plus_1
                s_t_plus_1 = np.ravel(s_t_plus_1)
                if len(s_t_plus_1) == 1900:
                    s_t_plus_1 = np.concatenate(([0], s_t_plus_1), axis=0)
                if len(s_t_plus_1) == 1901:
                    s_t_plus_1 = np.concatenate((s_t_plus_1, model_list),
                                                axis=0)
                s_t = s_t_plus_1
                len_data = len_data_plus_1
            act = (action1, action2)
            agent.memory.store_transition(s_t, act, reward)  # put into the sample pool (to be revised)
            ep_rwd += reward
            print("\033[0;36;40m" + "reward:" + str(reward) + "\033[0m")
            print("\033[0;36;40m" + "terminal:" + str(terminal) + "\033[0m")
            print("\033[0;36;40m" + "act:" + str(act) + "\033[0m")
            if terminal or terminal1:
                loss = agent.learn()  # episode finished: optimize the network
                loss_list.append(loss)

                print('Ep: %i' % i_episode, "|Ep_r: %f" % ep_rwd)
                sql = 'update notebook set trained_time = trained_time + 1 where id=' + str(
                    notebook_id)
                cursor, db = create_connection()
                cursor.execute(sql)
                db.commit()
                if act[0] not in act_reward:
                    act_reward[act[0]] = []
                act_reward[act[0]].append(ep_rwd)
                reward_list.append(ep_rwd)
                # if i_episode % 50 == 0:
                np.save('./reward_list.npy', reward_list)
                np.save('loss_pg', loss_list)
                np.save('act_reward_pg', act_reward)
                break