def train_epoch(data_path, sess, model, train_fetches, valid_fetches,
                train_writer, test_writer):
    """Run one pass over all training batches, validating and checkpointing.

    Relies on module-level names defined elsewhere in the file:
    ``n_tr_batches``, ``n_va_batches``, ``data_valid_path``, ``model_path``,
    ``FLAGS``, ``get_batch``, ``to_categorical``, ``valid_epoch`` and the
    best-so-far score ``last_f1`` (updated here).

    For each training step:
      * if the upcoming step (``global_step + 1``) is a multiple of
        ``FLAGS.valid_step``, the whole validation set is scored with
        ``valid_epoch``; when its F1 beats ``last_f1`` a checkpoint is saved
        to ``model_path`` through ``model.saver``;
      * one batch from ``data_path`` (batch order shuffled per epoch) is run
        through ``train_fetches``;
      * every 500 steps, the training summary is written to ``train_writer``
        and a summary computed on one random validation batch is written to
        ``test_writer``.

    Args:
        data_path: directory of training batches, passed to ``get_batch``.
        sess: session object whose ``run`` executes fetches.
        model: model exposing ``global_step``, ``saver`` and the input
            placeholders used in the feed dicts below.
        train_fetches: fetches for one training step; must evaluate to four
            values, the first of which is a summary.
        valid_fetches: fetches for one validation batch; must evaluate to
            ``(summary, cost)``.
        train_writer: writer receiving the training summaries.
        test_writer: writer receiving the validation-batch summaries.
    """
    global last_f1
    time0 = time.time()

    def make_feed(x1, x2, y_onehot, training):
        # Training feeds model.tst=False and FLAGS.keep_prob; evaluation feeds
        # model.tst=True and keep_prob=1.0.
        return {
            model.X1_inputs: x1,
            model.X2_inputs: x2,
            model.y_inputs: y_onehot,
            model.batch_size: len(y_onehot),
            model.tst: not training,
            model.keep_prob: FLAGS.keep_prob if training else 1.0,
        }

    batch_indexs = np.random.permutation(n_tr_batches)  # shuffle batch order
    for batch in tqdm(xrange(n_tr_batches)):
        global_step = sess.run(model.global_step)
        next_step = global_step + 1
        if next_step % FLAGS.valid_step == 0:
            valid_cost, precision, recall, f1 = valid_epoch(
                data_valid_path, sess, model)
            print(
                'Global_step=%d: valid cost=%g; p=%g, r=%g, f1=%g, time=%g s' %
                (global_step, valid_cost, precision, recall, f1,
                 time.time() - time0))
            time0 = time.time()
            if f1 > last_f1:
                last_f1 = f1
                saving_path = model.saver.save(sess, model_path, next_step)
                print('saved new model to %s ' % saving_path)

        # Training step. Read from the caller-supplied data_path; the previous
        # code ignored this parameter and read the global data_train_path.
        X1_batch, X2_batch, y_batch = get_batch(data_path, batch_indexs[batch])
        y_batch = to_categorical(y_batch)  # topic ids -> 0/1 target vectors
        # The cost is the mean cost of one batch; only the summary is used here.
        summary, _, _, _ = sess.run(
            train_fetches, make_feed(X1_batch, X2_batch, y_batch, True))

        # Every 500 steps, log the training summary plus a summary computed
        # on one randomly chosen validation batch.
        if next_step % 500 == 0:
            train_writer.add_summary(summary, global_step)
            batch_id = np.random.randint(0, n_va_batches)
            X1_batch, X2_batch, y_batch = get_batch(data_valid_path, batch_id)
            y_batch = to_categorical(y_batch)
            summary, _ = sess.run(
                valid_fetches, make_feed(X1_batch, X2_batch, y_batch, False))
            test_writer.add_summary(summary, global_step)
def valid_epoch(data_path, sess, model):
    """Evaluate the model on every validation batch stored in ``data_path``.

    The number of batches is the number of entries in ``data_path``. Batches
    are loaded as ``get_batch(data_path, i)`` for ``i`` in ``[0, n)``. For
    each sample, the indices of the 5 highest entries of ``model.y_pred`` are
    kept as predicted labels. They are scored against the raw labels with
    ``score_eval``.

    Args:
        data_path: directory holding the validation batches.
        sess: session object whose ``run`` executes fetches.
        model: model exposing ``loss``, ``y_pred`` and the input placeholders.

    Returns:
        ``(mean_cost, precision, recall, f1)``, where ``mean_cost`` is the
        unweighted mean of the per-batch ``model.loss`` values.

    Raises:
        ValueError: if ``data_path`` contains no entries.
    """
    n_batches = len(os.listdir(data_path))
    if n_batches == 0:
        raise ValueError('no validation batches found in %s' % data_path)

    fetches = [model.loss, model.y_pred]  # loop-invariant
    total_cost = 0.0
    predict_labels_list = list()  # top-5 predicted label indices per sample
    marked_labels_list = list()   # raw ground-truth labels per sample
    for i in xrange(n_batches):
        X1_batch, X2_batch, y_batch = get_batch(data_path, i)
        marked_labels_list.extend(y_batch)
        y_onehot = to_categorical(y_batch)
        feed_dict = {
            model.X1_inputs: X1_batch,
            model.X2_inputs: X2_batch,
            model.y_inputs: y_onehot,
            model.batch_size: len(y_onehot),
            model.tst: True,
            model.keep_prob: 1.0
        }
        _cost, predict_labels = sess.run(fetches, feed_dict)
        total_cost += _cost
        # argsort is ascending, so [-1:-6:-1] yields the 5 best indices,
        # highest score first.
        predict_labels_list.extend(
            row.argsort()[-1:-6:-1] for row in predict_labels)

    # Materialize as a list so score_eval gets a list on Python 3 as well.
    predict_label_and_marked_label_list = list(
        zip(predict_labels_list, marked_labels_list))
    precision, recall, f1 = score_eval(predict_label_and_marked_label_list)
    mean_cost = total_cost / n_batches
    return mean_cost, precision, recall, f1
def mark():
    """Gather the ground-truth labels of all validation batches and save them.

    Loads each of the ``n_va_batches`` batches with ``get_batch``, concatenates
    their raw label lists and prints a few diagnostics. One of them is the
    first 20 entries of the first converted row of the last batch. The labels
    are saved as ``marked_labels_list_dev_all.npy`` in the current working
    directory.

    NOTE(review): ``get_batch`` is called here with only a batch id, while
    ``train_epoch``/``valid_epoch`` call it as ``get_batch(path, batch_id)``;
    confirm which ``get_batch`` signature is in scope for this function.

    Raises:
        ValueError: if no labels were collected (no batches or empty batches).
    """
    marked_labels_list = list()
    y_batch = None
    for i in tqdm(xrange(n_va_batches)):
        _, _, y_batch = get_batch(i)
        marked_labels_list.extend(y_batch)
    if not marked_labels_list:
        raise ValueError('no validation labels collected')
    # Only the last batch's converted labels are printed below. Convert just
    # that batch instead of converting (and discarding) every batch in the loop.
    y_batch = to_categorical(y_batch)
    print (type(marked_labels_list),marked_labels_list[0],len(marked_labels_list))
    print(y_batch[0][0:20])
    marked_labels = np.asarray(marked_labels_list)
    print (type(marked_labels),marked_labels[0],marked_labels.shape)
    np.save('marked_labels_list_dev_all.npy',marked_labels)
def predict_test(sess, model):
    """Predict on the local test batches, score them and save per-class scores.

    Iterates over the ``n_te_batches`` batches returned by ``get_batch_t`` and
    applies ``softmax`` to ``model.y_pred``. For each sample, ``findindex`` is
    applied to that sample's 5 highest scores. Its result ``k`` selects the
    top-k class indices: the top 5 when it returns None, the top 1 when it
    returns 0. Precision/recall/F1 from ``score_eval`` are printed. The stacked
    score matrix (one row per sample) is saved to
    ``local_scores_path + model_name + '_test.npy'``.

    Relies on module-level names defined elsewhere: n_te_batches,
    get_batch_t, softmax, findindex, score_eval, local_scores_path,
    model_name.
    """
    time0 = time.time()
    predict_labels_list = list()  # predicted label indices, one array per sample
    marked_labels_list = list()   # ground-truth labels as returned by get_batch_t
    predict_scores = list()       # per-batch score matrices
    fetches = [model.y_pred]      # per-class scores
    for batch_id in tqdm(xrange(n_te_batches)):  # local test set
        X1_batch, X2_batch, y_batch = get_batch_t(batch_id)
        marked_labels_list.extend(y_batch)
        feed_dict = {model.X1_inputs: X1_batch, model.X2_inputs: X2_batch,
                     model.batch_size: len(X1_batch), model.tst: True,
                     model.keep_prob: 1.0}
        batch_scores = softmax(sess.run(fetches, feed_dict)[0])
        predict_scores.append(batch_scores)
        for row_scores in batch_scores:
            top5_scores = np.sort(row_scores, axis=-1)[-1:-6:-1]
            k = findindex(top5_scores)
            if k is None:
                k = 5
            elif k == 0:
                k = 1
            # argsort is ascending: take the last k indices, best first.
            predict_labels_list.append(row_scores.argsort()[-1:-k - 1:-1])

    # NOTE(review): the original author observed that these labels appear to
    # be shifted by -1 relative to the real topic ids (a printed pair was
    # (array([15, 327, 307, 478, 10]), [8, 15, 307, 0]) while the real ids
    # were [9, 16, 308, 1]) and did not know why; confirm the id convention
    # used by get_batch_t / score_eval.
    predict_label_and_marked_label_list = list(
        zip(predict_labels_list, marked_labels_list))
    print(predict_label_and_marked_label_list[0:2])
    precision, recall, f1 = score_eval(predict_label_and_marked_label_list)
    print('Local test p=%g, r=%g, f1=%g' % (precision, recall, f1))
    # Stack the per-batch matrices directly. Wrapping the list in np.asarray
    # first fails on recent NumPy (or builds an object array on older NumPy)
    # when batches differ in size, e.g. a smaller final batch.
    predict_scores = np.vstack(predict_scores)
    print('predict_scores:',predict_scores.shape)
    local_scores_name = local_scores_path + model_name + '_test.npy'
    np.save(local_scores_name, predict_scores)  # per-class scores per sample
    print('local_scores.shape=', predict_scores.shape)
    print('Writed the test scores into %s, time %g s' % (local_scores_name, time.time() - time0))
def predict_dev(sess, model, save_scores=False):
    """Predict on the validation (dev) set and report local precision/recall/F1.

    Iterates over the ``n_va_batches`` batches returned by ``get_batch`` and
    applies ``softmax`` to ``model.y_pred``. For each sample, ``findindex`` is
    applied to that sample's 5 highest scores. Its result ``k`` selects the
    top-k class indices: the top 5 when it returns None, the top 1 when it
    returns 0. Precision/recall/F1 from ``score_eval`` are printed.

    Relies on module-level names defined elsewhere: n_va_batches, get_batch,
    softmax, findindex, score_eval, local_scores_path, model_name.

    NOTE(review): ``get_batch`` is called here with only a batch id, while
    ``train_epoch``/``valid_epoch`` call it with ``(path, batch_id)``;
    confirm which signature is in scope.

    Args:
        sess: session object whose ``run`` executes fetches.
        model: model exposing ``y_pred`` and the input placeholders fed below.
        save_scores: if True, save the stacked per-class score matrix to
            ``local_scores_path + model_name + '_dev.npy'``. Defaults to False,
            which keeps the previous behavior of not writing the file.
    """
    time0 = time.time()
    predict_labels_list = list()  # predicted label indices, one array per sample
    marked_labels_list = list()   # ground-truth labels as returned by get_batch
    predict_scores = list()       # per-batch score matrices
    fetches = [model.y_pred]      # per-class scores
    for batch_id in tqdm(xrange(n_va_batches)):  # validation set
        X1_batch, X2_batch, y_batch = get_batch(batch_id)
        marked_labels_list.extend(y_batch)
        feed_dict = {model.X1_inputs: X1_batch, model.X2_inputs: X2_batch,
                     model.batch_size: len(X1_batch), model.tst: True,
                     model.keep_prob: 1.0}
        batch_scores = softmax(sess.run(fetches, feed_dict)[0])
        predict_scores.append(batch_scores)
        for row_scores in batch_scores:
            top5_scores = np.sort(row_scores, axis=-1)[-1:-6:-1]
            k = findindex(top5_scores)
            if k is None:
                k = 5
            elif k == 0:
                k = 1
            # argsort is ascending: take the last k indices, best first.
            predict_labels_list.append(row_scores.argsort()[-1:-k - 1:-1])

    # NOTE(review): the original author observed that these labels appear to
    # be shifted by -1 relative to the real topic ids (a printed pair was
    # (array([15, 327, 307, 478, 10]), [8, 15, 307, 0]) while the real ids
    # were [9, 16, 308, 1]) and did not know why; confirm the id convention
    # used by get_batch / score_eval.
    predict_label_and_marked_label_list = list(
        zip(predict_labels_list, marked_labels_list))
    print(predict_label_and_marked_label_list[0:2])
    precision, recall, f1 = score_eval(predict_label_and_marked_label_list)
    print('Local valid p=%g, r=%g, f1=%g' % (precision, recall, f1))
    # Stack the per-batch matrices directly. Wrapping the list in np.asarray
    # first fails on recent NumPy (or builds an object array on older NumPy)
    # when batches differ in size, e.g. a smaller final batch.
    predict_scores = np.vstack(predict_scores)
    print('predict_scores:',predict_scores.shape)
    local_scores_name = local_scores_path + model_name + '_dev.npy'
    print('local_scores.shape=', predict_scores.shape)
    if save_scores:
        np.save(local_scores_name, predict_scores)  # per-class scores per sample
        print('Writed the dev scores into %s, time %g s' % (local_scores_name, time.time() - time0))
    else:
        # Do not claim a write that did not happen.
        print('Dev scores not saved (pass save_scores=True to write %s), time %g s' % (local_scores_name, time.time() - time0))