# Example No. 1
# (listing score: 0)
def test():
    """Tag the restaurant ('res') test split with a saved model and dump predictions.

    Reads ``polarity_level/data_aspect/res/test/term_all.txt`` with
    ``load_data``, loads the object saved at ``backup3/<model_name>`` with
    ``torch.load`` and passes it to ``NeuralTagger_elmo().predict``.

    Two files are written:

    * ``case_study/<model_name without the trailing ".pt">_test.txt`` -- one
      line per sentence: tab-separated tokens, target tags, predicted tags,
      gold tags, and ``True``/``False`` for whether prediction equals gold;
    * ``data_aspect/res/test/ow.txt`` -- only the predicted tag string per line.

    Relies on module-level ``ds`` and ``args`` and on project helpers
    (``load_data``, ``load_w2v``, ``NeuralTagger_elmo``) defined elsewhere.
    """
    model_name = 'TD_3LSTM_0.7725_0.8020_14res2.pt'
    # f_test = "data/%s/test_all.txt" % ds
    f_test = "polarity_level/data_aspect/{0}/test/term_all.txt".format('res')
    test_text, test_t, test_ow = load_data(filename=f_test)
    f_w2v = "data/%s/embedding_all_glove300.txt" % ds
    # Only the word -> index mapping is used below; the embedding matrix is discarded.
    _, word2index = load_w2v(f_w2v)
    model = NeuralTagger_elmo()
    rnn = torch.load("backup3/%s" % model_name)
    result = model.predict(rnn, (test_text, test_t, test_ow), word2index, args)
    # Validate before opening (and truncating) any output file.
    assert len(result) == len(test_text)
    test_file = "case_study/" + model_name[0:-3] + "_test.txt"
    # NOTE(review): the input is read from polarity_level/data_aspect/... but
    # ow.txt is written to data_aspect/... (no polarity_level/ prefix) -- confirm.
    with codecs.open(test_file, 'w', encoding='utf-8') as fw, \
            codecs.open("data_aspect/{0}/test/ow.txt".format('res'), 'w',
                        encoding='utf-8') as fw2:
        for s, t, p, g in zip(test_text, test_t, result, test_ow):
            t_str = ' '.join([str(i) for i in t])
            # p is converted with .tolist() -- presumably a numpy array/tensor; confirm.
            p_str = ' '.join([str(i) for i in p.tolist()])
            g_str = ' '.join([str(i) for i in g])
            fw.write(' '.join(s) + '\t' + t_str + '\t' + p_str + '\t' + g_str +
                     '\t' + str(p_str == g_str) + '\n')
            fw2.write(p_str + '\n')
# Example No. 2
# (listing score: 0)
def main(_):
    """Predict labels for the evaluation set with a restored TF checkpoint.

    Maps words to ids with ``load_w2v``, loads and pads the evaluation samples
    to ``FLAGS.sample_len``, then restores the meta graph and weights from a
    hard-coded checkpoint path. It runs the ``output/logits`` tensor over
    consecutive batches of ``FLAGS.batch_size`` samples (the last batch may be
    smaller) with ``dropout_keep_prob`` fed as 1.0. ``getTopNPredictions``
    converts each batch of logits into top-N labels and scores.

    Output, one line per evaluated sample, prefixed by the first tab-separated
    field of the corresponding line in ``FLAGS.raw_eval_file``:

    * ``FLAGS.prediction_file``: ``<id>,<label1>,<label2>,...``
    * ``FLAGS.prediction_debug_file``: ``<id>,<label1>:<score1>,...``

    The positional argument is ignored.
    """
    _, word2id = load_w2v(FLAGS.word_embedding_file, 200)
    eval_x = loadEvalSample(FLAGS.eval_data_file, word2id)
    eval_x = pad_sequences(eval_x, maxlen=FLAGS.sample_len, value=0.)
    eval_sample_size = len(eval_x)
    print('Evaluation samples size:' + str(eval_sample_size))
    id2label_map = loadId2LabelMap(FLAGS.label_map)

    # NOTE(review): checkpoint path is hard-coded; the commented alternative
    # picks the latest checkpoint from FLAGS.checkpoint_dir.
    checkpoint_file = 'zhihu/TextACNN/runs/1501587555/checkpoints/model-130411'  #tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    print('checkpoint_file:' + checkpoint_file)
    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = True
        # Using the session as a context manager makes it the default session
        # and closes it once inference is done.
        with tf.Session(config=session_conf) as sess, tf.device('/gpu:0'):
            saver = tf.train.import_meta_graph(
                '{}.meta'.format(checkpoint_file))
            saver.restore(sess, checkpoint_file)
            input_x = graph.get_operation_by_name('input_x').outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                'dropout_keep_prob').outputs[0]
            logits = graph.get_operation_by_name('output/logits').outputs[0]
            all_predictions = []
            all_scores = []
            # One loop covers every sample, including a final partial batch and
            # the case where the whole set is smaller than one batch.
            for start in range(0, eval_sample_size, FLAGS.batch_size):
                feed_dict = {
                    input_x: eval_x[start:start + FLAGS.batch_size],
                    dropout_keep_prob: 1.
                }
                batch_logits = sess.run(logits, feed_dict)
                batch_predictions, batch_scores = getTopNPredictions(
                    batch_logits, id2label_map, FLAGS.top_predictions)
                all_predictions.extend(batch_predictions)
                all_scores.extend(batch_scores)
            print(len(all_predictions))

    # Line i of the raw file corresponds to evaluation sample i.
    with open(FLAGS.raw_eval_file) as raw_fp:
        question_ids = [line.strip().split('\t')[0] for line in raw_fp]

    with open(FLAGS.prediction_file, 'w') as fp, \
            open(FLAGS.prediction_debug_file, 'w') as debug_fp:
        for i, predictions in enumerate(all_predictions):
            question_id = question_ids[i]
            scores = all_scores[i]
            res_str = question_id
            res_debug_str = question_id
            for label, score in zip(predictions, scores):
                res_str += ',' + label
                res_debug_str += ',' + label + ':' + str(score)
            fp.write(res_str + '\n')
            debug_fp.write(res_debug_str + '\n')
# Example No. 3
# (listing score: 0)
def main(_):
    """Train a TextCNN classifier and checkpoint it under ./runs/<timestamp>.

    Loads word vectors (dimension hard-coded to 256) and a train/validation
    split via ``loadSamples``, then builds ``TextCNN`` from ``FLAGS``. It
    trains for ``FLAGS.num_epochs`` epochs on full mini-batches of
    ``FLAGS.batch_size`` samples, padding each batch to ``FLAGS.sample_len``.

    Once more than 20000 batches have been run in total:

    * every ``FLAGS.evaluate_every`` batches of an epoch, the validation set is
      evaluated with dropout disabled. Mean loss/accuracy and
      ``calZhihuScore()`` are printed and written to the ``dev`` summary dir.
    * every ``FLAGS.checkpoint_every`` batches, a checkpoint is saved. With
      ``FLAGS.save_best_model`` set, it is saved only when the zhihu score,
      accuracy or loss improves on its previous best (checked in that order).

    The positional argument is ignored.
    """
    print('Loading word2vec model finished:%s' % (FLAGS.word_embedding_file))
    #w2v_model, word2id = load_w2v(FLAGS.word_embedding_file, FLAGS.embedding_size)
    w2v_model, word2id = load_w2v(FLAGS.word_embedding_file, 256)
    print('Load word2vec model finished')
    print('Loading train/valid samples:%s' % (FLAGS.training_data))
    train_x, train_y, valid_x, valid_y = loadSamples(
        FLAGS.training_data, FLAGS.label_file, FLAGS.label_map, word2id,
        FLAGS.valid_rate, FLAGS.num_classes)
    print('Load train/valid samples finished')
    labelNumStats(valid_y)

    train_sample_size = len(train_x)
    dev_sample_size = len(valid_x)
    print('Training sample size:%d' % (train_sample_size))
    print('Valid sample size:%d' % (dev_sample_size))

    timestamp = str(int(time.time()))
    runs_dir = os.path.abspath(os.path.join(os.path.curdir, 'runs'))
    out_dir = os.path.abspath(os.path.join(runs_dir, timestamp))
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, 'checkpoints'))
    for directory in (runs_dir, out_dir, checkpoint_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)
    checkpoint_prefix = os.path.join(checkpoint_dir, 'model')

    batch_size = FLAGS.batch_size
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = True
        sess = tf.Session(config=session_conf)
        with sess.as_default(), tf.device('/gpu:0'):
            text_cnn = TextCNN(
                sample_len=FLAGS.sample_len,
                num_classes=FLAGS.num_classes,
                learning_rate=FLAGS.learning_rate,
                decay_steps=FLAGS.decay_steps,
                decay_rate=FLAGS.decay_rate,
                embedding_size=FLAGS.embedding_size,
                filter_sizes=list(map(int, FLAGS.filter_sizes.split(','))),
                num_filters=FLAGS.num_filters,
                l2_reg_lambda=FLAGS.l2_reg_lambda,
                w2v_model=w2v_model)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)
            train_summary_dir = os.path.join(out_dir, 'summaries', 'train')
            dev_summary_dir = os.path.join(out_dir, 'summaries', 'dev')
            loss_summary = tf.summary.scalar('loss', text_cnn.loss_val)
            acc_summary = tf.summary.scalar('accuracy', text_cnn.accuracy)
            train_summary_op = tf.summary.merge([loss_summary, acc_summary])
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)
            dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

            sess.run(tf.global_variables_initializer())
            total_loss = 0.
            total_acc = 0.
            total_step = 0.
            best_valid_acc = 0.
            best_valid_loss = 1000.
            best_valid_zhihu_score = 0.
            this_step_valid_acc = 0.
            this_step_valid_loss = 0.
            this_step_zhihu_score = 0.
            for epoch in range(0, FLAGS.num_epochs):
                print('epoch:' + str(epoch))
                if FLAGS.shuffle:
                    shuffle_indices = np.random.permutation(np.arange(train_sample_size))
                    train_x = train_x[shuffle_indices]
                    train_y = train_y[shuffle_indices]
                batch_step = 0
                batch_loss = 0.
                batch_acc = 0.
                # Full batches only. The "+ 1" keeps the final batch when the
                # sample count is an exact multiple of the batch size.
                for start, end in zip(range(0, train_sample_size, batch_size),
                                      range(batch_size, train_sample_size + 1, batch_size)):
                    batch_input_x = pad_sequences(train_x[start:end], maxlen=FLAGS.sample_len, value=0.)
                    batch_input_y = train_y[start:end]
                    feed_dict = {
                        text_cnn.input_x: batch_input_x,
                        text_cnn.input_y: batch_input_y,
                        # NOTE(review): training keep-prob is hard-coded.
                        text_cnn.dropout_keep_prob: 0.39
                    }
                    loss, acc, step, summaries, _ = sess.run([text_cnn.loss_val, text_cnn.accuracy, text_cnn.global_step, train_summary_op, text_cnn.train_op], feed_dict)
                    train_summary_writer.add_summary(summaries, step)
                    total_loss += loss
                    total_acc += acc
                    batch_loss += loss
                    batch_acc += acc
                    batch_step += 1
                    total_step += 1.
                    if batch_step % FLAGS.print_stats_every == 0:
                        time_str = datetime.datetime.now().isoformat()
                        print('[%s]Epoch:%d\tBatch_Step:%d\tTrain_Loss:%.4f/%.4f/%.4f\tTrain_Accuracy:%.4f/%.4f/%.4f' % (time_str, epoch, batch_step, loss, batch_loss / batch_step, total_loss / total_step, acc, batch_acc / batch_step, total_acc / total_step))
                    # Evaluation and checkpointing only start after 20000 batches.
                    if batch_step % FLAGS.evaluate_every == 0 and total_step > 20000:
                        eval_loss = 0.
                        eval_acc = 0.
                        eval_step = 0
                        for dev_start, dev_end in zip(range(0, dev_sample_size, batch_size),
                                                      range(batch_size, dev_sample_size + 1, batch_size)):
                            dev_batch_x = pad_sequences(valid_x[dev_start:dev_end], maxlen=FLAGS.sample_len, value=0.)
                            dev_batch_y = valid_y[dev_start:dev_end]
                            feed_dict = {
                                text_cnn.input_x: dev_batch_x,
                                text_cnn.input_y: dev_batch_y,
                                text_cnn.dropout_keep_prob: 1.
                            }
                            step, dev_loss, dev_acc, logits = sess.run([text_cnn.global_step, text_cnn.loss_val, text_cnn.accuracy, text_cnn.logits], feed_dict)
                            # Accumulates state that calZhihuScore() reads (helper defined elsewhere).
                            zhihuStats(logits, dev_batch_y)
                            eval_loss += dev_loss
                            eval_acc += dev_acc
                            eval_step += 1
                        if eval_step == 0:
                            print('Skip evaluation: validation set is smaller than one batch')
                        else:
                            this_step_zhihu_score = calZhihuScore()
                            time_str = datetime.datetime.now().isoformat()
                            print('[%s]Eval_Loss:%.4f\tEval_Accuracy:%.4f\tZhihu_Score:%.4f' % (time_str, eval_loss / eval_step, eval_acc / eval_step, this_step_zhihu_score))
                            this_step_valid_acc = eval_acc / eval_step
                            this_step_valid_loss = eval_loss / eval_step
                            # Write the averaged metrics directly. A
                            # tf.summary.scalar built from a Python float is a
                            # graph constant and would always record its initial value.
                            valid_summary = tf.Summary(value=[
                                tf.Summary.Value(tag='loss', simple_value=float(this_step_valid_loss)),
                                tf.Summary.Value(tag='accuracy', simple_value=float(this_step_valid_acc)),
                                tf.Summary.Value(tag='zhihu_score', simple_value=float(this_step_zhihu_score)),
                            ])
                            dev_summary_writer.add_summary(valid_summary, step)
                    if batch_step % FLAGS.checkpoint_every == 0 and total_step > 20000:
                        if not FLAGS.save_best_model:
                            path = saver.save(sess, checkpoint_prefix, global_step=step)
                            print('Saved model checkpoint to %s' % path)
                        elif this_step_zhihu_score > best_valid_zhihu_score:
                            path = saver.save(sess, checkpoint_prefix, global_step=step)
                            print('Saved best zhihu_score model checkpoint to %s[%.4f,%.4f,%.4f]' % (path, this_step_valid_loss, this_step_valid_acc, this_step_zhihu_score))
                            best_valid_zhihu_score = this_step_zhihu_score
                        elif this_step_valid_acc > best_valid_acc:
                            path = saver.save(sess, checkpoint_prefix, global_step=step)
                            print('Saved best acc model checkpoint to %s[%.4f,%.4f,%.4f]' % (path, this_step_valid_loss, this_step_valid_acc, this_step_zhihu_score))
                            best_valid_acc = this_step_valid_acc
                        elif this_step_valid_loss < best_valid_loss:
                            path = saver.save(sess, checkpoint_prefix, global_step=step)
                            print('Saved best loss model checkpoint to %s[%.4f,%.4f,%.4f]' % (path, this_step_valid_loss, this_step_valid_acc, this_step_zhihu_score))
                            best_valid_loss = this_step_valid_loss
# Example No. 4
# (listing score: 0)
def main(_):
    """Train a MemNet classifier with a topic memory and checkpoint it under ./runs/<timestamp>.

    Loads word vectors and the train/validation split via ``loadSamples``
    (document body, title and labels). It builds the topic memory with
    ``getTopicEmbedding``, constructs ``MemNet`` from ``FLAGS``, and trains for
    ``FLAGS.num_epochs`` epochs on full mini-batches of ``FLAGS.batch_size``.
    Inputs are padded by ``paddingX`` and labels by ``paddingY``.

    * Every ``FLAGS.evaluate_every`` batches of an epoch, the validation set is
      evaluated with dropout disabled. Mean loss/accuracy and
      ``calZhihuScore()`` are printed and written to the ``dev`` summary dir.
    * Every ``FLAGS.checkpoint_every`` batches, a checkpoint is saved. With
      ``FLAGS.save_best_model`` set, it is saved when the zhihu score, accuracy
      or loss improves (checked in that order), or otherwise whenever the total
      batch count is a multiple of 6000.

    The positional argument is ignored.
    """
    print('Loading word2vec model finished:%s' % (FLAGS.word_embedding_file))
    w2v_model, word2id = load_w2v(FLAGS.word_embedding_file,
                                  FLAGS.embedding_size)
    #w2v_model, word2id = load_w2v(FLAGS.word_embedding_file, 256)
    print('Load word2vec model finished')
    print('Loading train/valid samples:%s' % (FLAGS.training_data))
    train_x, train_title, train_y, valid_x, valid_title, valid_y, label_map = loadSamples(
        FLAGS.training_data, FLAGS.label_file, FLAGS.label_map,
        FLAGS.eval_data_file, word2id, FLAGS.valid_rate, FLAGS.num_classes,
        FLAGS.sent_len, FLAGS.sent_len, FLAGS.doc_len)
    assert (len(train_x) == len(train_title) and len(train_x) == len(train_y))
    print('Load train/valid samples finished')

    # Memory matrix passed to MemNet as mem_model.
    mem = getTopicEmbedding(FLAGS.topic_info, w2v_model, word2id, label_map,
                            FLAGS.embedding_size)
    print('mem', mem.shape)

    labelNumStats(valid_y)

    train_sample_size = len(train_x)
    dev_sample_size = len(valid_x)
    print('Training sample size:%d' % (train_sample_size))
    print('Valid sample size:%d' % (dev_sample_size))

    timestamp = str(int(time.time()))
    runs_dir = os.path.abspath(os.path.join(os.path.curdir, 'runs'))
    out_dir = os.path.abspath(os.path.join(runs_dir, timestamp))
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, 'checkpoints'))
    for directory in (runs_dir, out_dir, checkpoint_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)
    checkpoint_prefix = os.path.join(checkpoint_dir, 'model')

    batch_size = FLAGS.batch_size
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = True
        sess = tf.Session(config=session_conf)
        #sess = tf.Session()
        with sess.as_default(), tf.device('/gpu:0'):
            mem_net = MemNet(num_classes=FLAGS.num_classes,
                             learning_rate=FLAGS.learning_rate,
                             decay_steps=FLAGS.decay_steps,
                             decay_rate=FLAGS.decay_rate,
                             l2_reg_lambda=FLAGS.l2_reg_lambda,
                             embedding_size=FLAGS.embedding_size,
                             doc_len=FLAGS.doc_len,
                             sent_len=FLAGS.sent_len,
                             w2v_model=w2v_model,
                             mem_size=FLAGS.embedding_size * 2,
                             mem_model=mem,
                             rnn_hidden_size=FLAGS.rnn_hidden_size,
                             fc_layer_size=FLAGS.fc_layer_size,
                             title_len=FLAGS.sent_len)

            # Drop local references to the vocabulary/embeddings; presumably
            # they are no longer needed once MemNet is built (confirm).
            print('delete word2id')
            word2id = {}
            print('delete w2v_model')
            w2v_model = []

            saver = tf.train.Saver(tf.global_variables(),
                                   max_to_keep=FLAGS.num_checkpoints)
            train_summary_dir = os.path.join(out_dir, 'summaries', 'train')
            dev_summary_dir = os.path.join(out_dir, 'summaries', 'dev')
            loss_summary = tf.summary.scalar('loss', mem_net.loss_val)
            acc_summary = tf.summary.scalar('accuracy', mem_net.accuracy)
            train_summary_op = tf.summary.merge([loss_summary, acc_summary])
            train_summary_writer = tf.summary.FileWriter(
                train_summary_dir, sess.graph)
            dev_summary_writer = tf.summary.FileWriter(dev_summary_dir,
                                                       sess.graph)

            sess.run(tf.global_variables_initializer())
            total_loss = 0.
            total_acc = 0.
            total_step = 0.
            best_valid_acc = 0.
            best_valid_loss = 1000.
            best_valid_zhihu_score = 0.
            this_step_valid_acc = 0.
            this_step_valid_loss = 0.
            this_step_zhihu_score = 0.
            for epoch in range(0, FLAGS.num_epochs):
                print('epoch:' + str(epoch))
                if FLAGS.shuffle:
                    shuffle_indices = np.random.permutation(
                        np.arange(train_sample_size))
                    train_x = train_x[shuffle_indices]
                    train_title = train_title[shuffle_indices]
                    train_y = train_y[shuffle_indices]
                batch_step = 0
                batch_loss = 0.
                batch_acc = 0.
                # Full batches only. The "+ 1" keeps the final batch when the
                # sample count is an exact multiple of the batch size.
                for start, end in zip(
                        range(0, train_sample_size, batch_size),
                        range(batch_size, train_sample_size + 1, batch_size)):
                    batch_input_x, mask, batch_input_title, title_mask = paddingX(
                        train_x[start:end], train_title[start:end],
                        FLAGS.sent_len, FLAGS.sent_len, FLAGS.doc_len)
                    batch_input_y = paddingY(train_y[start:end],
                                             FLAGS.num_classes)

                    feed_dict = {
                        mem_net.input_x: batch_input_x,
                        mem_net.input_title: batch_input_title,
                        mem_net.input_y: batch_input_y,
                        mem_net.mask: mask,
                        mem_net.title_mask: title_mask,
                        mem_net.l1_dropout_keep_prob:
                        FLAGS.l1_dropout_keep_prob,
                        mem_net.l2_dropout_keep_prob:
                        FLAGS.l2_dropout_keep_prob
                    }
                    loss, acc, step, summaries, _ = sess.run([
                        mem_net.loss_val, mem_net.accuracy,
                        mem_net.global_step, train_summary_op, mem_net.train_op
                    ], feed_dict)
                    train_summary_writer.add_summary(summaries, step)
                    total_loss += loss
                    total_acc += acc
                    batch_loss += loss
                    batch_acc += acc
                    batch_step += 1
                    total_step += 1.
                    if batch_step % FLAGS.print_stats_every == 0:
                        time_str = datetime.datetime.now().isoformat()
                        print(
                            '[%s]Epoch:%d\tBatch_Step:%d\tTrain_Loss:%.4f/%.4f/%.4f\tTrain_Accuracy:%.4f/%.4f/%.4f'
                            % (time_str, epoch, batch_step, loss, batch_loss /
                               batch_step, total_loss / total_step, acc,
                               batch_acc / batch_step, total_acc / total_step))
                    if batch_step % FLAGS.evaluate_every == 0 and total_step > 0:
                        eval_loss = 0.
                        eval_acc = 0.
                        eval_step = 0
                        for dev_start, dev_end in zip(
                                range(0, dev_sample_size, batch_size),
                                range(batch_size, dev_sample_size + 1,
                                      batch_size)):
                            dev_x, dev_mask, dev_title, dev_title_mask = paddingX(
                                valid_x[dev_start:dev_end],
                                valid_title[dev_start:dev_end],
                                FLAGS.sent_len, FLAGS.sent_len, FLAGS.doc_len)
                            dev_y = paddingY(valid_y[dev_start:dev_end],
                                             FLAGS.num_classes)
                            feed_dict = {
                                mem_net.input_x: dev_x,
                                mem_net.input_title: dev_title,
                                mem_net.input_y: dev_y,
                                mem_net.mask: dev_mask,
                                mem_net.title_mask: dev_title_mask,
                                # Dropout must be off when measuring validation
                                # metrics (previously the training keep-probs were fed).
                                mem_net.l1_dropout_keep_prob: 1.,
                                mem_net.l2_dropout_keep_prob: 1.
                            }
                            step, dev_loss, dev_acc, logits = sess.run([
                                mem_net.global_step, mem_net.loss_val,
                                mem_net.accuracy, mem_net.logits
                            ], feed_dict)
                            # Accumulates state that calZhihuScore() reads (helper defined elsewhere).
                            zhihuStats(logits, dev_y)
                            eval_loss += dev_loss
                            eval_acc += dev_acc
                            eval_step += 1
                        if eval_step == 0:
                            print('Skip evaluation: validation set is smaller than one batch')
                        else:
                            this_step_zhihu_score = calZhihuScore()
                            time_str = datetime.datetime.now().isoformat()
                            print(
                                '[%s]Eval_Loss:%.4f\tEval_Accuracy:%.4f\tZhihu_Score:%.4f'
                                % (time_str, eval_loss / eval_step,
                                   eval_acc / eval_step, this_step_zhihu_score))
                            this_step_valid_acc = eval_acc / eval_step
                            this_step_valid_loss = eval_loss / eval_step
                            # One summary per evaluation, holding the averaged metrics.
                            valid_summary = tf.Summary(value=[
                                tf.Summary.Value(tag='loss', simple_value=float(this_step_valid_loss)),
                                tf.Summary.Value(tag='accuracy', simple_value=float(this_step_valid_acc)),
                                tf.Summary.Value(tag='zhihu_score', simple_value=float(this_step_zhihu_score)),
                            ])
                            dev_summary_writer.add_summary(valid_summary, step)
                    if batch_step % FLAGS.checkpoint_every == 0 and total_step > 0:
                        if not FLAGS.save_best_model:
                            path = saver.save(sess,
                                              checkpoint_prefix,
                                              global_step=step)
                            print('Saved model checkpoint to %s' % path)
                        elif this_step_zhihu_score > best_valid_zhihu_score:
                            path = saver.save(sess,
                                              checkpoint_prefix,
                                              global_step=step)
                            print(
                                'Saved best zhihu_score model checkpoint to %s[%.4f,%.4f,%.4f]'
                                % (path, this_step_valid_loss,
                                   this_step_valid_acc, this_step_zhihu_score))
                            best_valid_zhihu_score = this_step_zhihu_score
                        elif this_step_valid_acc > best_valid_acc:
                            path = saver.save(sess,
                                              checkpoint_prefix,
                                              global_step=step)
                            print(
                                'Saved best acc model checkpoint to %s[%.4f,%.4f,%.4f]'
                                % (path, this_step_valid_loss,
                                   this_step_valid_acc, this_step_zhihu_score))
                            best_valid_acc = this_step_valid_acc
                        elif this_step_valid_loss < best_valid_loss:
                            path = saver.save(sess,
                                              checkpoint_prefix,
                                              global_step=step)
                            print(
                                'Saved best loss model checkpoint to %s[%.4f,%.4f,%.4f]'
                                % (path, this_step_valid_loss,
                                   this_step_valid_acc, this_step_zhihu_score))
                            best_valid_loss = this_step_valid_loss
                        elif total_step % 6000 == 0:
                            path = saver.save(sess,
                                              checkpoint_prefix,
                                              global_step=step)
                            print(
                                'Saved model checkpoint to %s[%.4f,%.4f,%.4f]'
                                % (path, this_step_valid_loss,
                                   this_step_valid_acc, this_step_zhihu_score))