Exemplo n.º 1
0
def train(args, weights_matrix):
    """ Training BiLSTMCRF model with NER and entity tag sequences.

    Args:
        args: dict that contains options in command (docopt-style keys such
            as ``'--max-epoch'`` plus positional keys such as ``'TRAIN'``,
            ``'SENT_VOCAB'``, ``'TAG_VOCAB_NER'``, ``'TAG_VOCAB_ENTITY'`` and
            ``'METHOD'``).
        weights_matrix: embedding weights handed to ``bilstm_crf.BiLSTMCRF``
            at construction and to ``BiLSTMCRF.load`` when the best checkpoint
            is restored.

    Whenever the dev loss improves (by the ``--patience-threshold`` factor),
    the model is saved to ``--model-save-path``, the optimizer state to
    ``--optimizer-save-path``, and the checkpoint's ``embedding.weight`` is
    dumped as JSON to ``./data/weights_matrix.json``. After ``--max-patience``
    validations without improvement the best checkpoint is restored and the
    learning rate is multiplied by ``--lr-decay``; after ``--max-decay`` such
    decays training stops early.
    """
    sent_vocab = Vocab.load(args['SENT_VOCAB'])
    tag_vocab_ner = Vocab.load(args['TAG_VOCAB_NER'])
    tag_vocab_entity = Vocab.load(args['TAG_VOCAB_ENTITY'])
    method = args['METHOD']
    train_data, dev_data = utils.generate_train_dev_dataset(
        args['TRAIN'], sent_vocab, tag_vocab_ner, tag_vocab_entity)
    print('num of training examples: %d' % (len(train_data)))
    print('num of development examples: %d' % (len(dev_data)))

    # Parse every option once up front instead of inside the training loop.
    max_epoch = int(args['--max-epoch'])
    log_every = int(args['--log-every'])
    validation_every = int(args['--validation-every'])
    batch_size = int(args['--batch-size'])
    clip_max_norm = float(args['--clip_max_norm'])
    patience_threshold = float(args['--patience-threshold'])
    max_patience = int(args['--max-patience'])
    max_decay = int(args['--max-decay'])
    lr_decay = float(args['--lr-decay'])
    model_save_path = args['--model-save-path']
    optimizer_save_path = args['--optimizer-save-path']
    min_dev_loss = float('inf')
    device = torch.device('cuda' if args['--cuda'] else 'cpu')
    patience, decay_num = 0, 0

    model = bilstm_crf.BiLSTMCRF(weights_matrix, sent_vocab,
                                 tag_vocab_ner, tag_vocab_entity,
                                 float(args['--dropout-rate']),
                                 int(args['--embed-size']),
                                 int(args['--hidden-size'])).to(device)
    print(model)
    # NOTE(review): unlike the other training variants, parameters are not
    # re-initialized here; a blanket N(0, 0.01) pass over every 'weight'
    # tensor would presumably overwrite the embedding seeded from
    # weights_matrix — confirm before adding one.

    optimizer = torch.optim.Adam(model.parameters(), lr=float(args['--lr']))
    train_iter = 0  # train iter num
    record_loss_sum, record_tgt_word_sum, record_batch_size = 0, 0, 0  # sum in one training log
    cum_loss_sum, cum_tgt_word_sum, cum_batch_size = 0, 0, 0  # sum in one validation log
    record_start, cum_start = time.time(), time.time()

    print('start training...')
    for epoch in range(max_epoch):
        for sentences, tags_ner, tags_entity in utils.batch_iter(
                train_data, batch_size=batch_size):
            train_iter += 1
            current_batch_size = len(sentences)
            sentences, sent_lengths = utils.pad(sentences,
                                                sent_vocab[sent_vocab.PAD],
                                                device)
            tags_ner, _ = utils.pad(tags_ner, tag_vocab_ner[tag_vocab_ner.PAD],
                                    device)
            tags_entity, _ = utils.pad(tags_entity,
                                       tag_vocab_entity[tag_vocab_entity.PAD],
                                       device)

            # back propagation
            optimizer.zero_grad()
            batch_loss = model(sentences, tags_ner, tags_entity, sent_lengths,
                               method)  # shape: (b,)
            loss = batch_loss.mean()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=clip_max_norm)
            optimizer.step()

            # Compute the batch statistics once and feed both accumulators.
            batch_loss_sum = batch_loss.sum().item()
            batch_word_num = sum(sent_lengths)

            record_loss_sum += batch_loss_sum
            record_batch_size += current_batch_size
            record_tgt_word_sum += batch_word_num

            cum_loss_sum += batch_loss_sum
            cum_batch_size += current_batch_size
            cum_tgt_word_sum += batch_word_num

            if train_iter % log_every == 0:
                print(
                    'log: epoch %d, iter %d, %.1f words/sec, avg_loss %f, time %.1f sec'
                    % (epoch + 1, train_iter, record_tgt_word_sum /
                       (time.time() - record_start), record_loss_sum /
                       record_batch_size, time.time() - record_start))
                record_loss_sum, record_batch_size, record_tgt_word_sum = 0, 0, 0
                record_start = time.time()

            if train_iter % validation_every == 0:
                print(
                    'dev: epoch %d, iter %d, %.1f words/sec, avg_loss %f, time %.1f sec'
                    % (epoch + 1, train_iter, cum_tgt_word_sum /
                       (time.time() - cum_start),
                       cum_loss_sum / cum_batch_size, time.time() - cum_start))
                cum_loss_sum, cum_batch_size, cum_tgt_word_sum = 0, 0, 0

                dev_loss = cal_dev_loss(model, dev_data, 64, sent_vocab,
                                        tag_vocab_ner, tag_vocab_entity,
                                        device, method)
                if dev_loss < min_dev_loss * patience_threshold:
                    min_dev_loss = dev_loss
                    model.save(model_save_path)
                    torch.save(optimizer.state_dict(), optimizer_save_path)
                    # epoch + 1 to match the 1-based epoch in every other log.
                    print('Reached %d epochs, Save result model to %s' %
                          (epoch + 1, model_save_path))
                    patience = 0
                    # Save the word embeddings of the checkpoint just written.
                    print("Saving the model")
                    params = torch.load(
                        model_save_path,
                        map_location=lambda storage, loc: storage)
                    new_weights_matrix = params['state_dict'][
                        'embedding.weight']
                    file_path = "./data/weights_matrix.json"
                    # `with` guarantees the file is flushed and closed.
                    with codecs.open(file_path, 'w',
                                     encoding='utf-8') as weights_file:
                        json.dump(new_weights_matrix.tolist(),
                                  weights_file,
                                  separators=(',', ':'),
                                  sort_keys=True,
                                  indent=4)
                else:
                    patience += 1
                    if patience == max_patience:
                        decay_num += 1
                        if decay_num == max_decay:
                            print('Early stop. Save result model to %s' %
                                  model_save_path)
                            return
                        lr = optimizer.param_groups[0]['lr'] * lr_decay
                        model = bilstm_crf.BiLSTMCRF.load(
                            weights_matrix, model_save_path, device)
                        # The previous optimizer still references the
                        # discarded model's parameters; rebuild it on the
                        # reloaded model before restoring its saved state,
                        # otherwise optimizer.step() would never update
                        # the new model.
                        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
                        optimizer.load_state_dict(
                            torch.load(optimizer_save_path,
                                       map_location=device))
                        # load_state_dict restores the saved lr; re-apply
                        # the decayed one.
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        patience = 0
                print(
                    'dev: epoch %d, iter %d, dev_loss %f, patience %d, decay_num %d'
                    % (epoch + 1, train_iter, dev_loss, patience, decay_num))
                cum_start = time.time()
                # Exclude validation time from the next training-log rate.
                if train_iter % log_every == 0:
                    record_start = time.time()
    print('Reached %d epochs, Save result model to %s' %
          (max_epoch, model_save_path))
Exemplo n.º 2
0
def train(args):
    """ Training BiLSTMCRF model

    Args:
        args: dict that contains options in command (docopt-style keys such
            as ``'--max-epoch'`` plus positional keys ``'SENT_VOCAB'``,
            ``'TAG_VOCAB'`` and ``'TRAIN'``).

    Whenever the dev loss improves (by the ``--patience-threshold`` factor),
    the model is saved to ``--model-save-path`` and the optimizer state to
    ``--optimizer-save-path``. After ``--max-patience`` validations without
    improvement the best checkpoint is restored and the learning rate is
    multiplied by ``--lr-decay``; after ``--max-decay`` such decays training
    stops early.
    """
    sent_vocab = Vocab.load(args['SENT_VOCAB'])
    tag_vocab = Vocab.load(args['TAG_VOCAB'])
    train_data, dev_data = utils.generate_train_dev_dataset(
        args['TRAIN'], sent_vocab, tag_vocab)
    print('num of training examples: %d' % (len(train_data)))
    print('num of development examples: %d' % (len(dev_data)))

    # Parse every option once up front instead of inside the training loop.
    max_epoch = int(args['--max-epoch'])
    log_every = int(args['--log-every'])
    validation_every = int(args['--validation-every'])
    batch_size = int(args['--batch-size'])
    clip_max_norm = float(args['--clip_max_norm'])
    patience_threshold = float(args['--patience-threshold'])
    max_patience = int(args['--max-patience'])
    max_decay = int(args['--max-decay'])
    lr_decay = float(args['--lr-decay'])
    model_save_path = args['--model-save-path']
    optimizer_save_path = args['--optimizer-save-path']
    min_dev_loss = float('inf')
    device = torch.device('cuda' if args['--cuda'] else 'cpu')
    patience, decay_num = 0, 0

    model = bilstm_crf.BiLSTMCRF(sent_vocab, tag_vocab,
                                 float(args['--dropout-rate']),
                                 int(args['--embed-size']),
                                 int(args['--hidden-size'])).to(device)
    # Re-initialize: tensors named '*weight*' ~ N(0, 0.01), all others 0.
    for name, param in model.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, 0, 0.01)
        else:
            nn.init.constant_(param.data, 0)

    optimizer = torch.optim.Adam(model.parameters(), lr=float(args['--lr']))
    train_iter = 0  # train iter num
    record_loss_sum, record_tgt_word_sum, record_batch_size = 0, 0, 0  # sum in one training log
    cum_loss_sum, cum_tgt_word_sum, cum_batch_size = 0, 0, 0  # sum in one validation log
    record_start, cum_start = time.time(), time.time()

    print('start training...')
    for epoch in range(max_epoch):
        for sentences, tags in utils.batch_iter(train_data,
                                                batch_size=batch_size):
            train_iter += 1
            current_batch_size = len(sentences)
            sentences, sent_lengths = utils.pad(sentences,
                                                sent_vocab[sent_vocab.PAD],
                                                device)
            tags, _ = utils.pad(tags, tag_vocab[tag_vocab.PAD], device)

            # back propagation
            optimizer.zero_grad()
            batch_loss = model(sentences, tags, sent_lengths)  # shape: (b,)
            loss = batch_loss.mean()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=clip_max_norm)
            optimizer.step()

            # Compute the batch statistics once and feed both accumulators.
            batch_loss_sum = batch_loss.sum().item()
            batch_word_num = sum(sent_lengths)

            record_loss_sum += batch_loss_sum
            record_batch_size += current_batch_size
            record_tgt_word_sum += batch_word_num

            cum_loss_sum += batch_loss_sum
            cum_batch_size += current_batch_size
            cum_tgt_word_sum += batch_word_num

            if train_iter % log_every == 0:
                print(
                    'log: epoch %d, iter %d, %.1f words/sec, avg_loss %f, time %.1f sec'
                    % (epoch + 1, train_iter, record_tgt_word_sum /
                       (time.time() - record_start), record_loss_sum /
                       record_batch_size, time.time() - record_start))
                record_loss_sum, record_batch_size, record_tgt_word_sum = 0, 0, 0
                record_start = time.time()

            if train_iter % validation_every == 0:
                print(
                    'dev: epoch %d, iter %d, %.1f words/sec, avg_loss %f, time %.1f sec'
                    % (epoch + 1, train_iter, cum_tgt_word_sum /
                       (time.time() - cum_start),
                       cum_loss_sum / cum_batch_size, time.time() - cum_start))
                cum_loss_sum, cum_batch_size, cum_tgt_word_sum = 0, 0, 0

                dev_loss = cal_dev_loss(model, dev_data, 64, sent_vocab,
                                        tag_vocab, device)
                if dev_loss < min_dev_loss * patience_threshold:
                    min_dev_loss = dev_loss
                    model.save(model_save_path)
                    torch.save(optimizer.state_dict(), optimizer_save_path)
                    patience = 0
                else:
                    patience += 1
                    if patience == max_patience:
                        decay_num += 1
                        if decay_num == max_decay:
                            print('Early stop. Save result model to %s' %
                                  model_save_path)
                            return
                        lr = optimizer.param_groups[0]['lr'] * lr_decay
                        model = bilstm_crf.BiLSTMCRF.load(
                            model_save_path, device)
                        # The previous optimizer still references the
                        # discarded model's parameters; rebuild it on the
                        # reloaded model before restoring its saved state,
                        # otherwise optimizer.step() would never update
                        # the new model.
                        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
                        optimizer.load_state_dict(
                            torch.load(optimizer_save_path,
                                       map_location=device))
                        # load_state_dict restores the saved lr; re-apply
                        # the decayed one.
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        patience = 0
                print(
                    'dev: epoch %d, iter %d, dev_loss %f, patience %d, decay_num %d'
                    % (epoch + 1, train_iter, dev_loss, patience, decay_num))
                cum_start = time.time()
                # Exclude validation time from the next training-log rate.
                if train_iter % log_every == 0:
                    record_start = time.time()
    print('Reached %d epochs, Save result model to %s' %
          (max_epoch, model_save_path))
Exemplo n.º 3
0
def train(args):
    """ Training BiLSTMCRF model

    Args:
        args: namespace-like object (attribute access, e.g. ``args.max_epoch``,
            ``args.TRAIN``, ``args.word2vec_path``) holding the command-line
            options.

    Whenever the dev loss improves (by the ``patience_threshold`` factor),
    the model is saved to ``model_save_path`` and the optimizer state to
    ``optimizer_save_path``. After ``max_patience`` validations without
    improvement the best checkpoint is restored and the learning rate is
    multiplied by ``lr_decay``; after ``max_decay`` such decays training
    stops early.
    """
    sent_vocab = Vocab.load(args.SENT_VOCAB)
    tag_vocab = Vocab.load(args.TAG_VOCAB)
    train_data, dev_data = utils.generate_train_dev_dataset(
        args.TRAIN, sent_vocab, tag_vocab)
    print('num of training examples: %d' % (len(train_data)))
    print('num of development examples: %d' % (len(dev_data)))

    # Parse every option once up front instead of inside the training loop.
    max_epoch = int(args.max_epoch)
    log_every = int(args.log_every)
    validation_every = int(args.validation_every)
    batch_size = int(args.batch_size)
    clip_max_norm = float(args.clip_max_norm)
    patience_threshold = float(args.patience_threshold)
    max_patience = int(args.max_patience)
    max_decay = int(args.max_decay)
    lr_decay = float(args.lr_decay)
    model_save_path = args.model_save_path
    optimizer_save_path = args.optimizer_save_path
    min_dev_loss = float('inf')
    device = torch.device('cuda' if args.cuda else 'cpu')
    patience, decay_num = 0, 0

    # word2vec matrix (noted by the author as currently unused).
    # NOTE(review): the re-initialization loop below resets every parameter
    # whose name contains 'weight', which presumably overwrites an embedding
    # built from this matrix — confirm.
    ko_model = gensim.models.Word2Vec.load(args.word2vec_path)
    word2vec_matrix = ko_model.wv.vectors

    model = bilstm_crf.BiLSTMCRF(sent_vocab, tag_vocab, word2vec_matrix,
                                 float(args.dropout_rate),
                                 int(args.embed_size),
                                 int(args.hidden_size)).to(device)
    # Re-initialize: tensors named '*weight*' ~ N(0, 0.01), all others 0.
    for name, param in model.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, 0, 0.01)
        else:
            nn.init.constant_(param.data, 0)

    # RMSprop is used instead of Adam in this variant.
    optimizer = torch.optim.RMSprop(model.parameters(), lr=float(args.lr))
    train_iter = 0  # train iter num
    record_loss_sum, record_tgt_word_sum, record_batch_size = 0, 0, 0  # sum in one training log
    cum_loss_sum, cum_tgt_word_sum, cum_batch_size = 0, 0, 0  # sum in one validation log
    record_start, cum_start = time.time(), time.time()

    print('start training...')
    for epoch in range(max_epoch):
        for sentences, tags in utils.batch_iter(train_data,
                                                batch_size=batch_size):
            train_iter += 1
            current_batch_size = len(sentences)
            sentences, sent_lengths = utils.pad(sentences,
                                                sent_vocab[sent_vocab.PAD],
                                                device)
            tags, _ = utils.pad(tags, tag_vocab[tag_vocab.PAD], device)

            # back propagation
            optimizer.zero_grad()
            batch_loss = model(sentences, tags, sent_lengths)  # shape: (b,)
            loss = batch_loss.mean()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=clip_max_norm)
            optimizer.step()

            # Compute the batch statistics once and feed both accumulators.
            batch_loss_sum = batch_loss.sum().item()
            batch_word_num = sum(sent_lengths)

            record_loss_sum += batch_loss_sum
            record_batch_size += current_batch_size
            record_tgt_word_sum += batch_word_num

            cum_loss_sum += batch_loss_sum
            cum_batch_size += current_batch_size
            cum_tgt_word_sum += batch_word_num

            if train_iter % log_every == 0:
                print(
                    'log: epoch %d, iter %d, %.1f words/sec, avg_loss %f, time %.1f sec'
                    % (epoch + 1, train_iter, record_tgt_word_sum /
                       (time.time() - record_start), record_loss_sum /
                       record_batch_size, time.time() - record_start))
                record_loss_sum, record_batch_size, record_tgt_word_sum = 0, 0, 0
                record_start = time.time()

            if train_iter % validation_every == 0:
                print(
                    'dev: epoch %d, iter %d, %.1f words/sec, avg_loss %f, time %.1f sec'
                    % (epoch + 1, train_iter, cum_tgt_word_sum /
                       (time.time() - cum_start),
                       cum_loss_sum / cum_batch_size, time.time() - cum_start))
                cum_loss_sum, cum_batch_size, cum_tgt_word_sum = 0, 0, 0

                dev_loss = cal_dev_loss(model, dev_data, 64, sent_vocab,
                                        tag_vocab, device)
                cal_f1_score(model, dev_data, 64, sent_vocab, tag_vocab,
                             device)
                if dev_loss < min_dev_loss * patience_threshold:
                    min_dev_loss = dev_loss
                    model.save(model_save_path)
                    torch.save(optimizer.state_dict(), optimizer_save_path)
                    patience = 0
                else:
                    patience += 1
                    if patience == max_patience:
                        decay_num += 1
                        if decay_num == max_decay:
                            print('Early stop. Save result model to %s' %
                                  model_save_path)
                            return
                        lr = optimizer.param_groups[0]['lr'] * lr_decay
                        model = bilstm_crf.BiLSTMCRF.load(
                            model_save_path, device)
                        # The previous optimizer still references the
                        # discarded model's parameters; rebuild it on the
                        # reloaded model before restoring its saved state,
                        # otherwise optimizer.step() would never update
                        # the new model.
                        optimizer = torch.optim.RMSprop(model.parameters(),
                                                        lr=lr)
                        optimizer.load_state_dict(
                            torch.load(optimizer_save_path,
                                       map_location=device))
                        # load_state_dict restores the saved lr; re-apply
                        # the decayed one.
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        patience = 0
                print(
                    'dev: epoch %d, iter %d, dev_loss %f, patience %d, decay_num %d'
                    % (epoch + 1, train_iter, dev_loss, patience, decay_num))
                cum_start = time.time()
                # Exclude validation time from the next training-log rate.
                if train_iter % log_every == 0:
                    record_start = time.time()
    print('Reached %d epochs, Save result model to %s' %
          (max_epoch, model_save_path))