Ejemplo n.º 1
0
def main():
    """Train the embedding classifier on ``./data/mimic3.p`` and report metrics.

    The function loads the pickled train/validation/test splits, optionally
    sub-samples the training set (``opt.part_data`` / ``opt.portion``), builds
    the TF1 graph through ``emb_classifier`` and runs ``opt.max_epochs``
    epochs of mini-batch training.  Every ``opt.valid_freq`` updates it
    evaluates accuracy on the first 500 training indices, and AUC on the
    validation and test sets (plus F1 and precision@5 on the test set).  A
    checkpoint is written to ``opt.save_path`` at the end of every epoch.

    Takes no arguments and returns ``None``; all configuration comes from
    ``Options()``.  ``KeyboardInterrupt`` during training is caught and
    reported instead of propagating.
    """
    # Prepare training and testing data
    opt = Options()
    # load data
    loadpath = "./data/mimic3.p"
    embpath = "mimic3_emb.p"
    opt.num_class = 50

    # NOTE(review): pickle executes arbitrary code on load; only use with
    # trusted data files.
    with open(loadpath, "rb") as data_file:
        x = cPickle.load(data_file)
    train, train_text, train_lab = x[0], x[1], x[2]
    val, val_text, val_lab = x[3], x[4], x[5]
    test, test_text, test_lab = x[6], x[7], x[8]
    # In this pickle layout index 9 is ixtoword and index 10 is wordtoix.
    wordtoix, ixtoword = x[10], x[9]
    del x
    print("load data finished")

    train_lab = np.array(train_lab, dtype='float32')
    val_lab = np.array(val_lab, dtype='float32')
    test_lab = np.array(test_lab, dtype='float32')
    opt.n_words = len(ixtoword)
    if opt.part_data:
        #np.random.seed(123)
        # Keep a random subset (without replacement) of the training data.
        train_ind = np.random.choice(len(train),
                                     int(len(train) * opt.portion),
                                     replace=False)
        train = [train[t] for t in train_ind]
        train_lab = [train_lab[t] for t in train_ind]

    os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.GPUID)

    print(dict(opt))
    print('Total words: %d' % opt.n_words)

    try:
        with open(embpath, 'rb') as emb_file:
            opt.W_emb = np.array(cPickle.load(emb_file), dtype='float32')
        opt.W_class_emb = load_class_embedding(wordtoix, opt)
    except IOError:
        # Without pre-trained embeddings the embedding matrix must be trained.
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32,
                            shape=[opt.batch_size, opt.maxlen],
                            name='x_')
        x_mask_ = tf.placeholder(tf.float32,
                                 shape=[opt.batch_size, opt.maxlen],
                                 name='x_mask_')
        keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        y_ = tf.placeholder(tf.float32,
                            shape=[opt.batch_size, opt.num_class],
                            name='y_')
        class_penalty_ = tf.placeholder(tf.float32, shape=())
        accuracy_, loss_, train_op, W_norm_, global_step, logits_, prob_ = emb_classifier(
            x_, x_mask_, y_, keep_prob, opt, class_penalty_)
    uidx = 0
    max_val_accuracy = 0.
    max_test_accuracy = 0.
    max_val_auc_mean = 0.
    max_test_auc_mean = 0.

    config = tf.ConfigProto(
        log_device_placement=False,
        allow_soft_placement=True,
    )
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                save_keys = tensors_key_in_file(opt.save_path)
                ss = set([var.name for var in t_vars]) & set(
                    [s + ":0" for s in save_keys.keys()])
                cc = {var.name: var for var in t_vars}
                # only restore variables with correct shape
                ss_right_shape = set(
                    [s for s in ss if cc[s].get_shape() == save_keys[s[:-2]]])

                loader = tf.train.Saver(var_list=[
                    var for var in t_vars if var.name in ss_right_shape
                ])
                loader.restore(sess, opt.save_path)

                print("Loading variables from '%s'." % opt.save_path)
                print("Loaded variables:" + str(ss))

            except Exception:
                # Best-effort restore: fall back to fresh weights, but do not
                # swallow KeyboardInterrupt/SystemExit like a bare except did.
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        try:
            for epoch in range(opt.max_epochs):
                print("Starting epoch %d" % epoch)
                kf = get_minibatches_idx(len(train),
                                         opt.batch_size,
                                         shuffle=True)
                for _, train_index in kf:
                    uidx += 1
                    sents = [train[t] for t in train_index]
                    x_labels = [train_lab[t] for t in train_index]
                    x_labels = np.array(x_labels)
                    x_labels = x_labels.reshape((len(x_labels), opt.num_class))

                    x_batch, x_batch_mask = prepare_data_for_emb(sents, opt)
                    _, loss, step, = sess.run(
                        [train_op, loss_, global_step],
                        feed_dict={
                            x_: x_batch,
                            x_mask_: x_batch_mask,
                            y_: x_labels,
                            keep_prob: opt.dropout,
                            class_penalty_: opt.class_penalty
                        })

                    if uidx % opt.valid_freq == 0:
                        train_correct = 0.0
                        # Estimate training accuracy on the first 500
                        # training indices only (dropout disabled).
                        kf_train = get_minibatches_idx(500,
                                                       opt.batch_size,
                                                       shuffle=True)
                        for _, train_index in kf_train:
                            train_sents = [train[t] for t in train_index]
                            train_labels = [train_lab[t] for t in train_index]
                            train_labels = np.array(train_labels)
                            train_labels = train_labels.reshape(
                                (len(train_labels), opt.num_class))
                            x_train_batch, x_train_batch_mask = prepare_data_for_emb(
                                train_sents, opt)
                            train_accuracy = sess.run(accuracy_,
                                                      feed_dict={
                                                          x_: x_train_batch,
                                                          x_mask_:
                                                          x_train_batch_mask,
                                                          y_: train_labels,
                                                          keep_prob: 1.0,
                                                          class_penalty_: 0.0
                                                      })

                            train_correct += train_accuracy * len(train_index)

                        train_accuracy = train_correct / 500

                        print("Iteration %d: Training loss %f " % (uidx, loss))
                        print("Train accuracy %f " % train_accuracy)

                        # ---- validation pass ----
                        val_correct = 0.0
                        val_logits_list = []
                        val_prob_list = []
                        val_true_list = []

                        kf_val = get_minibatches_idx(len(val),
                                                     opt.batch_size,
                                                     shuffle=True)
                        for _, val_index in kf_val:
                            val_sents = [val[t] for t in val_index]
                            val_labels = [val_lab[t] for t in val_index]
                            val_labels = np.array(val_labels)
                            val_labels = val_labels.reshape(
                                (len(val_labels), opt.num_class))
                            x_val_batch, x_val_batch_mask = prepare_data_for_emb(
                                val_sents, opt)
                            val_accuracy, val_logits, val_probs = sess.run(
                                [accuracy_, logits_, prob_],
                                feed_dict={
                                    x_: x_val_batch,
                                    x_mask_: x_val_batch_mask,
                                    y_: val_labels,
                                    keep_prob: 1.0,
                                    class_penalty_: 0.0
                                })

                            val_correct += val_accuracy * len(val_index)
                            val_logits_list += val_logits.tolist()
                            val_prob_list += val_probs.tolist()
                            val_true_list += val_labels.tolist()

                        val_accuracy = val_correct / len(val)
                        val_logits_array = np.asarray(val_logits_list)
                        val_prob_array = np.asarray(val_prob_list)
                        val_true_array = np.asarray(val_true_list)
                        val_auc_list = []
                        val_auc_micro = roc_auc_score(y_true=val_true_array,
                                                      y_score=val_logits_array,
                                                      average='micro')
                        val_auc_macro = roc_auc_score(y_true=val_true_array,
                                                      y_score=val_logits_array,
                                                      average='macro')
                        for i in range(opt.num_class):
                            # Per-class AUC only for classes that have at
                            # least one positive example in this split.
                            if np.any(val_true_array[:, i] > 0):
                                val_auc = roc_auc_score(
                                    y_true=val_true_array[:, i],
                                    y_score=val_logits_array[:, i],
                                )
                                val_auc_list.append(val_auc)
                        # Average over the collected per-class AUCs (the
                        # original averaged only the last scalar).
                        val_auc_mean = (float(np.mean(val_auc_list))
                                        if val_auc_list else float('nan'))

                        # print("Validation accuracy %f " % val_accuracy)
                        print("val auc macro %f micro %f " %
                              (val_auc_macro, val_auc_micro))

                        # ---- test pass (run at every validation step) ----
                        if True:
                            test_correct = 0.0
                            test_logits_list = []
                            test_prob_list = []
                            test_true_list = []

                            kf_test = get_minibatches_idx(len(test),
                                                          opt.batch_size,
                                                          shuffle=True)
                            for _, test_index in kf_test:
                                test_sents = [test[t] for t in test_index]
                                test_labels = [test_lab[t] for t in test_index]
                                test_labels = np.array(test_labels)
                                test_labels = test_labels.reshape(
                                    (len(test_labels), opt.num_class))
                                x_test_batch, x_test_batch_mask = prepare_data_for_emb(
                                    test_sents, opt)

                                test_accuracy, test_logits, test_probs = sess.run(
                                    [accuracy_, logits_, prob_],
                                    feed_dict={
                                        x_: x_test_batch,
                                        x_mask_: x_test_batch_mask,
                                        y_: test_labels,
                                        keep_prob: 1.0,
                                        class_penalty_: 0.0
                                    })

                                # Accumulate once per batch (the original
                                # added this term twice, doubling accuracy).
                                test_correct += test_accuracy * len(test_index)
                                test_logits_list += test_logits.tolist()
                                test_prob_list += test_probs.tolist()
                                test_true_list += test_labels.tolist()
                            test_accuracy = test_correct / len(test)
                            test_logits_array = np.asarray(test_logits_list)
                            test_prob_array = np.asarray(test_prob_list)
                            test_true_array = np.asarray(test_true_list)
                            test_auc_list = []
                            test_auc_micro = roc_auc_score(
                                y_true=test_true_array,
                                y_score=test_logits_array,
                                average='micro')
                            test_auc_macro = roc_auc_score(
                                y_true=test_true_array,
                                y_score=test_logits_array,
                                average='macro')

                            # F1 uses probabilities thresholded at 0.5.
                            test_f1_micro = micro_f1(
                                test_prob_array.ravel() > 0.5,
                                test_true_array.ravel(),
                            )
                            test_f1_macro = macro_f1(
                                test_prob_array > 0.5,
                                test_true_array,
                            )
                            test_p5 = precision_at_k(test_logits_array,
                                                     test_true_array, 5)

                            for i in range(opt.num_class):
                                if np.any(test_true_array[:, i] > 0):
                                    test_auc = roc_auc_score(
                                        y_true=test_true_array[:, i],
                                        y_score=test_logits_array[:, i],
                                    )
                                    test_auc_list.append(test_auc)

                            test_auc_mean = (float(np.mean(test_auc_list))
                                             if test_auc_list else float('nan'))
                            print("Test auc macro %f micro %f " %
                                  (test_auc_macro, test_auc_micro))
                            print("Test f1 macro %f micro %f " %
                                  (test_f1_macro, test_f1_micro))
                            print("P5 %f" % test_p5)
                            # Holds the most recent mean test AUC, not a
                            # running maximum, despite the variable name.
                            max_test_auc_mean = test_auc_mean

                # print("Epoch %d: Max Test accuracy %f" % (epoch, max_test_accuracy))
                print("Epoch %d: Max Test auc %f" % (epoch, max_test_auc_mean))
                saver.save(sess, opt.save_path, global_step=epoch)
            # NOTE(review): max_test_accuracy is never updated in this
            # function, so this always prints the initial 0.0.
            print("Max Test accuracy %f " % max_test_accuracy)

        except KeyboardInterrupt:
            print('Training interupted')
            print("Max Test accuracy %f " % max_test_accuracy)
Ejemplo n.º 2
0
def main():
    """Train the embedding classifier on the dataset named by ``opt.dataset``.

    Steps performed:

    1. Pick the data/embedding pickle paths and class names for
       ``opt.dataset`` ('Tweet', 'N20short', 'agnews', 'dbpedia' or
       'yelp_full').
    2. Load the splits, optionally sub-sample the training set, and load the
       pre-trained embeddings when the embedding file exists.
    3. Build the TF1 graph via ``emb_classifier`` and train for
       ``opt.max_epochs`` epochs, printing training accuracy (first 500
       training indices) and validation accuracy every ``opt.valid_freq``
       updates.  Checkpoints are saved after every epoch.
    4. After training, run the test set once, keep the ``opt.topnlabel``
       highest class probabilities per document and write them to the file
       ``test`` (first line: number of test documents).

    Takes no arguments and returns ``None``.

    Raises:
        ValueError: if ``opt.dataset`` is not one of the supported names.
    """
    # Prepare training and testing data
    opt = Options()
    # load data
    if opt.dataset == 'Tweet':
        loadpath = "./data/langdetect_tweet0.7.p"
        embpath = "./data/langdetect_tweet_emb.p"
        opt.num_class = 4
        opt.class_name = ['apple', 'google', 'microsoft', 'twitter']
    elif opt.dataset == 'N20short':
        loadpath = "./data/N20short.p"
        embpath = "./data/N20short_emb.p"
        opt.class_name = [
            'rec.autos', 'talk.politics.misc', 'sci.electronics',
            'comp.sys.ibm.pc.hardware', 'talk.politics.guns', 'sci.med',
            'rec.motorcycles', 'soc.religion.christian',
            'comp.sys.mac.hardware', 'comp.graphics', 'sci.space',
            'alt.atheism', 'rec.sport.baseball', 'comp.windows.x',
            'talk.religion.misc', 'comp.os.ms-windows.misc', 'misc.forsale',
            'talk.politics.mideast', 'sci.crypt', 'rec.sport.hockey'
        ]
        opt.num_class = len(opt.class_name)
    elif opt.dataset == 'agnews':
        loadpath = "./data/ag_news.p"
        embpath = "./data/ag_news_glove.p"
        opt.num_class = 4
        opt.class_name = ['World', 'Sports', 'Business', 'Science']
    elif opt.dataset == 'dbpedia':
        loadpath = "./data/dbpedia.p"
        embpath = "./data/dbpedia_glove.p"
        opt.num_class = 14
        opt.class_name = [
            'Company',
            'Educational Institution',
            'Artist',
            'Athlete',
            'Office Holder',
            'Mean Of Transportation',
            'Building',
            'Natural Place',
            'Village',
            'Animal',
            'Plant',
            'Album',
            'Film',
            'Written Work',
        ]
    elif opt.dataset == 'yelp_full':
        loadpath = "./data/yelp_full.p"
        embpath = "./data/yelp_full_glove.p"
        opt.num_class = 5
        opt.class_name = ['worst', 'bad', 'middle', 'good', 'best']
    else:
        # Fail with a clear message instead of an UnboundLocalError on
        # `loadpath` further down.
        raise ValueError("Unknown dataset: %r" % (opt.dataset,))

    # NOTE(review): pickle executes arbitrary code on load; only use with
    # trusted data files.
    with open(loadpath, "rb") as data_file:
        x = cPickle.load(data_file, encoding='iso-8859-1')
    train, val, test = x[0], x[1], x[2]
    print(len(val))
    train_lab, val_lab, test_lab = x[3], x[4], x[5]
    wordtoix, ixtoword = x[6], x[7]
    del x
    print("len of train,val,test:", len(train), len(val), len(test))
    print("load data finished")

    train_lab = np.array(train_lab, dtype='float32')
    val_lab = np.array(val_lab, dtype='float32')
    test_lab = np.array(test_lab, dtype='float32')
    opt.n_words = len(ixtoword)
    if opt.part_data:
        #np.random.seed(123)
        # Keep a random subset (without replacement) of the training data.
        train_ind = np.random.choice(len(train),
                                     int(len(train) * opt.portion),
                                     replace=False)
        train = [train[t] for t in train_ind]
        train_lab = [train_lab[t] for t in train_ind]

    os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.GPUID)

    print(dict(opt))
    print('Total words: %d' % opt.n_words)

    try:
        with open(embpath, 'rb') as emb_file:
            opt.W_emb = np.array(cPickle.load(emb_file,
                                              encoding='iso-8859-1'),
                                 dtype='float32')
        opt.W_class_emb = load_class_embedding(wordtoix, opt)
    except IOError:
        # Without pre-trained embeddings the embedding matrix must be trained.
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/cpu:0'):
        x_ = tf.placeholder(tf.int32,
                            shape=[opt.batch_size, opt.maxlen],
                            name='x_')
        x_mask_ = tf.placeholder(tf.float32,
                                 shape=[opt.batch_size, opt.maxlen],
                                 name='x_mask_')
        keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        y_ = tf.placeholder(tf.float32,
                            shape=[opt.batch_size, opt.num_class],
                            name='y_')
        class_penalty_ = tf.placeholder(tf.float32, shape=())
        accuracy_, loss_, train_op, W_norm_, global_step, prob_ = emb_classifier(
            x_, x_mask_, y_, keep_prob, opt, class_penalty_)
    uidx = 0
    max_val_accuracy = 0.
    max_test_accuracy = 0.

    config = tf.ConfigProto(
        log_device_placement=False,
        allow_soft_placement=True,
    )
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                save_keys = tensors_key_in_file(opt.save_path)
                ss = set([var.name for var in t_vars]) & set(
                    [s + ":0" for s in save_keys.keys()])
                cc = {var.name: var for var in t_vars}
                # only restore variables with correct shape
                ss_right_shape = set(
                    [s for s in ss if cc[s].get_shape() == save_keys[s[:-2]]])

                loader = tf.train.Saver(var_list=[
                    var for var in t_vars if var.name in ss_right_shape
                ])
                loader.restore(sess, opt.save_path)

                print("Loading variables from '%s'." % opt.save_path)
                print("Loaded variables:" + str(ss))

            except Exception:
                # Best-effort restore: fall back to fresh weights, but do not
                # swallow KeyboardInterrupt/SystemExit like a bare except did.
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        try:
            for epoch in range(opt.max_epochs):
                print("Starting epoch %d" % epoch)
                kf = get_minibatches_idx(len(train),
                                         opt.batch_size,
                                         shuffle=True)
                for _, train_index in kf:
                    uidx += 1
                    sents = [train[t] for t in train_index]
                    x_labels = [train_lab[t] for t in train_index]
                    x_labels = np.array(x_labels)
                    x_labels = x_labels.reshape((len(x_labels), opt.num_class))
                    x_batch, x_batch_mask = prepare_data_for_emb(sents, opt)
                    _, loss, step, = sess.run(
                        [train_op, loss_, global_step],
                        feed_dict={
                            x_: x_batch,
                            x_mask_: x_batch_mask,
                            y_: x_labels,
                            keep_prob: opt.dropout,
                            class_penalty_: opt.class_penalty
                        })

                    if uidx % opt.valid_freq == 0:
                        train_correct = 0.0
                        # Estimate training accuracy on the first 500
                        # training indices only (dropout disabled).
                        kf_train = get_minibatches_idx(500,
                                                       opt.batch_size,
                                                       shuffle=True)
                        for _, train_index in kf_train:
                            train_sents = [train[t] for t in train_index]
                            train_labels = [train_lab[t] for t in train_index]
                            train_labels = np.array(train_labels)
                            train_labels = train_labels.reshape(
                                (len(train_labels), opt.num_class))
                            x_train_batch, x_train_batch_mask = prepare_data_for_emb(
                                train_sents, opt)
                            train_accuracy = sess.run(accuracy_,
                                                      feed_dict={
                                                          x_: x_train_batch,
                                                          x_mask_:
                                                          x_train_batch_mask,
                                                          y_: train_labels,
                                                          keep_prob: 1.0,
                                                          class_penalty_: 0.0
                                                      })

                            train_correct += train_accuracy * len(train_index)

                        train_accuracy = train_correct / 500

                        print("Iteration %d: Training loss %f " % (uidx, loss))
                        print("Train accuracy %f " % train_accuracy)

                        val_correct = 0.0
                        kf_val = get_minibatches_idx(len(val),
                                                     opt.batch_size,
                                                     shuffle=True)
                        for _, val_index in kf_val:
                            val_sents = [val[t] for t in val_index]
                            val_labels = [val_lab[t] for t in val_index]
                            val_labels = np.array(val_labels)
                            val_labels = val_labels.reshape(
                                (len(val_labels), opt.num_class))
                            x_val_batch, x_val_batch_mask = prepare_data_for_emb(
                                val_sents, opt)

                            val_accuracy = sess.run(accuracy_,
                                                    feed_dict={
                                                        x_: x_val_batch,
                                                        x_mask_:
                                                        x_val_batch_mask,
                                                        y_: val_labels,
                                                        keep_prob: 1.0,
                                                        class_penalty_: 0.0
                                                    })
                            val_correct += val_accuracy * len(val_index)

                        val_accuracy = val_correct / len(val)
                        print("Validation accuracy %f " % val_accuracy)

                        # Track the best validation accuracy seen so far; the
                        # test set is only evaluated once, after training.
                        if val_accuracy > max_val_accuracy:
                            max_val_accuracy = val_accuracy

                # Save both a per-epoch checkpoint and a fixed-name one.
                saver.save(sess, opt.save_path, global_step=epoch)
                saver.save(sess, "save_model/model.ckpt")

            # ---- final test pass (fixed order so output rows match docs) ----
            test_correct = 0.0

            kf_test = get_minibatches_idx(len(test),
                                          opt.batch_size,
                                          shuffle=False)
            for _, test_index in kf_test:
                test_sents = [test[t] for t in test_index]
                test_labels = [test_lab[t] for t in test_index]
                test_labels = np.array(test_labels)
                test_labels = test_labels.reshape(
                    (len(test_labels), opt.num_class))
                x_test_batch, x_test_batch_mask = prepare_data_for_emb(
                    test_sents, opt)

                test_accuracy, predict_prob = sess.run(
                    [accuracy_, prob_],
                    feed_dict={
                        x_: x_test_batch,
                        x_mask_: x_test_batch_mask,
                        y_: test_labels,
                        keep_prob: 1.0,
                        class_penalty_: 0.0
                    })

                for prob in predict_prob:
                    # Keep only the opt.topnlabel largest probabilities of
                    # this document; every other class stays 0.
                    topnlabel_onedoc = [0] * opt.num_class
                    for iter_topnlabel in range(opt.topnlabel):
                        index_label = np.argwhere(prob == max(prob))
                        topnlabel_onedoc[index_label[0]
                                         [0]] = prob[index_label][0][0]
                        # Mask the current maximum (and any ties) so the next
                        # iteration picks the next-largest value.
                        prob[index_label] = -1
                    # NOTE(review): topnlabel_docwithoutlabel is not defined
                    # in this function; presumably a module-level list --
                    # confirm, otherwise this raises NameError.
                    topnlabel_docwithoutlabel.append(topnlabel_onedoc)
                test_correct += test_accuracy * len(test_index)
            print(topnlabel_docwithoutlabel)
            test_accuracy = test_correct / len(test)
            print("Predict accuracy %f " % test_accuracy)

            max_test_accuracy = test_accuracy

            filename = 'test'
            # Use a context manager so the output is flushed and closed (the
            # original never closed the handle).
            with open(filename, 'w') as out_file:
                out_file.write(str(len(test)))
                out_file.write('\n')
                for topic_prob in topnlabel_docwithoutlabel:
                    print(topic_prob)
                    for prob_each_label in topic_prob:
                        out_file.write(str(prob_each_label))
                        out_file.write(" ")
                    out_file.write('\n')

        except KeyboardInterrupt:
            print('Training interupted')
            print("Max Test accuracy %f " % max_test_accuracy)
Ejemplo n.º 3
0
# Script entry: parse CLI options, attach vocabulary and embeddings to `args`,
# then either train or predict depending on `args.predict`.
args = argparser.parse_args()

# NOTE(review): pickle executes arbitrary code on load; only use with trusted
# data files.
vocabpath = './data1/vocab.p'
with open(vocabpath, 'rb') as vocab_file:
    word2idx, idx2word = pickle.load(vocab_file, encoding='latin1')
args.word2idx = word2idx
args.idx2word = idx2word

args.vocab_size = len(idx2word)
print('Total words: %d' % args.vocab_size)
print('batch size: %d' % args.batch_size)
print('learning rate: %.4f' % args.lr)

class_name = ['好', '不好']
args.num_class = len(class_name)
embpath = './data1/data_fast.p'
try:
    # The first element of the pickled object is used as the word-embedding
    # matrix.
    with open(embpath, 'rb') as emb_file:
        args.W_emb = np.array(pickle.load(emb_file, encoding='latin1')[0],
                              dtype='float32')
    args.W_class_emb = load_class_embedding(word2idx, args.W_emb, class_name)
    args.W_emb = torch.FloatTensor(args.W_emb)
    args.W_class_emb = torch.FloatTensor(args.W_class_emb)
except IOError:
    # Continue without pre-trained embeddings when the file is missing.
    print('No embedding file found.')

# The string 'None' (not the None object) selects training mode.
if args.predict == 'None':
    training(args)
else:
    predict(args)
Ejemplo n.º 4
0
def main():
    """Train and evaluate the embedding classifier on ``opt.dataset``.

    Loads the pickled corpus and word embeddings for the dataset selected in
    ``Options``, builds the TensorFlow graph through ``emb_classifier`` and
    runs ``opt.max_epochs`` epochs of mini-batch training.  Every
    ``opt.valid_freq`` updates it prints the accuracy on a sample of the
    training set and on the validation set; when validation accuracy
    improves, the test set is evaluated too.  A checkpoint is saved to
    ``opt.save_path`` after every epoch.

    Raises:
        ValueError: if ``opt.dataset`` is not one of the supported names.
    """
    # Prepare training and testing data
    opt = Options()
    # Pick the corpus / embedding pickles and the class names per dataset.
    if opt.dataset == 'yahoo':
        loadpath = "./data/yahoo.p"
        embpath = "./data/yahoo_glove.p"
        opt.num_class = 10
        opt.class_name = [
            'Society Culture', 'Science Mathematics', 'Health',
            'Education Reference', 'Computers Internet', 'Sports',
            'Business Finance', 'Entertainment Music', 'Family Relationships',
            'Politics Government'
        ]
    elif opt.dataset == 'agnews':
        loadpath = "./data/ag_news.p"
        embpath = "./data/ag_news_glove.p"
        opt.num_class = 4
        opt.class_name = ['World', 'Sports', 'Business', 'Science']
    elif opt.dataset == 'dbpedia':
        loadpath = "./data/dbpedia.p"
        embpath = "./data/dbpedia_glove.p"
        opt.num_class = 14
        opt.class_name = [
            'Company',
            'Educational Institution',
            'Artist',
            'Athlete',
            'Office Holder',
            'Mean Of Transportation',
            'Building',
            'Natural Place',
            'Village',
            'Animal',
            'Plant',
            'Album',
            'Film',
            'Written Work',
        ]
    elif opt.dataset == 'yelp_full':
        loadpath = "./data/yelp_full.p"
        embpath = "./data/yelp_full_glove.p"
        opt.num_class = 5
        opt.class_name = ['worst', 'bad', 'middle', 'good', 'best']
    elif opt.dataset == 'tweets_cleaned':
        loadpath = './data/tweet_cleaned.p'
        embpath = './data/tweet_cleaned_200d_glove.p'
        opt.num_class = 4
        opt.class_name = ['arrow', 'cult', 'until', 'master']
    else:
        # An unsupported name used to fall through and crash later on an
        # unbound ``loadpath``; fail early with a clear message instead.
        raise ValueError('Unsupported dataset: %r' % (opt.dataset,))

    # NOTE(review): pickle can execute arbitrary code while loading; only use
    # trusted data files.  ``with`` closes the handle right after loading.
    with open(loadpath, "rb") as data_file:
        x = cPickle.load(data_file)
    train, val, test = x[0], x[1], x[2]
    train_lab, val_lab, test_lab = x[6], x[7], x[8]
    wordtoix, ixtoword = x[9], x[10]
    del x
    print("loading data finished")

    # convert labels into float
    train_lab = np.array(train_lab, dtype='float32')
    val_lab = np.array(val_lab, dtype='float32')
    test_lab = np.array(test_lab, dtype='float32')
    opt.n_words = len(ixtoword)
    if opt.part_data:
        # Train on a random subset (drawn without replacement) whose size is
        # the fraction ``opt.portion`` of the training set.
        train_ind = np.random.choice(len(train),
                                     int(len(train) * opt.portion),
                                     replace=False)
        train = [train[t] for t in train_ind]
        train_lab = [train_lab[t] for t in train_ind]

    # Training accuracy is estimated on at most 500 samples; the cap keeps the
    # sampled indices inside the (possibly subsampled) training set.
    n_train_eval = min(500, len(train))

    os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.GPUID)

    print(dict(opt))
    print('Total words: %d' % opt.n_words)

    try:
        with open(embpath, 'rb') as emb_file:
            # The first element of the pickled payload is used as the word
            # embedding matrix.
            opt.W_emb = np.array(cPickle.load(emb_file)[0], dtype='float32')
        opt.W_class_emb = load_class_embedding(wordtoix, opt)
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32,
                            shape=[opt.batch_size, opt.maxlen],
                            name='x_')
        x_mask_ = tf.placeholder(tf.float32,
                                 shape=[opt.batch_size, opt.maxlen],
                                 name='x_mask_')
        keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        y_ = tf.placeholder(tf.float32,
                            shape=[opt.batch_size, opt.num_class],
                            name='y_')
        class_penalty_ = tf.placeholder(tf.float32, shape=())
        accuracy_, loss_, train_op, W_norm_, global_step = emb_classifier(
            x_, x_mask_, y_, keep_prob, opt, class_penalty_)
    uidx = 0
    max_val_accuracy = 0.
    max_test_accuracy = 0.

    config = tf.ConfigProto(
        log_device_placement=False,
        allow_soft_placement=True,
    )
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        # The writers are constructed with the session graph; nothing else is
        # written through them in this function.
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                save_keys = tensors_key_in_file(opt.save_path)
                ss = set([var.name for var in t_vars]) & set(
                    [s + ":0" for s in save_keys.keys()])
                cc = {var.name: var for var in t_vars}
                # only restore variables with correct shape
                ss_right_shape = set(
                    [s for s in ss if cc[s].get_shape() == save_keys[s[:-2]]])

                loader = tf.train.Saver(var_list=[
                    var for var in t_vars if var.name in ss_right_shape
                ])
                loader.restore(sess, opt.save_path)

                print("Loading variables from '%s'." % opt.save_path)
                print("Loaded variables:" + str(ss))

            except Exception:
                # ``Exception`` rather than a bare ``except`` so Ctrl-C and
                # SystemExit are not mistaken for a missing checkpoint.
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        def evaluate_accuracy(sents, labels, n_samples):
            """Return the accuracy of the current model over ``n_samples`` items.

            Batches come from ``get_minibatches_idx(n_samples, ...)``; each
            batch accuracy is weighted by its size and the total is divided
            by ``n_samples``.  Dropout is disabled (``keep_prob`` 1.0) and the
            class penalty is fed as 0.0.
            """
            correct = 0.0
            batches = get_minibatches_idx(n_samples,
                                          opt.batch_size,
                                          shuffle=True)
            for _, batch_index in batches:
                batch_sents = [sents[t] for t in batch_index]
                batch_labels = np.array([labels[t] for t in batch_index])
                batch_labels = batch_labels.reshape(
                    (len(batch_labels), opt.num_class))
                batch_x, batch_mask = prepare_data_for_emb(batch_sents, opt)
                batch_accuracy = sess.run(accuracy_,
                                          feed_dict={
                                              x_: batch_x,
                                              x_mask_: batch_mask,
                                              y_: batch_labels,
                                              keep_prob: 1.0,
                                              class_penalty_: 0.0
                                          })
                correct += batch_accuracy * len(batch_index)
            return correct / n_samples

        try:
            for epoch in range(opt.max_epochs):
                print("Starting epoch %d" % epoch)
                kf = get_minibatches_idx(len(train),
                                         opt.batch_size,
                                         shuffle=True)
                for _, train_index in kf:
                    uidx += 1
                    sents = [train[t] for t in train_index]
                    x_labels = np.array([train_lab[t] for t in train_index])
                    x_labels = x_labels.reshape((len(x_labels), opt.num_class))

                    x_batch, x_batch_mask = prepare_data_for_emb(sents, opt)
                    _, loss, _ = sess.run(
                        [train_op, loss_, global_step],
                        feed_dict={
                            x_: x_batch,
                            x_mask_: x_batch_mask,
                            y_: x_labels,
                            keep_prob: opt.dropout,
                            class_penalty_: opt.class_penalty
                        })

                    # Only evaluate every ``opt.valid_freq`` updates.
                    if uidx % opt.valid_freq != 0:
                        continue

                    train_accuracy = evaluate_accuracy(train, train_lab,
                                                       n_train_eval)
                    print("Iteration %d: Training loss %f " % (uidx, loss))
                    print("Train accuracy %f " % train_accuracy)

                    val_accuracy = evaluate_accuracy(val, val_lab, len(val))
                    print("Validation accuracy %f " % val_accuracy)

                    # The test set is scored only when validation improves, so
                    # ``max_test_accuracy`` is the test accuracy at the best
                    # validation point, not the maximum over all evaluations.
                    if val_accuracy > max_val_accuracy:
                        max_val_accuracy = val_accuracy
                        test_accuracy = evaluate_accuracy(
                            test, test_lab, len(test))
                        print("Test accuracy %f " % test_accuracy)
                        max_test_accuracy = test_accuracy

                print("Epoch %d: Max Test accuracy %f" %
                      (epoch, max_test_accuracy))
                saver.save(sess, opt.save_path, global_step=epoch)
            print("Max Test accuracy %f " % max_test_accuracy)

        except KeyboardInterrupt:
            print('Training interupted')
            print("Max Test accuracy %f " % max_test_accuracy)
Ejemplo n.º 5
0
def main():
    """Train and evaluate the G-weighted embedding classifier on ``opt.dataset``.

    Loads the pickled corpus, the per-sample weights ``G`` and the word
    embeddings for the dataset selected in ``Options``, builds the
    TensorFlow graph through ``emb_classifier`` and runs ``opt.max_epochs``
    epochs of mini-batch training.  Every ``opt.valid_freq`` updates it
    reports accuracy on a training sample, the validation set and the test
    set, appending the numbers to ``<dataset>_Train_message.csv`` and
    ``<dataset>_Classification_Results.csv``.  A checkpoint is saved to
    ``opt.save_path`` after every epoch, and a two-row summary is appended to
    the results CSV on completion or on Ctrl-C.

    Raises:
        ValueError: if ``opt.dataset`` is not one of the supported names.
    """
    # Prepare training and testing data
    opt = Options()
    # Pick the corpus / embedding / G-weight pickles and class names.
    if opt.dataset == 'yahoo':
        loadpath = "/home/dell/PycharmProjects/NLP/Idea-1/New_leam_dataset/yahoo.p"
        embpath = "/home/dell/PycharmProjects/NLP/Idea-1/New_leam_dataset/yahoo_glove.p"
        load_G_path = '/home/dell/PycharmProjects/NLP/Idea-1/Results/yahoo/08/cnn/yahoo_G.pickle'
        opt.num_class = 10
        opt.class_name = ['Society Culture',
                          'Science Mathematics',
                          'Health',
                          'Education Reference',
                          'Computers Internet',
                          'Sports',
                          'Business Finance',
                          'Entertainment Music',
                          'Family Relationships',
                          'Politics Government']
    elif opt.dataset == 'agnews':
        loadpath = "/home/dell/PycharmProjects/NLP/Idea-1/New_leam_dataset/ag_news.pickle"
        embpath = "/home/dell/PycharmProjects/NLP/Idea-1/New_leam_dataset/ag_news_glove.pickle"
        load_G_path = '/home/dell/PycharmProjects/NLP/Idea-1/Results/ag_news/08/cnn/ag_news_G.pickle'
        opt.num_class = 4
        opt.class_name = ['World',
                          'Sports',
                          'Business',
                          'Science']
    elif opt.dataset == 'dbpedia':
        loadpath = "/home/dell/PycharmProjects/NLP/Idea-1/New_leam_dataset/dbpedia.pickle"
        embpath = "/home/dell/PycharmProjects/NLP/Idea-1/New_leam_dataset/dbpedia_glove.pickle"
        load_G_path = '/home/dell/PycharmProjects/NLP/Idea-1/Results/dbpedia/08/cnn/dbpedia_G.pickle'
        opt.num_class = 14
        opt.class_name = ['Company',
                          'Educational Institution',
                          'Artist',
                          'Athlete',
                          'Office Holder',
                          'Mean Of Transportation',
                          'Building',
                          'Natural Place',
                          'Village',
                          'Animal',
                          'Plant',
                          'Album',
                          'Film',
                          'Written Work',
                          ]
    elif opt.dataset == 'yelp_full':
        loadpath = "/home/dell/PycharmProjects/NLP/Idea-1/New_leam_dataset/yelp_full.pickle"
        embpath = "/home/dell/PycharmProjects/NLP/Idea-1/New_leam_dataset/yelp_full_glove.pickle"
        load_G_path = '/home/dell/PycharmProjects/NLP/Idea-1/Results/yelp_full/08/cnn/yelp_full_G.pickle'
        opt.num_class = 5
        opt.class_name = ['worst',
                          'bad',
                          'middle',
                          'good',
                          'best']
    elif opt.dataset == 'yelp':
        loadpath = "/home/dell/PycharmProjects/NLP/Idea-1/New_leam_dataset/yelp.pickle"
        embpath = "/home/dell/PycharmProjects/NLP/Idea-1/New_leam_dataset/yelp_glove.pickle"
        load_G_path = '/home/dell/PycharmProjects/NLP/Idea-1/Results/yelp/08/cnn/yelp_G.pickle'
        opt.num_class = 2
        opt.class_name = ['bad',
                          'good']
    else:
        # An unsupported name used to fall through and crash later on an
        # unbound ``loadpath``; fail early with a clear message instead.
        raise ValueError('Unsupported dataset: %r' % (opt.dataset,))

    # NOTE(review): pickle can execute arbitrary code while loading; only use
    # trusted data files.  ``with`` closes each handle right after loading.
    with open(loadpath, "rb") as data_file:
        x = pickle.load(data_file)
    # Per the original author's note: sentences are lists of word ids,
    # already tokenized, and not yet padded to a common length.
    train, val, test = x[0], x[1], x[2]
    # Per the original author's note: labels are one-hot encoded.
    train_lab, val_lab, test_lab = x[3], x[4], x[5]
    wordtoix, ixtoword = x[6], x[7]

    # Load the weights G (one entry per sample of each split).
    with open(load_G_path, "rb") as g_file:
        G_train, G_val, G_test = pickle.load(g_file)

    del x
    print("load data finished")

    train_lab = np.array(train_lab, dtype='float32')
    val_lab = np.array(val_lab, dtype='float32')
    test_lab = np.array(test_lab, dtype='float32')
    opt.n_words = len(ixtoword)
    if opt.part_data:
        # Train on a random subset (drawn without replacement) whose size is
        # the fraction ``opt.portion`` of the training set.
        train_ind = np.random.choice(len(train), int(len(train) * opt.portion), replace=False)
        train = [train[t] for t in train_ind]
        train_lab = [train_lab[t] for t in train_ind]
        # G_train is indexed with the same positions as ``train`` below, so
        # it has to be subsampled identically to stay aligned.
        G_train = [G_train[t] for t in train_ind]

    # Training accuracy is estimated on at most 500 samples; the cap keeps the
    # sampled indices inside the (possibly subsampled) training set.
    n_train_eval = min(500, len(train))

    os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.GPUID)

    print(dict(opt))
    print('Total words: %d' % opt.n_words)

    try:
        with open(embpath, 'rb') as emb_file:
            opt.W_emb = np.array(pickle.load(emb_file), dtype='float32')
        opt.W_class_emb = load_class_embedding(wordtoix, opt)
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.maxlen], name='x_')
        x_mask_ = tf.placeholder(tf.float32, shape=[opt.batch_size, opt.maxlen], name='x_mask_')
        keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        y_ = tf.placeholder(tf.float32, shape=[opt.batch_size, opt.num_class], name='y_')
        class_penalty_ = tf.placeholder(tf.float32, shape=())
        G_our = tf.placeholder(tf.float32, shape=[opt.batch_size, opt.maxlen, opt.num_class], name='G_our')
        seq_len = tf.placeholder(tf.int32, shape=[opt.batch_size], name='sque_sentence_num')
        accuracy_, loss_, train_op, W_norm_, global_step = emb_classifier(
            x_, x_mask_, y_, keep_prob, opt, class_penalty_, G_our, seq_len)
    uidx = 0
    max_val_accuracy = 0.
    max_test_accuracy = 0.
    val_acc = 0.
    # Initialised so the final report and the Ctrl-C handler can read it even
    # if validation accuracy never improved (or was never measured).
    test_acc = 0.

    train_csv = opt.dataset + '_Train_message.csv'
    results_csv = opt.dataset + '_Classification_Results.csv'

    def append_csv_row(path, header, row):
        """Append ``row`` to ``path``, writing ``header`` first for a new file."""
        is_new_file = not os.path.exists(path)
        with open(path, 'a', newline='') as out:
            csv_write = csv.writer(out, dialect='excel')
            if is_new_file:
                csv_write.writerow(header)
            csv_write.writerow(row)

    def write_summary():
        """Append the two best-accuracy summary rows to the results CSV."""
        with open(results_csv, 'a', newline='') as out:
            csv_write = csv.writer(out, dialect='excel')
            csv_write.writerow(['Max Test accuracy:', max_test_accuracy, 'val accuracy', val_acc])
            csv_write.writerow(['Max val accuracy:', max_val_accuracy, 'test accuracy', test_acc])

    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True, )
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        # The writers are constructed with the session graph; nothing else is
        # written through them in this function.
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                save_keys = tensors_key_in_file(opt.save_path)
                ss = set([var.name for var in t_vars]) & set([s + ":0" for s in save_keys.keys()])
                cc = {var.name: var for var in t_vars}
                # only restore variables with correct shape
                ss_right_shape = set([s for s in ss if cc[s].get_shape() == save_keys[s[:-2]]])

                loader = tf.train.Saver(var_list=[var for var in t_vars if var.name in ss_right_shape])
                loader.restore(sess, opt.save_path)

                print("Loading variables from '%s'." % opt.save_path)
                print("Loaded variables:" + str(ss))

            except Exception:
                # ``Exception`` rather than a bare ``except`` so Ctrl-C and
                # SystemExit are not mistaken for a missing checkpoint.
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        def evaluate_accuracy(sents, weights, labels, n_samples):
            """Return the accuracy of the current model over ``n_samples`` items.

            Batches come from ``get_minibatches_idx(n_samples, ...)``; each
            batch accuracy is weighted by its size and the total is divided
            by ``n_samples``.  Dropout is disabled (``keep_prob`` 1.0) and the
            class penalty is fed as 0.0.
            """
            correct = 0.0
            batches = get_minibatches_idx(n_samples, opt.batch_size, shuffle=True)
            for _, batch_index in batches:
                batch_sents = [sents[t] for t in batch_index]
                batch_weights = [weights[t] for t in batch_index]
                batch_labels = np.array([labels[t] for t in batch_index])
                batch_labels = batch_labels.reshape((len(batch_labels), opt.num_class))
                batch_x, batch_mask, batch_G, batch_seq_len = prepare_data_for_emb(
                    batch_sents, batch_weights, opt)
                batch_accuracy = sess.run(accuracy_, feed_dict={
                    x_: batch_x,
                    x_mask_: batch_mask,
                    y_: batch_labels,
                    keep_prob: 1.0,
                    class_penalty_: 0.0,
                    G_our: batch_G,
                    seq_len: batch_seq_len,
                })
                correct += batch_accuracy * len(batch_index)
            return correct / n_samples

        try:
            for epoch in range(opt.max_epochs):
                print("Starting epoch %d" % epoch)
                kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
                for _, train_index in kf:
                    uidx += 1
                    sents = [train[t] for t in train_index]
                    G1 = [G_train[t] for t in train_index]
                    x_labels = np.array([train_lab[t] for t in train_index])
                    x_labels = x_labels.reshape((len(x_labels), opt.num_class))
                    x_batch, x_batch_mask, G_batch, seq_len_batch = prepare_data_for_emb(sents, G1, opt)

                    _, loss, _ = sess.run([train_op, loss_, global_step], feed_dict={
                        x_: x_batch,
                        x_mask_: x_batch_mask,
                        y_: x_labels,
                        keep_prob: opt.dropout,
                        class_penalty_: opt.class_penalty,
                        G_our: G_batch,
                        seq_len: seq_len_batch,
                    })

                    # Only evaluate every ``opt.valid_freq`` updates.
                    if uidx % opt.valid_freq != 0:
                        continue

                    train_accuracy = evaluate_accuracy(train, G_train, train_lab, n_train_eval)
                    print("Iteration %d: Training loss %f " % (uidx, loss))
                    print("Train accuracy %f " % train_accuracy)
                    append_csv_row(train_csv,
                                   ["epoch", "Training loss", "Train accuracy"],
                                   [epoch, loss, train_accuracy])

                    val_accuracy = evaluate_accuracy(val, G_val, val_lab, len(val))
                    print("Validation accuracy %f " % val_accuracy)

                    # Evaluate on the test set at every validation point.
                    test_accuracy = evaluate_accuracy(test, G_test, test_lab, len(test))
                    print("Test accuracy %f " % test_accuracy)

                    # Track the best test accuracy and the validation accuracy
                    # seen at that point ...
                    if test_accuracy > max_test_accuracy:
                        max_test_accuracy = test_accuracy
                        val_acc = val_accuracy

                    # ... and the best validation accuracy with its matching
                    # test accuracy.
                    if val_accuracy > max_val_accuracy:
                        max_val_accuracy = val_accuracy
                        test_acc = test_accuracy

                    append_csv_row(results_csv,
                                   ["epoch", "val_accuracy", "test_accuracy"],
                                   [epoch, val_accuracy, test_accuracy])

                print("Epoch %d: Max Test accuracy %f" % (epoch, max_test_accuracy))
                saver.save(sess, opt.save_path, global_step=epoch)

            print("Max Test accuracy %f , val accuracy %f " % (max_test_accuracy, val_acc))
            print("Max val accuracy %f , test accuracy %f" % (max_val_accuracy, test_acc))
            write_summary()
        except KeyboardInterrupt:
            print('Training interupted')
            print("Max Test accuracy %f " % max_test_accuracy)
            write_summary()
def main():
    """Train the classifier built by ``emb_classifier`` on the Yahoo data.

    Loads the pickled dataset from ``./yahoo4char.p`` and the word
    embeddings from ``./yahoo_glove.p``, builds the TensorFlow graph, and
    runs ``opt.max_epochs`` epochs of minibatch training.  Every
    ``opt.valid_freq`` updates it reports accuracy on a training sample and
    on the validation set; whenever validation accuracy reaches a new best,
    the test set is evaluated and that score is remembered.  A checkpoint is
    saved to ``opt.save_path`` after every epoch.

    Ctrl-C stops training early and still prints the best test accuracy
    recorded so far.
    """
    loadpath = "./yahoo4char.p"
    embpath = "./yahoo_glove.p"
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(loadpath, "rb") as data_file:
        x = pickle.load(data_file)
    train, val, test = x[0], x[1], x[2]
    train_lab, val_lab, test_lab = x[3], x[4], x[5]
    wordtoix, ixtoword = x[6], x[7]
    del x
    print("load data finished")

    train_lab = np.array(train_lab, dtype='float32')
    val_lab = np.array(val_lab, dtype='float32')
    test_lab = np.array(test_lab, dtype='float32')

    opt = Options()
    opt.num_class = 10
    opt.class_name = ['Society Culture',
                      'Science Mathematics',
                      'Health',
                      'Education Reference',
                      'Computers Internet',
                      'Sports',
                      'Business Finance',
                      'Entertainment Music',
                      'Family Relationships',
                      'Politics Government']
    opt.n_words = len(ixtoword)
    # os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.GPUID)
    with open(embpath, 'rb') as emb_file:
        opt.W_emb = np.array(pickle.load(emb_file)[0], dtype='float32')
    opt.W_class_emb = load_class_embedding(wordtoix, opt)

    with tf.device('/gpu:0'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.maxlen], name='x_')
        x_mask_ = tf.placeholder(tf.float32, shape=[opt.batch_size, opt.maxlen], name='x_mask_')
        keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        y_ = tf.placeholder(tf.float32, shape=[opt.batch_size, opt.num_class], name='y_')
        class_penalty_ = tf.placeholder(tf.float32, shape=())
        accuracy_, loss_, train_op, W_norm_, global_step = emb_classifier(
            x_, x_mask_, y_, keep_prob, opt, class_penalty_)

    def make_feed(sents, labels, indices, keep, penalty):
        """Build the feed dict for the examples selected by ``indices``."""
        batch_sents = [sents[t] for t in indices]
        # Stored labels are shifted down by one before one-hot encoding.
        batch_labels = np.array([labels[t] for t in indices]) - 1
        # Pass the class count explicitly: if the width were inferred from
        # the largest label in the batch, a batch missing the highest class
        # would not match the fixed [batch_size, num_class] shape of y_.
        # NOTE(review): assumes to_categorical takes the class count as its
        # second positional argument (as keras.utils.to_categorical does) --
        # confirm against the file's import.
        batch_labels = to_categorical(batch_labels, opt.num_class)
        x_batch, x_batch_mask = prepare_data_for_emb(batch_sents, opt)
        return {x_: x_batch, x_mask_: x_batch_mask, y_: batch_labels,
                keep_prob: keep, class_penalty_: penalty}

    def evaluate(sess, sents, labels, n_samples):
        """Return accuracy over indices ``0 .. n_samples - 1`` of ``sents``.

        Dropout is disabled (keep_prob 1.0) and the class penalty is 0.0.
        Per-batch accuracies are weighted by batch size before averaging.
        """
        correct = 0.0
        for _, indices in get_minibatches_idx(n_samples, opt.batch_size, shuffle=True):
            batch_accuracy = sess.run(
                accuracy_, feed_dict=make_feed(sents, labels, indices, 1.0, 0.0))
            correct += batch_accuracy * len(indices)
        return correct / n_samples

    uidx = 0
    max_val_accuracy = 0.
    max_test_accuracy = 0.
    # Training accuracy is estimated on at most 500 examples; the cap keeps
    # the index range valid when the training set is smaller than that.
    train_sample_size = min(500, len(train))

    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True, )
    # config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        try:
            sess.run(tf.global_variables_initializer())
            for epoch in range(opt.max_epochs):
                print("Starting epoch %d" % epoch)
                kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
                for _, train_index in kf:
                    uidx += 1
                    _, loss = sess.run(
                        [train_op, loss_],
                        feed_dict=make_feed(train, train_lab, train_index,
                                            opt.dropout, opt.class_penalty))

                    if uidx % opt.valid_freq != 0:
                        continue

                    train_accuracy = evaluate(sess, train, train_lab, train_sample_size)
                    print("Iteration %d: Training loss %f " % (uidx, loss))
                    print("Train accuracy %f " % train_accuracy)

                    val_accuracy = evaluate(sess, val, val_lab, len(val))
                    print("Validation accuracy %f " % val_accuracy)

                    if val_accuracy > max_val_accuracy:
                        max_val_accuracy = val_accuracy
                        # The reported "max" test accuracy is the test score
                        # at the best validation accuracy seen so far.
                        test_accuracy = evaluate(sess, test, test_lab, len(test))
                        print("Test accuracy %f " % test_accuracy)
                        max_test_accuracy = test_accuracy

                print("Epoch %d: Max Test accuracy %f" % (epoch, max_test_accuracy))
                saver.save(sess, opt.save_path, global_step=epoch)
            print("Max Test accuracy %f " % max_test_accuracy)
        except KeyboardInterrupt:
            # Only report an interruption when one actually happened.
            print('Training interupted')
            print("Max Test accuracy %f " % max_test_accuracy)
        finally:
            train_writer.close()
            test_writer.close()