예제 #1
0
def main(_):
    """Build data, or train/test the multi-task model on the Fudan tasks.

    When ``FLAGS.build_data`` is set, runs ``build_data()`` and exits the
    process without building a model. Otherwise it:

    1. loads word embeddings with ``util.load_embedding``;
    2. collects one ``(task_name, data)`` pair per task for train and test
       from ``fudan.read_tfrecord``;
    3. builds the train/valid models, named ``'fudan-mtl'`` with an
       ``'-adv'`` suffix when ``FLAGS.adv`` is set;
    4. runs ``test`` if ``FLAGS.test`` is set, else ``train``.

    Args:
        _: Unused positional argument (presumably the argv list passed by
            ``tf.app.run``; TODO confirm).
    """
    if FLAGS.build_data:
        # Use sys.exit rather than the bare ``exit`` builtin: ``exit`` is
        # added by the ``site`` module for interactive use and is not
        # guaranteed to exist (e.g. under ``python -S`` or frozen builds).
        import sys
        build_data()
        sys.exit()

    word_embed = util.load_embedding(word_dim=FLAGS.word_dim)
    with tf.Graph().as_default():
        all_train = []
        all_test = []
        # Each element of the reader is a (train, test) pair; its position
        # is the task id used to look up the task name.
        data_iter = fudan.read_tfrecord(FLAGS.num_epochs, FLAGS.batch_size)
        for task_id, (train_data, test_data) in enumerate(data_iter):
            task_name = fudan.get_task_name(task_id)
            all_train.append((task_name, train_data))
            all_test.append((task_name, test_data))

        model_name = 'fudan-mtl-adv' if FLAGS.adv else 'fudan-mtl'
        m_train, m_valid = mtl_model.build_train_valid_model(
            model_name, word_embed, all_train, all_test, FLAGS.adv, FLAGS.test)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())  # for file queue
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            sess.run(init_op)
            print('=' * 80)

            if FLAGS.test:
                test(sess, m_valid)
            else:
                train(sess, m_train, m_valid)
예제 #2
0
def main(_):
    """Build data, or train/test the Fudan multi-task model from iterators.

    When ``FLAGS.build_data`` is set, runs ``build_data()`` and exits the
    process. Otherwise it loads word embeddings and obtains train/test
    iterators from ``fudan.read_tfrecord``. It builds the ``'fudan-mtl'``
    train/valid models on the iterators' ``get_next()`` tensors, then runs
    ``test`` if ``FLAGS.test`` is set, else ``train`` (which also receives
    the test iterator).

    Args:
        _: Unused positional argument (presumably the argv list passed by
            ``tf.app.run``; TODO confirm).
    """
    if FLAGS.build_data:
        # Use sys.exit rather than the bare ``exit`` builtin, which is only
        # provided by the ``site`` module for interactive sessions.
        import sys
        build_data()
        sys.exit()

    word_embed = util.load_embedding(word_dim=FLAGS.word_dim)
    with tf.Graph().as_default():
        train_iter, test_iter = fudan.read_tfrecord(FLAGS.num_epochs,
                                                    FLAGS.batch_size)
        train_data = train_iter.get_next()
        test_data = test_iter.get_next()

        m_train, m_valid = mtl_model.build_train_valid_model(
            'fudan-mtl', word_embed, train_data, test_data, FLAGS.adv)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())  # for file queue
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            sess.run(init_op)
            print('=' * 80)

            if FLAGS.test:
                # The original passed the undefined name ``models`` here,
                # which raised NameError. Evaluate the validation model
                # instead.
                # NOTE(review): ``train`` is given ``test_iter``; confirm
                # whether ``test`` also needs it (e.g. to initialize it).
                test(sess, m_valid)
            else:
                train(sess, m_train, m_valid, test_iter)
예제 #3
0
def main(_):
    """Build data, or train/test the SemEval + DBpedia multi-task model.

    When ``FLAGS.build_data`` is set, runs ``build_data()`` and exits the
    process. Otherwise it reads train/test iterators for both SemEval and
    DBpedia. It builds train/valid models named ``'mtl-dbpedia-<word_dim>'``
    from their ``get_next()`` tensors, passing ``FLAGS.is_mtl``,
    ``FLAGS.is_adv`` and ``FLAGS.test``. It then runs ``test`` on the SemEval
    test iterator if ``FLAGS.test`` is set, else trains with
    ``train_semeval``.

    Args:
        _: Unused positional argument (presumably the argv list passed by
            ``tf.app.run``; TODO confirm).
    """
    if FLAGS.build_data:
        # Use sys.exit rather than the bare ``exit`` builtin, which is only
        # provided by the ``site`` module for interactive sessions.
        import sys
        build_data()
        sys.exit()

    word_embed = util.load_embedding(word_dim=FLAGS.word_dim)
    with tf.Graph().as_default():
        semeval_train_iter, semeval_test_iter = semeval.read_tfrecord(
            FLAGS.num_epochs, FLAGS.batch_size)
        dbpedia_train_iter, dbpedia_test_iter = dbpedia.read_tfrecord(
            FLAGS.num_epochs, FLAGS.batch_size)
        model_name = 'mtl-dbpedia-%d' % FLAGS.word_dim
        semeval_train = semeval_train_iter.get_next()
        semeval_test = semeval_test_iter.get_next()
        dbpedia_train = dbpedia_train_iter.get_next()
        dbpedia_test = dbpedia_test_iter.get_next()
        m_train, m_valid = mtl_model.build_train_valid_model(
            model_name, word_embed, semeval_train, semeval_test, dbpedia_train,
            dbpedia_test, FLAGS.is_mtl, FLAGS.is_adv, FLAGS.test)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())  # for file queue
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            sess.run(init_op)
            print('=' * 80)

            if FLAGS.test:
                test(sess, m_valid, semeval_test_iter)
            else:
                # Only the SemEval training loop is active. To train on
                # DBpedia instead, call:
                #   train_dbpedia(sess, m_train, m_valid, dbpedia_test_iter)
                train_semeval(sess, m_train, m_valid, semeval_test_iter)
예제 #4
0
def main(_):
  """Build the model selected by ``FLAGS.model``, then trace, test or train it.

  Reads ``(train_data, test_data, word_embed)`` from ``base_reader.inputs()``.
  It builds train/valid models with ``cnn_model`` (``FLAGS.model == 'cnn'``)
  or ``mtl_model`` (``FLAGS.model == 'mtl'``) and calls
  ``m_train.set_saver(FLAGS.model)``. Inside a session it then runs, in
  priority order: ``trace_runtime`` if ``FLAGS.trace``, ``test`` if
  ``FLAGS.test``, otherwise ``train``.

  Args:
    _: Unused positional argument (presumably the argv list passed by
      ``tf.app.run``; TODO confirm).

  Raises:
    ValueError: If ``FLAGS.model`` is neither ``'cnn'`` nor ``'mtl'``.
  """
  with tf.Graph().as_default():
    train_data, test_data, word_embed = base_reader.inputs()

    if FLAGS.model == 'cnn':
      m_train, m_valid = cnn_model.build_train_valid_model(
          word_embed, train_data, test_data)
    elif FLAGS.model == 'mtl':
      m_train, m_valid = mtl_model.build_train_valid_model(
          word_embed, train_data, test_data)
    else:
      # Without this branch an unknown value leaves m_train unbound and the
      # program dies with an UnboundLocalError on the next line.
      raise ValueError("Unknown FLAGS.model %r; expected 'cnn' or 'mtl'."
                       % (FLAGS.model,))

    m_train.set_saver(FLAGS.model)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())  # for file queue

    config = tf.ConfigProto()
    # Alternative to allow_growth: cap memory at a fixed share of the GPU,
    # e.g. config.gpu_options.per_process_gpu_memory_fraction = 0.9 (90%).
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
      sess.run(init_op)
      print('='*80)

      if FLAGS.trace:
        trace_runtime(sess, m_train)
      elif FLAGS.test:
        test(sess, m_valid)
      else:
        train(sess, m_train, m_valid)