Example #1
def train(network='rnn'):
	word2id, id2word = load_data(TOKEN_DATA)
	tag2id, id2tag = load_data(TAG_DATA)
	x_train, y_train, seq_lens, _, _ = generate_data(TRAIN_DATA, word2id, tag2id, max_len=hp.max_len)
	x_dev, y_dev, dev_seq_lens, _, source_tag = generate_data(DEV_DATA, word2id, tag2id, max_len=hp.max_len)
	vocab_size = len(word2id)
	num_tags = len(tag2id)
	if network == "transformer":
		model = TransformerCRFModel(vocab_size, num_tags, is_training=True)
	elif network == 'rnn':
		model = BiRnnCRF(vocab_size, num_tags)
	elif network in ('cnn', 'match-pyramid'):
		# 'match-pyramid' currently falls back to the same CNN-CRF model
		model = CnnCRF(vocab_size, num_tags)
	else:
		return  # unknown network name: nothing to train
	sv = tf.train.Supervisor(graph=model.graph, logdir=logdir, save_model_secs=0)
	with sv.managed_session() as sess:
		for epoch in range(1, hp.num_epochs + 1):
			if sv.should_stop():
				break
			train_loss = []
			for x_batch, y_batch, len_batch in batch_data(x_train, y_train, seq_lens, hp.batch_size):
				feed_dict = {model.x: x_batch, model.y: y_batch, model.seq_lens: len_batch}
				loss, _ = sess.run([model.loss, model.train_op], feed_dict=feed_dict)
				train_loss.append(loss)
			
			dev_loss = []
			predict_lists = []
			for x_batch, y_batch, len_batch in batch_data(x_dev, y_dev, dev_seq_lens, hp.batch_size):
				feed_dict = {model.x: x_batch, model.y: y_batch, model.seq_lens: len_batch}
				loss, logits = sess.run([model.loss, model.logits], feed_dict)
				dev_loss.append(loss)
				
				transition = model.transition.eval(session=sess)
				pre_seq = model.predict(logits, transition, len_batch)
				pre_label = recover_label(pre_seq, len_batch, id2tag)
				predict_lists.extend(pre_label)
			train_loss_v = np.round(float(np.mean(train_loss)), 4)
			dev_loss_v = np.round(float(np.mean(dev_loss)), 4)
			print('****************************************************')
			acc, p, r, f = get_ner_fmeasure(source_tag, predict_lists)
			print('epoch:\t{}\ttrain loss:\t{}\tdev loss:\t{}'.format(epoch, train_loss_v, dev_loss_v))
			print('acc:\t{}\tp:\t{}\tr:\t{}\tf:\t{}'.format(acc, p, r, f))
			print('****************************************************\n\n')
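The `batch_data` generator above is not shown in the snippet. From its call sites it takes the padded inputs, the labels, and the true sequence lengths, and yields `(x_batch, y_batch, len_batch)` triples. A minimal sketch under that assumption (a hypothetical reconstruction, not the repository's actual helper):

import numpy as np

def batch_data(x, y, seq_lens, batch_size):
    # Hypothetical reconstruction of the helper used above: walk the
    # data in order and yield fixed-size mini-batches as NumPy arrays.
    n = len(x)
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        yield (np.asarray(x[start:end]),
               np.asarray(y[start:end]),
               np.asarray(seq_lens[start:end]))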
Example #2
    def evaluate_batch(self, eva_data):
        wl = self.args.vocab.wl
        cl = self.args.vocab.cl

        batch_size = self.args.batch_size
        # put the model in evaluation mode
        self.model.eval()
        pred_results = []
        gold_results = []
        for i, (words, label_ids) in enumerate(
                self.args.vocab.minibatches(eva_data, batch_size=batch_size)):
            char_ids, word_ids = zip(*words)
            word_ids, sequence_lengths = seqPAD.pad_sequences(word_ids,
                                                              pad_tok=0,
                                                              wthres=wl,
                                                              cthres=cl)
            char_ids, word_lengths = seqPAD.pad_sequences(char_ids,
                                                          pad_tok=0,
                                                          nlevels=2,
                                                          wthres=wl,
                                                          cthres=cl)
            label_ids, _ = seqPAD.pad_sequences(label_ids,
                                                pad_tok=0,
                                                wthres=wl,
                                                cthres=cl)

            data_tensors = Data2tensor.sort_tensors(label_ids,
                                                    word_ids,
                                                    sequence_lengths,
                                                    char_ids,
                                                    word_lengths,
                                                    volatile_flag=True)
            (label_tensor, word_tensor, sequence_lengths, word_seq_recover,
             char_tensor, word_lengths, char_seq_recover) = data_tensors
            mask_tensor = word_tensor > 0

            label_score = self.model(word_tensor, sequence_lengths,
                                     char_tensor, word_lengths,
                                     char_seq_recover)

            label_prob, label_pred = self.model.inference(
                label_score, mask_tensor)

            pred_label, gold_label = recover_label(label_pred, label_tensor,
                                                   mask_tensor,
                                                   self.args.vocab.l2i,
                                                   word_seq_recover)
            pred_results += pred_label
            gold_results += gold_label
        acc, p, r, f = get_ner_fmeasure(gold_results, pred_results)

        return acc, f
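`get_ner_fmeasure` scores the recovered sequences at the chunk (entity) level. As an illustration only, once the gold and predicted tag sequences have been segmented into `(type, start, end)` chunks, precision, recall, and F1 reduce to set operations; this sketch shows the arithmetic, not the library's actual implementation:

def chunk_f1(gold_chunks, pred_chunks):
    # Illustrative only: the real get_ner_fmeasure also derives the
    # chunks from BIO/BIOES tag sequences before comparing them.
    gold, pred = set(gold_chunks), set(pred_chunks)
    correct = len(gold & pred)
    p = correct / len(pred) if pred else 0.0
    r = correct / len(gold) if gold else 0.0
    f = 2 * p * r / (p + r) if p + r > 0 else 0.0
    return p, r, f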
Example #3
def train():
    # load data sets
    train_sentences = load_sentences(FLAGS.train_file, FLAGS.lower,
                                     FLAGS.zeros)
    dev_sentences = load_sentences(FLAGS.dev_file, FLAGS.lower, FLAGS.zeros)
    test_sentences = load_sentences(FLAGS.test_file, FLAGS.lower, FLAGS.zeros)

    # Tag-scheme conversion (IOB / IOBES) is currently disabled:
    # update_tag_scheme(train_sentences, FLAGS.tag_schema)
    # update_tag_scheme(test_sentences, FLAGS.tag_schema)

    # create maps if not exist
    if not os.path.isfile(FLAGS.map_file):
        # create dictionary for word
        _c, char_to_id, id_to_char = char_mapping(train_sentences, FLAGS.lower)

        # Create a dictionary and a mapping for tags
        _t, tag_to_id, id_to_tag = tag_mapping(train_sentences)
        os.makedirs(FLAGS.save_path, exist_ok=True)
        with open(FLAGS.map_file, "wb") as f:
            pickle.dump([char_to_id, id_to_char, tag_to_id, id_to_tag], f)
    else:
        with open(FLAGS.map_file, "rb") as f:
            char_to_id, id_to_char, tag_to_id, id_to_tag = pickle.load(f)

    # prepare data, get a collection of list containing index
    train_data = prepare_padding_dataset(train_sentences, FLAGS.max_seq_len,
                                         char_to_id, tag_to_id, FLAGS.lower)
    dev_data = prepare_padding_dataset(dev_sentences, FLAGS.max_seq_len,
                                       char_to_id, tag_to_id, FLAGS.lower)
    test_data = prepare_padding_dataset(test_sentences, FLAGS.max_seq_len,
                                        char_to_id, tag_to_id, FLAGS.lower)

    print("%i / %i / %i sentences in train / dev / test." %
          (len(train_data), len(dev_data), len(test_data)))

    train_manager = BatchManager(train_data, FLAGS.batch_size)
    dev_manager = BatchManager(dev_data, 100)
    test_manager = BatchManager(test_data, 100)
    """
    batch = train_manager.batch_data[0]
    strings, chars, segs, tags = batch
    for chrs in chars:
        print(chrs)
    for chrs in segs:
        print(chrs)
    print(tag_to_id)
    """
    # make path for store log and model if not exist
    make_path(FLAGS)
    if os.path.isfile(FLAGS.config_file):
        config = load_config(FLAGS.config_file)
    else:
        config = config_model(char_to_id, tag_to_id)
        save_config(config, FLAGS.config_file)

    log_path = os.path.join(FLAGS.save_path, "log", FLAGS.log_file)
    logger = get_logger(log_path)
    print_config(config, logger)

    # limit GPU memory
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    steps_per_epoch = train_manager.len_data
    with tf.Session(config=tf_config) as sess:
        model = TransformerCRFModel(config, is_training=True)
        sess.run(tf.global_variables_initializer())
        logger.info("start training")
        loss = []
        # best F1 trackers must be initialised outside the epoch loop;
        # resetting them each epoch would make every epoch look "best"
        best_dev_f1 = 0.0
        best_test_f1 = 0.0
        for i in range(100):  # epochs
            for batch in train_manager.iter_batch(shuffle=True):
                step, batch_loss = model.run_step(sess, True, batch)
                loss.append(batch_loss)
                if step % FLAGS.steps_check == 0:
                    iteration = step // steps_per_epoch + 1
                    logger.info("iteration:{} step:{}/{}, "
                                "NER loss:{:>9.6f}".format(
                                    iteration, step % steps_per_epoch,
                                    steps_per_epoch, np.mean(loss)))
                    loss = []

            predict_lists = []
            source_tag = []
            for batch in dev_manager.iter_batch(shuffle=False):
                lengths, logits = model.run_step(sess, False, batch)
                _, chars, segs, tags = batch
                transition = model.transition.eval(session=sess)
                pre_seq = model.predict(logits, transition, lengths)
                pre_label = recover_label(pre_seq, lengths, id_to_tag)
                """
                for p in range(len(pre_label)):
                    print(chars[p])
                    print(pre_label[p])
                """
                source_label = recover_label(tags, lengths, id_to_tag)
                predict_lists.extend(pre_label)
                source_tag.extend(source_label)
            train_loss_v = np.round(float(np.mean(loss)), 4)  # mean loss since the last log step
            print('****************************************************')
            acc, p, r, f = get_ner_fmeasure(source_tag, predict_lists,
                                            config["tag_schema"])
            logger.info('epoch:\t{}\ttrain loss:\t{}\t'.format(
                i + 1, train_loss_v))
            logger.info('dev acc:\t{}\tp:\t{}\tr:\t{}\tf:\t{}'.format(
                acc, p, r, f))

            # reset the buffers so test metrics are not mixed with dev predictions
            predict_lists = []
            source_tag = []
            for batch in test_manager.iter_batch(shuffle=False):
                lengths, logits = model.run_step(sess, False, batch)
                _, chars, segs, tags = batch
                transition = model.transition.eval(session=sess)
                pre_seq = model.predict(logits, transition, lengths)
                pre_label = recover_label(pre_seq, lengths, id_to_tag)
                source_label = recover_label(tags, lengths, id_to_tag)
                predict_lists.extend(pre_label)
                source_tag.extend(source_label)

            acc_t, p_t, r_t, f_t = get_ner_fmeasure(source_tag, predict_lists,
                                                    config["tag_schema"])
            logger.info('test acc:\t{}\tp:\t{}\tr:\t{}\tf:\t{}'.format(
                acc_t, p_t, r_t, f_t))
            if f > best_dev_f1:
                save_model(sess, model, FLAGS.ckpt_path, logger)
                best_dev_f1 = f
                best_test_f1 = f_t
                logger.info(
                    'save epoch:\t{} model with best dev f1-score'.format(i +
                                                                          1))

            print('****************************************************\n\n')
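Examples #1 and #3 both call a three-argument `recover_label(sequences, lengths, id2tag)` to map padded id sequences back to tag strings (Example #2 uses a different, tensor-based helper of the same name). A plausible sketch of the three-argument version, assuming plain Python lists:

def recover_label(seqs, lengths, id2tag):
    # Hypothetical reconstruction: cut each sequence back to its true
    # length, then map tag ids to tag strings.
    return [[id2tag[t] for t in seq[:length]]
            for seq, length in zip(seqs, lengths)]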