Example #1
def train(notebook_root, dataset_root, ip):
    def check_model(notebook_id):
        model_dic = eval(CONFIG.get('models', 'model_dic'))
        cursor, db = create_connection()
        sql = 'select model_type from result where notebook_id = ' + str(
            notebook_id)
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        model_list = np.zeros([len(model_dic.keys())])
        check = False
        for row in sql_res:
            if row[0] in model_dic.keys():
                model_id = model_dic[row[0]] - 1
                model_list[model_id] = 1
                check = True
        return check, model_list

    def create_notebook_pool():
        notebook_pool = []
        in_result = []
        cursor, db = create_connection()
        sql = 'select distinct notebook_id from result'
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        for row in sql_res:
            in_result.append(int(row[0]))
        in_notebook = []
        sql = 'select distinct id from notebook where isRandom=1'
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        for row in sql_res:
            in_notebook.append(int(row[0]))

        in_ope = []
        sql = 'select distinct notebook_id from operator'
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        for row in sql_res:
            if int(row[0]) not in in_ope:
                in_ope.append(int(row[0]))

        sql = 'select pair.nid from pair,dataset where pair.did=dataset.id and dataset.server_ip = \'' + ip + '\''
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        for row in sql_res:
            if int(row[0]) not in in_result:
                continue
            if int(row[0]) not in in_notebook:
                continue
            if int(row[0]) in in_ope:
                continue
            if int(row[0]) not in notebook_pool:
                notebook_pool.append(int(row[0]))
        print('notebook_pool:', len(notebook_pool))
        return notebook_pool

    notebook_pool = create_notebook_pool()
    train_config = eval(CONFIG.get('train', 'train'))
    nepisode = train_config['nepisode']
    obs_dim = train_config['obs_dim']
    ope_dic = eval(CONFIG.get('operators', 'operations'))
    learning_rate = train_config['learning_rate']
    gamma = train_config['gamma']
    dense_dim = train_config['dense_dim']
    column_num = train_config['column_num']
    act_1_dim = 0
    for item in ope_dic:
        if ope_dic[item]['index'] > act_1_dim:
            act_1_dim = ope_dic[item]['index']  # 27

    agent = ActorCritic(act_dim_1=act_1_dim + 1,
                        act_dim_2=column_num,
                        obs_dim=obs_dim,
                        lr_actor=learning_rate,
                        lr_value=learning_rate,
                        gamma=gamma)

    if os.path.exists('reward_list_d1m4.npy'):
        print('reward_list_A2C_1 exists')
        reward_list = list(np.load('./reward_list_d1m4.npy',
                                   allow_pickle=True))
    else:
        reward_list = []

    if os.path.exists('reward_list_r_d1m4.npy'):
        print('reward_list_A2C_r_1 exists')
        reward_list_r = list(
            np.load('./reward_list_r_d1m4.npy', allow_pickle=True))
    else:
        reward_list_r = []

    if os.path.exists('./act_reward_d1m4.npy'):
        print('exists')
        act_reward = np.load('./act_reward_d1m4.npy', allow_pickle=True).item()
    else:
        act_reward = {}

    if os.path.exists('max_reward_d1m4.npy'):
        print('max_reward_A2C_1 exists')
        max_reward = list(np.load('./max_reward_d1m4.npy', allow_pickle=True))
    else:
        max_reward = []
    if os.path.exists('value_loss_d1m4.npy'):
        print('value_loss_d1m4 exists')
        value_loss_list = list(
            np.load('./value_loss_d1m4.npy', allow_pickle=True))
    else:
        value_loss_list = []
    if os.path.exists('actor_loss_d1m4.npy'):
        print('actor_loss_d1m4 exists')
        actor_loss_list = list(
            np.load('./actor_loss_d1m4.npy', allow_pickle=True))
    else:
        actor_loss_list = []

    if os.path.exists('suceed_action_d1m4.npy'):
        print('suceed_action exists')
        suceed_action = list(
            np.load('./suceed_action_d1m4.npy', allow_pickle=True))
    else:
        suceed_action = []
    if os.path.exists('fail_action_d1m4.npy'):
        print('fail_action exists')
        fail_action = list(np.load('./fail_action_d1m4.npy',
                                   allow_pickle=True))
    else:
        fail_action = []

    for i_episode in range(nepisode):
        ep_rwd = 0
        notebook_id = random.choice(notebook_pool)
        print("\033[0;35;40m" + "notebook_id:" + str(notebook_id) + "\033[0m")
        notebook_path = notebook_root + str(notebook_id) + '.ipynb'
        notebook_code = get_code_txt(notebook_path)
        res_line_number = -1
        s_t, len_data = rpc_client_get_origin_state(notebook_id, notebook_code,
                                                    column_num, ip)
        check_result, model_list = check_model(notebook_id)
        while s_t == 'run failed' or not check_result:
            notebook_pool.remove(notebook_id)
            notebook_id = random.choice(notebook_pool)
            print("\033[0;34;40m" + "notebook_id:" + str(notebook_id) +
                  "\033[0m")
            notebook_path = notebook_root + str(notebook_id) + '.ipynb'
            notebook_code = get_code_txt(notebook_path)
            s_t, len_data = rpc_client_get_origin_state(
                notebook_id, notebook_code, column_num, ip)
            check_result, model_list = check_model(notebook_id)

        s_t_p = s_t
        s_t = np.ravel(s_t)
        type_ = np.array([int(np.load('type.npy', allow_pickle=True))])
        if len(s_t) == 1900:
            s_t = np.concatenate((type_, s_t), axis=0)
        if len(s_t) == 1901:
            s_t = np.concatenate((s_t, model_list), axis=0)
        if len(s_t) == 0:
            continue

        temp_act_reward_dic = []
        chosed_list = []
        while True:
            # act, _ = agent.step(obs0)  # act is the action sampled from the policy; _ holds the Q-values of all actions
            terminal1 = False
            if int(np.load('type.npy', allow_pickle=True)) != 1:
                terminal1 = True
            # agent.look_weight()
            action1, act_prob1, action2, act_prob2, random_chose = agent.step(
                s_t, s_t_p, len_data)

            check_res = check_action_by_rule(action1 + 1,
                                             action2 + 1,
                                             s_t_p,
                                             len_data,
                                             column_num=column_num)
            s_t_plus_1 = np.zeros([1942])  # placeholder next state used when the action fails
            try_time = 0
            while (str(action1) + '+' + str(action2) in chosed_list
                   or str(action1) + '+' + str(100)
                   in chosed_list) and try_time < 20:
                action1, act_prob1, action2, act_prob2, random_chose = agent.step(
                    s_t, s_t_p, len_data)
                check_res = check_action_by_rule(action1 + 1,
                                                 action2 + 1,
                                                 s_t_p,
                                                 len_data,
                                                 column_num=column_num)
                try_time += 1  # count every attempt so a failed rule check cannot loop forever
                if not check_res:
                    continue
            chosed_list.append(str(action1) + '+' + str(action2))
            print('check_res', check_res)
            if not check_res:
                reward = -1.0
                terminal = True
                compare = False
            else:
                if action2 == column_num - 1:
                    target_content = {
                        'operation': action1 + 1,
                        'data_object': -1,
                    }
                    compare = compare_state(s_t, s_t_plus_1)
                    if compare:
                        print('action1:', action1)
                else:
                    target_content = {
                        'operation': action1 + 1,
                        'data_object': action2,
                    }

                # try:
                s_t, action, reward, terminal, s_t_plus_1, notebook_code, res_line_number, len_data_plus_1 = rpc_client_do_an_action(
                    notebook_id, notebook_code, target_content, column_num,
                    res_line_number, ip)
                # except Exception as e:
                #     print(e)
                #     break
                if s_t == []:
                    print('s_t is null')
                    break

                if reward == -2:
                    print('reward is -2, skipping this action')
                    continue

                s_t = np.ravel(s_t)
                try:
                    type_ = np.array(
                        [int(np.load('type.npy', allow_pickle=True))])
                    if int(np.load('type.npy', allow_pickle=True)) != 1:
                        terminal = True
                    # os.system('rm -f type.npy')
                except Exception:
                    type_ = np.array([0])
                if len(s_t) == 1900:
                    s_t = np.concatenate((type_, s_t), axis=0)
                if len(s_t) == 1901:
                    s_t = np.concatenate((s_t, model_list), axis=0)
                if s_t_plus_1 == []:
                    s_t_plus_1 = np.zeros([1942])  # fall back to the zero placeholder next state
                    reward = -1
                    terminal = True
                s_t_p = s_t_plus_1
                s_t_plus_1 = np.ravel(s_t_plus_1)
                try:
                    type_1 = np.array(
                        [int(np.load('type_1.npy', allow_pickle=True))])
                    if int(np.load('type_1.npy', allow_pickle=True)) != 1:
                        terminal1 = True
                    # os.system('rm -f type_1.npy')
                except Exception:
                    type_1 = np.array([0])
                if len(s_t_plus_1) == 1900:
                    s_t_plus_1 = np.concatenate((type_1, s_t_plus_1), axis=0)
                if len(s_t_plus_1) == 1901:
                    s_t_plus_1 = np.concatenate((s_t_plus_1, model_list),
                                                axis=0)

                compare = compare_state(s_t, s_t_plus_1)
                if compare:
                    print('action1:', action1)
                    reward = -0.5

                # else:
                #     reward = -1
                # s_t = s_t_plus_1
                len_data = len_data_plus_1
            act = (action1, action2)
            # if str(action1) not in act_reward.keys():
            #     act_reward[str(action1)] = []
            # if str(action1) not in temp_act_reward_dic.keys():
            #     temp_act_reward_dic[str(action1)] = []

            # act_reward[str(action1)].append((act_prob1, action2, act_prob2, reward, 'not changed:'+str(compare), 'terminal:'+str(terminal), 'is random chosed:'+str(random_chose)))
            temp_act_reward_dic.append(
                (action1, action2, act_prob1, act_prob2,
                 'reward:' + str(reward), 'not changed:' + str(compare),
                 'terminal:' + str(terminal),
                 'is random chosed:' + str(random_chose)))

            if reward > 0:
                suceed_action.append((notebook_id, act))
                reward *= 1000
            if reward < 0 and reward != -1:
                fail_action.append((notebook_id, act))
            if reward == 0:
                reward = 0.5
            agent.memory.store_transition(s_t, act[0], act[1], reward)
            ep_rwd += reward
            reward_list_r.append(reward)
            if reward == -1:
                s_t_plus_1 = np.zeros([1942])
            s_t = s_t_plus_1
            print("\033[0;36;40m" + "reward:" + str(reward) + "\033[0m")
            print("\033[0;36;40m" + "terminal:" + str(terminal) + "\033[0m")
            print("\033[0;36;40m" + "act:" + str(act) + "\033[0m")
            if not random_chose:
                max_reward.append(reward)
            if terminal or terminal1:
                last_value = agent.get_value(
                    [s_t])  # last_value holds the Q-values of all actions at the episode's final step
                value_loss, actor_loss, act_advantage_dic = agent.learn(
                    last_value, terminal)
                for index, item in enumerate(temp_act_reward_dic):
                    action1 = item[0]
                    if str(action1) not in act_reward.keys():
                        act_reward[str(action1)] = []
                    act_reward[str(action1)].append(
                        (item[2], item[1], item[3], item[4],
                         'advantage:' + str(act_advantage_dic[index][2]),
                         'actor_loss:' + str(act_advantage_dic[index][3]),
                         item[5], item[6], item[7]))
                value_loss_list.append(value_loss)
                actor_loss_list.append(actor_loss)
                print('Ep: %i' % i_episode, "|Ep_r: %i" % ep_rwd)
                reward_list.append(ep_rwd)
                np.save('reward_list_d1m4', reward_list)
                np.save('reward_list_r_d1m4', reward_list_r)
                np.save('act_reward_d1m4', act_reward)
                np.save('max_reward_d1m4', max_reward)
                np.save('suceed_action_d1m4', suceed_action)
                np.save('fail_action_d1m4', fail_action)
                np.save('value_loss_d1m4', value_loss_list)
                np.save('actor_loss_d1m4', actor_loss_list)
                break
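
Note: the 1900 → 1901 → 1942 padding recurs throughout these examples: a flattened 1900-dim state gets a 1-dim type flag prepended, then the model one-hot appended (evidently 41 entries, since 1901 + 41 = 1942). A minimal refactoring sketch; pad_state is a name of our own choosing, not a helper from the original code:

import numpy as np

def pad_state(state, type_flag, model_list):
    # Pad a flattened state to the 1942-dim layout the agents expect.
    state = np.ravel(state)
    if len(state) == 1900:  # raw state: prepend the notebook type flag
        state = np.concatenate(([type_flag], state), axis=0)
    if len(state) == 1901:  # typed state: append the model one-hot vector
        state = np.concatenate((state, model_list), axis=0)
    return state  # 1942-dim, matching the np.zeros([1942]) placeholders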
Example #2
def addSamples(notebook_root, dataset_root, ip):
    def create_notebook_pool():
        notebook_pool = []
        in_result = []
        cursor, db = create_connection()
        sql = 'select distinct notebook_id from result'
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        for row in sql_res:
            in_result.append(int(row[0]))
        sql = 'select pair.nid from pair,dataset where pair.did=dataset.id and dataset.server_ip = \'' + ip + '\''
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        for row in sql_res:
            if int(row[0]) not in in_result:
                continue
            if int(row[0]) not in notebook_pool:
                notebook_pool.append(int(row[0]))
        return notebook_pool

    notebook_pool = create_notebook_pool()
    train_config = eval(CONFIG.get('train', 'train'))
    nepisode = train_config['nepisode']
    obs_dim = train_config['obs_dim']
    ope_dic = eval(CONFIG.get('operators', 'operations'))
    learning_rate = train_config['learning_rate']
    gamma = train_config['gamma']
    dense_dim = train_config['dense_dim']
    column_num = train_config['column_num']
    act_1_dim = 0
    for item in ope_dic:
        if ope_dic[item]['index'] > act_1_dim:
            act_1_dim = ope_dic[item]['index']  # 27

    # The agent is used in the loop below, so its instantiation cannot stay commented out.
    agent = PolicyGradient(act_1_dim=act_1_dim,
                           act_2_dim=column_num,
                           obs_dim=obs_dim,
                           dense_dim=dense_dim,
                           lr=learning_rate,
                           gamma=gamma)
    for i_episode in range(nepisode):
        ep_rwd = 0
        notebook_id = random.choice(notebook_pool)
        print("\033[0;35;40m" + "notebook_id:" + str(notebook_id) + "\033[0m")
        notebook_path = notebook_root + str(notebook_id) + '.ipynb'
        notebook_code = get_code_txt(notebook_path)
        res_line_number = -1
        s_t = get_origin_state(notebook_id,
                               notebook_code,
                               column_num,
                               dataset_root=dataset_root)
        print(s_t)
        while s_t == 'run failed':
            notebook_pool.remove(notebook_id)
            notebook_id = random.choice(notebook_pool)
            print("\033[0;34;40m" + "notebook_id:" + str(notebook_id) +
                  "\033[0m")
            notebook_path = notebook_root + str(notebook_id) + '.ipynb'
            notebook_code = get_code_txt(notebook_path)
            s_t = get_origin_state(notebook_id,
                                   notebook_code,
                                   column_num,
                                   dataset_root=dataset_root)
        print(s_t)
        while True:
            s_t = np.ravel(s_t)
            action1, action2 = agent.step(s_t)  # predict the next action from the current state via the network (to be changed)
            target_content = {
                'operation': action1,
                'data_object': action2,
            }
            act = (action1, action2)
            s_t, action, reward, terminal, s_t_plus_1, notebook_code, res_line_number = do_an_action(
                notebook_id, notebook_code, target_content, column_num,
                res_line_number)  # execute the action; get the new state, immediate reward, and terminal flag
            print("\033[0;36;40m" + "reward:" + str(reward) + "\033[0m")
            print("\033[0;36;40m" + "terminal:" + str(terminal) + "\033[0m")
            s_t = np.ravel(s_t)
            s_t_plus_1 = np.ravel(s_t_plus_1)
            agent.memory.store_transition(s_t, act, reward)  # store in the replay pool (to be changed)
            s_t = s_t_plus_1
            ep_rwd += reward

            if terminal:
                agent.learn()  # one full episode has ended; optimize the network
                print('Ep: %i' % i_episode, "|Ep_r: %i" % ep_rwd)
                break
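
A hypothetical invocation (the paths and IP below are placeholders, not values from the source): both train() and addSamples() expect a directory of <notebook_id>.ipynb files and the server IP that create_notebook_pool() uses to filter datasets.

if __name__ == '__main__':
    # notebook_root must end with a separator, since paths are built by string concatenation
    addSamples(notebook_root='/data/notebooks/',
               dataset_root='/data/datasets/',
               ip='192.168.1.10')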
Example #3
def train(notebook_root, dataset_root, ip):
    def check_model(notebook_id):
        model_dic = eval(CONFIG.get('models', 'model_dic'))
        cursor, db = create_connection()
        sql = 'select model_type from result where notebook_id = ' + str(notebook_id)
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        model_list = np.zeros([len(model_dic.keys())])
        check = False
        for row in sql_res:
            if row[0] in model_dic.keys():
                model_id = model_dic[row[0]] - 1
                model_list[model_id] = 1
                check = True
        return check, model_list

    def create_notebook_pool():
        notebook_pool = []
        in_result = []
        cursor, db = create_connection()
        sql = 'select distinct notebook_id from result'
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        for row in sql_res:
            in_result.append(int(row[0]))
        in_notebook = []
        sql = 'select distinct id from notebook where isRandom=1'
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        for row in sql_res:
            in_notebook.append(int(row[0]))

        sql = 'select pair.nid from pair,dataset where pair.did=dataset.id and dataset.server_ip = \'' + ip + '\''
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        for row in sql_res:
            if int(row[0]) not in in_result:
                continue
            if int(row[0]) not in in_notebook:
                continue
            if int(row[0]) not in notebook_pool:
                notebook_pool.append(int(row[0]))
        return notebook_pool


    notebook_pool = create_notebook_pool()
    train_config = eval(CONFIG.get('train', 'train'))
    nepisode = train_config['nepisode']
    obs_dim = train_config['obs_dim']
    ope_dic = eval(CONFIG.get('operators', 'operations'))
    learning_rate = train_config['learning_rate']
    gamma = train_config['gamma']
    dense_dim = train_config['dense_dim']
    column_num = train_config['column_num']
    epsilon = train_config['epsilon']
    epsilon_step = train_config['epsilon_step']
    act_1_dim = 0
    batch_size = train_config['batch_size']
    for item in ope_dic:
        if ope_dic[item]['index'] > act_1_dim:
            act_1_dim = ope_dic[item]['index']  # 27

    agent = DQN(act_dim_1=act_1_dim + 1,
                act_dim_2=column_num,
                obs_dim=obs_dim,
                dense_dim=dense_dim,
                lr_q_value=learning_rate,
                gamma=gamma,
                epsilon=epsilon,
                batch_size=batch_size)

    if os.path.exists('reward_list_dqn_1.npy'):
        print('reward_list_dqn_1 exists')
        reward_list = list(np.load('./reward_list_dqn_1.npy', allow_pickle=True))
    else:
        reward_list = []

    if os.path.exists('loss_list_dqn.npy'):
        print('loss exists')
        loss_list = list(np.load('./loss_list_dqn.npy', allow_pickle=True))
    else:
        loss_list = []
    iteration = 0
    if os.path.exists('act_reward_dqn.npy'):
        print('act_reward_dqn exists')
        act_reward = np.load('./act_reward_dqn.npy', allow_pickle=True).item()
    else:
        act_reward = {}
    for i_episode in range(nepisode):
        ep_rwd = 0
        notebook_id = random.choice(notebook_pool)
        print("\033[0;35;40m" + "notebook_id:" + str(notebook_id) + "\033[0m")
        notebook_path = notebook_root + str(notebook_id) + '.ipynb'
        notebook_code = get_code_txt(notebook_path)
        res_line_number = -1
        s_t, len_data = rpc_client_get_origin_state(notebook_id, notebook_code, column_num, ip)
        check_result, model_list = check_model(notebook_id)
        while s_t == 'run failed' or not check_result:
            notebook_pool.remove(notebook_id)
            notebook_id = random.choice(notebook_pool)
            print("\033[0;34;40m" + "notebook_id:" + str(notebook_id) + "\033[0m")
            notebook_path = notebook_root + str(notebook_id) + '.ipynb'
            notebook_code = get_code_txt(notebook_path)
            s_t, len_data = rpc_client_get_origin_state(notebook_id, notebook_code, column_num, ip)
            check_result, model_list = check_model(notebook_id)

        s_t_p = s_t
        s_t = np.ravel(s_t)
        type_ = np.array([int(np.load('type.npy', allow_pickle=True))])
        if len(s_t) == 1900:
            s_t = np.concatenate((type_, s_t), axis=0)
        if len(s_t) == 1901:
            s_t = np.concatenate((s_t, model_list), axis=0)
        if len(s_t) == 0:
            continue
        while True:
            terminal1 = False
            if int(np.load('type.npy', allow_pickle=True)) != 1:
                terminal1 = True

            action1, action2 = agent.step(s_t, len_data)
            check_res = check_action_by_rule(action1 + 1, action2 + 1, s_t_p, len_data, column_num=column_num)
            s_t_plus_1 = np.zeros([1942])
            if not check_res:
                reward = -1.0
                terminal = True

            else:
                if action2 == column_num - 1:
                    target_content = {
                        'operation': action1 + 1,
                        'data_object': -1,
                    }
                else:
                    target_content = {
                        'operation': action1 + 1,
                        'data_object': action2,
                    }

                try:
                    s_t, action, reward, terminal, s_t_plus_1, notebook_code, res_line_number, len_data_plus_1 = rpc_client_do_an_action(
                        notebook_id, notebook_code, target_content, column_num, res_line_number, ip)
                except Exception:
                    break
                if s_t == []:
                    break

                if reward == -2:
                    continue

                s_t = np.ravel(s_t)
                type_ = np.array([int(np.load('type.npy', allow_pickle=True))])
                if int(np.load('type.npy', allow_pickle=True)) != 1:
                    terminal = True
                if len(s_t) == 1900:
                    s_t = np.concatenate((type_, s_t), axis=0)
                if len(s_t) == 1901:
                    s_t = np.concatenate((s_t, model_list), axis=0)

                s_t_p = s_t_plus_1
                s_t_plus_1 = np.ravel(s_t_plus_1)
                if len(s_t_plus_1) == 1900:
                    s_t_plus_1 = np.concatenate(([0], s_t_plus_1), axis=0)
                if len(s_t_plus_1) == 1901:
                    s_t_plus_1 = np.concatenate((s_t_plus_1, model_list), axis=0)
                s_t = s_t_plus_1
                len_data = len_data_plus_1
            act = (action1, action2)
            if reward > 0:
                reward *= 1000

            if reward == 0:
                reward = 0.5

            agent.memory.store_transition(s_t, act[0], act[1], reward, s_t_plus_1, terminal)
            agent.memory.save_buffer()
            ep_rwd += reward
            print('iteration', iteration)
            print("\033[0;36;40m" + "reward:" + str(reward) + "\033[0m")
            print("\033[0;36;40m" + "terminal:" + str(terminal) + "\033[0m")
            print("\033[0;36;40m" + "act:" + str(act) + "\033[0m")

            iteration += 1
            if iteration >= 20:
                loss = agent.learn()
                loss_list.append(loss)
                if iteration % epsilon_step == 0:
                    agent.epsilon = max(agent.epsilon * 0.99, 0.001)
                if act[0] not in act_reward:
                    act_reward[act[0]] = []
                act_reward[act[0]].append(ep_rwd)
            if terminal or terminal1:
                print('Ep: %i' % i_episode, "|Ep_r: %i" % ep_rwd)
                reward_list.append(ep_rwd)
                np.save('reward_list_dqn_1', reward_list)
                np.save('loss_list_dqn', loss_list)
                np.save('act_reward_dqn', act_reward)
                break
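
The DQN trainer decays epsilon multiplicatively: once iteration >= 20, it multiplies agent.epsilon by 0.99 every epsilon_step iterations, floored at 0.001. A self-contained sketch of that schedule (epsilon0=1.0 and epsilon_step=100 are illustrative assumptions; the real values come from train_config):

def epsilon_at(iteration, epsilon0=1.0, epsilon_step=100):
    # Epsilon after `iteration` steps under a 0.99-per-epsilon_step decay.
    decays = iteration // epsilon_step
    return max(epsilon0 * (0.99 ** decays), 0.001)

for it in (0, 1000, 10000, 100000):
    print(it, epsilon_at(it))  # decays toward the 0.001 floor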
Example #4
def train(notebook_root, dataset_root, ip):
    def check_model(notebook_id):
        model_dic = eval(CONFIG.get('models', 'model_dic'))
        cursor, db = create_connection()
        sql = 'select model_type from result where notebook_id = ' + str(
            notebook_id)
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        model_list = np.zeros([len(model_dic.keys())])
        check = False
        for row in sql_res:
            if row[0] in model_dic.keys():
                model_id = model_dic[row[0]] - 1
                model_list[model_id] = 1
                check = True
        return check, model_list

    def create_notebook_pool():
        notebook_pool = []
        in_result = []
        cursor, db = create_connection()
        sql = 'select distinct notebook_id from result'
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        for row in sql_res:
            in_result.append(int(row[0]))

        in_notebook = []
        sql = 'select distinct id from notebook where isRandom=1'
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        for row in sql_res:
            in_notebook.append(int(row[0]))

        sql = 'select pair.nid from pair,dataset where pair.did=dataset.id and dataset.server_ip = \'' + ip + '\''
        cursor.execute(sql)
        sql_res = cursor.fetchall()
        for row in sql_res:
            if int(row[0]) not in in_result:
                continue
            if int(row[0]) not in in_notebook:
                continue
            if int(row[0]) not in notebook_pool:
                notebook_pool.append(int(row[0]))
        return notebook_pool

    notebook_pool = create_notebook_pool()
    train_config = eval(CONFIG.get('train', 'train'))
    nepisode = train_config['nepisode']
    obs_dim = train_config['obs_dim']
    ope_dic = eval(CONFIG.get('operators', 'operations'))
    learning_rate = train_config['learning_rate']
    gamma = train_config['gamma']
    dense_dim = train_config['dense_dim']
    column_num = train_config['column_num']
    act_1_dim = 0
    for item in ope_dic:
        if ope_dic[item]['index'] > act_1_dim:
            act_1_dim = ope_dic[item]['index']  # 27

    agent = PolicyGradient(act_1_dim=act_1_dim + 1,
                           act_2_dim=column_num,
                           obs_dim=obs_dim,
                           dense_dim=dense_dim,
                           lr=learning_rate,
                           gamma=gamma)

    if os.path.exists('reward_list.npy'):
        print('exists')
        reward_list = list(np.load('./reward_list.npy', allow_pickle=True))
    else:
        reward_list = []

    if os.path.exists('loss_pg.npy'):
        print('loss exists')
        loss_list = list(np.load('./loss_pg.npy', allow_pickle=True))
    else:
        loss_list = []
    iteration = 0
    if os.path.exists('act_reward_pg.npy'):
        print('exists')
        act_reward = np.load('./act_reward_pg.npy', allow_pickle=True).item()
    else:
        act_reward = {}
    for i_episode in range(nepisode):
        ep_rwd = 0
        notebook_id = random.choice(notebook_pool)
        print("\033[0;35;40m" + "notebook_id:" + str(notebook_id) + "\033[0m")
        notebook_path = notebook_root + str(notebook_id) + '.ipynb'
        notebook_code = get_code_txt(notebook_path)
        res_line_number = -1
        s_t, len_data = rpc_client_get_origin_state(notebook_id, notebook_code,
                                                    column_num, ip)
        # print(s_t)
        check_result, model_list = check_model(notebook_id)
        while s_t == 'run failed' or not check_result:
            notebook_pool.remove(notebook_id)
            notebook_id = random.choice(notebook_pool)
            print("\033[0;34;40m" + "notebook_id:" + str(notebook_id) +
                  "\033[0m")
            notebook_path = notebook_root + str(notebook_id) + '.ipynb'
            notebook_code = get_code_txt(notebook_path)
            s_t, len_data = rpc_client_get_origin_state(
                notebook_id, notebook_code, column_num, ip)
            check_result, model_list = check_model(notebook_id)

        s_t_p = s_t
        s_t = np.ravel(s_t)
        type_ = np.array([int(np.load('type.npy', allow_pickle=True))])
        if len(s_t) == 1900:
            s_t = np.concatenate((type_, s_t), axis=0)
        if len(s_t) == 1901:
            s_t = np.concatenate((s_t, model_list), axis=0)
        # pprint.pprint(s_t)
        if len(s_t) == 0:
            continue
        while True:
            terminal1 = False
            if int(np.load('type.npy', allow_pickle=True)) != 1:
                terminal1 = True
            action1, action2 = agent.step(s_t)  # predict the next action from the current state via the network (to be changed)
            # print(len_data)
            # print(len(s_t_p))
            check_res = check_action_by_rule(action1 + 1,
                                             action2,
                                             s_t_p,
                                             len_data,
                                             column_num=column_num)
            s_t_plus_1 = np.zeros([1942])
            if not check_res:
                reward = -1.0
                terminal = True
            else:
                if action2 == column_num - 1:
                    target_content = {
                        'operation': action1 + 1,
                        'data_object': -1,
                    }
                else:
                    target_content = {
                        'operation': action1 + 1,
                        'data_object': action2,
                    }
                # execute the action; get the new state, immediate reward, and terminal flag
                eventlet.monkey_patch()
                try:
                    with eventlet.Timeout(60, False):  # give the RPC call at most 60 seconds
                        s_t, action, reward, terminal, s_t_plus_1, notebook_code, res_line_number, len_data_plus_1 = rpc_client_do_an_action(
                            notebook_id, notebook_code, target_content,
                            column_num, res_line_number, ip)
                except Exception:
                    break
                if s_t == []:
                    break
                if reward == -2:
                    continue

                s_t = np.ravel(s_t)
                type_ = np.array([int(np.load('type.npy', allow_pickle=True))])
                if int(np.load('type.npy', allow_pickle=True)) != 1:
                    terminal = True
                if len(s_t) == 1900:
                    s_t = np.concatenate((type_, s_t), axis=0)
                if len(s_t) == 1901:
                    s_t = np.concatenate((s_t, model_list), axis=0)

                s_t_p = s_t_plus_1
                s_t_plus_1 = np.ravel(s_t_plus_1)
                if len(s_t_plus_1) == 1900:
                    s_t_plus_1 = np.concatenate(([0], s_t_plus_1), axis=0)
                if len(s_t_plus_1) == 1901:
                    s_t_plus_1 = np.concatenate((s_t_plus_1, model_list),
                                                axis=0)
                s_t = s_t_plus_1
                len_data = len_data_plus_1
            act = (action1, action2)
            agent.memory.store_transition(s_t, act, reward)  # store in the replay pool (to be changed)
            ep_rwd += reward
            print("\033[0;36;40m" + "reward:" + str(reward) + "\033[0m")
            print("\033[0;36;40m" + "terminal:" + str(terminal) + "\033[0m")
            print("\033[0;36;40m" + "act:" + str(act) + "\033[0m")
            if terminal or terminal1:
                loss = agent.learn()  # one full episode has ended; optimize the network
                loss_list.append(loss)

                print('Ep: %i' % i_episode, "|Ep_r: %f" % ep_rwd)
                sql = 'update notebook set trained_time = trained_time + 1 where id=' + str(
                    notebook_id)
                cursor, db = create_connection()
                cursor.execute(sql)
                db.commit()
                if act[0] not in act_reward:
                    act_reward[act[0]] = []
                act_reward[act[0]].append(ep_rwd)
                reward_list.append(ep_rwd)
                # if i_episode % 50 == 0:
                np.save('./reward_list.npy', reward_list)
                np.save('loss_pg', loss_list)
                np.save('act_reward_pg', act_reward)
                break
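
All four trainers repeat the same resume-from-disk pattern for their logs: np.load the .npy file if it exists, otherwise start from an empty container. A hedged consolidation sketch; load_or_init is our name, not one from the original code:

import os
import numpy as np

def load_or_init(path, default):
    # Resume a checkpointed log from disk, or start from `default`.
    if not os.path.exists(path):
        return default
    loaded = np.load(path, allow_pickle=True)
    # dicts were saved whole via np.save, so .item() unwraps them;
    # list-like logs round-trip as arrays and are converted back to lists
    return loaded.item() if isinstance(default, dict) else list(loaded)

reward_list = load_or_init('./reward_list.npy', [])   # list checkpoint
act_reward = load_or_init('./act_reward_pg.npy', {})  # dict checkpoint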