Code example #1
def main(_):
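    # load the experiment config and the trimmed pre-trained word embeddings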
    config = config_lib.get_config()
    embed = dataset.Embed(config.out_dir, config.trimmed_embed300_file,
                          config.vocab_file)
    ini_word_embed = embed.load_embedding()

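    # SemEval train/test TFRecord data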
    semeval_record = semeval_v2.SemEvalCleanedRecordData(
        None, config.out_dir, config.semeval_train_record,
        config.semeval_test_record)

    vocab_tags = dataset.Label(config.semeval_dir, config.semeval_tags_file)

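    # build the input pipelines and the train/valid models in a single graph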
    with tf.Graph().as_default():
        train_iter = semeval_record.train_data(config.hparams.num_epochs,
                                               config.hparams.batch_size)
        test_iter = semeval_record.test_data(1, config.hparams.batch_size)

        train_data = train_iter.get_next()
        test_data = test_iter.get_next()

        m_train, m_valid = rnn_model.build_train_valid_model(
            config, ini_word_embed, train_data, test_data)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())  # for file queue
        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True

        for tensor in tf.trainable_variables():
            tf.logging.info(tensor.op.name)

        with tf.Session(config=sess_config) as sess:
            sess.run(init_op)
            print('=' * 80)

            # for batch in range(3):
            #   # (labels, lengths, sentence, tags) = sess.run(train_data)
            #   # print(sentence.shape, tags.shape)
            #   l, w = sess.run([onehot_tags, weights])
            #   print(l.shape, w.shape)
            #   print(w)

            # # sess.run(test_iter.initializer)
            # # for batch in range(28):
            # #   (labels, lengths, sentence, tags) = sess.run(test_data)
            # #   print(sentence.shape, tags.shape)
            # exit()

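            # run evaluation only, or the full training loop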
            if FLAGS.test:
                test(sess, m_valid, test_iter, vocab_tags)
            else:
                train_semeval(config, sess, m_train, m_valid, test_iter,
                              vocab_tags)
Code example #2
def main(_):
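    # vocabulary manager and pre-trained word embeddings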
    vocab_mgr = dataset.VocabMgr()
    word_embed = vocab_mgr.load_embedding()
    semeval_record = semeval_v2.SemEvalCleanedRecordData(None)

    # load dataset
    train_data = semeval_record.train_data(FLAGS.num_epochs, FLAGS.batch_size)
    test_data = semeval_record.test_data(1, FLAGS.batch_size)

    # model_name = 'cnn-%d-%d' % (FLAGS.word_dim, FLAGS.num_epochs)
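    # build the CNN model (FLAGS.is_adv appears to toggle adversarial training)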
    model = cnn_model.CNNModel(word_embed, FLAGS.is_adv)

    # for tensor in tf.trainable_variables():
    #   tf.logging.info(tensor.op.name)

    model.train_and_eval(FLAGS.num_epochs, 80, FLAGS.lrn_rate, train_data,
                         test_data)
Code example #3
def main(_):
  vocab_mgr = dataset.VocabMgr()
  word_embed = vocab_mgr.load_embedding()
  nyt_record = nyt2010.NYT2010CleanedRecordData(None)
  semeval_record = semeval_v2.SemEvalCleanedRecordData(None)

  with tf.Graph().as_default():
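    # SemEval input pipelines; the NYT2010 unsupervised pipeline is left commented out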
    train_iter = semeval_record.train_data(FLAGS.num_epochs, FLAGS.batch_size)
    test_iter = semeval_record.test_data(1, FLAGS.batch_size)
    # unsup_iter = nyt_record.unsup_data(FLAGS.num_epochs, FLAGS.batch_size)
                                          
    model_name = 'cnn-%d-%d' % (FLAGS.word_dim, FLAGS.num_epochs)
    train_data = train_iter.get_next()
    test_data = test_iter.get_next()
    # unsup_data = unsup_iter.get_next()
    unsup_data = None
    m_train, m_valid = cnn_model.build_train_valid_model(
                          model_name, word_embed,
                          train_data, test_data, unsup_data,
                          FLAGS.is_adv, FLAGS.is_test)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())  # for file queue
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    for tensor in tf.trainable_variables():
      tf.logging.info(tensor.op.name)
    
    with tf.Session(config=config) as sess:
      sess.run(init_op)
      print('='*80)

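      # run evaluation only, or the full training loop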
      if FLAGS.is_test:
        test(sess, m_valid, test_iter)
      else:
        train_semeval(sess, m_train, m_valid, test_iter)
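
All of these main(_) entry points assume that TF 1.x command-line flags (FLAGS.num_epochs, FLAGS.batch_size, FLAGS.word_dim, FLAGS.lrn_rate, FLAGS.is_adv, FLAGS.is_test) and the tf.app.run() entry point are defined at module level, which the snippets omit. A minimal sketch of that boilerplate follows; the flag names are taken from the snippets above, but the default values are placeholders, not taken from the original project.

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer('num_epochs', 50, 'number of training epochs')
flags.DEFINE_integer('batch_size', 100, 'mini-batch size')
flags.DEFINE_integer('word_dim', 300, 'word embedding dimension')
flags.DEFINE_float('lrn_rate', 0.001, 'learning rate')
flags.DEFINE_boolean('is_adv', False, 'enable adversarial training')
flags.DEFINE_boolean('is_test', False, 'run evaluation only instead of training')
FLAGS = flags.FLAGS

if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()  # parses the flags and calls main(_)
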
Code example #4
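# NOTE: semeval_text (the SemEval raw-text data) and config are assumed to be created earlier in the script.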
semeval_text.length_statistics()

# gen vocab
vocab = dataset.Vocab(config.out_dir, config.vocab_file)
vocab.generate_vocab(semeval_text.tokens())

# trim embedding
embed = dataset.Embed(config.out_dir, config.trimmed_embed300_file,
                      config.vocab_file)
google_embed = dataset.Embed(config.pretrain_embed_dir,
                             config.google_embed300_file,
                             config.google_words_file)
embed.trim_pretrain_embedding(google_embed)

# build SemEval record data
semeval_text.set_vocab(vocab)
tag_converter = semeval_v2.TagConverter(config.semeval_dir,
                                        config.semeval_relations_file,
                                        config.semeval_tags_file)
semeval_text.set_tag_converter(tag_converter)
semeval_record = semeval_v2.SemEvalCleanedRecordData(
    semeval_text, config.out_dir, config.semeval_train_record,
    config.semeval_test_record)
semeval_record.generate_data()

# INFO:tensorflow:(percent, quantile) [(50, 18.0), (70, 22.0), (80, 25.0),
#                              (90, 29.0), (95, 34.0), (98, 40.0), (100, 97.0)]
# INFO:tensorflow:generate vocab to data/generated/vocab.txt
# INFO:tensorflow:trim embedding to data/generated/embed300.trim.npy
# INFO:tensorflow:generate TFRecord data
Code example #5
semeval_text = semeval_v2.SemEvalCleanedTextData()
# nyt_text = nyt2010.NYT2010CleanedTextData()

# length statistics
semeval_text.length_statistics()
# nyt_text.length_statistics()

# gen vocab
vocab_mgr = dataset.VocabMgr()
vocab_mgr.generate_vocab(semeval_text.tokens())

# trim embedding
vocab_mgr.trim_pretrain_embedding()

# build SemEval record data
semeval_text.set_vocab_mgr(vocab_mgr)
semeval_record = semeval_v2.SemEvalCleanedRecordData(semeval_text)
semeval_record.generate_data()

# build nyt record data
# nyt_text.set_vocab_mgr(vocab_mgr)
# nyt_record = nyt2010.NYT2010CleanedRecordData(nyt_text)
# nyt_record.generate_data()

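# The logged output below appears to come from a run with the NYT2010 lines enabled
# (two length-statistics lines and two "generate TFRecord data" messages):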
# INFO:tensorflow:(percent, quantile) [(50, 17.0), (70, 21.0), (80, 24.0), (90, 29.0), (95, 33.0), (98, 40.0), (100, 98.0)]
# INFO:tensorflow:(percent, quantile) [(50, 39.0), (70, 47.0), (80, 53.0), (90, 62.0), (95, 71.0), (98, 84.0), (100, 9621.0)]
# INFO:tensorflow:generate TFRecord data
# INFO:tensorflow:generate TFRecord data
# INFO:tensorflow:ignore 1361 examples