Example 1
    def test_file(self, file_path):
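        """Decode file_path with the trained network, write predictions to a
        temporary CoNLL file, and return (acc, precision, recall, F1)."""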
        self.network.eval()
        tmp_filename = 'tmp/%s_test' % (str(uid))
        self.writer.start(tmp_filename)

        self.data_small_test = conll03_data.read_data_to_variable(
            file_path,
            self.word_alphabet,
            self.char_alphabet,
            self.pos_alphabet,
            self.chunk_alphabet,
            self.ner_alphabet,
            use_gpu=self.use_gpu,
            volatile=True)

        for batch in conll03_data.iterate_batch_variable(
                self.data_small_test, self.batch_size):
            word, char, pos, chunk, labels, masks, lengths = batch
            preds, _ = self.network.decode(
                word,
                char,
                target=labels,
                mask=masks,
                leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
            self.writer.write(word.data.cpu().numpy(),
                              pos.data.cpu().numpy(),
                              chunk.data.cpu().numpy(),
                              preds.cpu().numpy(),
                              labels.data.cpu().numpy(),
                              lengths.cpu().numpy())
        self.writer.close()
        test_acc, test_precision, test_recall, test_f1 = evaluate(tmp_filename)
        return test_acc, test_precision, test_recall, test_f1
Example 2
def extract_features(data_name, feat_name):
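    """Load pickled (sentence, mention, ontology types, weight) tuples, run each
    sentence through the NER network, and pickle the feature vector taken at
    each mention's head index."""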
    with open('temp/' + data_name, 'rb') as f:
        train_data = pickle.load(f)
    print(len(train_data))

    network = torch.load('temp/ner_tuned.pt')
    word_alphabet, char_alphabet, pos_alphabet, \
        chunk_alphabet, ner_alphabet = conll03_data.create_alphabets("ner_alphabet/", None)

    feats = []
    for sent, mention, ont_types, weight in train_data:
        with open('tmp', 'w') as f:
            for i, word in enumerate(sent.words):
                f.write('{0} {1} -- -- O\n'.format(i + 1, word.word))
        sent_data = conll03_data.read_data_to_variable('tmp',
                                                       word_alphabet,
                                                       char_alphabet,
                                                       pos_alphabet,
                                                       chunk_alphabet,
                                                       ner_alphabet,
                                                       use_gpu=False,
                                                       volatile=True)
        os.system('rm tmp')
        word, char, pos, chunk, labels, masks, lengths = conll03_data.iterate_batch_variable(
            sent_data, 1).next()
        feat = network.feature(word,
                               char,
                               target=labels,
                               mask=masks,
                               leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
        print(feat.size())
        feat_vec = feat[0, mention['head_index'], :]
        feats.append((feat_vec.data.numpy(), ont_types, weight))
        print(np.shape(feats[-1][0]))

    with open('temp/' + feat_name, 'wb') as f:
        pickle.dump(feats, f)
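
Note that these examples target Python 2 and an older PyTorch API (volatile=True, loss.data[0], iterator .next()). A rough sketch of the equivalent calls under Python 3 with a current PyTorch release, assuming the same conll03_data helpers, would be:

# Assumed Python 3 / modern PyTorch adaptation; not part of the original code.
batch = next(conll03_data.iterate_batch_variable(sent_data, 1))  # .next() -> next(...)
word, char, pos, chunk, labels, masks, lengths = batch
with torch.no_grad():  # replaces volatile=True for inference
    feat = network.feature(word, char, target=labels, mask=masks,
                           leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)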
Example 3
    def test(self):
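        """Decode the held-out test set, write predictions to a temporary CoNLL
        file, and return (acc, precision, recall, F1)."""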
        self.network.eval()
        # evaluate on the test data
        tmp_filename = 'tmp/%s_test' % (str(uid))
        self.writer.start(tmp_filename)

        for batch in conll03_data.iterate_batch_variable(
                self.data_test, self.batch_size):
            word, char, pos, chunk, labels, masks, lengths = batch
            preds, _ = self.network.decode(
                word,
                char,
                target=labels,
                mask=masks,
                leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
            self.writer.write(word.data.cpu().numpy(),
                              pos.data.cpu().numpy(),
                              chunk.data.cpu().numpy(),
                              preds.cpu().numpy(),
                              labels.data.cpu().numpy(),
                              lengths.cpu().numpy())
        self.writer.close()
        test_acc, test_precision, test_recall, test_f1 = evaluate(tmp_filename)
        return test_acc, test_precision, test_recall, test_f1
Example 4
def main():
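    """Parse hyperparameters, build a bi-directional RNN-CNN-CRF tagger, and train
    it on CoNLL-03 data with SGD and learning-rate decay, evaluating on dev each
    epoch and on test whenever dev F1 improves."""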
    parser = argparse.ArgumentParser(
        description='Tuning with bi-directional RNN-CNN-CRF')
    parser.add_argument('--mode',
                        choices=['RNN', 'LSTM', 'GRU'],
                        help='architecture of rnn',
                        required=True)
    parser.add_argument('--num_epochs',
                        type=int,
                        default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Number of sentences in each batch')
    parser.add_argument('--hidden_size',
                        type=int,
                        default=128,
                        help='Number of hidden units in RNN')
    parser.add_argument('--tag_space',
                        type=int,
                        default=0,
                        help='Dimension of tag space')
    parser.add_argument('--num_layers',
                        type=int,
                        default=1,
                        help='Number of layers of RNN')
    parser.add_argument('--num_filters',
                        type=int,
                        default=30,
                        help='Number of filters in CNN')
    parser.add_argument('--char_dim',
                        type=int,
                        default=30,
                        help='Dimension of Character embeddings')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.015,
                        help='Learning rate')
    parser.add_argument('--decay_rate',
                        type=float,
                        default=0.1,
                        help='Decay rate of learning rate')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.0,
                        help='weight for regularization')
    parser.add_argument('--dropout',
                        choices=['std', 'variational'],
                        help='type of dropout',
                        required=True)
    parser.add_argument('--p_rnn',
                        nargs=2,
                        type=float,
                        required=True,
                        help='dropout rate for RNN')
    parser.add_argument('--p_in',
                        type=float,
                        default=0.33,
                        help='dropout rate for input embeddings')
    parser.add_argument('--p_out',
                        type=float,
                        default=0.33,
                        help='dropout rate for output layer')
    parser.add_argument('--bigram',
                        action='store_true',
                        help='bi-gram parameter for CRF')
    parser.add_argument('--schedule',
                        type=int,
                        help='schedule for learning rate decay')
    parser.add_argument('--unk_replace',
                        type=float,
                        default=0.,
                        help='The rate to replace a singleton word with UNK')
    parser.add_argument('--embedding',
                        choices=['glove', 'senna', 'sskip', 'polyglot'],
                        help='Embedding for words',
                        required=True)
    parser.add_argument('--embedding_dict', help='path for embedding dict')
    parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"

    args = parser.parse_args()

    logger = get_logger("NERCRF")

    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    momentum = 0.9
    decay_rate = args.decay_rate
    gamma = args.gamma
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    unk_replace = args.unk_replace
    bigram = args.bigram
    embedding = args.embedding
    embedding_path = args.embedding_dict

    embedd_dict, embedd_dim = utils.load_embedding_dict(
        embedding, embedding_path)

    logger.info("Creating Alphabets")
    word_alphabet, char_alphabet, pos_alphabet, \
    chunk_alphabet, ner_alphabet = conll03_data.create_alphabets("data/alphabets/ner_crf/", train_path, data_paths=[dev_path, test_path],
                                                                 embedd_dict=embedd_dict, max_vocabulary_size=50000)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())
    logger.info("Chunk Alphabet Size: %d" % chunk_alphabet.size())
    logger.info("NER Alphabet Size: %d" % ner_alphabet.size())

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    data_train = conll03_data.read_data_to_variable(train_path,
                                                    word_alphabet,
                                                    char_alphabet,
                                                    pos_alphabet,
                                                    chunk_alphabet,
                                                    ner_alphabet,
                                                    use_gpu=use_gpu)
    num_data = sum(data_train[1])
    num_labels = ner_alphabet.size()

    data_dev = conll03_data.read_data_to_variable(dev_path,
                                                  word_alphabet,
                                                  char_alphabet,
                                                  pos_alphabet,
                                                  chunk_alphabet,
                                                  ner_alphabet,
                                                  use_gpu=use_gpu,
                                                  volatile=True)
    data_test = conll03_data.read_data_to_variable(test_path,
                                                   word_alphabet,
                                                   char_alphabet,
                                                   pos_alphabet,
                                                   chunk_alphabet,
                                                   ner_alphabet,
                                                   use_gpu=use_gpu,
                                                   volatile=True)

    writer = CoNLL03Writer(word_alphabet, char_alphabet, pos_alphabet,
                           chunk_alphabet, ner_alphabet)

    def construct_word_embedding_table():
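        """Build the word embedding table from the pretrained dictionary, using
        random uniform vectors for out-of-vocabulary words."""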
        scale = np.sqrt(3.0 / embedd_dim)
        table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        table[conll03_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, embedd_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in embedd_dict:
                embedding = embedd_dict[word]
            elif word.lower() in embedd_dict:
                embedding = embedd_dict[word.lower()]
            else:
                embedding = np.random.uniform(
                    -scale, scale, [1, embedd_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('oov: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    logger.info("constructing network...")

    char_dim = args.char_dim
    window = 3
    num_layers = args.num_layers
    tag_space = args.tag_space
    initializer = nn.init.xavier_uniform
    if args.dropout == 'std':
        network = BiRecurrentConvCRF(embedd_dim,
                                     word_alphabet.size(),
                                     char_dim,
                                     char_alphabet.size(),
                                     num_filters,
                                     window,
                                     mode,
                                     hidden_size,
                                     num_layers,
                                     num_labels,
                                     tag_space=tag_space,
                                     embedd_word=word_table,
                                     p_in=p_in,
                                     p_out=p_out,
                                     p_rnn=p_rnn,
                                     bigram=bigram,
                                     initializer=initializer)
    else:
        network = BiVarRecurrentConvCRF(embedd_dim,
                                        word_alphabet.size(),
                                        char_dim,
                                        char_alphabet.size(),
                                        num_filters,
                                        window,
                                        mode,
                                        hidden_size,
                                        num_layers,
                                        num_labels,
                                        tag_space=tag_space,
                                        embedd_word=word_table,
                                        p_in=p_in,
                                        p_out=p_out,
                                        p_rnn=p_rnn,
                                        bigram=bigram,
                                        initializer=initializer)

    if use_gpu:
        network.cuda()

    lr = learning_rate
    optim = SGD(network.parameters(),
                lr=lr,
                momentum=momentum,
                weight_decay=gamma,
                nesterov=True)
    logger.info(
        "Network: %s, num_layer=%d, hidden=%d, filter=%d, tag_space=%d, crf=%s"
        % (mode, num_layers, hidden_size, num_filters, tag_space,
           'bigram' if bigram else 'unigram'))
    logger.info(
        "training: l2: %f, (#training data: %d, batch: %d, unk replace: %.2f)"
        % (gamma, num_data, batch_size, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" %
                (p_in, p_out, p_rnn))

    num_batches = num_data / batch_size + 1
    dev_f1 = 0.0
    dev_acc = 0.0
    dev_precision = 0.0
    dev_recall = 0.0
    test_f1 = 0.0
    test_acc = 0.0
    test_precision = 0.0
    test_recall = 0.0
    best_epoch = 0
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s(%s), learning rate=%.4f, decay rate=%.4f (schedule=%d)): '
            % (epoch, mode, args.dropout, lr, decay_rate, schedule))
        train_err = 0.
        train_total = 0.

        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            word, char, _, _, labels, masks, lengths = conll03_data.get_batch_variable(
                data_train, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss = network.loss(word, char, labels, mask=masks)
            loss.backward()
            optim.step()

            num_inst = word.size(0)
            train_err += loss.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 100 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, time left (estimated): %.2fs' % (
                    batch, num_batches, train_err / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, time: %.2fs' %
              (num_batches, train_err / train_total, time.time() - start_time))

        # evaluate performance on dev data
        network.eval()
        tmp_filename = 'tmp/%s_dev%d' % (str(uid), epoch)
        writer.start(tmp_filename)

        for batch in conll03_data.iterate_batch_variable(data_dev, batch_size):
            word, char, pos, chunk, labels, masks, lengths = batch
            preds, _ = network.decode(
                word,
                char,
                target=labels,
                mask=masks,
                leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
            writer.write(word.data.cpu().numpy(),
                         pos.data.cpu().numpy(),
                         chunk.data.cpu().numpy(),
                         preds.cpu().numpy(),
                         labels.data.cpu().numpy(),
                         lengths.cpu().numpy())
        writer.close()
        acc, precision, recall, f1 = evaluate(tmp_filename)
        print(
            'dev acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%' %
            (acc, precision, recall, f1))

        if dev_f1 < f1:
            dev_f1 = f1
            dev_acc = acc
            dev_precision = precision
            dev_recall = recall
            best_epoch = epoch

            # evaluate on test data when better performance detected
            tmp_filename = 'tmp/%s_test%d' % (str(uid), epoch)
            writer.start(tmp_filename)

            for batch in conll03_data.iterate_batch_variable(
                    data_test, batch_size):
                word, char, pos, chunk, labels, masks, lengths = batch
                preds, _ = network.decode(
                    word,
                    char,
                    target=labels,
                    mask=masks,
                    leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
                writer.write(word.data.cpu().numpy(),
                             pos.data.cpu().numpy(),
                             chunk.data.cpu().numpy(),
                             preds.cpu().numpy(),
                             labels.data.cpu().numpy(),
                             lengths.cpu().numpy())
            writer.close()
            test_acc, test_precision, test_recall, test_f1 = evaluate(
                tmp_filename)

        print(
            "best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
            % (dev_acc, dev_precision, dev_recall, dev_f1, best_epoch))
        print(
            "best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
            % (test_acc, test_precision, test_recall, test_f1, best_epoch))

        if epoch % schedule == 0:
            lr = learning_rate / (1.0 + epoch * decay_rate)
            optim = SGD(network.parameters(),
                        lr=lr,
                        momentum=momentum,
                        weight_decay=gamma,
                        nesterov=True)
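
For reference, a hypothetical invocation of this training script could look like the following (the script name and data paths are placeholders, not taken from the original; the flags match the argparse options defined above):

python ner_crf_train.py --mode LSTM --dropout std --p_rnn 0.33 0.5 \
    --num_epochs 100 --batch_size 16 --hidden_size 256 --tag_space 128 \
    --num_layers 1 --num_filters 30 --char_dim 30 --learning_rate 0.015 \
    --decay_rate 0.05 --schedule 1 --gamma 0.0 --bigram \
    --embedding glove --embedding_dict glove.6B.100d.txt \
    --train train.conll --dev dev.conll --test test.conll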
Example 5
def regen_train_data(nlp, fpath):
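    """Walk the LTF corpus, keep named-entity mentions whose inferred ontology type
    is coherent and nominal mentions with a known ontology type, and pickle the
    resulting training tuples and their NER feature vectors."""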
    ontology = OntologyType()
    decisions = ontology.load_decision_tree()
    network = torch.load('temp/ner_tuned.pt')
    word_alphabet, char_alphabet, pos_alphabet, \
        chunk_alphabet, ner_alphabet = conll03_data.create_alphabets("ner_alphabet/", None)

    train_set = set()
    for fname in os.listdir('../../data/txt/'):
        if fname.endswith('.dump'):
            train_set.add(fname[:-5])
    print(train_set)

    train_data = []
    train_feat = []

    for root, dirs, files in os.walk('../../data/ltf/'):
        for file in files:
            if file in train_set:
                print(file)
                sents, doc = read_ltf_offset(os.path.join(root, file))
                for sent in sents:
                    named_ents, ners, feats = extract_ner(sent)
                    for mention, feat in zip(named_ents, feats):
                        prdt_type = infer_type(feat, decisions)
                        coherence = type_coherence(mention['type'], prdt_type,
                                                   ontology)
                        if coherence > 0:
                            train_data.append(
                                (sent, mention,
                                 [prdt_type] + ontology.lookup_all(prdt_type),
                                 coherence))
                            train_feat.append(
                                (feat,
                                 [prdt_type] + ontology.lookup_all(prdt_type),
                                 coherence))
                            print(sent.get_text())
                    nominals = extract_nominals(sent, nlp, ners)
                    for mention in nominals:
                        ont_types = ontology.lookup_all(mention['headword'])
                        if ont_types:
                            train_data.append((sent, mention, ont_types, 1.0))
                            with open('tmp', 'w') as f:
                                for i, word in enumerate(sent.words):
                                    f.write('{0} {1} -- -- O\n'.format(
                                        i + 1, word.word))
                            sent_data = conll03_data.read_data_to_variable(
                                'tmp',
                                word_alphabet,
                                char_alphabet,
                                pos_alphabet,
                                chunk_alphabet,
                                ner_alphabet,
                                use_gpu=False,
                                volatile=True)
                            os.system('rm tmp')
                            word, char, pos, chunk, labels, masks, lengths = conll03_data.iterate_batch_variable(
                                sent_data, 1).next()
                            feat = network.feature(
                                word,
                                char,
                                target=labels,
                                mask=masks,
                                leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS
                            )
                            feat_vec = feat[0, mention['head_index'], :]
                            train_feat.append(
                                (feat_vec.data.numpy(), ont_types, 1.0))
                            print(sent.get_text())

    with open('temp/' + fpath + 'data.dump', 'wb') as f:
        pickle.dump(train_data, f)
    with open('temp/' + fpath + 'feat.dump', 'wb') as f:
        pickle.dump(train_feat, f)
def main():
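    """Fine-tune a pretrained BiRecurrentConvCRF on target-domain data: extend the
    word and NER alphabets (adding VEH/WEA tags), transplant the old CRF weights
    into a larger ChainCRF, train with Adam, and save the tuned model and alphabets."""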
    embedding = 'glove'
    embedding_path = '/media/xianyang/OS/workspace/ner/glove.6B/glove.6B.100d.txt'
    word_alphabet, char_alphabet, pos_alphabet, \
    chunk_alphabet, ner_alphabet = conll03_data.create_alphabets("/media/xianyang/OS/workspace/ner/NeuroNLP2/data/alphabets/ner_crf/", None)
    char_dim = 30
    num_filters = 30
    window = 3
    mode = 'LSTM'
    hidden_size = 256
    num_layers = 1
    num_labels = ner_alphabet.size()
    tag_space = 128
    p = 0.5
    bigram = True
    embedd_dim = 100
    use_gpu = False

    print(len(word_alphabet.get_content()['instances']))
    print(ner_alphabet.get_content())

    # writer = CoNLL03Writer(word_alphabet, char_alphabet, pos_alphabet, chunk_alphabet, ner_alphabet)
    network = BiRecurrentConvCRF(embedd_dim,
                                 word_alphabet.size(),
                                 char_dim,
                                 char_alphabet.size(),
                                 num_filters,
                                 window,
                                 mode,
                                 hidden_size,
                                 num_layers,
                                 num_labels,
                                 tag_space=tag_space,
                                 embedd_word=None,
                                 p_rnn=p,
                                 bigram=bigram)
    network.load_state_dict(torch.load('temp/23df51_model45'))

    ner_alphabet.add('B-VEH')
    ner_alphabet.add('I-VEH')
    ner_alphabet.add('B-WEA')
    ner_alphabet.add('I-WEA')
    num_new_word = 0

    with open('temp/target.train.conll', 'r') as f:
        sents = []
        sent_buffer = []
        for line in f:
            if len(line) <= 1:
                sents.append(sent_buffer)
                sent_buffer = []
            else:
                id, word, _, _, ner = line.strip().split()
                if word_alphabet.get_index(word) == 0:
                    word_alphabet.add(word)
                    num_new_word += 1
                sent_buffer.append((word_alphabet.get_index(word),
                                    ner_alphabet.get_index(ner)))

    print(len(word_alphabet.get_content()['instances']))
    print(ner_alphabet.get_content())

    init_embed = network.word_embedd.weight.data
    init_embed = np.concatenate(
        (init_embed, np.zeros((num_new_word, embedd_dim))), axis=0)
    network.word_embedd = Embedding(word_alphabet.size(), embedd_dim,
                                    torch.from_numpy(init_embed))

    old_crf = network.crf
    new_crf = ChainCRF(tag_space, ner_alphabet.size(), bigram=bigram)
    trans_matrix = np.zeros((new_crf.num_labels, old_crf.num_labels))
    for i in range(old_crf.num_labels):
        trans_matrix[i, i] = 1
    new_crf.state_nn.weight.data = torch.FloatTensor(
        np.dot(trans_matrix, old_crf.state_nn.weight.data))
    network.crf = new_crf

    target_train_data = conll03_data.read_data_to_variable(
        'temp/target.train.conll',
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        chunk_alphabet,
        ner_alphabet,
        use_gpu=False,
        volatile=False)
    target_dev_data = conll03_data.read_data_to_variable(
        'temp/target.dev.conll',
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        chunk_alphabet,
        ner_alphabet,
        use_gpu=False,
        volatile=False)
    target_test_data = conll03_data.read_data_to_variable(
        'temp/target.test.conll',
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        chunk_alphabet,
        ner_alphabet,
        use_gpu=False,
        volatile=False)

    num_epoch = 50
    batch_size = 32
    num_data = sum(target_train_data[1])
    num_batches = num_data / batch_size + 1
    unk_replace = 0.0
    # optim = SGD(network.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0, nesterov=True)
    optim = Adam(network.parameters(), lr=1e-3)

    for epoch in range(1, num_epoch + 1):
        train_err = 0.
        train_total = 0.
        start_time = time.time()
        num_back = 0
        network.train()

        for batch in range(1, num_batches + 1):
            word, char, _, _, labels, masks, lengths = conll03_data.get_batch_variable(
                target_train_data, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss = network.loss(word, char, labels, mask=masks)
            loss.backward()
            optim.step()

            num_inst = word.size(0)
            train_err += loss.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            if batch % 20 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d loss: %.4f, time: %.2fs' % (
                    num_batches, train_err / train_total,
                    time.time() - start_time)
                print(log_info)
                num_back = len(log_info)

        writer = CoNLL03Writer(word_alphabet, char_alphabet, pos_alphabet,
                               chunk_alphabet, ner_alphabet)
        os.system('rm temp/output.txt')
        writer.start('temp/output.txt')
        network.eval()
        for batch in conll03_data.iterate_batch_variable(
                target_dev_data, batch_size):
            word, char, pos, chunk, labels, masks, lengths, _ = batch
            preds, _, _ = network.decode(
                word,
                char,
                target=labels,
                mask=masks,
                leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
            writer.write(word.data.cpu().numpy(),
                         pos.data.cpu().numpy(),
                         chunk.data.cpu().numpy(),
                         preds.cpu().numpy(),
                         labels.data.cpu().numpy(),
                         lengths.cpu().numpy())
        writer.close()

        acc, precision, recall, f1 = evaluate('temp/output.txt')
        log_info = 'dev: %f %f %f %f' % (acc, precision, recall, f1)
        print(log_info)

        if epoch % 10 == 0:
            writer = CoNLL03Writer(word_alphabet, char_alphabet, pos_alphabet,
                                   chunk_alphabet, ner_alphabet)
            os.system('rm temp/output.txt')
            writer.start('temp/output.txt')
            network.eval()
            for batch in conll03_data.iterate_batch_variable(
                    target_test_data, batch_size):
                word, char, pos, chunk, labels, masks, lengths, _ = batch
                preds, _, _ = network.decode(
                    word,
                    char,
                    target=labels,
                    mask=masks,
                    leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
                writer.write(word.data.cpu().numpy(),
                             pos.data.cpu().numpy(),
                             chunk.data.cpu().numpy(),
                             preds.cpu().numpy(),
                             labels.data.cpu().numpy(),
                             lengths.cpu().numpy())
            writer.close()

            acc, precision, recall, f1 = evaluate('temp/output.txt')
            log_info = 'test: %f %f %f %f' % (acc, precision, recall, f1)
            print(log_info)

    torch.save(network, 'temp/tuned_0905.pt')
    alphabet_directory = '0905_alphabet/'
    word_alphabet.save(alphabet_directory)
    char_alphabet.save(alphabet_directory)
    pos_alphabet.save(alphabet_directory)
    chunk_alphabet.save(alphabet_directory)
    ner_alphabet.save(alphabet_directory)
def sample():
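    """Active-learning selection: decode the unannotated data, rank sentences by
    length-normalized decoding confidence, and write the least confident unseen
    sentences (up to max_sents / max_words) to a query file for annotation."""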
    network = torch.load('temp/ner_active.pt')
    word_alphabet, char_alphabet, pos_alphabet, \
        chunk_alphabet, ner_alphabet = conll03_data.create_alphabets("active_alphabet/", None)

    unannotated_data = conll03_data.read_data_to_variable(
        'temp/unannotated.conll',
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        chunk_alphabet,
        ner_alphabet,
        use_gpu=False,
        volatile=True)

    annotated = set()
    with open('temp/annotated.conll', 'r') as f:
        sent_buffer = []
        for line in f:
            if len(line) > 1:
                _, word, _, _, _ = line.strip().split()
                sent_buffer.append(word)
            else:
                annotated.add(' '.join(sent_buffer))
                sent_buffer = []
    print('total annotated data: {}'.format(len(annotated)))

    uncertain = []
    max_sents = 100
    max_words = 500

    writer = CoNLL03Writer(word_alphabet, char_alphabet, pos_alphabet,
                           chunk_alphabet, ner_alphabet)
    writer.start('temp/output.txt')
    network.eval()
    tiebreaker = count()
    for batch in conll03_data.iterate_batch_variable(unannotated_data, 32):
        word, char, pos, chunk, labels, masks, lengths, raws = batch
        preds, _, confidence = network.decode(
            word,
            char,
            target=labels,
            mask=masks,
            leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
        writer.write(word.data.cpu().numpy(),
                     pos.data.cpu().numpy(),
                     chunk.data.cpu().numpy(),
                     preds.cpu().numpy(),
                     labels.data.cpu().numpy(),
                     lengths.cpu().numpy())
        for _ in range(confidence.size()[0]):
            heapq.heappush(uncertain,
                           (confidence[_].numpy()[0] / lengths[_],
                            tiebreaker.next(), word[_].data.numpy(), raws[_]))
    writer.close()

    cost_sents = 0
    cost_words = 0
    with open('temp/query.conll', 'w') as q:
        while cost_sents < max_sents and cost_words < max_words and uncertain:
            sample = heapq.heappop(uncertain)
            if len(sample[3]) <= 5:
                continue
            # print(sample[0])
            # print([word_alphabet.get_instance(wid) for wid in sample[2]])
            print(sample[3])
            to_write = []
            for word in sample[3]:
                if is_url(word):
                    word = '<_URL>'
                to_write.append(word.encode('ascii', 'ignore'))
            if ' '.join(to_write) in annotated:
                continue
            for wn, word in enumerate(to_write):
                q.write('{0} {1} -- -- O\n'.format(wn + 1, word))
            q.write('\n')
            cost_sents += 1
            cost_words += len(sample[3])
def retrain(train_path, dev_path):
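    """Fine-tune the loaded network on newly annotated data, extending the word
    alphabet and embedding table for unseen words, save the model and alphabets,
    then decode the dev set and return (acc, precision, recall, F1)."""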
    network = torch.load('temp/ner_tuned.pt')
    word_alphabet, char_alphabet, pos_alphabet, \
        chunk_alphabet, ner_alphabet = conll03_data.create_alphabets("ner_alphabet/", None)

    num_new_word = 0
    with open(train_path, 'r') as f:
        sents = []
        sent_buffer = []
        for line in f:
            if len(line) <= 1:
                sents.append(sent_buffer)
                sent_buffer = []
            else:
                id, word, _, _, ner = line.strip().split()
                if word_alphabet.get_index(word) == 0:
                    word_alphabet.add(word)
                    num_new_word += 1
                sent_buffer.append((word_alphabet.get_index(word),
                                    ner_alphabet.get_index(ner)))
    print('{} new words.'.format(num_new_word))
    init_embed = network.word_embedd.weight.data
    embedd_dim = init_embed.shape[1]
    init_embed = np.concatenate(
        (init_embed, np.zeros((num_new_word, embedd_dim))), axis=0)
    network.word_embedd = Embedding(word_alphabet.size(), embedd_dim,
                                    torch.from_numpy(init_embed))

    target_train_data = conll03_data.read_data_to_variable(train_path,
                                                           word_alphabet,
                                                           char_alphabet,
                                                           pos_alphabet,
                                                           chunk_alphabet,
                                                           ner_alphabet,
                                                           use_gpu=False,
                                                           volatile=False)

    num_epoch = 50
    batch_size = 20
    num_data = sum(target_train_data[1])
    num_batches = num_data / batch_size + 1
    unk_replace = 0.0
    optim = SGD(network.parameters(),
                lr=0.01,
                momentum=0.9,
                weight_decay=0.0,
                nesterov=True)

    for epoch in range(num_epoch):
        train_err = 0.
        train_total = 0.
        start_time = time.time()
        num_back = 0
        network.train()

        for batch in range(1, num_batches + 1):
            word, char, _, _, labels, masks, lengths = conll03_data.get_batch_variable(
                target_train_data, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss = network.loss(word, char, labels, mask=masks)
            loss.backward()
            optim.step()

            num_inst = word.size(0)
            train_err += loss.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            print('train: %d loss: %.4f, time: %.2fs' %
                  (num_batches, train_err / train_total,
                   time.time() - start_time))

    torch.save(network, 'temp/ner_active.pt')
    alphabet_directory = 'active_alphabet/'
    word_alphabet.save(alphabet_directory)
    char_alphabet.save(alphabet_directory)
    pos_alphabet.save(alphabet_directory)
    chunk_alphabet.save(alphabet_directory)
    ner_alphabet.save(alphabet_directory)

    target_dev_data = conll03_data.read_data_to_variable(dev_path,
                                                         word_alphabet,
                                                         char_alphabet,
                                                         pos_alphabet,
                                                         chunk_alphabet,
                                                         ner_alphabet,
                                                         use_gpu=False,
                                                         volatile=False)
    writer = CoNLL03Writer(word_alphabet, char_alphabet, pos_alphabet,
                           chunk_alphabet, ner_alphabet)
    os.system('rm output.txt')
    writer.start('output.txt')
    network.eval()
    for batch in conll03_data.iterate_batch_variable(target_dev_data,
                                                     batch_size):
        word, char, pos, chunk, labels, masks, lengths, _ = batch
        preds, _, _ = network.decode(
            word,
            char,
            target=labels,
            mask=masks,
            leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
        writer.write(word.data.cpu().numpy(),
                     pos.data.cpu().numpy(),
                     chunk.data.cpu().numpy(),
                     preds.cpu().numpy(),
                     labels.data.cpu().numpy(),
                     lengths.cpu().numpy())
    writer.close()

    acc, precision, recall, f1 = evaluate('output.txt')
    print(acc, precision, recall, f1)
    return acc, precision, recall, f1