示例#1
0
def main(_):
    """Entry point: optionally build TFRecord data, then train or test the CNN model."""
    if FLAGS.build_data:
        build_data()
        exit()

    _, vocab_freq = semeval.load_vocab_and_freq()
    word_embed = util.load_embedding(word_dim=FLAGS.word_dim)

    with tf.Graph().as_default():
        train_iter, test_iter = semeval.read_tfrecord(
            FLAGS.num_epochs, FLAGS.batch_size)
        name = 'cnn-%d-%d' % (FLAGS.word_dim, FLAGS.num_epochs)
        train_batch = train_iter.get_next()
        test_batch = test_iter.get_next()
        m_train, m_valid = cnn_model.build_train_valid_model(
            name, word_embed, vocab_freq, train_batch, test_batch,
            FLAGS.is_adv, FLAGS.is_test)

        # Local variables back the input-pipeline file queues, so both
        # initializers are grouped together.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True  # claim GPU memory on demand

        with tf.Session(config=sess_config) as sess:
            sess.run(init_op)
            print('=' * 80)

            if FLAGS.is_test:
                test(sess, m_valid, test_iter)
            else:
                train_semeval(sess, m_train, m_valid, test_iter)
示例#2
0
def main(_):
    """Entry point: optionally build TFRecord data, then train or test the MTL model."""
    if FLAGS.build_data:
        build_data()
        exit()

    word_embed = util.load_embedding(word_dim=FLAGS.word_dim)
    with tf.Graph().as_default():
        train_iter, test_iter = fudan.read_tfrecord(FLAGS.num_epochs,
                                                    FLAGS.batch_size)
        train_data = train_iter.get_next()
        test_data = test_iter.get_next()

        m_train, m_valid = mtl_model.build_train_valid_model(
            'fudan-mtl', word_embed, train_data, test_data, FLAGS.adv)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())  # for file queue
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # allocate GPU memory on demand

        with tf.Session(config=config) as sess:
            sess.run(init_op)
            print('=' * 80)

            if FLAGS.test:
                # BUG FIX: `models` was undefined here and raised a NameError
                # whenever --test was set. Evaluate with the validation model
                # and test iterator, mirroring the train() call below.
                test(sess, m_valid, test_iter)
            else:
                train(sess, m_train, m_valid, test_iter)
示例#3
0
def main(_):
    """Entry point: optionally build TFRecord data, then train or test the multi-task model."""
    if FLAGS.build_data:
        build_data()
        exit()

    word_embed = util.load_embedding(word_dim=FLAGS.word_dim)
    with tf.Graph().as_default():
        # Collect per-task (name, data) pairs for every task in the dataset.
        all_train, all_test = [], []
        for task_id, (train_data, test_data) in enumerate(
                fudan.read_tfrecord(FLAGS.num_epochs, FLAGS.batch_size)):
            task_name = fudan.get_task_name(task_id)
            all_train.append((task_name, train_data))
            all_test.append((task_name, test_data))

        # Adversarial runs get a distinct checkpoint/summary name.
        model_name = 'fudan-mtl' + ('-adv' if FLAGS.adv else '')
        m_train, m_valid = mtl_model.build_train_valid_model(
            model_name, word_embed, all_train, all_test, FLAGS.adv, FLAGS.test)

        # Local variables back the input-pipeline file queues.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True

        with tf.Session(config=sess_config) as sess:
            sess.run(init_op)
            print('=' * 80)

            if FLAGS.test:
                test(sess, m_valid)
            else:
                train(sess, m_train, m_valid)
示例#4
0
def main(_):
    """Entry point: optionally build data, then train or test the SemEval+DBpedia MTL model."""
    if FLAGS.build_data:
        build_data()
        exit()

    word_embed = util.load_embedding(word_dim=FLAGS.word_dim)
    with tf.Graph().as_default():
        sem_train_iter, sem_test_iter = semeval.read_tfrecord(
            FLAGS.num_epochs, FLAGS.batch_size)
        db_train_iter, db_test_iter = dbpedia.read_tfrecord(
            FLAGS.num_epochs, FLAGS.batch_size)
        name = 'mtl-dbpedia-%d' % FLAGS.word_dim
        m_train, m_valid = mtl_model.build_train_valid_model(
            name, word_embed,
            sem_train_iter.get_next(), sem_test_iter.get_next(),
            db_train_iter.get_next(), db_test_iter.get_next(),
            FLAGS.is_mtl, FLAGS.is_adv, FLAGS.test)

        # Local variables back the input-pipeline file queues.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True

        with tf.Session(config=sess_config) as sess:
            sess.run(init_op)
            print('=' * 80)

            if FLAGS.test:
                test(sess, m_valid, sem_test_iter)
            else:
                train_semeval(sess, m_train, m_valid, sem_test_iter)
示例#5
0
def run_model(num_examples, is_train, inspect_data=False, plot=False):
    """Run one full pass (train or eval) of the MTL model.

    Args:
        num_examples: total number of examples to consume this pass; the
            number of batches is num_examples // FLAGS.batch_size.
        is_train: if True, run the training ops and save a checkpoint;
            otherwise run evaluation only.
        inspect_data: if True, collect per-task accuracy breakdowns into a
            pandas DataFrame and print it (no summaries are written).
        plot: if True (only meaningful with inspect_data), plot alignments
            via inspect().

    Returns:
        (mean_loss, mean_acc) averaged over all executed batches and tasks.
    """
    if FLAGS.subword:
        word_embed = None  # no pretrained table in subword mode
    else:
        word_embed = util.load_embedding(word_dim=FLAGS.hidden_size)

    name = "train" if is_train else "eval"

    with tf.Graph().as_default():

        global_step = tf.train.get_or_create_global_step()
        model = mtl_model.MTLModel(
            word_embed, input_fn(is_train=is_train), FLAGS.adv, is_train=is_train)
        model.build_train_op()
        model.set_saver(model_name())

        with tf.Session() as sess:
            try:
                init_op = tf.group(tf.global_variables_initializer(),
                                   tf.local_variables_initializer())  # for file queue
                sess.run(init_op)
                model.restore(sess)
            except Exception as e:
                # Best effort: continue with freshly initialized variables
                # when no usable checkpoint exists.
                tf.logging.warning("restore failed: {}".format(str(e)))

            summary_prefix = os.path.join(get_logdir(), model_name())
            writer = tf.summary.FileWriter(
                summary_prefix + '/{}'.format(name), sess.graph)

            n_task = len(model.tensors)
            # BUG FIX: use the integer batch count everywhere. Previously the
            # final averages were divided by the fractional
            # num_examples / batch_size even though only int(...) batches ran.
            num_batches = num_examples // FLAGS.batch_size

            merged = model.merged_summary(name)

            all_loss, all_acc = 0., 0.

            if inspect_data:
                eval_errors = collections.defaultdict(list)

            for batch in range(num_batches):
                if inspect_data:
                    eval_fetch = [model.tensors, model.data, model.alignments,
                                  model.pred, model.separate_acc]
                    # res = [[acc, loss] per task]
                    res, data, align, pred, separate_acc = sess.run(eval_fetch)

                    if plot:
                        inspect(data, align, pred)

                    if FLAGS.vader:
                        for i, ((acc, _), (private_acc, shared_acc, vader_acc)) in enumerate(
                                zip(res, separate_acc.values())):
                            eval_errors[fudan.get_task_name(i)].append(
                                [float(acc), float(private_acc),
                                 float(shared_acc), float(vader_acc)])
                    else:
                        for i, ((acc, _), (private_acc, shared_acc)) in enumerate(
                                zip(res, separate_acc.values())):
                            eval_errors[fudan.get_task_name(i)].append(
                                [float(acc), float(private_acc), float(shared_acc)])

                else:
                    if is_train:
                        train_fetch = [model.tensors, model.train_ops, merged, global_step]
                        res, _, summary, gs = sess.run(train_fetch)  # res = [[acc], [loss]]
                        writer.add_summary(summary, gs)
                    else:
                        eval_fetch = [model.tensors, merged, global_step]
                        res, summary, gs = sess.run(eval_fetch)  # res = [[acc], [loss]]
                        # Evaluation keeps its own step counter so summaries
                        # from successive eval passes do not overwrite each other.
                        global eval_step
                        writer.add_summary(summary, eval_step)
                        eval_step = eval_step + 1

                res = np.array(res)

                # BUG FIX: np.float was removed in NumPy 1.24; the builtin
                # float is the documented replacement.
                all_loss += sum(res[:, 1].astype(float))
                all_acc += sum(res[:, 0].astype(float))

            # Guard against num_examples < batch_size (zero executed batches).
            denom = num_batches * n_task
            if denom:
                all_loss /= denom
                all_acc /= denom

            if is_train:
                model.save(sess, global_step)

            if inspect_data:
                # BUG FIX: the non-vader branch previously listed four column
                # names for three data columns, making the DataFrame
                # constructor raise. It only collects three values per task.
                if FLAGS.vader:
                    columns = ['err', 'private_err', 'shared_err', 'vader_err']
                else:
                    columns = ['err', 'private_err', 'shared_err']
                # Accuracies -> error rates.
                df = 1 - pd.DataFrame(
                    data=np.array(list(eval_errors.values())).mean(axis=1),
                    index=list(eval_errors.keys()),
                    columns=columns,
                )

                print(df)
                print(df.mean())

            return all_loss, all_acc