Example #1
    def __init__(self, train, test, embeddings_filename, batch_size=1):
        self.train_path = train
        self.test_path = test
        self.mode = 'LSTM'
        self.dropout = 'std'
        self.num_epochs = 1
        self.batch_size = batch_size
        self.hidden_size = 256
        self.num_filters = 30
        self.learning_rate = 0.01
        self.momentum = 0.9
        self.decay_rate = 0.05
        self.gamma = 0.0
        self.schedule = 1
        self.p_rnn = tuple([0.33, 0.5])
        self.p_in = 0.33
        self.p_out = 0.5
        self.unk_replace = 0.0
        self.bigram = True
        self.embedding = 'glove'
        self.logger = get_logger("NERCRF")
        self.char_dim = 30
        self.window = 3
        self.num_layers = 1
        self.tag_space = 128
        self.initializer = nn.init.xavier_uniform

        self.use_gpu = torch.cuda.is_available()

        self.embedd_dict, self.embedd_dim = utils.load_embedding_dict(
            self.embedding, embeddings_filename)
        self.word_alphabet, self.char_alphabet, self.pos_alphabet, \
            self.chunk_alphabet, self.ner_alphabet = conll03_data.create_alphabets(
                "data/alphabets/ner_crf/", self.train_path,
                data_paths=[self.test_path],
                embedd_dict=self.embedd_dict,
                max_vocabulary_size=50000)
        self.word_table = self.construct_word_embedding_table()

        self.logger.info("Word Alphabet Size: %d" % self.word_alphabet.size())
        self.logger.info("Character Alphabet Size: %d" %
                         self.char_alphabet.size())
        self.logger.info("POS Alphabet Size: %d" % self.pos_alphabet.size())
        self.logger.info("Chunk Alphabet Size: %d" %
                         self.chunk_alphabet.size())
        self.logger.info("NER Alphabet Size: %d" % self.ner_alphabet.size())
        self.num_labels = self.ner_alphabet.size()

        self.data_test = conll03_data.read_data_to_variable(
            self.test_path,
            self.word_alphabet,
            self.char_alphabet,
            self.pos_alphabet,
            self.chunk_alphabet,
            self.ner_alphabet,
            use_gpu=self.use_gpu,
            volatile=True)
        self.writer = CoNLL03Writer(self.word_alphabet, self.char_alphabet,
                                    self.pos_alphabet, self.chunk_alphabet,
                                    self.ner_alphabet)
Example #2
def main():
    parser = argparse.ArgumentParser(
        description='Tuning with bi-directional RNN-CNN-CRF')
    parser.add_argument('--mode',
                        choices=['RNN', 'LSTM', 'GRU'],
                        help='architecture of rnn',
                        required=True)
    parser.add_argument('--num_epochs',
                        type=int,
                        default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Number of sentences in each batch')
    parser.add_argument('--hidden_size',
                        type=int,
                        default=128,
                        help='Number of hidden units in RNN')
    parser.add_argument('--tag_space',
                        type=int,
                        default=0,
                        help='Dimension of tag space')
    parser.add_argument('--num_layers',
                        type=int,
                        default=1,
                        help='Number of layers of RNN')
    parser.add_argument('--num_filters',
                        type=int,
                        default=30,
                        help='Number of filters in CNN')
    parser.add_argument('--char_dim',
                        type=int,
                        default=30,
                        help='Dimension of Character embeddings')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.015,
                        help='Learning rate')
    parser.add_argument('--decay_rate',
                        type=float,
                        default=0.1,
                        help='Decay rate of learning rate')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.0,
                        help='weight for regularization')
    parser.add_argument('--dropout',
                        choices=['std', 'variational'],
                        help='type of dropout',
                        required=True)
    parser.add_argument('--p_rnn',
                        nargs=2,
                        type=float,
                        required=True,
                        help='dropout rate for RNN')
    parser.add_argument('--p_in',
                        type=float,
                        default=0.33,
                        help='dropout rate for input embeddings')
    parser.add_argument('--p_out',
                        type=float,
                        default=0.33,
                        help='dropout rate for output layer')
    parser.add_argument('--bigram',
                        action='store_true',
                        help='bi-gram parameter for CRF')
    parser.add_argument('--schedule',
                        type=int,
                        help='schedule for learning rate decay')
    parser.add_argument('--unk_replace',
                        type=float,
                        default=0.,
                        help='The rate to replace a singleton word with UNK')
    parser.add_argument('--embedding',
                        choices=['glove', 'senna', 'sskip', 'polyglot'],
                        help='Embedding for words',
                        required=True)
    parser.add_argument('--embedding_dict', help='path for embedding dict')
    parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"

    args = parser.parse_args()

    logger = get_logger("NERCRF")

    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    momentum = 0.9
    decay_rate = args.decay_rate
    gamma = args.gamma
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    unk_replace = args.unk_replace
    bigram = args.bigram
    embedding = args.embedding
    embedding_path = args.embedding_dict

    embedd_dict, embedd_dim = utils.load_embedding_dict(
        embedding, embedding_path)

    logger.info("Creating Alphabets")
    word_alphabet, char_alphabet, pos_alphabet, \
    chunk_alphabet, ner_alphabet = conll03_data.create_alphabets("data/alphabets/ner_crf/", train_path, data_paths=[dev_path, test_path],
                                                                 embedd_dict=embedd_dict, max_vocabulary_size=50000)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())
    logger.info("Chunk Alphabet Size: %d" % chunk_alphabet.size())
    logger.info("NER Alphabet Size: %d" % ner_alphabet.size())

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    data_train = conll03_data.read_data_to_variable(train_path,
                                                    word_alphabet,
                                                    char_alphabet,
                                                    pos_alphabet,
                                                    chunk_alphabet,
                                                    ner_alphabet,
                                                    use_gpu=use_gpu)
    num_data = sum(data_train[1])
    num_labels = ner_alphabet.size()

    data_dev = conll03_data.read_data_to_variable(dev_path,
                                                  word_alphabet,
                                                  char_alphabet,
                                                  pos_alphabet,
                                                  chunk_alphabet,
                                                  ner_alphabet,
                                                  use_gpu=use_gpu,
                                                  volatile=True)
    data_test = conll03_data.read_data_to_variable(test_path,
                                                   word_alphabet,
                                                   char_alphabet,
                                                   pos_alphabet,
                                                   chunk_alphabet,
                                                   ner_alphabet,
                                                   use_gpu=use_gpu,
                                                   volatile=True)

    writer = CoNLL03Writer(word_alphabet, char_alphabet, pos_alphabet,
                           chunk_alphabet, ner_alphabet)

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / embedd_dim)
        table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        table[conll03_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, embedd_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in embedd_dict:
                embedding = embedd_dict[word]
            elif word.lower() in embedd_dict:
                embedding = embedd_dict[word.lower()]
            else:
                embedding = np.random.uniform(
                    -scale, scale, [1, embedd_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('oov: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    logger.info("constructing network...")

    char_dim = args.char_dim
    window = 3
    num_layers = args.num_layers
    tag_space = args.tag_space
    initializer = nn.init.xavier_uniform
    if args.dropout == 'std':
        network = BiRecurrentConvCRF(embedd_dim,
                                     word_alphabet.size(),
                                     char_dim,
                                     char_alphabet.size(),
                                     num_filters,
                                     window,
                                     mode,
                                     hidden_size,
                                     num_layers,
                                     num_labels,
                                     tag_space=tag_space,
                                     embedd_word=word_table,
                                     p_in=p_in,
                                     p_out=p_out,
                                     p_rnn=p_rnn,
                                     bigram=bigram,
                                     initializer=initializer)
    else:
        network = BiVarRecurrentConvCRF(embedd_dim,
                                        word_alphabet.size(),
                                        char_dim,
                                        char_alphabet.size(),
                                        num_filters,
                                        window,
                                        mode,
                                        hidden_size,
                                        num_layers,
                                        num_labels,
                                        tag_space=tag_space,
                                        embedd_word=word_table,
                                        p_in=p_in,
                                        p_out=p_out,
                                        p_rnn=p_rnn,
                                        bigram=bigram,
                                        initializer=initializer)

    if use_gpu:
        network.cuda()

    lr = learning_rate
    optim = SGD(network.parameters(),
                lr=lr,
                momentum=momentum,
                weight_decay=gamma,
                nesterov=True)
    logger.info(
        "Network: %s, num_layer=%d, hidden=%d, filter=%d, tag_space=%d, crf=%s"
        % (mode, num_layers, hidden_size, num_filters, tag_space,
           'bigram' if bigram else 'unigram'))
    logger.info(
        "training: l2: %f, (#training data: %d, batch: %d, unk replace: %.2f)"
        % (gamma, num_data, batch_size, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" %
                (p_in, p_out, p_rnn))

    num_batches = num_data // batch_size + 1  # integer division so range() below gets an int
    dev_f1 = 0.0
    dev_acc = 0.0
    dev_precision = 0.0
    dev_recall = 0.0
    test_f1 = 0.0
    test_acc = 0.0
    test_precision = 0.0
    test_recall = 0.0
    best_epoch = 0
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s(%s), learning rate=%.4f, decay rate=%.4f (schedule=%d)): '
            % (epoch, mode, args.dropout, lr, decay_rate, schedule))
        train_err = 0.
        train_total = 0.

        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            word, char, _, _, labels, masks, lengths = conll03_data.get_batch_variable(
                data_train, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss = network.loss(word, char, labels, mask=masks)
            loss.backward()
            optim.step()

            num_inst = word.size(0)
            train_err += loss.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 100 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, time left (estimated): %.2fs' % (
                    batch, num_batches, train_err / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, time: %.2fs' %
              (num_batches, train_err / train_total, time.time() - start_time))

        # evaluate performance on dev data
        network.eval()
        tmp_filename = 'tmp/%s_dev%d' % (str(uid), epoch)
        writer.start(tmp_filename)

        for batch in conll03_data.iterate_batch_variable(data_dev, batch_size):
            word, char, pos, chunk, labels, masks, lengths = batch
            preds, _ = network.decode(
                word,
                char,
                target=labels,
                mask=masks,
                leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
            writer.write(word.data.cpu().numpy(),
                         pos.data.cpu().numpy(),
                         chunk.data.cpu().numpy(),
                         preds.cpu().numpy(),
                         labels.data.cpu().numpy(),
                         lengths.cpu().numpy())
        writer.close()
        acc, precision, recall, f1 = evaluate(tmp_filename)
        print(
            'dev acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%' %
            (acc, precision, recall, f1))

        if dev_f1 < f1:
            dev_f1 = f1
            dev_acc = acc
            dev_precision = precision
            dev_recall = recall
            best_epoch = epoch

            # evaluate on test data when better performance detected
            tmp_filename = 'tmp/%s_test%d' % (str(uid), epoch)
            writer.start(tmp_filename)

            for batch in conll03_data.iterate_batch_variable(
                    data_test, batch_size):
                word, char, pos, chunk, labels, masks, lengths = batch
                preds, _ = network.decode(
                    word,
                    char,
                    target=labels,
                    mask=masks,
                    leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
                writer.write(word.data.cpu().numpy(),
                             pos.data.cpu().numpy(),
                             chunk.data.cpu().numpy(),
                             preds.cpu().numpy(),
                             labels.data.cpu().numpy(),
                             lengths.cpu().numpy())
            writer.close()
            test_acc, test_precision, test_recall, test_f1 = evaluate(
                tmp_filename)

        print(
            "best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
            % (dev_acc, dev_precision, dev_recall, dev_f1, best_epoch))
        print(
            "best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
            % (test_acc, test_precision, test_recall, test_f1, best_epoch))

        if epoch % schedule == 0:
            lr = learning_rate / (1.0 + epoch * decay_rate)
            optim = SGD(network.parameters(),
                        lr=lr,
                        momentum=momentum,
                        weight_decay=gamma,
                        nesterov=True)
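
A hypothetical command line for the training script above (the script name and data paths are assumptions; the flags mirror the argparse definitions, including the required --mode, --dropout, --p_rnn and --embedding, plus --schedule, which the learning-rate decay logic expects to be set):

    python train_ner_crf.py --mode LSTM --dropout std --p_rnn 0.33 0.5 \
        --embedding glove --embedding_dict data/glove.6B.100d.txt \
        --train data/conll03/eng.train.conll --dev data/conll03/eng.dev.conll \
        --test data/conll03/eng.test.conll --schedule 1 --bigram
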
Example #3
def main():
    parser = argparse.ArgumentParser(
        description='NER with bi-directional RNN-CNN')
    parser.add_argument('--config',
                        type=str,
                        help='config file',
                        required=True)
    parser.add_argument('--num_epochs',
                        type=int,
                        default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Number of sentences in each batch')
    parser.add_argument('--loss_type',
                        choices=['sentence', 'token'],
                        default='sentence',
                        help='loss type (default: sentence)')
    parser.add_argument('--optim',
                        choices=['sgd', 'adam'],
                        help='type of optimizer',
                        required=True)
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.1,
                        help='Learning rate')
    parser.add_argument('--lr_decay',
                        type=float,
                        default=0.999995,
                        help='Decay rate of learning rate')
    parser.add_argument('--amsgrad', action='store_true', help='AMS Grad')
    parser.add_argument('--grad_clip',
                        type=float,
                        default=0,
                        help='max norm for gradient clip (default 0: no clip)')
    parser.add_argument('--warmup_steps',
                        type=int,
                        default=0,
                        metavar='N',
                        help='number of steps to warm up (default: 0)')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=0.0,
                        help='weight for l2 norm decay')
    parser.add_argument('--unk_replace',
                        type=float,
                        default=0.,
                        help='The rate to replace a singleton word with UNK')
    parser.add_argument('--embedding',
                        choices=['glove', 'senna', 'sskip', 'polyglot'],
                        help='Embedding for words',
                        required=True)
    parser.add_argument('--embedding_dict', help='path for embedding dict')
    parser.add_argument('--train',
                        help='path for training file.',
                        required=True)
    parser.add_argument('--dev', help='path for dev file.', required=True)
    parser.add_argument('--test', help='path for test file.', required=True)
    parser.add_argument('--model_path',
                        help='path for saving model file.',
                        required=True)

    args = parser.parse_args()

    logger = get_logger("NER")

    args.cuda = torch.cuda.is_available()
    device = torch.device('cuda', 0) if args.cuda else torch.device('cpu')
    train_path = args.train
    dev_path = args.dev
    test_path = args.test

    num_epochs = args.num_epochs
    batch_size = args.batch_size
    optim = args.optim
    learning_rate = args.learning_rate
    lr_decay = args.lr_decay
    amsgrad = args.amsgrad
    warmup_steps = args.warmup_steps
    weight_decay = args.weight_decay
    grad_clip = args.grad_clip

    loss_ty_token = args.loss_type == 'token'
    unk_replace = args.unk_replace

    model_path = args.model_path
    model_name = os.path.join(model_path, 'model.pt')
    embedding = args.embedding
    embedding_path = args.embedding_dict

    print(args)

    embedd_dict, embedd_dim = utils.load_embedding_dict(
        embedding, embedding_path)

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets')
    word_alphabet, char_alphabet, pos_alphabet, chunk_alphabet, ner_alphabet = conll03_data.create_alphabets(
        alphabet_path,
        train_path,
        data_paths=[dev_path, test_path],
        embedd_dict=embedd_dict,
        max_vocabulary_size=50000)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())
    logger.info("Chunk Alphabet Size: %d" % chunk_alphabet.size())
    logger.info("NER Alphabet Size: %d" % ner_alphabet.size())

    logger.info("Reading Data")

    data_train = conll03_data.read_bucketed_data(train_path, word_alphabet,
                                                 char_alphabet, pos_alphabet,
                                                 chunk_alphabet, ner_alphabet)
    num_data = sum(data_train[1])
    num_labels = ner_alphabet.size()

    data_dev = conll03_data.read_data(dev_path, word_alphabet, char_alphabet,
                                      pos_alphabet, chunk_alphabet,
                                      ner_alphabet)
    data_test = conll03_data.read_data(test_path, word_alphabet, char_alphabet,
                                       pos_alphabet, chunk_alphabet,
                                       ner_alphabet)

    writer = CoNLL03Writer(word_alphabet, char_alphabet, pos_alphabet,
                           chunk_alphabet, ner_alphabet)

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / embedd_dim)
        table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        table[conll03_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, embedd_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in embedd_dict:
                embedding = embedd_dict[word]
            elif word.lower() in embedd_dict:
                embedding = embedd_dict[word.lower()]
            else:
                embedding = np.random.uniform(
                    -scale, scale, [1, embedd_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('oov: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()

    logger.info("constructing network...")

    hyps = json.load(open(args.config, 'r'))
    json.dump(hyps,
              open(os.path.join(model_path, 'config.json'), 'w'),
              indent=2)
    dropout = hyps['dropout']
    crf = hyps['crf']
    bigram = hyps['bigram']
    assert embedd_dim == hyps['embedd_dim']
    char_dim = hyps['char_dim']
    mode = hyps['rnn_mode']
    hidden_size = hyps['hidden_size']
    out_features = hyps['out_features']
    num_layers = hyps['num_layers']
    p_in = hyps['p_in']
    p_out = hyps['p_out']
    p_rnn = hyps['p_rnn']
    activation = hyps['activation']
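    # A hypothetical config.json matching the keys read above (all values are
    # assumptions, not taken from the original repository):
    # {
    #   "rnn_mode": "LSTM", "hidden_size": 256, "out_features": 128, "num_layers": 1,
    #   "char_dim": 30, "embedd_dim": 100, "dropout": "std", "crf": true, "bigram": true,
    #   "p_in": 0.33, "p_out": 0.5, "p_rnn": [0.33, 0.5], "activation": "elu"
    # }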

    if dropout == 'std':
        if crf:
            network = BiRecurrentConvCRF(embedd_dim,
                                         word_alphabet.size(),
                                         char_dim,
                                         char_alphabet.size(),
                                         mode,
                                         hidden_size,
                                         out_features,
                                         num_layers,
                                         num_labels,
                                         embedd_word=word_table,
                                         p_in=p_in,
                                         p_out=p_out,
                                         p_rnn=p_rnn,
                                         bigram=bigram,
                                         activation=activation)
        else:
            network = BiRecurrentConv(embedd_dim,
                                      word_alphabet.size(),
                                      char_dim,
                                      char_alphabet.size(),
                                      mode,
                                      hidden_size,
                                      out_features,
                                      num_layers,
                                      num_labels,
                                      embedd_word=word_table,
                                      p_in=p_in,
                                      p_out=p_out,
                                      p_rnn=p_rnn,
                                      activation=activation)
    elif dropout == 'variational':
        if crf:
            network = BiVarRecurrentConvCRF(embedd_dim,
                                            word_alphabet.size(),
                                            char_dim,
                                            char_alphabet.size(),
                                            mode,
                                            hidden_size,
                                            out_features,
                                            num_layers,
                                            num_labels,
                                            embedd_word=word_table,
                                            p_in=p_in,
                                            p_out=p_out,
                                            p_rnn=p_rnn,
                                            bigram=bigram,
                                            activation=activation)
        else:
            network = BiVarRecurrentConv(embedd_dim,
                                         word_alphabet.size(),
                                         char_dim,
                                         char_alphabet.size(),
                                         mode,
                                         hidden_size,
                                         out_features,
                                         num_layers,
                                         num_labels,
                                         embedd_word=word_table,
                                         p_in=p_in,
                                         p_out=p_out,
                                         p_rnn=p_rnn,
                                         activation=activation)
    else:
        raise ValueError('Unknown dropout type: {}'.format(dropout))

    network = network.to(device)

    optimizer, scheduler = get_optimizer(network.parameters(), optim,
                                         learning_rate, lr_decay, amsgrad,
                                         weight_decay, warmup_steps)
    model = "{}-CNN{}".format(mode, "-CRF" if crf else "")
    logger.info("Network: %s, num_layer=%d, hidden=%d, act=%s" %
                (model, num_layers, hidden_size, activation))
    logger.info(
        "training: l2: %f, (#training data: %d, batch: %d, unk replace: %.2f)"
        % (weight_decay, num_data, batch_size, unk_replace))
    logger.info("dropout(in, out, rnn): %s(%.2f, %.2f, %s)" %
                (dropout, p_in, p_out, p_rnn))
    print('# of Parameters: %d' %
          (sum([param.numel() for param in network.parameters()])))

    best_f1 = 0.0
    best_acc = 0.0
    best_precision = 0.0
    best_recall = 0.0
    test_f1 = 0.0
    test_acc = 0.0
    test_precision = 0.0
    test_recall = 0.0
    best_epoch = 0
    patient = 0
    num_batches = num_data // batch_size + 1
    result_path = os.path.join(model_path, 'tmp')
    if not os.path.exists(result_path):
        os.makedirs(result_path)
    for epoch in range(1, num_epochs + 1):
        start_time = time.time()
        train_loss = 0.
        num_insts = 0
        num_words = 0
        num_back = 0
        network.train()
        lr = scheduler.get_lr()[0]
        print('Epoch %d (%s, lr=%.6f, lr decay=%.6f, amsgrad=%s, l2=%.1e): ' %
              (epoch, optim, lr, lr_decay, amsgrad, weight_decay))
        if args.cuda:
            torch.cuda.empty_cache()
        gc.collect()
        for step, data in enumerate(
                iterate_data(data_train,
                             batch_size,
                             bucketed=True,
                             unk_replace=unk_replace,
                             shuffle=True)):
            optimizer.zero_grad()
            words = data['WORD'].to(device)
            chars = data['CHAR'].to(device)
            labels = data['NER'].to(device)
            masks = data['MASK'].to(device)

            nbatch = words.size(0)
            nwords = masks.sum().item()

            loss_total = network.loss(words, chars, labels, mask=masks).sum()
            if loss_ty_token:
                loss = loss_total.div(nwords)
            else:
                loss = loss_total.div(nbatch)
            loss.backward()
            if grad_clip > 0:
                clip_grad_norm_(network.parameters(), grad_clip)
            optimizer.step()
            scheduler.step()

            with torch.no_grad():
                num_insts += nbatch
                num_words += nwords
                train_loss += loss_total.item()

            # update log
            if step % 100 == 0:
                torch.cuda.empty_cache()
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                curr_lr = scheduler.get_lr()[0]
                log_info = '[%d/%d (%.0f%%) lr=%.6f] loss: %.4f (%.4f)' % (
                    step, num_batches, 100. * step / num_batches, curr_lr,
                    train_loss / num_insts, train_loss / num_words)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('total: %d (%d), loss: %.4f (%.4f), time: %.2fs' %
              (num_insts, num_words, train_loss / num_insts,
               train_loss / num_words, time.time() - start_time))
        print('-' * 100)

        # evaluate performance on dev data
        with torch.no_grad():
            outfile = os.path.join(result_path, 'pred_dev%d' % epoch)
            scorefile = os.path.join(result_path, "score_dev%d" % epoch)
            acc, precision, recall, f1 = eval(data_dev, network, writer,
                                              outfile, scorefile, device)
            print(
                'Dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%'
                % (acc, precision, recall, f1))
            if best_f1 < f1:
                torch.save(network.state_dict(), model_name)
                best_f1 = f1
                best_acc = acc
                best_precision = precision
                best_recall = recall
                best_epoch = epoch

                # evaluate on test data when better performance detected
                outfile = os.path.join(result_path, 'pred_test%d' % epoch)
                scorefile = os.path.join(result_path, "score_test%d" % epoch)
                test_acc, test_precision, test_recall, test_f1 = eval(
                    data_test, network, writer, outfile, scorefile, device)
                print(
                    'test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%'
                    % (test_acc, test_precision, test_recall, test_f1))
                patient = 0
            else:
                patient += 1
            print('-' * 100)

            print(
                "Best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d (%d))"
                % (best_acc, best_precision, best_recall, best_f1, best_epoch,
                   patient))
            print(
                "Best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d (%d))"
                % (test_acc, test_precision, test_recall, test_f1, best_epoch,
                   patient))
            print('=' * 100)

        if patient > 4:
            logger.info('reset optimizer momentums')
            scheduler.reset_state()
            patient = 0
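
A hypothetical invocation of the script above (script name, config path and data paths are assumptions; the required flags are --config, --optim, --embedding, --train, --dev, --test and --model_path):

    python ner.py --config configs/ner_conll03.json --optim sgd \
        --embedding glove --embedding_dict data/glove.6B.100d.txt \
        --train data/conll03/eng.train.conll --dev data/conll03/eng.dev.conll \
        --test data/conll03/eng.test.conll --model_path models/ner_conll03
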
Example #4
def main():
    # Arguments parser
    parser = argparse.ArgumentParser(description='Tuning with DNN Model for NER')
    # Model Hyperparameters
    parser.add_argument('--mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn', default='LSTM')
    parser.add_argument('--encoder_mode', choices=['cnn', 'lstm'], help='Encoder type for sentence encoding',
                        default='lstm')
    parser.add_argument('--char_method', choices=['cnn', 'lstm'], help='Method to create character-level embeddings',
                        required=True)
    parser.add_argument('--hidden_size', type=int, default=128, help='Number of hidden units in RNN for sentence level')
    parser.add_argument('--char_hidden_size', type=int, default=30, help='Output character-level embeddings size')
    parser.add_argument('--char_dim', type=int, default=30, help='Dimension of Character embeddings')
    parser.add_argument('--tag_space', type=int, default=0, help='Dimension of tag space')
    parser.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
    parser.add_argument('--dropout', choices=['std', 'gcn'], help='Dropout method',
                        default='gcn')
    parser.add_argument('--p_em', type=float, default=0.33, help='dropout rate for input embeddings')
    parser.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input of RNN model')
    parser.add_argument('--p_rnn', nargs=3, type=float, required=True, help='dropout rate for RNN')
    parser.add_argument('--p_tag', type=float, default=0.33, help='dropout rate for output layer')
    parser.add_argument('--bigram', action='store_true', help='bi-gram parameter for CRF')

    parser.add_argument('--adj_attn', choices=['cossim', 'flex_cossim', 'flex_cossim2', 'concat', '', 'multihead'],
                        default='')

    # Data loading and storing params
    parser.add_argument('--embedding_dict', help='path for embedding dict')
    parser.add_argument('--dataset_name', type=str, default='alexa', help='Which dataset to use')
    parser.add_argument('--train', type=str, required=True, help='Path of train set')
    parser.add_argument('--dev', type=str, required=True, help='Path of dev set')
    parser.add_argument('--test', type=str, required=True, help='Path of test set')
    parser.add_argument('--results_folder', type=str, default='results', help='The folder to store results')
    parser.add_argument('--alphabets_folder', type=str, default='data/alphabets',
                        help='The folder to store alphabets files')

    # Training parameters
    parser.add_argument('--cuda', action='store_true', help='whether using GPU')
    parser.add_argument('--num_epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='Number of sentences in each batch')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Base learning rate')
    parser.add_argument('--decay_rate', type=float, default=0.95, help='Decay rate of learning rate')
    parser.add_argument('--schedule', type=int, default=3, help='schedule for learning rate decay')
    parser.add_argument('--gamma', type=float, default=0.0, help='weight for l2 regularization')
    parser.add_argument('--max_norm', type=float, default=1., help='Max norm for gradients')
    parser.add_argument('--gpu_id', type=int, nargs='+', required=True, help='which gpu to use for training')

    parser.add_argument('--learning_rate_gcn', type=float, default=5e-4, help='Base learning rate')
    parser.add_argument('--gcn_warmup', type=int, default=200, help='Base learning rate')
    parser.add_argument('--pretrain_lstm', type=float, default=10, help='Base learning rate')

    parser.add_argument('--adj_loss_lambda', type=float, default=0.)
    parser.add_argument('--lambda1', type=float, default=1.)
    parser.add_argument('--lambda2', type=float, default=0.)
    parser.add_argument('--seed', type=int, default=None)

    # Misc
    parser.add_argument('--embedding', choices=['glove', 'senna', 'alexa'], help='Embedding for words', required=True)
    parser.add_argument('--restore', action='store_true', help='whether restore from stored parameters')
    parser.add_argument('--save_checkpoint', type=str, default='', help='the path to save the model')
    parser.add_argument('--o_tag', type=str, default='O', help='The default tag for outside tag')
    parser.add_argument('--unk_replace', type=float, default=0., help='The rate to replace a singleton word with UNK')
    parser.add_argument('--evaluate_raw_format', action='store_true', help='The tagging format for evaluation')


    parser.add_argument('--eval_type', type=str, default="micro_f1", choices=['micro_f1', 'acc'])
    parser.add_argument('--show_network', action='store_true', help='whether to display the network structure')
    parser.add_argument('--smooth', action='store_true', help='whether to skip all pdb break points')

    parser.add_argument('--uid', type=str, default='temp')
    parser.add_argument('--misc', type=str, default='')

    args = parser.parse_args()
    show_var(['args'])

    uid = args.uid
    results_folder = args.results_folder
    dataset_name = args.dataset_name
    use_tensorboard = True

    save_dset_dir = '{}../dset/{}/graph'.format(results_folder, dataset_name)
    result_file_path = '{}/{dataset}_{uid}_result'.format(results_folder, dataset=dataset_name, uid=uid)

    save_loss_path = '{}/{dataset}_{uid}_loss'.format(results_folder, dataset=dataset_name, uid=uid)
    save_lr_path = '{}/{dataset}_{uid}_lr'.format(results_folder, dataset=dataset_name, uid='temp')
    save_tb_path = '{}/tensorboard/'.format(results_folder)

    logger = get_logger("NERCRF")
    loss_recorder = LossRecorder(uid=uid)
    record = TensorboardLossRecord(use_tensorboard, save_tb_path, uid=uid)

    # rename the parameters
    mode = args.mode
    encoder_mode = args.encoder_mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    char_hidden_size = args.char_hidden_size
    char_method = args.char_method
    learning_rate = args.learning_rate
    momentum = 0.9
    decay_rate = args.decay_rate
    gamma = args.gamma
    max_norm = args.max_norm
    schedule = args.schedule
    dropout = args.dropout
    p_em = args.p_em
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_tag = args.p_tag
    unk_replace = args.unk_replace
    bigram = args.bigram
    embedding = args.embedding
    embedding_path = args.embedding_dict
    evaluate_raw_format = args.evaluate_raw_format
    o_tag = args.o_tag
    restore = args.restore
    save_checkpoint = args.save_checkpoint
    alphabets_folder = args.alphabets_folder
    use_elmo = False
    p_em_vec = 0.
    graph_model = 'gnn'
    coref_edge_filt = ''

    learning_rate_gcn = args.learning_rate_gcn
    gcn_warmup = args.gcn_warmup
    pretrain_lstm = args.pretrain_lstm

    adj_loss_lambda = args.adj_loss_lambda
    lambda1 = args.lambda1
    lambda2 = args.lambda2

    if args.smooth:
        import pdb
        pdb.set_trace = lambda: None

    misc = "{}".format(str(args.misc))

    score_file = "{}/{dataset}_{uid}_score".format(results_folder, dataset=dataset_name, uid=uid)

    for folder in [results_folder, alphabets_folder, save_dset_dir]:
        if not os.path.exists(folder):
            os.makedirs(folder)

    def set_seed(seed):
        if not seed:
            seed = int(show_time())
        print("[Info] seed set to: {}".format(seed))
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

    set_seed(args.seed)

    embedd_dict, embedd_dim = utils.load_embedding_dict(embedding, embedding_path)

    logger.info("Creating Alphabets")
    word_alphabet, char_alphabet, ner_alphabet = conll03_data.create_alphabets(
        "{}/{}/".format(alphabets_folder, dataset_name), train_path, data_paths=[dev_path, test_path],
        embedd_dict=embedd_dict, max_vocabulary_size=50000)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("NER Alphabet Size: %d" % ner_alphabet.size())

    logger.info("Reading Data")
    device = torch.device('cuda') if args.cuda else torch.device('cpu')
    print(device)

    data_train = conll03_data.read_data(train_path, word_alphabet, char_alphabet,
                                        ner_alphabet,
                                        graph_model, batch_size, ori_order=False,
                                        total_batch="{}x".format(num_epochs + 1),
                                        unk_replace=unk_replace, device=device,
                                        save_path=save_dset_dir + '/train', coref_edge_filt=coref_edge_filt
                                        )
    # , shuffle=True,
    num_data = data_train.data_len
    num_labels = ner_alphabet.size()
    graph_types = data_train.meta_info['graph_types']

    data_dev = conll03_data.read_data(dev_path, word_alphabet, char_alphabet,
                                      ner_alphabet,
                                      graph_model, batch_size, ori_order=True, unk_replace=unk_replace, device=device,
                                      save_path=save_dset_dir + '/dev',
                                      coref_edge_filt=coref_edge_filt)

    data_test = conll03_data.read_data(test_path, word_alphabet, char_alphabet,
                                       ner_alphabet,
                                       graph_model, batch_size, ori_order=True, unk_replace=unk_replace, device=device,
                                       save_path=save_dset_dir + '/test',
                                       coref_edge_filt=coref_edge_filt)

    writer = CoNLL03Writer(word_alphabet, char_alphabet, ner_alphabet)

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / embedd_dim)
        table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        table[conll03_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, embedd_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in embedd_dict:
                embedding = embedd_dict[word]
            elif word.lower() in embedd_dict:
                embedding = embedd_dict[word.lower()]
            else:
                embedding = np.random.uniform(-scale, scale, [1, embedd_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('oov: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    logger.info("constructing network...")

    char_dim = args.char_dim
    window = 3
    num_layers = args.num_layers
    tag_space = args.tag_space
    initializer = nn.init.xavier_uniform_

    p_gcn = [0.5, 0.5]

    d_graph = 256
    d_out = 256
    d_inner_hid = 128
    d_k = 32
    d_v = 32
    n_head = 4
    n_gcn_layer = 1

    p_rnn2 = [0.0, 0.5, 0.5]

    adj_attn = args.adj_attn
    mask_singles = True
    post_lstm = 1
    position_enc_mode = 'none'

    adj_memory = False

    if dropout == 'gcn':
        network = BiRecurrentConvGraphCRF(embedd_dim, word_alphabet.size(), char_dim, char_alphabet.size(),
                                          char_hidden_size, window, mode, encoder_mode, hidden_size, num_layers,
                                          num_labels,
                                          graph_model, n_head, d_graph, d_inner_hid, d_k, d_v, p_gcn, n_gcn_layer,
                                          d_out, post_lstm=post_lstm, mask_singles=mask_singles,
                                          position_enc_mode=position_enc_mode, adj_attn=adj_attn,
                                          adj_loss_lambda=adj_loss_lambda,
                                          tag_space=tag_space, embedd_word=word_table,
                                          use_elmo=use_elmo, p_em_vec=p_em_vec, p_em=p_em, p_in=p_in, p_tag=p_tag,
                                          p_rnn=p_rnn, p_rnn2=p_rnn2,
                                          bigram=bigram, initializer=initializer)

    elif dropout == 'std':
        network = BiRecurrentConvCRF(embedd_dim, word_alphabet.size(), char_dim, char_alphabet.size(), char_hidden_size,
                                     window, mode, encoder_mode, hidden_size, num_layers, num_labels,
                                     tag_space=tag_space, embedd_word=word_table, use_elmo=use_elmo, p_em_vec=p_em_vec,
                                     p_em=p_em, p_in=p_in, p_tag=p_tag, p_rnn=p_rnn, bigram=bigram,
                                     initializer=initializer)

    # whether restore from trained model
    if restore:
        network.load_state_dict(torch.load(save_checkpoint + '_best.pth'))  # load trained model

    logger.info("cuda()ing network...")

    network = network.to(device)

    if dataset_name == 'conll03' and data_dev.data_len > 26:
        sample = data_dev.pad_batch(data_dev.dataset[25:26])
    else:
        sample = data_dev.pad_batch(data_dev.dataset[:1])
    plot_att_change(sample, network, record, save_tb_path + 'att/', uid='temp', epoch=0, device=device,
                    word_alphabet=word_alphabet, show_net=args.show_network,
                    graph_types=data_train.meta_info['graph_types'])

    logger.info("finished cuda()ing network...")

    lr = learning_rate
    lr_gcn = learning_rate_gcn
    optim = Optimizer('sgd', 'adam', network, dropout, lr=learning_rate,
                      lr_gcn=learning_rate_gcn,
                      wd=0., wd_gcn=0., momentum=momentum, lr_decay=decay_rate, schedule=schedule,
                      gcn_warmup=gcn_warmup,
                      pretrain_lstm=pretrain_lstm)
    nn.utils.clip_grad_norm_(network.parameters(), max_norm)
    logger.info(
        "Network: %s, encoder_mode=%s, num_layer=%d, hidden=%d, char_hidden_size=%d, char_method=%s, tag_space=%d, crf=%s" % \
        (mode, encoder_mode, num_layers, hidden_size, char_hidden_size, char_method, tag_space,
         'bigram' if bigram else 'unigram'))
    logger.info("training: l2: %f, (#training data: %d, batch: %d, unk replace: %.2f)" % (
        gamma, num_data, batch_size, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (p_in, p_tag, p_rnn))

    num_batches = num_data // batch_size + 1
    dev_f1 = 0.0
    dev_acc = 0.0
    dev_precision = 0.0
    dev_recall = 0.0
    test_f1 = 0.0
    test_acc = 0.0
    test_precision = 0.0
    test_recall = 0.0
    best_epoch = 0
    best_test_f1 = 0.0
    best_test_acc = 0.0
    best_test_precision = 0.0
    best_test_recall = 0.0
    best_test_epoch = 0.0

    loss_recorder.start(save_loss_path, mode='w', misc=misc)
    fwrite('', save_lr_path)
    fwrite(json.dumps(vars(args)) + '\n', result_file_path)

    for epoch in range(1, num_epochs + 1):
        show_var(['misc'])

        lr_state = 'Epoch %d (uid=%s, lr=%.2E, lr_gcn=%.2E, decay rate=%.4f): ' % (
            epoch, uid, Decimal(optim.curr_lr), Decimal(optim.curr_lr_gcn), decay_rate)
        print(lr_state)
        fwrite(lr_state[:-2] + '\n', save_lr_path, mode='a')

        train_err = 0.
        train_err2 = 0.
        train_total = 0.

        start_time = time.time()
        num_back = 0
        network.train()
        for batch_i in range(1, num_batches + 1):

            batch_doc = data_train.next()
            char, word, posi, labels, feats, adjs, words_en = [batch_doc[i] for i in [
                "chars", "word_ids", "posi", "ner_ids", "feat_ids", "adjs", "words_en"]]

            sent_word, sent_char, sent_labels, sent_mask, sent_length, _ = network._doc2sent(
                word, char, labels)

            optim.zero_grad()

            adjs_into_model = adjs if adj_memory else adjs.clone()

            loss, (ner_loss, adj_loss) = network.loss(None, word, char, adjs_into_model, labels,
                                                      graph_types=graph_types, lambda1=lambda1, lambda2=lambda2)

            # loss = network.loss(_, sent_word, sent_char, sent_labels, mask=sent_mask)
            loss.backward()
            optim.step()

            with torch.no_grad():
                num_inst = sent_mask.size(0)
                train_err += ner_loss * num_inst
                train_err2 += adj_loss * num_inst
                train_total += num_inst

            time_ave = (time.time() - start_time) / batch_i
            time_left = (num_batches - batch_i) * time_ave

            # update log
            if batch_i % 20 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss1: %.4f, loss2: %.4f, time left (estimated): %.2fs' % (
                    batch_i, num_batches, train_err / train_total, train_err2 / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

            optim.update(epoch, batch_i, num_batches, network)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, loss2: %.4f, time: %.2fs' % (
            num_batches, train_err / train_total, train_err2 / train_total, time.time() - start_time))

        # evaluate performance on dev data
        with torch.no_grad():
            network.eval()
            tmp_filename = "{}/{dataset}_{uid}_output_dev".format(results_folder, dataset=dataset_name, uid=uid)

            writer.start(tmp_filename)

            for batch in data_dev:
                char, word, posi, labels, feats, adjs, words_en = [batch[i] for i in [
                    "chars", "word_ids", "posi", "ner_ids", "feat_ids", "adjs", "words_en"]]
                sent_word, sent_char, sent_labels, sent_mask, sent_length, _ = network._doc2sent(
                    word, char, labels)

                preds, _ = network.decode(
                    None, word, char, adjs.clone(), target=labels, leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS,
                    graph_types=graph_types)
                # preds, _ = network.decode(_, sent_word, sent_char, target=sent_labels, mask=sent_mask,
                #                           leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
                writer.write(sent_word.cpu().numpy(), preds.cpu().numpy(), sent_labels.cpu().numpy(),
                             sent_length.cpu().numpy())
            writer.close()


            if args.eval_type == "acc":
                acc, precision, recall, f1 = evaluate_tokenacc(tmp_filename)
                f1 = acc
            else:
                acc, precision, recall, f1 = evaluate(tmp_filename, score_file, evaluate_raw_format, o_tag)

            print('dev acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%' % (acc, precision, recall, f1))

            # plot loss and attention
            record.plot_loss(epoch, train_err / train_total, f1)

            plot_att_change(sample, network, record, save_tb_path + 'att/', uid="{}_{:03d}".format(uid, epoch),
                            epoch=epoch, device=device,
                            word_alphabet=word_alphabet, show_net=False, graph_types=graph_types)

            if dev_f1 < f1:
                dev_f1 = f1
                dev_acc = acc
                dev_precision = precision
                dev_recall = recall
                best_epoch = epoch

                # evaluate on test data when better performance detected
                tmp_filename = "{}/{dataset}_{uid}_output_test".format(results_folder, dataset=dataset_name,
                                                                       uid=uid)
                writer.start(tmp_filename)

                for batch in data_test:
                    char, word, posi, labels, feats, adjs, words_en = [batch[i] for i in [
                        "chars", "word_ids", "posi", "ner_ids", "feat_ids", "adjs", "words_en"]]
                    sent_word, sent_char, sent_labels, sent_mask, sent_length, _ = network._doc2sent(
                        word, char, labels)

                    preds, _ = network.decode(
                        None, word, char, adjs.clone(), target=labels, leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS,
                        graph_types=graph_types)
                    # preds, _ = network.decode(_, sent_word, sent_char, target=sent_labels, mask=sent_mask,
                    #                           leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)

                    writer.write(sent_word.cpu().numpy(), preds.cpu().numpy(), sent_labels.cpu().numpy(),
                                 sent_length.cpu().numpy())
                writer.close()

                if args.eval_type == "acc":
                    test_acc, test_precision, test_recall, test_f1 = evaluate_tokenacc(tmp_filename)
                    test_f1 = test_acc
                else:
                    test_acc, test_precision, test_recall, test_f1 = evaluate(
                        tmp_filename, score_file, evaluate_raw_format, o_tag)

                if best_test_f1 < test_f1:
                    best_test_acc, best_test_precision, best_test_recall, best_test_f1 = test_acc, test_precision, test_recall, test_f1
                    best_test_epoch = epoch

                # save the model parameters
                if save_checkpoint:
                    torch.save(network.state_dict(), save_checkpoint + '_best.pth')

            print("best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)" % (
                dev_acc, dev_precision, dev_recall, dev_f1, best_epoch))
            print("best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)" % (
                test_acc, test_precision, test_recall, test_f1, best_epoch))
            print("overall best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)" % (
                best_test_acc, best_test_precision, best_test_recall, best_test_f1, best_test_epoch))

        # optim.update(epoch, 1, num_batches, network)
        loss_recorder.write(epoch, train_err / train_total, train_err2 / train_total,
                            Decimal(optim.curr_lr), Decimal(optim.curr_lr_gcn), f1, best_test_f1, test_f1)
    with open(result_file_path, 'a') as ofile:
        ofile.write("best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)\n" % (
            dev_acc, dev_precision, dev_recall, dev_f1, best_epoch))
        ofile.write("best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)\n" % (
            test_acc, test_precision, test_recall, test_f1, best_epoch))
        ofile.write("overall best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)\n\n" % (
            best_test_acc, best_test_precision, best_test_recall, best_test_f1, best_test_epoch))

    record.close()

    print('Training finished!')


def main():
    embedding = 'glove'
    embedding_path = '/media/xianyang/OS/workspace/ner/glove.6B/glove.6B.100d.txt'
    word_alphabet, char_alphabet, pos_alphabet, \
    chunk_alphabet, ner_alphabet = conll03_data.create_alphabets("/media/xianyang/OS/workspace/ner/NeuroNLP2/data/alphabets/ner_crf/", None)
    char_dim = 30
    num_filters = 30
    window = 3
    mode = 'LSTM'
    hidden_size = 256
    num_layers = 1
    num_labels = ner_alphabet.size()
    tag_space = 128
    p = 0.5
    bigram = True
    embedd_dim = 100
    use_gpu = False

    print(len(word_alphabet.get_content()['instances']))
    print(ner_alphabet.get_content())

    # writer = CoNLL03Writer(word_alphabet, char_alphabet, pos_alphabet, chunk_alphabet, ner_alphabet)
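    # rebuild the source model; the hyperparameters above must match the saved checkpoint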
    network = BiRecurrentConvCRF(embedd_dim,
                                 word_alphabet.size(),
                                 char_dim,
                                 char_alphabet.size(),
                                 num_filters,
                                 window,
                                 mode,
                                 hidden_size,
                                 num_layers,
                                 num_labels,
                                 tag_space=tag_space,
                                 embedd_word=None,
                                 p_rnn=p,
                                 bigram=bigram)
    network.load_state_dict(torch.load('temp/23df51_model45'))

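    # extend the NER tag set with the target-domain labels (VEH/WEA, i.e.
    # vehicle and weapon entities)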
    ner_alphabet.add('B-VEH')
    ner_alphabet.add('I-VEH')
    ner_alphabet.add('B-WEA')
    ner_alphabet.add('I-WEA')
    num_new_word = 0

    with open('temp/target.train.conll', 'r') as f:
        sents = []
        sent_buffer = []
        for line in f:
            if len(line) <= 1:
                sents.append(sent_buffer)
                sent_buffer = []
            else:
                _, word, _, _, ner = line.strip().split()
                if word_alphabet.get_index(word) == 0:
                    word_alphabet.add(word)
                    num_new_word += 1
                sent_buffer.append((word_alphabet.get_index(word),
                                    ner_alphabet.get_index(ner)))

    print(len(word_alphabet.get_content()['instances']))
    print(ner_alphabet.get_content())

    # grow the word embedding table: keep the existing vectors and append
    # zero rows for the newly added words, staying in float32
    init_embed = network.word_embedd.weight.data.cpu().numpy()
    init_embed = np.concatenate(
        (init_embed, np.zeros((num_new_word, embedd_dim), dtype=np.float32)),
        axis=0)
    network.word_embedd = Embedding(word_alphabet.size(), embedd_dim,
                                    torch.from_numpy(init_embed))

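    # swap in a CRF over the enlarged tag set: emission weights for the
    # original labels are copied over, rows for the new labels start at zero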
    old_crf = network.crf
    new_crf = ChainCRF(tag_space, ner_alphabet.size(), bigram=bigram)
    trans_matrix = np.zeros((new_crf.num_labels, old_crf.num_labels))
    for i in range(old_crf.num_labels):
        trans_matrix[i, i] = 1
    new_crf.state_nn.weight.data = torch.FloatTensor(
        np.dot(trans_matrix, old_crf.state_nn.weight.data))
    network.crf = new_crf

    target_train_data = conll03_data.read_data_to_variable(
        'temp/target.train.conll',
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        chunk_alphabet,
        ner_alphabet,
        use_gpu=False,
        volatile=False)
    target_dev_data = conll03_data.read_data_to_variable(
        'temp/target.dev.conll',
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        chunk_alphabet,
        ner_alphabet,
        use_gpu=False,
        volatile=False)
    target_test_data = conll03_data.read_data_to_variable(
        'temp/target.test.conll',
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        chunk_alphabet,
        ner_alphabet,
        use_gpu=False,
        volatile=False)

    num_epoch = 50
    batch_size = 32
    num_data = sum(target_train_data[1])
    num_batches = num_data // batch_size + 1
    unk_replace = 0.0
    # optim = SGD(network.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0, nesterov=True)
    optim = Adam(network.parameters(), lr=1e-3)

    for epoch in range(1, num_epoch + 1):
        train_err = 0.
        train_total = 0.
        start_time = time.time()
        num_back = 0
        network.train()

        for batch in range(1, num_batches + 1):
            word, char, _, _, labels, masks, lengths = conll03_data.get_batch_variable(
                target_train_data, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss = network.loss(word, char, labels, mask=masks)
            loss.backward()
            optim.step()

            num_inst = word.size(0)
            train_err += loss.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            if batch % 20 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, time: %.2fs' % (
                    batch, num_batches, train_err / train_total,
                    time.time() - start_time)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        writer = CoNLL03Writer(word_alphabet, char_alphabet, pos_alphabet,
                               chunk_alphabet, ner_alphabet)
        os.system('rm temp/output.txt')
        writer.start('temp/output.txt')
        network.eval()
        for batch in conll03_data.iterate_batch_variable(
                target_dev_data, batch_size):
            word, char, pos, chunk, labels, masks, lengths, _ = batch
            preds, _, _ = network.decode(
                word,
                char,
                target=labels,
                mask=masks,
                leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
            writer.write(word.data.cpu().numpy(),
                         pos.data.cpu().numpy(),
                         chunk.data.cpu().numpy(),
                         preds.cpu().numpy(),
                         labels.data.cpu().numpy(),
                         lengths.cpu().numpy())
        writer.close()

        acc, precision, recall, f1 = evaluate('temp/output.txt')
        log_info = 'dev: %f %f %f %f' % (acc, precision, recall, f1)
        print(log_info)

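        # every 10 epochs, also evaluate on the held-out test set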
        if epoch % 10 == 0:
            writer = CoNLL03Writer(word_alphabet, char_alphabet, pos_alphabet,
                                   chunk_alphabet, ner_alphabet)
            os.system('rm temp/output.txt')
            writer.start('temp/output.txt')
            network.eval()
            for batch in conll03_data.iterate_batch_variable(
                    target_test_data, batch_size):
                word, char, pos, chunk, labels, masks, lengths, _ = batch
                preds, _, _ = network.decode(
                    word,
                    char,
                    target=labels,
                    mask=masks,
                    leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
                writer.write(word.data.cpu().numpy(),
                             pos.data.cpu().numpy(),
                             chunk.data.cpu().numpy(),
                             preds.cpu().numpy(),
                             labels.data.cpu().numpy(),
                             lengths.cpu().numpy())
            writer.close()

            acc, precision, recall, f1 = evaluate('temp/output.txt')
            log_info = 'test: %f %f %f %f' % (acc, precision, recall, f1)
            print(log_info)

    torch.save(network, 'temp/tuned_0905.pt')
    alphabet_directory = '0905_alphabet/'
    word_alphabet.save(alphabet_directory)
    char_alphabet.save(alphabet_directory)
    pos_alphabet.save(alphabet_directory)
    chunk_alphabet.save(alphabet_directory)
    ner_alphabet.save(alphabet_directory)
Example #6
0
def main():
    # Arguments parser
    parser = argparse.ArgumentParser(
        description='Tuning with DNN Model for NER')
    # Model Hyperparameters
    parser.add_argument('--mode',
                        choices=['RNN', 'LSTM', 'GRU'],
                        help='architecture of rnn',
                        default='LSTM')
    parser.add_argument('--encoder_mode',
                        choices=['cnn', 'lstm'],
                        help='Encoder type for sentence encoding',
                        default='lstm')
    parser.add_argument('--char_method',
                        choices=['cnn', 'lstm'],
                        help='Method to create character-level embeddings',
                        required=True)
    parser.add_argument(
        '--hidden_size',
        type=int,
        default=128,
        help='Number of hidden units in RNN for sentence level')
    parser.add_argument('--char_hidden_size',
                        type=int,
                        default=30,
                        help='Output character-level embeddings size')
    parser.add_argument('--char_dim',
                        type=int,
                        default=30,
                        help='Dimension of Character embeddings')
    parser.add_argument('--tag_space',
                        type=int,
                        default=0,
                        help='Dimension of tag space')
    parser.add_argument('--num_layers',
                        type=int,
                        default=1,
                        help='Number of layers of RNN')
    parser.add_argument('--dropout',
                        choices=['std', 'weight_drop'],
                        help='Dropout method',
                        default='weight_drop')
    parser.add_argument('--p_em',
                        type=float,
                        default=0.33,
                        help='dropout rate for input embeddings')
    parser.add_argument('--p_in',
                        type=float,
                        default=0.33,
                        help='dropout rate for input of RNN model')
    parser.add_argument('--p_rnn',
                        nargs=2,
                        type=float,
                        required=True,
                        help='dropout rate for RNN')
    parser.add_argument('--p_out',
                        type=float,
                        default=0.33,
                        help='dropout rate for output layer')
    parser.add_argument('--bigram',
                        action='store_true',
                        help='bi-gram parameter for CRF')

    # Data loading and storing params
    parser.add_argument('--embedding_dict', help='path for embedding dict')
    parser.add_argument('--dataset_name',
                        type=str,
                        default='alexa',
                        help='Which dataset to use')
    parser.add_argument('--train',
                        type=str,
                        required=True,
                        help='Path of train set')
    parser.add_argument('--dev',
                        type=str,
                        required=True,
                        help='Path of dev set')
    parser.add_argument('--test',
                        type=str,
                        required=True,
                        help='Path of test set')
    parser.add_argument('--results_folder',
                        type=str,
                        default='results',
                        help='The folder to store results')
    parser.add_argument('--tmp_folder',
                        type=str,
                        default='tmp',
                        help='The folder to store tmp files')
    parser.add_argument('--alphabets_folder',
                        type=str,
                        default='data/alphabets',
                        help='The folder to store alphabets files')
    parser.add_argument('--result_file_name',
                        type=str,
                        default='hyperparameters_tuning',
                        help='File name to store some results')
    parser.add_argument('--result_file_path',
                        type=str,
                        default='results/hyperparameters_tuning',
                        help='Full path of the file to store results')

    # Training parameters
    parser.add_argument('--cuda',
                        action='store_true',
                        help='whether using GPU')
    parser.add_argument('--num_epochs',
                        type=int,
                        default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Number of sentences in each batch')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.001,
                        help='Base learning rate')
    parser.add_argument('--decay_rate',
                        type=float,
                        default=0.95,
                        help='Decay rate of learning rate')
    parser.add_argument('--schedule',
                        type=int,
                        default=3,
                        help='schedule for learning rate decay')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.0,
                        help='weight for l2 regularization')
    parser.add_argument('--max_norm',
                        type=float,
                        default=1.,
                        help='Max norm for gradients')
    parser.add_argument('--gpu_id',
                        type=int,
                        nargs='+',
                        required=True,
                        help='which gpu to use for training')

    # Misc
    parser.add_argument('--embedding',
                        choices=['glove', 'senna', 'alexa'],
                        help='Embedding for words',
                        required=True)
    parser.add_argument('--restore',
                        action='store_true',
                        help='whether restore from stored parameters')
    parser.add_argument('--save_checkpoint',
                        type=str,
                        default='',
                        help='the path to save the model')
    parser.add_argument('--o_tag',
                        type=str,
                        default='O',
                        help='The tag used for tokens outside any entity')
    parser.add_argument('--unk_replace',
                        type=float,
                        default=0.,
                        help='The rate to replace a singleton word with UNK')
    parser.add_argument('--evaluate_raw_format',
                        action='store_true',
                        help='Evaluate using the raw tagging format')

    args = parser.parse_args()

    logger = get_logger("NERCRF")

    # rename the parameters
    mode = args.mode
    encoder_mode = args.encoder_mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    char_hidden_size = args.char_hidden_size
    char_method = args.char_method
    learning_rate = args.learning_rate
    momentum = 0.9
    decay_rate = args.decay_rate
    gamma = args.gamma
    max_norm = args.max_norm
    schedule = args.schedule
    dropout = args.dropout
    p_em = args.p_em
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    unk_replace = args.unk_replace
    bigram = args.bigram
    embedding = args.embedding
    embedding_path = args.embedding_dict
    dataset_name = args.dataset_name
    result_file_name = args.result_file_name
    evaluate_raw_format = args.evaluate_raw_format
    o_tag = args.o_tag
    restore = args.restore
    save_checkpoint = args.save_checkpoint
    gpu_id = args.gpu_id
    results_folder = args.results_folder
    tmp_folder = args.tmp_folder
    alphabets_folder = args.alphabets_folder
    use_elmo = False
    p_em_vec = 0.
    result_file_path = args.result_file_path

    score_file = "%s/score_gpu_%s" % (tmp_folder, '-'.join(map(str, gpu_id)))

    if not os.path.exists(results_folder):
        os.makedirs(results_folder)
    if not os.path.exists(tmp_folder):
        os.makedirs(tmp_folder)
    if not os.path.exists(alphabets_folder):
        os.makedirs(alphabets_folder)

    embedd_dict, embedd_dim = utils.load_embedding_dict(
        embedding, embedding_path)

    logger.info("Creating Alphabets")
    word_alphabet, char_alphabet, ner_alphabet = conll03_data.create_alphabets(
        "{}/{}/".format(alphabets_folder, dataset_name),
        train_path,
        data_paths=[dev_path, test_path],
        embedd_dict=embedd_dict,
        max_vocabulary_size=50000)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("NER Alphabet Size: %d" % ner_alphabet.size())

    logger.info("Reading Data")
    device = torch.device('cuda') if args.cuda else torch.device('cpu')
    print(device)

    data_train = conll03_data.read_data_to_tensor(train_path,
                                                  word_alphabet,
                                                  char_alphabet,
                                                  ner_alphabet,
                                                  device=device)
    num_data = sum(data_train[1])
    num_labels = ner_alphabet.size()

    data_dev = conll03_data.read_data_to_tensor(dev_path,
                                                word_alphabet,
                                                char_alphabet,
                                                ner_alphabet,
                                                device=device)
    data_test = conll03_data.read_data_to_tensor(test_path,
                                                 word_alphabet,
                                                 char_alphabet,
                                                 ner_alphabet,
                                                 device=device)

    writer = CoNLL03Writer(word_alphabet, char_alphabet, ner_alphabet)

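    # build the initial word embedding matrix from the pretrained dictionary;
    # words missing from it are given small random vectors (counted as OOV)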
    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / embedd_dim)
        table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        table[conll03_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, embedd_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in embedd_dict:
                embedding = embedd_dict[word]
            elif word.lower() in embedd_dict:
                embedding = embedd_dict[word.lower()]
            else:
                embedding = np.random.uniform(
                    -scale, scale, [1, embedd_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('oov: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    logger.info("constructing network...")

    char_dim = args.char_dim
    window = 3
    num_layers = args.num_layers
    tag_space = args.tag_space
    initializer = nn.init.xavier_uniform_
    if args.dropout == 'std':
        network = BiRecurrentConvCRF(embedd_dim,
                                     word_alphabet.size(),
                                     char_dim,
                                     char_alphabet.size(),
                                     char_hidden_size,
                                     window,
                                     mode,
                                     encoder_mode,
                                     hidden_size,
                                     num_layers,
                                     num_labels,
                                     tag_space=tag_space,
                                     embedd_word=word_table,
                                     use_elmo=use_elmo,
                                     p_em_vec=p_em_vec,
                                     p_em=p_em,
                                     p_in=p_in,
                                     p_out=p_out,
                                     p_rnn=p_rnn,
                                     bigram=bigram,
                                     initializer=initializer)
    elif args.dropout == 'var':
        network = BiVarRecurrentConvCRF(embedd_dim,
                                        word_alphabet.size(),
                                        char_dim,
                                        char_alphabet.size(),
                                        char_hidden_size,
                                        window,
                                        mode,
                                        encoder_mode,
                                        hidden_size,
                                        num_layers,
                                        num_labels,
                                        tag_space=tag_space,
                                        embedd_word=word_table,
                                        use_elmo=use_elmo,
                                        p_em_vec=p_em_vec,
                                        p_em=p_em,
                                        p_in=p_in,
                                        p_out=p_out,
                                        p_rnn=p_rnn,
                                        bigram=bigram,
                                        initializer=initializer)
    else:
        network = BiWeightDropRecurrentConvCRF(embedd_dim,
                                               word_alphabet.size(),
                                               char_dim,
                                               char_alphabet.size(),
                                               char_hidden_size,
                                               window,
                                               mode,
                                               encoder_mode,
                                               hidden_size,
                                               num_layers,
                                               num_labels,
                                               tag_space=tag_space,
                                               embedd_word=word_table,
                                               p_em=p_em,
                                               p_in=p_in,
                                               p_out=p_out,
                                               p_rnn=p_rnn,
                                               bigram=bigram,
                                               initializer=initializer)

    network = network.to(device)

    lr = learning_rate
    optim = SGD(network.parameters(),
                lr=lr,
                momentum=momentum,
                weight_decay=gamma,
                nesterov=True)
    # optim = Adam(network.parameters(), lr=lr, weight_decay=gamma, amsgrad=True)
    logger.info("Network: %s, encoder_mode=%s, num_layer=%d, hidden=%d, char_hidden_size=%d, char_method=%s, tag_space=%d, crf=%s" % \
        (mode, encoder_mode, num_layers, hidden_size, char_hidden_size, char_method, tag_space, 'bigram' if bigram else 'unigram'))
    logger.info(
        "training: l2: %f, (#training data: %d, batch: %d, unk replace: %.2f)"
        % (gamma, num_data, batch_size, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" %
                (p_in, p_out, p_rnn))

    num_batches = num_data // batch_size + 1
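    # best dev scores so far, plus the test scores measured at the best dev epoch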
    dev_f1 = 0.0
    dev_acc = 0.0
    dev_precision = 0.0
    dev_recall = 0.0
    test_f1 = 0.0
    test_acc = 0.0
    test_precision = 0.0
    test_recall = 0.0
    best_epoch = 0
    best_test_f1 = 0.0
    best_test_acc = 0.0
    best_test_precision = 0.0
    best_test_recall = 0.0
    best_test_epoch = 0
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s(%s), learning rate=%.4f, decay rate=%.4f (schedule=%d)): '
            % (epoch, mode, args.dropout, lr, decay_rate, schedule))

        train_err = 0.
        train_total = 0.

        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            _, word, char, labels, masks, lengths = conll03_data.get_batch_tensor(
                data_train, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss = network.loss(_, word, char, labels, mask=masks)
            loss.backward()
            # clip gradients before the parameter update
            nn.utils.clip_grad_norm_(network.parameters(), max_norm)
            optim.step()

            with torch.no_grad():
                num_inst = word.size(0)
                train_err += loss * num_inst
                train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 20 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, time left (estimated): %.2fs' % (
                    batch, num_batches, train_err / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, time: %.2fs' %
              (num_batches, train_err / train_total, time.time() - start_time))

        # evaluate performance on dev data
        with torch.no_grad():
            network.eval()
            tmp_filename = '%s/gpu_%s_dev' % (tmp_folder, '-'.join(
                map(str, gpu_id)))
            writer.start(tmp_filename)

            for batch in conll03_data.iterate_batch_tensor(
                    data_dev, batch_size):
                _, word, char, labels, masks, lengths = batch
                preds, _ = network.decode(
                    _,
                    word,
                    char,
                    target=labels,
                    mask=masks,
                    leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
                writer.write(word.cpu().numpy(),
                             preds.cpu().numpy(),
                             labels.cpu().numpy(),
                             lengths.cpu().numpy())
            writer.close()
            acc, precision, recall, f1 = evaluate(tmp_filename, score_file,
                                                  evaluate_raw_format, o_tag)
            print(
                'dev acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%'
                % (acc, precision, recall, f1))

            if dev_f1 < f1:
                dev_f1 = f1
                dev_acc = acc
                dev_precision = precision
                dev_recall = recall
                best_epoch = epoch

                # evaluate on test data when better performance detected
                tmp_filename = '%s/gpu_%s_test' % (tmp_folder, '-'.join(
                    map(str, gpu_id)))
                writer.start(tmp_filename)

                for batch in conll03_data.iterate_batch_tensor(
                        data_test, batch_size):
                    _, word, char, labels, masks, lengths = batch
                    preds, _ = network.decode(
                        _,
                        word,
                        char,
                        target=labels,
                        mask=masks,
                        leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
                    writer.write(word.cpu().numpy(),
                                 preds.cpu().numpy(),
                                 labels.cpu().numpy(),
                                 lengths.cpu().numpy())
                writer.close()
                test_acc, test_precision, test_recall, test_f1 = evaluate(
                    tmp_filename, score_file, evaluate_raw_format, o_tag)
                if best_test_f1 < test_f1:
                    best_test_acc, best_test_precision, best_test_recall, best_test_f1 = test_acc, test_precision, test_recall, test_f1
                    best_test_epoch = epoch

            print(
                "best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
                % (dev_acc, dev_precision, dev_recall, dev_f1, best_epoch))
            print(
                "best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
                % (test_acc, test_precision, test_recall, test_f1, best_epoch))
            print(
                "overall best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
                % (best_test_acc, best_test_precision, best_test_recall,
                   best_test_f1, best_test_epoch))

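        # inverse-time learning-rate decay every `schedule` epochs; the SGD
        # optimizer is rebuilt with the new rate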
        if epoch % schedule == 0:
            lr = learning_rate / (1.0 + epoch * decay_rate)
            optim = SGD(network.parameters(),
                        lr=lr,
                        momentum=momentum,
                        weight_decay=gamma,
                        nesterov=True)

    with open(result_file_path, 'a') as ofile:
        ofile.write(
            "best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)\n"
            % (dev_acc, dev_precision, dev_recall, dev_f1, best_epoch))
        ofile.write(
            "best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)\n"
            % (test_acc, test_precision, test_recall, test_f1, best_epoch))
        ofile.write(
            "overall best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)\n\n"
            % (best_test_acc, best_test_precision, best_test_recall,
               best_test_f1, best_test_epoch))
    print('Training finished!')


def sample():
    network = torch.load('temp/ner_active.pt')
    word_alphabet, char_alphabet, pos_alphabet, \
        chunk_alphabet, ner_alphabet = conll03_data.create_alphabets("active_alphabet/", None)

    unannotated_data = conll03_data.read_data_to_variable(
        'temp/unannotated.conll',
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        chunk_alphabet,
        ner_alphabet,
        use_gpu=False,
        volatile=True)

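    # remember sentences that were already annotated so they are not queried again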
    annotated = set()
    with open('temp/annotated.conll', 'r') as f:
        sent_buffer = []
        for line in f:
            if len(line) > 1:
                _, word, _, _, _ = line.strip().split()
                sent_buffer.append(word)
            else:
                annotated.add(' '.join(sent_buffer))
                sent_buffer = []
    print('total annotated data: {}'.format(len(annotated)))

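    # uncertainty sampling: keep a min-heap of sentences keyed by
    # length-normalized decoder confidence and query the least confident ones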
    uncertain = []
    max_sents = 100
    max_words = 500

    writer = CoNLL03Writer(word_alphabet, char_alphabet, pos_alphabet,
                           chunk_alphabet, ner_alphabet)
    writer.start('temp/output.txt')
    network.eval()
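    # unique, increasing tie-breaker so heap comparisons never reach the numpy arrays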
    tiebreaker = count()
    for batch in conll03_data.iterate_batch_variable(unannotated_data, 32):
        word, char, pos, chunk, labels, masks, lengths, raws = batch
        preds, _, confidence = network.decode(
            word,
            char,
            target=labels,
            mask=masks,
            leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
        writer.write(word.data.cpu().numpy(),
                     pos.data.cpu().numpy(),
                     chunk.data.cpu().numpy(),
                     preds.cpu().numpy(),
                     labels.data.cpu().numpy(),
                     lengths.cpu().numpy())
        for i in range(confidence.size()[0]):
            heapq.heappush(uncertain,
                           (confidence[i].numpy()[0] / lengths[i],
                            next(tiebreaker), word[i].data.numpy(), raws[i]))
    writer.close()

    cost_sents = 0
    cost_words = 0
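    # pop the least-confident sentences into the query file until the
    # sentence or word annotation budget is spent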
    with open('temp/query.conll', 'w') as q:
        while cost_sents < max_sents and cost_words < max_words and uncertain:
            sample = heapq.heappop(uncertain)
            if len(sample[3]) <= 5:
                continue
            # print(sample[0])
            # print([word_alphabet.get_instance(wid) for wid in sample[2]])
            print(sample[3])
            to_write = []
            for word in sample[3]:
                if is_url(word):
                    word = '<_URL>'
                to_write.append(word.encode('ascii', 'ignore').decode('ascii'))
            if ' '.join(to_write) in annotated:
                continue
            for wn, word in enumerate(to_write):
                q.write('{0} {1} -- -- O\n'.format(wn + 1, word))
            q.write('\n')
            cost_sents += 1
            cost_words += len(sample[3])


def retrain(train_path, dev_path):
    network = torch.load('temp/ner_tuned.pt')
    word_alphabet, char_alphabet, pos_alphabet, \
        chunk_alphabet, ner_alphabet = conll03_data.create_alphabets("ner_alphabet/", None)

    num_new_word = 0
    with open(train_path, 'r') as f:
        sents = []
        sent_buffer = []
        for line in f:
            if len(line) <= 1:
                sents.append(sent_buffer)
                sent_buffer = []
            else:
                _, word, _, _, ner = line.strip().split()
                if word_alphabet.get_index(word) == 0:
                    word_alphabet.add(word)
                    num_new_word += 1
                sent_buffer.append((word_alphabet.get_index(word),
                                    ner_alphabet.get_index(ner)))
    print('{} new words.'.format(num_new_word))
    # grow the word embedding table with zero rows for the new words,
    # staying in float32 so it matches the rest of the model
    init_embed = network.word_embedd.weight.data.cpu().numpy()
    embedd_dim = init_embed.shape[1]
    init_embed = np.concatenate(
        (init_embed, np.zeros((num_new_word, embedd_dim), dtype=np.float32)),
        axis=0)
    network.word_embedd = Embedding(word_alphabet.size(), embedd_dim,
                                    torch.from_numpy(init_embed))

    target_train_data = conll03_data.read_data_to_variable(train_path,
                                                           word_alphabet,
                                                           char_alphabet,
                                                           pos_alphabet,
                                                           chunk_alphabet,
                                                           ner_alphabet,
                                                           use_gpu=False,
                                                           volatile=False)

    num_epoch = 50
    batch_size = 20
    num_data = sum(target_train_data[1])
    num_batches = num_data // batch_size + 1
    unk_replace = 0.0
    optim = SGD(network.parameters(),
                lr=0.01,
                momentum=0.9,
                weight_decay=0.0,
                nesterov=True)

    for epoch in range(num_epoch):
        train_err = 0.
        train_total = 0.
        start_time = time.time()
        num_back = 0
        network.train()

        for batch in range(1, num_batches + 1):
            word, char, _, _, labels, masks, lengths = conll03_data.get_batch_variable(
                target_train_data, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss = network.loss(word, char, labels, mask=masks)
            loss.backward()
            optim.step()

            num_inst = word.size(0)
            train_err += loss.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            print('train: %d/%d loss: %.4f, time: %.2fs' %
                  (batch, num_batches, train_err / train_total,
                   time.time() - start_time))

    torch.save(network, 'temp/ner_active.pt')
    alphabet_directory = 'active_alphabet/'
    word_alphabet.save(alphabet_directory)
    char_alphabet.save(alphabet_directory)
    pos_alphabet.save(alphabet_directory)
    chunk_alphabet.save(alphabet_directory)
    ner_alphabet.save(alphabet_directory)

    target_dev_data = conll03_data.read_data_to_variable(dev_path,
                                                         word_alphabet,
                                                         char_alphabet,
                                                         pos_alphabet,
                                                         chunk_alphabet,
                                                         ner_alphabet,
                                                         use_gpu=False,
                                                         volatile=False)
    writer = CoNLL03Writer(word_alphabet, char_alphabet, pos_alphabet,
                           chunk_alphabet, ner_alphabet)
    os.system('rm output.txt')
    writer.start('output.txt')
    network.eval()
    for batch in conll03_data.iterate_batch_variable(target_dev_data,
                                                     batch_size):
        word, char, pos, chunk, labels, masks, lengths, _ = batch
        preds, _, _ = network.decode(
            word,
            char,
            target=labels,
            mask=masks,
            leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
        writer.write(word.data.cpu().numpy(),
                     pos.data.cpu().numpy(),
                     chunk.data.cpu().numpy(),
                     preds.cpu().numpy(),
                     labels.data.cpu().numpy(),
                     lengths.cpu().numpy())
    writer.close()

    acc, precision, recall, f1 = evaluate('output.txt')
    print(acc, precision, recall, f1)
    return acc, precision, recall, f1