Example #1
def train(data, opt, fold_idx):
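    # Train the sequence-labeling (NER) model: evaluate on the dev set after each
    # epoch, keep the best checkpoint under opt.output, and early-stop after
    # opt.patience epochs without improvement. Returns the best dev P/R/F.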

    model = SeqModel(data, opt)

    optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.l2)

    if opt.tune_wordemb == False:
        my_utils.freeze_net(model.word_hidden.wordrep.word_embedding)

    best_dev_f = -10
    best_dev_p = -10
    best_dev_r = -10

    bad_counter = 0

    for idx in range(opt.iter):
        epoch_start = time.time()

        if opt.elmo:
            my_utils.shuffle(data.train_texts, data.train_Ids)
        else:
            random.shuffle(data.train_Ids)

        model.train()
        model.zero_grad()
        batch_size = opt.batch_size
        train_num = len(data.train_Ids)
        total_batch = train_num // batch_size + 1

        for batch_id in range(total_batch):

            start = batch_id * batch_size
            end = (batch_id + 1) * batch_size
            if end > train_num:
                end = train_num
            instance = data.train_Ids[start:end]
            if opt.elmo:
                instance_text = data.train_texts[start:end]
            else:
                instance_text = None
            if not instance:
                continue

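            # batchify_with_label presumably pads/sorts this slice of instances
            # into tensors and moves them to the GPU when opt.gpu is set
            # (inferred from its arguments and return values)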
            batch_word, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask, batch_features, batch_text = batchify_with_label(
                data, instance, instance_text, opt.gpu)

            loss, tag_seq = model.neg_log_likelihood_loss(
                batch_word, batch_wordlen, batch_char, batch_charlen,
                batch_charrecover, batch_label, mask, batch_features,
                batch_text)

            loss.backward()

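            # optionally clip the gradient norm before the optimizer step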
            if opt.gradient_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               opt.gradient_clip)
            optimizer.step()
            model.zero_grad()

        epoch_finish = time.time()
        logging.info("epoch: %s training finished. Time: %.2fs" %
                     (idx, epoch_finish - epoch_start))

        if opt.dev_file:
            _, _, p, r, f, _, _ = evaluate(data, opt, model, "dev", True)
            logging.info("Dev: p: %.4f, r: %.4f, f: %.4f" % (p, r, f))
        else:
            f = best_dev_f

        if f > best_dev_f:
            logging.info("Exceed previous best f score on dev: %.4f" %
                         (best_dev_f))

            if fold_idx is None:
                torch.save(model.state_dict(),
                           os.path.join(opt.output, "model.pkl"))
            else:
                torch.save(
                    model.state_dict(),
                    os.path.join(opt.output,
                                 "model_{}.pkl".format(fold_idx + 1)))

            best_dev_f = f
            best_dev_p = p
            best_dev_r = r

            # if opt.test_file:
            #     _, _, p, r, f, _, _ = evaluate(data, opt, model, "test", True, opt.nbest)
            #     logging.info("Test: p: %.4f, r: %.4f, f: %.4f" % (p, r, f))

            bad_counter = 0
        else:
            bad_counter += 1

        if len(opt.dev_file) != 0 and bad_counter >= opt.patience:
            logging.info('Early Stop!')
            break

    logging.info("train finished")

    if len(opt.dev_file) == 0:
        torch.save(model.state_dict(), os.path.join(opt.output, "model.pkl"))

    return best_dev_p, best_dev_r, best_dev_f
Example #2
def train(train_data, dev_data, test_data, d, dictionary, dictionary_reverse,
          opt, fold_idx, isMeddra_dict):
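    # Train the VSM-based normalization model: build word/concept alphabets,
    # optionally load pretrained word embeddings, generate mention/concept
    # training instances, and train with early stopping on the dev set.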
    logging.info("train the vsm-based normalization model ...")

    external_train_data = []
    if d.config.get('norm_ext_corpus') is not None:
        for k, v in d.config['norm_ext_corpus'].items():
            if k == 'tac':
                external_train_data.extend(
                    load_data_fda(v['path'], True, v.get('types'),
                                  v.get('types'), False, True))
            else:
                raise RuntimeError("unsupported external corpus")
    if len(external_train_data) != 0:
        train_data.extend(external_train_data)

    logging.info("build alphabet ...")
    word_alphabet = Alphabet('word')
    norm_utils.build_alphabet_from_dict(word_alphabet, dictionary,
                                        isMeddra_dict)
    norm_utils.build_alphabet(word_alphabet, train_data)
    if opt.dev_file:
        norm_utils.build_alphabet(word_alphabet, dev_data)
    if opt.test_file:
        norm_utils.build_alphabet(word_alphabet, test_data)
    norm_utils.fix_alphabet(word_alphabet)
    logging.info("alphabet size {}".format(word_alphabet.size()))

    if d.config.get('norm_emb') is not None:
        logging.info("load pretrained word embedding ...")
        pretrain_word_embedding, word_emb_dim = build_pretrain_embedding(
            d.config.get('norm_emb'), word_alphabet, opt.word_emb_dim, False)
        word_embedding = nn.Embedding(word_alphabet.size(),
                                      word_emb_dim,
                                      padding_idx=0)
        word_embedding.weight.data.copy_(
            torch.from_numpy(pretrain_word_embedding))
        embedding_dim = word_emb_dim
    else:
        logging.info("randomly initialize word embedding ...")
        word_embedding = nn.Embedding(word_alphabet.size(),
                                      d.word_emb_dim,
                                      padding_idx=0)
        word_embedding.weight.data.copy_(
            torch.from_numpy(
                random_embedding(word_alphabet.size(), d.word_emb_dim)))
        embedding_dim = d.word_emb_dim

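    # build the alphabet of dictionary concepts (the normalization targets)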
    dict_alphabet = Alphabet('dict')
    norm_utils.init_dict_alphabet(dict_alphabet, dictionary)
    norm_utils.fix_alphabet(dict_alphabet)

    logging.info("init_vector_for_dict")
    poses, poses_lengths = init_vector_for_dict(word_alphabet, dict_alphabet,
                                                dictionary, isMeddra_dict)

    vsm_model = VsmNormer(word_alphabet, word_embedding, embedding_dim,
                          dict_alphabet, poses, poses_lengths)

    logging.info("generate instances for training ...")
    train_X = []
    train_Y = []

    for doc in train_data:
        if isMeddra_dict:
            temp_X, temp_Y = generate_instances(doc.entities, word_alphabet,
                                                dict_alphabet)
        else:
            temp_X, temp_Y = generate_instances_ehr(doc.entities,
                                                    word_alphabet,
                                                    dict_alphabet,
                                                    dictionary_reverse)
        train_X.extend(temp_X)
        train_Y.extend(temp_Y)

    train_loader = DataLoader(MyDataset(train_X, train_Y),
                              opt.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)

    optimizer = optim.Adam(vsm_model.parameters(),
                           lr=opt.lr,
                           weight_decay=opt.l2)

    if opt.tune_wordemb == False:
        freeze_net(vsm_model.word_embedding)

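    # optionally pre-train the model on the dictionary entries themselves,
    # controlled by the norm_vsm_pretrain config flag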
    if d.config['norm_vsm_pretrain'] == '1':
        dict_pretrain(dictionary, dictionary_reverse, d, isMeddra_dict,
                      optimizer, vsm_model)

    best_dev_f = -10
    best_dev_p = -10
    best_dev_r = -10

    bad_counter = 0

    logging.info("start training ...")

    for idx in range(opt.iter):
        epoch_start = time.time()

        vsm_model.train()

        train_iter = iter(train_loader)
        num_iter = len(train_loader)

        sum_loss = 0

        correct, total = 0, 0

        for i in range(num_iter):

            x, lengths, y = next(train_iter)

            l, y_pred = vsm_model.forward_train(x, lengths, y)

            sum_loss += l.item()

            l.backward()

            if opt.gradient_clip > 0:
                torch.nn.utils.clip_grad_norm_(vsm_model.parameters(),
                                               opt.gradient_clip)
            optimizer.step()
            vsm_model.zero_grad()

            total += y.size(0)
            _, pred = torch.max(y_pred, 1)
            correct += (pred == y).sum().item()

        epoch_finish = time.time()
        accuracy = 100.0 * correct / total
        logging.info(
            "epoch: %s training finished. Time: %.2fs. loss: %.4f Accuracy %.2f"
            % (idx, epoch_finish - epoch_start, sum_loss / num_iter, accuracy))

        if opt.dev_file:
            p, r, f = norm_utils.evaluate(dev_data, dictionary,
                                          dictionary_reverse, vsm_model, None,
                                          None, d, isMeddra_dict)
            logging.info("Dev: p: %.4f, r: %.4f, f: %.4f" % (p, r, f))
        else:
            f = best_dev_f

        if f > best_dev_f:
            logging.info("Exceed previous best f score on dev: %.4f" %
                         (best_dev_f))

            if fold_idx is None:
                torch.save(vsm_model, os.path.join(opt.output, "vsm.pkl"))
            else:
                torch.save(
                    vsm_model,
                    os.path.join(opt.output,
                                 "vsm_{}.pkl".format(fold_idx + 1)))

            best_dev_f = f
            best_dev_p = p
            best_dev_r = r

            bad_counter = 0
        else:
            bad_counter += 1

        if len(opt.dev_file) != 0 and bad_counter >= opt.patience:
            logging.info('Early Stop!')
            break

    logging.info("train finished")

    if len(opt.dev_file) == 0:
        torch.save(vsm_model, os.path.join(opt.output, "vsm.pkl"))

    return best_dev_p, best_dev_r, best_dev_f
def joint_train(data, old_data, opt):
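    # Jointly train the NER component (seq_wordseq + seq_model) and the relation
    # extraction component (classify_wordseq + classify_model), optionally
    # warm-starting both from the checkpoints in opt.pretrained_model_dir.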

    if not os.path.exists(opt.output):
        os.makedirs(opt.output)

    if opt.pretrained_model_dir != 'None':
        seq_model = SeqModel(data)
        if opt.test_in_cpu:
            seq_model.load_state_dict(
                torch.load(os.path.join(opt.pretrained_model_dir,
                                        'ner_model.pkl'),
                           map_location='cpu'))
        else:
            seq_model.load_state_dict(
                torch.load(
                    os.path.join(opt.pretrained_model_dir, 'ner_model.pkl')))

        if (data.label_alphabet_size != seq_model.crf.tagset_size)\
                or (data.HP_hidden_dim != seq_model.hidden2tag.weight.size(1)):
            raise RuntimeError("ner_model not compatible")

        seq_wordseq = WordSequence(data, False, True, True, True)

        if ((data.word_emb_dim != seq_wordseq.wordrep.word_embedding.embedding_dim)\
            or (data.char_emb_dim != seq_wordseq.wordrep.char_feature.char_embeddings.embedding_dim)\
            or (data.feature_emb_dims[0] != seq_wordseq.wordrep.feature_embedding_dims[0])\
            or (data.feature_emb_dims[1] != seq_wordseq.wordrep.feature_embedding_dims[1])):
            raise RuntimeError("ner_wordseq not compatible")

        old_seq_wordseq = WordSequence(old_data, False, True, True, True)
        if opt.test_in_cpu:
            old_seq_wordseq.load_state_dict(
                torch.load(os.path.join(opt.pretrained_model_dir,
                                        'ner_wordseq.pkl'),
                           map_location='cpu'))
        else:
            old_seq_wordseq.load_state_dict(
                torch.load(
                    os.path.join(opt.pretrained_model_dir, 'ner_wordseq.pkl')))

        # sd = old_seq_wordseq.lstm.state_dict()

        for old, new in zip(old_seq_wordseq.lstm.parameters(),
                            seq_wordseq.lstm.parameters()):
            new.data.copy_(old)

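        # copy the overlapping rows of each embedding table from the pretrained
        # model (this assumes the old alphabet indices form a prefix of the new ones)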
        vocab_size = old_seq_wordseq.wordrep.word_embedding.num_embeddings
        seq_wordseq.wordrep.word_embedding.weight.data[:vocab_size, :] = \
            old_seq_wordseq.wordrep.word_embedding.weight.data[:vocab_size, :]

        vocab_size = old_seq_wordseq.wordrep.char_feature.char_embeddings.num_embeddings
        seq_wordseq.wordrep.char_feature.char_embeddings.weight.data[:vocab_size, :] = \
            old_seq_wordseq.wordrep.char_feature.char_embeddings.weight.data[:vocab_size, :]

        for i, feature_embedding in enumerate(
                old_seq_wordseq.wordrep.feature_embeddings):
            vocab_size = feature_embedding.num_embeddings
            seq_wordseq.wordrep.feature_embeddings[i].weight.data[:vocab_size, :] = \
                feature_embedding.weight.data[:vocab_size, :]

        # for word in data.word_alphabet.iteritems():
        #
        #     old_seq_wordseq.wordrep.word_embedding.weight.data data.word_alphabet.get_index(word)
        #

        classify_wordseq = WordSequence(data, True, False, True, False)

        if ((data.word_emb_dim != classify_wordseq.wordrep.word_embedding.embedding_dim)\
            or (data.re_feature_emb_dims[data.re_feature_name2id['[POSITION]']] != classify_wordseq.wordrep.position1_emb.embedding_dim)\
            or (data.feature_emb_dims[1] != classify_wordseq.wordrep.feature_embedding_dims[0])):
            raise RuntimeError("re_wordseq not compatible")

        old_classify_wordseq = WordSequence(old_data, True, False, True, False)
        if opt.test_in_cpu:
            old_classify_wordseq.load_state_dict(
                torch.load(os.path.join(opt.pretrained_model_dir,
                                        're_wordseq.pkl'),
                           map_location='cpu'))
        else:
            old_classify_wordseq.load_state_dict(
                torch.load(
                    os.path.join(opt.pretrained_model_dir, 're_wordseq.pkl')))

        for old, new in zip(old_classify_wordseq.lstm.parameters(),
                            classify_wordseq.lstm.parameters()):
            new.data.copy_(old)

        vocab_size = old_classify_wordseq.wordrep.word_embedding.num_embeddings
        classify_wordseq.wordrep.word_embedding.weight.data[:vocab_size, :] = \
            old_classify_wordseq.wordrep.word_embedding.weight.data[:vocab_size, :]

        vocab_size = old_classify_wordseq.wordrep.position1_emb.num_embeddings
        classify_wordseq.wordrep.position1_emb.weight.data[:vocab_size, :] = \
            old_classify_wordseq.wordrep.position1_emb.weight.data[:vocab_size, :]

        vocab_size = old_classify_wordseq.wordrep.position2_emb.num_embeddings
        classify_wordseq.wordrep.position2_emb.weight.data[:vocab_size, :] = \
            old_classify_wordseq.wordrep.position2_emb.weight.data[:vocab_size, :]

        vocab_size = old_classify_wordseq.wordrep.feature_embeddings[0].num_embeddings
        classify_wordseq.wordrep.feature_embeddings[0].weight.data[:vocab_size, :] = \
            old_classify_wordseq.wordrep.feature_embeddings[0].weight.data[:vocab_size, :]

        classify_model = ClassifyModel(data)

        old_classify_model = ClassifyModel(old_data)
        if opt.test_in_cpu:
            old_classify_model.load_state_dict(
                torch.load(os.path.join(opt.pretrained_model_dir,
                                        're_model.pkl'),
                           map_location='cpu'))
        else:
            old_classify_model.load_state_dict(
                torch.load(
                    os.path.join(opt.pretrained_model_dir, 're_model.pkl')))

        if (data.re_feature_alphabet_sizes[data.re_feature_name2id['[RELATION]']]
                != old_classify_model.linear.weight.size(0)
                or data.re_feature_emb_dims[data.re_feature_name2id['[ENTITY_TYPE]']]
                != old_classify_model.entity_type_emb.embedding_dim
                or data.re_feature_emb_dims[data.re_feature_name2id['[ENTITY]']]
                != old_classify_model.entity_emb.embedding_dim
                or data.re_feature_emb_dims[data.re_feature_name2id['[TOKEN_NUM]']]
                != old_classify_model.tok_num_betw_emb.embedding_dim
                or data.re_feature_emb_dims[data.re_feature_name2id['[ENTITY_NUM]']]
                != old_classify_model.et_num_emb.embedding_dim):
            raise RuntimeError("re_model not compatible")

        vocab_size = old_classify_model.entity_type_emb.num_embeddings
        classify_model.entity_type_emb.weight.data[:vocab_size, :] = \
            old_classify_model.entity_type_emb.weight.data[:vocab_size, :]

        vocab_size = old_classify_model.entity_emb.num_embeddings
        classify_model.entity_emb.weight.data[:vocab_size, :] = \
            old_classify_model.entity_emb.weight.data[:vocab_size, :]

        vocab_size = old_classify_model.tok_num_betw_emb.num_embeddings
        classify_model.tok_num_betw_emb.weight.data[:vocab_size, :] = \
            old_classify_model.tok_num_betw_emb.weight.data[:vocab_size, :]

        vocab_size = old_classify_model.et_num_emb.num_embeddings
        classify_model.et_num_emb.weight.data[:vocab_size, :] = \
            old_classify_model.et_num_emb.weight.data[:vocab_size, :]

    else:
        seq_model = SeqModel(data)
        seq_wordseq = WordSequence(data, False, True, True, True)

        classify_wordseq = WordSequence(data, True, False, True, False)
        classify_model = ClassifyModel(data)

    iter_parameter = itertools.chain(
        *map(list, [seq_wordseq.parameters(),
                    seq_model.parameters()]))
    seq_optimizer = optim.Adam(iter_parameter,
                               lr=data.HP_lr,
                               weight_decay=data.HP_l2)
    iter_parameter = itertools.chain(*map(
        list, [classify_wordseq.parameters(),
               classify_model.parameters()]))
    classify_optimizer = optim.Adam(iter_parameter,
                                    lr=data.HP_lr,
                                    weight_decay=data.HP_l2)

    if data.tune_wordemb == False:
        my_utils.freeze_net(seq_wordseq.wordrep.word_embedding)
        my_utils.freeze_net(classify_wordseq.wordrep.word_embedding)

    re_X_positive = []
    re_Y_positive = []
    re_X_negative = []
    re_Y_negative = []
    relation_vocab = data.re_feature_alphabets[
        data.re_feature_name2id['[RELATION]']]
    my_collate = my_utils.sorted_collate1
    for i in range(len(data.re_train_X)):
        x = data.re_train_X[i]
        y = data.re_train_Y[i]

        if y != relation_vocab.get_index("</unk>"):
            re_X_positive.append(x)
            re_Y_positive.append(y)
        else:
            re_X_negative.append(x)
            re_Y_negative.append(y)

    re_dev_loader = DataLoader(my_utils.RelationDataset(
        data.re_dev_X, data.re_dev_Y),
                               data.HP_batch_size,
                               shuffle=False,
                               collate_fn=my_collate)
    # re_test_loader = DataLoader(my_utils.RelationDataset(data.re_test_X, data.re_test_Y), data.HP_batch_size, shuffle=False, collate_fn=my_collate)

    best_ner_score = -1
    best_re_score = -1
    count_performance_not_grow = 0

    for idx in range(data.HP_iteration):
        epoch_start = time.time()

        seq_wordseq.train()
        seq_wordseq.zero_grad()
        seq_model.train()
        seq_model.zero_grad()

        classify_wordseq.train()
        classify_wordseq.zero_grad()
        classify_model.train()
        classify_model.zero_grad()

        batch_size = data.HP_batch_size

        random.shuffle(data.train_Ids)
        ner_train_num = len(data.train_Ids)
        ner_total_batch = ner_train_num // batch_size + 1

        re_train_loader, re_train_iter = makeRelationDataset(
            re_X_positive, re_Y_positive, re_X_negative, re_Y_negative,
            data.unk_ratio, True, my_collate, data.HP_batch_size)
        re_total_batch = len(re_train_loader)

        total_batch = max(ner_total_batch, re_total_batch)
        min_batch = min(ner_total_batch, re_total_batch)

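        # interleave the two tasks: while their batches last, take one NER update
        # and one RE update per step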
        for batch_id in range(total_batch):

            if batch_id < ner_total_batch:
                start = batch_id * batch_size
                end = (batch_id + 1) * batch_size
                if end > ner_train_num:
                    end = ner_train_num
                instance = data.train_Ids[start:end]
                batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask, \
                    batch_permute_label = batchify_with_label(instance, data.HP_gpu)

                hidden = seq_wordseq.forward(batch_word, batch_features,
                                             batch_wordlen, batch_char,
                                             batch_charlen, batch_charrecover,
                                             None, None)
                hidden_adv = None
                loss, tag_seq = seq_model.neg_log_likelihood_loss(
                    hidden, hidden_adv, batch_label, mask)
                loss.backward()
                seq_optimizer.step()
                seq_wordseq.zero_grad()
                seq_model.zero_grad()

            if batch_id < re_total_batch:
                [batch_word, batch_features, batch_wordlen, batch_wordrecover, \
                 batch_char, batch_charlen, batch_charrecover, \
                 position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \
                 tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(
                    re_train_loader, re_train_iter)

                hidden = classify_wordseq.forward(batch_word, batch_features,
                                                  batch_wordlen, batch_char,
                                                  batch_charlen,
                                                  batch_charrecover,
                                                  position1_seq_tensor,
                                                  position2_seq_tensor)
                hidden_adv = None
                loss, pred = classify_model.neg_log_likelihood_loss(
                    hidden, hidden_adv, batch_wordlen, e1_token, e1_length,
                    e2_token, e2_length, e1_type, e2_type, tok_num_betw,
                    et_num, targets)
                loss.backward()
                classify_optimizer.step()
                classify_wordseq.zero_grad()
                classify_model.zero_grad()

        epoch_finish = time.time()
        logging.info("epoch: %s training finished. Time: %.2fs" %
                     (idx, epoch_finish - epoch_start))

        _, _, _, _, ner_score, _, _ = ner.evaluate(data, seq_wordseq,
                                                   seq_model, "dev")
        logging.info("ner evaluate: f: %.4f" % (ner_score))
        if ner_score > best_ner_score:
            logging.info("new best score: ner: %.4f" % (ner_score))
            best_ner_score = ner_score

            torch.save(seq_wordseq.state_dict(),
                       os.path.join(opt.output, 'ner_wordseq.pkl'))
            torch.save(seq_model.state_dict(),
                       os.path.join(opt.output, 'ner_model.pkl'))

            count_performance_not_grow = 0

            # _, _, _, _, test_ner_score, _, _ = ner.evaluate(data, seq_wordseq, seq_model, "test")
            # logging.info("ner evaluate on test: f: %.4f" % (test_ner_score))

        else:
            count_performance_not_grow += 1

        re_score = relation_extraction.evaluate(classify_wordseq,
                                                classify_model, re_dev_loader)
        logging.info("re evaluate: f: %.4f" % (re_score))
        if re_score > best_re_score:
            logging.info("new best score: re: %.4f" % (re_score))
            best_re_score = re_score

            torch.save(classify_wordseq.state_dict(),
                       os.path.join(opt.output, 're_wordseq.pkl'))
            torch.save(classify_model.state_dict(),
                       os.path.join(opt.output, 're_model.pkl'))

            count_performance_not_grow = 0

            # test_re_score = relation_extraction.evaluate(classify_wordseq, classify_model, re_test_loader)
            # logging.info("re evaluate on test: f: %.4f" % (test_re_score))
        else:
            count_performance_not_grow += 1

        if count_performance_not_grow > 2 * data.patience:
            logging.info("early stop")
            break

    logging.info("train finished")
def train(data, dir):
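    # Train the relation classifier, optionally with a self-adversarial branch
    # (opt.self_adv == 'grad' or 'label'); evaluate on the test loader every
    # epoch and save the best-accuracy checkpoints under dir.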

    my_collate = my_utils.sorted_collate1

    train_loader, train_iter = makeDatasetWithoutUnknown(data.re_train_X, data.re_train_Y, data.re_feature_alphabets[data.re_feature_name2id['[RELATION]']], True, my_collate, data.HP_batch_size)
    num_iter = len(train_loader)
    unk_loader, unk_iter = makeDatasetUnknown(data.re_train_X, data.re_train_Y, data.re_feature_alphabets[data.re_feature_name2id['[RELATION]']], my_collate, data.unk_ratio, data.HP_batch_size)

    test_loader = DataLoader(my_utils.RelationDataset(data.re_test_X, data.re_test_Y),
                              data.HP_batch_size, shuffle=False, collate_fn=my_collate)

    wordseq = WordSequence(data, True, False, False)
    model = ClassifyModel(data)
    if torch.cuda.is_available():
        model = model.cuda(data.HP_gpu)

    if opt.self_adv == 'grad':
        wordseq_adv = WordSequence(data, True, False, False)
    elif opt.self_adv == 'label':
        wordseq_adv = WordSequence(data, True, False, False)
        model_adv = ClassifyModel(data)
        if torch.cuda.is_available():
            model_adv = model_adv.cuda(data.HP_gpu)
    else:
        wordseq_adv = None

    if opt.self_adv == 'grad':
        iter_parameter = itertools.chain(
            *map(list, [wordseq.parameters(), wordseq_adv.parameters(), model.parameters()]))
        optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)
    elif opt.self_adv == 'label':
        iter_parameter = itertools.chain(*map(list, [wordseq.parameters(), model.parameters()]))
        optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)
        iter_parameter = itertools.chain(*map(list, [wordseq_adv.parameters(), model_adv.parameters()]))
        optimizer_adv = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)

    else:
        iter_parameter = itertools.chain(*map(list, [wordseq.parameters(), model.parameters()]))
        optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)

    if data.tune_wordemb == False:
        my_utils.freeze_net(wordseq.wordrep.word_embedding)
        if opt.self_adv != 'no':
            my_utils.freeze_net(wordseq_adv.wordrep.word_embedding)

    best_acc = 0.0
    logging.info("start training ...")
    for epoch in range(data.max_epoch):
        wordseq.train()
        wordseq.zero_grad()
        model.train()
        model.zero_grad()
        if opt.self_adv == 'grad':
            wordseq_adv.train()
            wordseq_adv.zero_grad()
        elif opt.self_adv == 'label':
            wordseq_adv.train()
            wordseq_adv.zero_grad()
            model_adv.train()
            model_adv.zero_grad()
        correct, total = 0, 0

        for i in range(num_iter):
            [batch_word, batch_features, batch_wordlen, batch_wordrecover, \
            batch_char, batch_charlen, batch_charrecover, \
            position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \
            tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(train_loader, train_iter)


            if opt.self_adv == 'grad':
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,
                                         batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,
                                         batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                loss, pred = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen,
                                                           e1_token, e1_length, e2_token, e2_length, e1_type, e2_type,
                                                           tok_num_betw, et_num, targets)
                loss.backward()
                my_utils.reverse_grad(wordseq_adv)
                optimizer.step()
                wordseq.zero_grad()
                wordseq_adv.zero_grad()
                model.zero_grad()

            elif opt.self_adv == 'label' :
                wordseq.unfreeze_net()
                wordseq_adv.freeze_net()
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,
                                         batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,
                                         batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                loss, pred = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen,
                                                           e1_token, e1_length, e2_token, e2_length, e1_type, e2_type,
                                                           tok_num_betw, et_num, targets)
                loss.backward()
                optimizer.step()
                wordseq.zero_grad()
                wordseq_adv.zero_grad()
                model.zero_grad()

                wordseq.freeze_net()
                wordseq_adv.unfreeze_net()
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,
                                         batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,
                                         batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                loss_adv, _ = model_adv.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen,
                                                           e1_token, e1_length, e2_token, e2_length, e1_type, e2_type,
                                                           tok_num_betw, et_num, targets_permute)
                loss_adv.backward()
                optimizer_adv.step()
                wordseq.zero_grad()
                wordseq_adv.zero_grad()
                model_adv.zero_grad()

            else:
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,
                                         batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                hidden_adv = None
                loss, pred = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen,
                                                           e1_token, e1_length, e2_token, e2_length, e1_type, e2_type,
                                                           tok_num_betw, et_num, targets)
                loss.backward()
                optimizer.step()
                wordseq.zero_grad()
                model.zero_grad()



            total += targets.size(0)
            correct += (pred == targets).sum().item()


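            # additionally train on a batch drawn from unk_loader each step
            # (presumably the '</unk>' / negative relation instances, judging by
            # makeDatasetUnknown and data.unk_ratio)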
            [batch_word, batch_features, batch_wordlen, batch_wordrecover, \
            batch_char, batch_charlen, batch_charrecover, \
            position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \
            tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(unk_loader, unk_iter)

            hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,
                                     batch_charrecover, position1_seq_tensor, position2_seq_tensor)
            hidden_adv = None
            loss, pred = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen,
                                                       e1_token, e1_length, e2_token, e2_length, e1_type, e2_type,
                                                       tok_num_betw, et_num, targets)
            loss.backward()
            optimizer.step()
            wordseq.zero_grad()
            model.zero_grad()


        unk_loader, unk_iter = makeDatasetUnknown(data.re_train_X, data.re_train_Y,
                                                  data.re_feature_alphabets[data.re_feature_name2id['[RELATION]']],
                                                  my_collate, data.unk_ratio, data.HP_batch_size)

        logging.info('epoch {} end'.format(epoch))
        logging.info('Train Accuracy: {}%'.format(100.0 * correct / total))

        test_accuracy = evaluate(wordseq, model, test_loader)
        # test_accuracy = evaluate(m, test_loader)
        logging.info('Test Accuracy: {}%'.format(test_accuracy))

        if test_accuracy > best_acc:
            best_acc = test_accuracy
            torch.save(wordseq.state_dict(), os.path.join(dir, 'wordseq.pkl'))
            torch.save(model.state_dict(), '{}/model.pkl'.format(dir))
            if opt.self_adv == 'grad':
                torch.save(wordseq_adv.state_dict(), os.path.join(dir, 'wordseq_adv.pkl'))
            elif opt.self_adv == 'label':
                torch.save(wordseq_adv.state_dict(), os.path.join(dir, 'wordseq_adv.pkl'))
                torch.save(model_adv.state_dict(), os.path.join(dir, 'model_adv.pkl'))
            logging.info('New best accuracy: {}'.format(best_acc))


    logging.info("training completed")
Example #5
def train(train_data, dev_data, test_data, d, dictionary, dictionary_reverse,
          opt, fold_idx, isMeddra_dict):
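    # Train the ensemble normalization model: combine the rule-based sieve with
    # either a single learned ensemble (opt.ensemble == 'learn') or separate
    # VSM-based and neural normalizers, with early stopping on the dev set.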
    logging.info("train the ensemble normalization model ...")

    external_train_data = []
    if d.config.get('norm_ext_corpus') is not None:
        for k, v in d.config['norm_ext_corpus'].items():
            if k == 'tac':
                external_train_data.extend(
                    load_data_fda(v['path'], True, v.get('types'),
                                  v.get('types'), False, True))
            else:
                raise RuntimeError("unsupported external corpus")
    if len(external_train_data) != 0:
        train_data.extend(external_train_data)

    logging.info("build alphabet ...")
    word_alphabet = Alphabet('word')
    norm_utils.build_alphabet_from_dict(word_alphabet, dictionary,
                                        isMeddra_dict)
    norm_utils.build_alphabet(word_alphabet, train_data)
    if opt.dev_file:
        norm_utils.build_alphabet(word_alphabet, dev_data)
    if opt.test_file:
        norm_utils.build_alphabet(word_alphabet, test_data)
    norm_utils.fix_alphabet(word_alphabet)

    if d.config.get('norm_emb') is not None:
        logging.info("load pretrained word embedding ...")
        pretrain_word_embedding, word_emb_dim = build_pretrain_embedding(
            d.config.get('norm_emb'), word_alphabet, opt.word_emb_dim, False)
        word_embedding = nn.Embedding(word_alphabet.size(),
                                      word_emb_dim,
                                      padding_idx=0)
        word_embedding.weight.data.copy_(
            torch.from_numpy(pretrain_word_embedding))
        embedding_dim = word_emb_dim
    else:
        logging.info("randomly initialize word embedding ...")
        word_embedding = nn.Embedding(word_alphabet.size(),
                                      d.word_emb_dim,
                                      padding_idx=0)
        word_embedding.weight.data.copy_(
            torch.from_numpy(
                random_embedding(word_alphabet.size(), d.word_emb_dim)))
        embedding_dim = d.word_emb_dim

    dict_alphabet = Alphabet('dict')
    norm_utils.init_dict_alphabet(dict_alphabet, dictionary)
    norm_utils.fix_alphabet(dict_alphabet)

    # rule
    logging.info("init rule-based normer")
    multi_sieve.init(opt, train_data, d, dictionary, dictionary_reverse,
                     isMeddra_dict)

    if opt.ensemble == 'learn':
        logging.info("init ensemble normer")
        poses = vsm.init_vector_for_dict(word_alphabet, dict_alphabet,
                                         dictionary, isMeddra_dict)
        ensemble_model = Ensemble(word_alphabet, word_embedding, embedding_dim,
                                  dict_alphabet, poses)
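        # pretrain_neural_model and pretrain_vsm_model are not defined in this
        # function; they are presumably module-level variables holding previously
        # trained normalizers whose linear layers seed the ensemble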
        if pretrain_neural_model is not None:
            ensemble_model.neural_linear.weight.data.copy_(
                pretrain_neural_model.linear.weight.data)
        if pretrain_vsm_model is not None:
            ensemble_model.vsm_linear.weight.data.copy_(
                pretrain_vsm_model.linear.weight.data)
        ensemble_train_X = []
        ensemble_train_Y = []
        for doc in train_data:
            temp_X, temp_Y = generate_instances(doc, word_alphabet,
                                                dict_alphabet, dictionary,
                                                dictionary_reverse,
                                                isMeddra_dict)

            ensemble_train_X.extend(temp_X)
            ensemble_train_Y.extend(temp_Y)
        ensemble_train_loader = DataLoader(MyDataset(ensemble_train_X,
                                                     ensemble_train_Y),
                                           opt.batch_size,
                                           shuffle=True,
                                           collate_fn=my_collate)
        ensemble_optimizer = optim.Adam(ensemble_model.parameters(),
                                        lr=opt.lr,
                                        weight_decay=opt.l2)
        if opt.tune_wordemb == False:
            freeze_net(ensemble_model.word_embedding)
    else:

        # vsm
        logging.info("init vsm-based normer")
        poses = vsm.init_vector_for_dict(word_alphabet, dict_alphabet,
                                         dictionary, isMeddra_dict)
        # alphabet can share between vsm and neural since they don't change
        # but word_embedding cannot
        vsm_model = vsm.VsmNormer(word_alphabet, copy.deepcopy(word_embedding),
                                  embedding_dim, dict_alphabet, poses)
        vsm_train_X = []
        vsm_train_Y = []
        for doc in train_data:
            if isMeddra_dict:
                temp_X, temp_Y = vsm.generate_instances(
                    doc.entities, word_alphabet, dict_alphabet)
            else:
                temp_X, temp_Y = vsm.generate_instances_ehr(
                    doc.entities, word_alphabet, dict_alphabet,
                    dictionary_reverse)

            vsm_train_X.extend(temp_X)
            vsm_train_Y.extend(temp_Y)
        vsm_train_loader = DataLoader(vsm.MyDataset(vsm_train_X, vsm_train_Y),
                                      opt.batch_size,
                                      shuffle=True,
                                      collate_fn=vsm.my_collate)
        vsm_optimizer = optim.Adam(vsm_model.parameters(),
                                   lr=opt.lr,
                                   weight_decay=opt.l2)
        if opt.tune_wordemb == False:
            freeze_net(vsm_model.word_embedding)

        if d.config['norm_vsm_pretrain'] == '1':
            vsm.dict_pretrain(dictionary, dictionary_reverse, d, True,
                              vsm_optimizer, vsm_model)

        # neural
        logging.info("init neural-based normer")
        neural_model = norm_neural.NeuralNormer(word_alphabet,
                                                copy.deepcopy(word_embedding),
                                                embedding_dim, dict_alphabet)

        neural_train_X = []
        neural_train_Y = []
        for doc in train_data:
            if isMeddra_dict:
                temp_X, temp_Y = norm_neural.generate_instances(
                    doc.entities, word_alphabet, dict_alphabet)
            else:
                temp_X, temp_Y = norm_neural.generate_instances_ehr(
                    doc.entities, word_alphabet, dict_alphabet,
                    dictionary_reverse)

            neural_train_X.extend(temp_X)
            neural_train_Y.extend(temp_Y)
        neural_train_loader = DataLoader(norm_neural.MyDataset(
            neural_train_X, neural_train_Y),
                                         opt.batch_size,
                                         shuffle=True,
                                         collate_fn=norm_neural.my_collate)
        neural_optimizer = optim.Adam(neural_model.parameters(),
                                      lr=opt.lr,
                                      weight_decay=opt.l2)
        if opt.tune_wordemb == False:
            freeze_net(neural_model.word_embedding)

        if d.config['norm_neural_pretrain'] == '1':
            neural_model.dict_pretrain(dictionary, dictionary_reverse, d, True,
                                       neural_optimizer, neural_model)

    best_dev_f = -10
    best_dev_p = -10
    best_dev_r = -10

    bad_counter = 0

    logging.info("start training ...")
    for idx in range(opt.iter):
        epoch_start = time.time()

        if opt.ensemble == 'learn':

            ensemble_model.train()
            ensemble_train_iter = iter(ensemble_train_loader)
            ensemble_num_iter = len(ensemble_train_loader)

            for i in range(ensemble_num_iter):
                x, rules, lengths, y = next(ensemble_train_iter)

                y_pred = ensemble_model.forward(x, rules, lengths)

                l = ensemble_model.loss(y_pred, y)

                l.backward()

                if opt.gradient_clip > 0:
                    torch.nn.utils.clip_grad_norm_(ensemble_model.parameters(),
                                                   opt.gradient_clip)
                ensemble_optimizer.step()
                ensemble_model.zero_grad()

        else:

            vsm_model.train()
            vsm_train_iter = iter(vsm_train_loader)
            vsm_num_iter = len(vsm_train_loader)

            for i in range(vsm_num_iter):
                x, lengths, y = next(vsm_train_iter)

                l, _ = vsm_model.forward_train(x, lengths, y)

                l.backward()

                if opt.gradient_clip > 0:
                    torch.nn.utils.clip_grad_norm_(vsm_model.parameters(),
                                                   opt.gradient_clip)
                vsm_optimizer.step()
                vsm_model.zero_grad()

            neural_model.train()
            neural_train_iter = iter(neural_train_loader)
            neural_num_iter = len(neural_train_loader)

            for i in range(neural_num_iter):

                x, lengths, y = next(neural_train_iter)

                y_pred = neural_model.forward(x, lengths)

                l = neural_model.loss(y_pred, y)

                l.backward()

                if opt.gradient_clip > 0:
                    torch.nn.utils.clip_grad_norm_(neural_model.parameters(),
                                                   opt.gradient_clip)
                neural_optimizer.step()
                neural_model.zero_grad()

        epoch_finish = time.time()
        logging.info("epoch: %s training finished. Time: %.2fs" %
                     (idx, epoch_finish - epoch_start))

        if opt.dev_file:
            if opt.ensemble == 'learn':
                # logging.info("weight w1: %.4f, w2: %.4f, w3: %.4f" % (ensemble_model.w1.data.item(), ensemble_model.w2.data.item(), ensemble_model.w3.data.item()))
                p, r, f = norm_utils.evaluate(dev_data, dictionary,
                                              dictionary_reverse, None, None,
                                              ensemble_model, d, isMeddra_dict)
            else:
                p, r, f = norm_utils.evaluate(dev_data, dictionary,
                                              dictionary_reverse, vsm_model,
                                              neural_model, None, d,
                                              isMeddra_dict)
            logging.info("Dev: p: %.4f, r: %.4f, f: %.4f" % (p, r, f))
        else:
            f = best_dev_f

        if f > best_dev_f:
            logging.info("Exceed previous best f score on dev: %.4f" %
                         (best_dev_f))

            if opt.ensemble == 'learn':
                if fold_idx is None:
                    torch.save(ensemble_model,
                               os.path.join(opt.output, "ensemble.pkl"))
                else:
                    torch.save(
                        ensemble_model,
                        os.path.join(opt.output,
                                     "ensemble_{}.pkl".format(fold_idx + 1)))
            else:
                if fold_idx is None:
                    torch.save(vsm_model, os.path.join(opt.output, "vsm.pkl"))
                    torch.save(neural_model,
                               os.path.join(opt.output, "norm_neural.pkl"))
                else:
                    torch.save(
                        vsm_model,
                        os.path.join(opt.output,
                                     "vsm_{}.pkl".format(fold_idx + 1)))
                    torch.save(
                        neural_model,
                        os.path.join(opt.output,
                                     "norm_neural_{}.pkl".format(fold_idx +
                                                                 1)))

            best_dev_f = f
            best_dev_p = p
            best_dev_r = r

            bad_counter = 0
        else:
            bad_counter += 1

        if len(opt.dev_file) != 0 and bad_counter >= opt.patience:
            logging.info('Early Stop!')
            break

    logging.info("train finished")

    if fold_idx is None:
        multi_sieve.finalize(True)
    else:
        if fold_idx == opt.cross_validation - 1:
            multi_sieve.finalize(True)
        else:
            multi_sieve.finalize(False)

    if len(opt.dev_file) == 0:
        if opt.ensemble == 'learn':
            torch.save(ensemble_model, os.path.join(opt.output,
                                                    "ensemble.pkl"))
        else:
            torch.save(vsm_model, os.path.join(opt.output, "vsm.pkl"))
            torch.save(neural_model, os.path.join(opt.output,
                                                  "norm_neural.pkl"))

    return best_dev_p, best_dev_r, best_dev_f
Example #6
def train(data, ner_dir, re_dir):
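    # Multi-task NER/RE training with stacked shared hidden layers: each task's
    # LSTM output is mixed with the other task's attention output, weighted by
    # alpha, which grows from 1/task_num toward 1 across the opt.hidden_num layers.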

    task_num = 2
    init_alpha = 1.0 / task_num
    if opt.hidden_num <= 1:
        step_alpha = 0
    else:
        step_alpha = (1.0 - init_alpha) / (opt.hidden_num - 1)

    ner_wordrep = WordRep(data, False, True, True, data.use_char)
    ner_hiddenlist = []
    for i in range(opt.hidden_num):
        if i == 0:
            input_size = data.word_emb_dim+data.HP_char_hidden_dim+data.feature_emb_dims[data.feature_name2id['[Cap]']]+ \
                         data.feature_emb_dims[data.feature_name2id['[POS]']]
            output_size = data.HP_hidden_dim
        else:
            input_size = data.HP_hidden_dim
            output_size = data.HP_hidden_dim

        temp = HiddenLayer(data, input_size, output_size)
        ner_hiddenlist.append(temp)

    seq_model = SeqModel(data)

    re_wordrep = WordRep(data, True, False, True, False)
    re_hiddenlist = []
    for i in range(opt.hidden_num):
        if i == 0:
            input_size = data.word_emb_dim + data.feature_emb_dims[data.feature_name2id['[POS]']]+\
                         2*data.re_feature_emb_dims[data.re_feature_name2id['[POSITION]']]
            output_size = data.HP_hidden_dim
        else:
            input_size = data.HP_hidden_dim
            output_size = data.HP_hidden_dim

        temp = HiddenLayer(data, input_size, output_size)
        re_hiddenlist.append(temp)

    classify_model = ClassifyModel(data)

    iter_parameter = itertools.chain(
        *map(list, [ner_wordrep.parameters(),
                    seq_model.parameters()] +
             [f.parameters() for f in ner_hiddenlist]))
    ner_optimizer = optim.Adam(iter_parameter,
                               lr=data.HP_lr,
                               weight_decay=data.HP_l2)
    iter_parameter = itertools.chain(
        *map(list, [re_wordrep.parameters(),
                    classify_model.parameters()] +
             [f.parameters() for f in re_hiddenlist]))
    re_optimizer = optim.Adam(iter_parameter,
                              lr=data.HP_lr,
                              weight_decay=data.HP_l2)

    if data.tune_wordemb == False:
        my_utils.freeze_net(ner_wordrep.word_embedding)
        my_utils.freeze_net(re_wordrep.word_embedding)

    re_X_positive = []
    re_Y_positive = []
    re_X_negative = []
    re_Y_negative = []
    relation_vocab = data.re_feature_alphabets[
        data.re_feature_name2id['[RELATION]']]
    my_collate = my_utils.sorted_collate1
    for i in range(len(data.re_train_X)):
        x = data.re_train_X[i]
        y = data.re_train_Y[i]

        if y != relation_vocab.get_index("</unk>"):
            re_X_positive.append(x)
            re_Y_positive.append(y)
        else:
            re_X_negative.append(x)
            re_Y_negative.append(y)

    re_test_loader = DataLoader(my_utils.RelationDataset(
        data.re_test_X, data.re_test_Y),
                                data.HP_batch_size,
                                shuffle=False,
                                collate_fn=my_collate)

    best_ner_score = -1
    best_re_score = -1

    for idx in range(data.HP_iteration):
        epoch_start = time.time()

        ner_wordrep.train()
        ner_wordrep.zero_grad()
        for hidden_layer in ner_hiddenlist:
            hidden_layer.train()
            hidden_layer.zero_grad()
        seq_model.train()
        seq_model.zero_grad()

        re_wordrep.train()
        re_wordrep.zero_grad()
        for hidden_layer in re_hiddenlist:
            hidden_layer.train()
            hidden_layer.zero_grad()
        classify_model.train()
        classify_model.zero_grad()

        batch_size = data.HP_batch_size

        random.shuffle(data.train_Ids)
        ner_train_num = len(data.train_Ids)
        ner_total_batch = ner_train_num // batch_size + 1

        re_train_loader, re_train_iter = makeRelationDataset(
            re_X_positive, re_Y_positive, re_X_negative, re_Y_negative,
            data.unk_ratio, True, my_collate, data.HP_batch_size)
        re_total_batch = len(re_train_loader)

        total_batch = max(ner_total_batch, re_total_batch)
        min_batch = min(ner_total_batch, re_total_batch)

        for batch_id in range(total_batch):

            if batch_id < min_batch:
                start = batch_id * batch_size
                end = (batch_id + 1) * batch_size
                if end > ner_train_num:
                    end = ner_train_num
                instance = data.train_Ids[start:end]
                ner_batch_word, ner_batch_features, ner_batch_wordlen, ner_batch_wordrecover, ner_batch_char, ner_batch_charlen, \
                ner_batch_charrecover, ner_batch_label, ner_mask, ner_batch_permute_label = batchify_with_label(instance, data.HP_gpu)

                [re_batch_word, re_batch_features, re_batch_wordlen, re_batch_wordrecover, re_batch_char, re_batch_charlen,
                 re_batch_charrecover, re_position1_seq_tensor, re_position2_seq_tensor, re_e1_token, re_e1_length, re_e2_token, re_e2_length,
                 re_e1_type, re_e2_type, re_tok_num_betw, re_et_num], [re_targets, re_targets_permute] = \
                    my_utils.endless_get_next_batch_without_rebatch1(re_train_loader, re_train_iter)

                if ner_batch_word.size(0) != re_batch_word.size(0):
                    continue  # skip this step if the NER and RE batches have different sizes

                ner_word_rep = ner_wordrep.forward(
                    ner_batch_word, ner_batch_features, ner_batch_wordlen,
                    ner_batch_char, ner_batch_charlen, ner_batch_charrecover,
                    None, None)

                re_word_rep = re_wordrep.forward(
                    re_batch_word, re_batch_features, re_batch_wordlen,
                    re_batch_char, re_batch_charlen, re_batch_charrecover,
                    re_position1_seq_tensor, re_position2_seq_tensor)
                alpha = init_alpha
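                # mix each task's LSTM output with the other task's attention
                # output; once alpha exceeds 0.99 the layers stop sharing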
                ner_hidden = ner_word_rep
                re_hidden = re_word_rep
                for i in range(opt.hidden_num):
                    if alpha > 0.99:
                        use_attn = False

                        ner_lstm_out, ner_att_out = ner_hiddenlist[i].forward(
                            ner_hidden, ner_batch_wordlen, use_attn)
                        re_lstm_out, re_att_out = re_hiddenlist[i].forward(
                            re_hidden, re_batch_wordlen, use_attn)

                        ner_hidden = ner_lstm_out
                        re_hidden = re_lstm_out
                    else:
                        use_attn = True

                        ner_lstm_out, ner_att_out = ner_hiddenlist[i].forward(
                            ner_hidden, ner_batch_wordlen, use_attn)
                        re_lstm_out, re_att_out = re_hiddenlist[i].forward(
                            re_hidden, re_batch_wordlen, use_attn)

                        ner_hidden = alpha * ner_lstm_out + (
                            1 - alpha) * re_att_out.unsqueeze(1)
                        re_hidden = alpha * re_lstm_out + (
                            1 - alpha) * ner_att_out.unsqueeze(1)

                    alpha += step_alpha

                ner_loss, ner_tag_seq = seq_model.neg_log_likelihood_loss(
                    ner_hidden, ner_batch_label, ner_mask)
                re_loss, re_pred = classify_model.neg_log_likelihood_loss(
                    re_hidden, re_batch_wordlen, re_e1_token, re_e1_length,
                    re_e2_token, re_e2_length, re_e1_type, re_e2_type,
                    re_tok_num_betw, re_et_num, re_targets)

                ner_loss.backward(retain_graph=True)
                re_loss.backward()

                ner_optimizer.step()
                re_optimizer.step()

                ner_wordrep.zero_grad()
                for hidden_layer in ner_hiddenlist:
                    hidden_layer.zero_grad()
                seq_model.zero_grad()

                re_wordrep.zero_grad()
                for hidden_layer in re_hiddenlist:
                    hidden_layer.zero_grad()
                classify_model.zero_grad()

            else:

                if batch_id < ner_total_batch:
                    start = batch_id * batch_size
                    end = (batch_id + 1) * batch_size
                    if end > ner_train_num:
                        end = ner_train_num
                    instance = data.train_Ids[start:end]
                    batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask, \
                        batch_permute_label = batchify_with_label(instance, data.HP_gpu)

                    ner_word_rep = ner_wordrep.forward(
                        batch_word, batch_features, batch_wordlen, batch_char,
                        batch_charlen, batch_charrecover, None, None)

                    ner_hidden = ner_word_rep
                    for i in range(opt.hidden_num):
                        ner_lstm_out, ner_att_out = ner_hiddenlist[i].forward(
                            ner_hidden, batch_wordlen, False)
                        ner_hidden = ner_lstm_out

                    loss, tag_seq = seq_model.neg_log_likelihood_loss(
                        ner_hidden, batch_label, mask)

                    loss.backward()
                    ner_optimizer.step()
                    ner_wordrep.zero_grad()
                    for hidden_layer in ner_hiddenlist:
                        hidden_layer.zero_grad()
                    seq_model.zero_grad()

                if batch_id < re_total_batch:
                    [batch_word, batch_features, batch_wordlen, batch_wordrecover, \
                     batch_char, batch_charlen, batch_charrecover, \
                     position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \
                     tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(
                        re_train_loader, re_train_iter)

                    re_word_rep = re_wordrep.forward(
                        batch_word, batch_features, batch_wordlen, batch_char,
                        batch_charlen, batch_charrecover, position1_seq_tensor,
                        position2_seq_tensor)

                    re_hidden = re_word_rep
                    for i in range(opt.hidden_num):
                        re_lstm_out, re_att_out = re_hiddenlist[i].forward(
                            re_hidden, batch_wordlen, False)
                        re_hidden = re_lstm_out

                    loss, pred = classify_model.neg_log_likelihood_loss(
                        re_hidden, batch_wordlen, e1_token, e1_length,
                        e2_token, e2_length, e1_type, e2_type, tok_num_betw,
                        et_num, targets)
                    loss.backward()
                    re_optimizer.step()
                    re_wordrep.zero_grad()
                    for hidden_layer in re_hiddenlist:
                        hidden_layer.zero_grad()
                    classify_model.zero_grad()

        epoch_finish = time.time()
        print("epoch: %s training finished. Time: %.2fs" %
              (idx, epoch_finish - epoch_start))

        ner_score = ner_evaluate(data, ner_wordrep, ner_hiddenlist, seq_model,
                                 "test")
        print("ner evaluate: f: %.4f" % (ner_score))

        re_score = re_evaluate(re_wordrep, re_hiddenlist, classify_model,
                               re_test_loader)
        print("re evaluate: f: %.4f" % (re_score))

        if ner_score + re_score > best_ner_score + best_re_score:
            print("new best score: ner: %.4f , re: %.4f" %
                  (ner_score, re_score))
            best_ner_score = ner_score
            best_re_score = re_score

            torch.save(ner_wordrep.state_dict(),
                       os.path.join(ner_dir, 'wordrep.pkl'))
            for i, hidden_layer in enumerate(ner_hiddenlist):
                torch.save(hidden_layer.state_dict(),
                           os.path.join(ner_dir, 'hidden_{}.pkl'.format(i)))
            torch.save(seq_model.state_dict(),
                       os.path.join(ner_dir, 'model.pkl'))

            torch.save(re_wordrep.state_dict(),
                       os.path.join(re_dir, 'wordrep.pkl'))
            for i, hidden_layer in enumerate(re_hiddenlist):
                torch.save(hidden_layer.state_dict(),
                           os.path.join(re_dir, 'hidden_{}.pkl'.format(i)))
            torch.save(classify_model.state_dict(),
                       os.path.join(re_dir, 'model.pkl'))
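
A minimal sketch (not part of the original code) of the alpha interpolation used above to pass information between the NER and RE towers; the tensor shapes are assumptions: an LSTM output of shape (batch, seq_len, hidden) and an attention summary of shape (batch, hidden).

import torch

def mix_hidden(own_lstm_out, other_att_out, alpha):
    # each tower keeps a fraction alpha of its own LSTM output and mixes in
    # (1 - alpha) of the other tower's attention summary, broadcast over time
    return alpha * own_lstm_out + (1 - alpha) * other_att_out.unsqueeze(1)

ner_lstm_out = torch.randn(2, 5, 8)  # hypothetical (batch, seq_len, hidden)
re_att_out = torch.randn(2, 8)       # hypothetical (batch, hidden)
print(mix_hidden(ner_lstm_out, re_att_out, alpha=0.7).shape)  # torch.Size([2, 5, 8])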
Exemple #7
0
def pipeline(data, ner_dir, re_dir):

    seq_model = SeqModel(data)
    seq_wordseq = WordSequence(data, False, True, True, data.use_char)

    classify_wordseq = WordSequence(data, True, False, True, False)
    classify_model = ClassifyModel(data)
    if torch.cuda.is_available():
        classify_model = classify_model.cuda(data.HP_gpu)

    iter_parameter = itertools.chain(
        *map(list, [seq_wordseq.parameters(),
                    seq_model.parameters()]))
    seq_optimizer = optim.Adam(iter_parameter,
                               lr=opt.ner_lr,
                               weight_decay=data.HP_l2)
    iter_parameter = itertools.chain(*map(
        list, [classify_wordseq.parameters(),
               classify_model.parameters()]))
    classify_optimizer = optim.Adam(iter_parameter,
                                    lr=opt.re_lr,
                                    weight_decay=data.HP_l2)

    if data.tune_wordemb == False:
        my_utils.freeze_net(seq_wordseq.wordrep.word_embedding)
        my_utils.freeze_net(classify_wordseq.wordrep.word_embedding)

    re_X_positive = []
    re_Y_positive = []
    re_X_negative = []
    re_Y_negative = []
    relation_vocab = data.re_feature_alphabets[
        data.re_feature_name2id['[RELATION]']]
    my_collate = my_utils.sorted_collate1
    for i in range(len(data.re_train_X)):
        x = data.re_train_X[i]
        y = data.re_train_Y[i]

        if y != relation_vocab.get_index("</unk>"):
            re_X_positive.append(x)
            re_Y_positive.append(y)
        else:
            re_X_negative.append(x)
            re_Y_negative.append(y)

    re_test_loader = DataLoader(my_utils.RelationDataset(
        data.re_test_X, data.re_test_Y),
                                data.HP_batch_size,
                                shuffle=False,
                                collate_fn=my_collate)

    best_ner_score = -1
    best_re_score = -1

    for idx in range(data.HP_iteration):
        epoch_start = time.time()

        seq_wordseq.train()
        seq_wordseq.zero_grad()
        seq_model.train()
        seq_model.zero_grad()

        classify_wordseq.train()
        classify_wordseq.zero_grad()
        classify_model.train()
        classify_model.zero_grad()

        batch_size = data.HP_batch_size

        random.shuffle(data.train_Ids)
        ner_train_num = len(data.train_Ids)
        ner_total_batch = ner_train_num // batch_size + 1

        re_train_loader, re_train_iter = makeRelationDataset(
            re_X_positive, re_Y_positive, re_X_negative, re_Y_negative,
            data.unk_ratio, True, my_collate, data.HP_batch_size)
        re_total_batch = len(re_train_loader)

        total_batch = max(ner_total_batch, re_total_batch)
        min_batch = min(ner_total_batch, re_total_batch)

        for batch_id in range(total_batch):

            if batch_id < ner_total_batch:
                start = batch_id * batch_size
                end = (batch_id + 1) * batch_size
                if end > ner_train_num:
                    end = ner_train_num
                instance = data.train_Ids[start:end]
                batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask, \
                    batch_permute_label = batchify_with_label(instance, data.HP_gpu)

                hidden = seq_wordseq.forward(batch_word, batch_features,
                                             batch_wordlen, batch_char,
                                             batch_charlen, batch_charrecover,
                                             None, None)
                hidden_adv = None
                loss, tag_seq = seq_model.neg_log_likelihood_loss(
                    hidden, hidden_adv, batch_label, mask)
                loss.backward()
                seq_optimizer.step()
                seq_wordseq.zero_grad()
                seq_model.zero_grad()

            if batch_id < re_total_batch:
                [batch_word, batch_features, batch_wordlen, batch_wordrecover, \
                 batch_char, batch_charlen, batch_charrecover, \
                 position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \
                 tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(
                    re_train_loader, re_train_iter)

                hidden = classify_wordseq.forward(batch_word, batch_features,
                                                  batch_wordlen, batch_char,
                                                  batch_charlen,
                                                  batch_charrecover,
                                                  position1_seq_tensor,
                                                  position2_seq_tensor)
                hidden_adv = None
                loss, pred = classify_model.neg_log_likelihood_loss(
                    hidden, hidden_adv, batch_wordlen, e1_token, e1_length,
                    e2_token, e2_length, e1_type, e2_type, tok_num_betw,
                    et_num, targets)
                loss.backward()
                classify_optimizer.step()
                classify_wordseq.zero_grad()
                classify_model.zero_grad()

        epoch_finish = time.time()
        print("epoch: %s training finished. Time: %.2fs" %
              (idx, epoch_finish - epoch_start))

        # _, _, _, _, f, _, _ = ner.evaluate(data, seq_wordseq, seq_model, "test")
        ner_score = ner.evaluate1(data, seq_wordseq, seq_model, "test")
        print("ner evaluate: f: %.4f" % (ner_score))

        re_score = relation_extraction.evaluate(classify_wordseq,
                                                classify_model, re_test_loader)
        print("re evaluate: f: %.4f" % (re_score))

        if ner_score + re_score > best_ner_score + best_re_score:
            print("new best score: ner: %.4f , re: %.4f" %
                  (ner_score, re_score))
            best_ner_score = ner_score
            best_re_score = re_score

            torch.save(seq_wordseq.state_dict(),
                       os.path.join(ner_dir, 'wordseq.pkl'))
            torch.save(seq_model.state_dict(),
                       os.path.join(ner_dir, 'model.pkl'))
            torch.save(classify_wordseq.state_dict(),
                       os.path.join(re_dir, 'wordseq.pkl'))
            torch.save(classify_model.state_dict(),
                       os.path.join(re_dir, 'model.pkl'))
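
A short sketch, with stand-in nn.Linear modules instead of the project's WordSequence and SeqModel classes, of the itertools.chain pattern used above to hand a single Adam optimizer the parameters of several modules at once.

import itertools
import torch.nn as nn
import torch.optim as optim

encoder = nn.Linear(4, 4)  # hypothetical stand-in for seq_wordseq
tagger = nn.Linear(4, 2)   # hypothetical stand-in for seq_model
iter_parameter = itertools.chain(
    *map(list, [encoder.parameters(), tagger.parameters()]))
optimizer = optim.Adam(iter_parameter, lr=1e-3, weight_decay=1e-8)
print(len(optimizer.param_groups[0]['params']))  # 4 tensors: two weights, two biases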
Exemple #8
0
def train(data):
    print "Training model..."
    data.show_data_summary()
    save_data_name = data.model_dir + ".dset"
    data.save(save_data_name)
    model = SeqModel(data)
    loss_function = nn.NLLLoss()
    if data.optimizer.lower() == "sgd":
        optimizer = optim.SGD(model.parameters(),
                              lr=data.HP_lr,
                              momentum=data.HP_momentum,
                              weight_decay=data.HP_l2)
    elif data.optimizer.lower() == "adagrad":
        optimizer = optim.Adagrad(model.parameters(),
                                  lr=data.HP_lr,
                                  weight_decay=data.HP_l2)
    elif data.optimizer.lower() == "adadelta":
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=data.HP_lr,
                                   weight_decay=data.HP_l2)
    elif data.optimizer.lower() == "rmsprop":
        optimizer = optim.RMSprop(model.parameters(),
                                  lr=data.HP_lr,
                                  weight_decay=data.HP_l2)
    elif data.optimizer.lower() == "adam":
        optimizer = optim.Adam(model.parameters(),
                               lr=data.HP_lr,
                               weight_decay=data.HP_l2)
    else:
        print("Optimizer illegal: %s" % (data.optimizer))
        exit(0)
    best_dev = -10

    if opt.tune_wordemb == False:
        my_utils.freeze_net(model.word_hidden.wordrep.word_embedding)

    # data.HP_iteration = 1
    ## start training
    for idx in range(data.HP_iteration):
        epoch_start = time.time()
        temp_start = epoch_start
        print("Epoch: %s/%s" % (idx, data.HP_iteration))
        if data.optimizer.lower() == "sgd":
            optimizer = lr_decay(optimizer, idx, data.HP_lr_decay, data.HP_lr)
        instance_count = 0
        sample_id = 0
        sample_loss = 0
        total_loss = 0
        right_token = 0
        whole_token = 0
        random.shuffle(data.train_Ids)
        ## set model in train mode
        model.train()
        model.zero_grad()
        batch_size = data.HP_batch_size
        batch_id = 0
        train_num = len(data.train_Ids)
        total_batch = train_num // batch_size + 1
        for batch_id in range(total_batch):
            start = batch_id * batch_size
            end = (batch_id + 1) * batch_size
            if end > train_num:
                end = train_num
            instance = data.train_Ids[start:end]
            if not instance:
                continue
            batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask = batchify_with_label(
                instance, data.HP_gpu)
            instance_count += 1
            loss, tag_seq = model.neg_log_likelihood_loss(
                batch_word, batch_features, batch_wordlen, batch_char,
                batch_charlen, batch_charrecover, batch_label, mask)
            right, whole = predict_check(tag_seq, batch_label, mask)
            right_token += right
            whole_token += whole
            sample_loss += loss.data.item()
            total_loss += loss.data.item()
            if end % 500 == 0:
                temp_time = time.time()
                temp_cost = temp_time - temp_start
                temp_start = temp_time
                print(
                    "     Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f"
                    % (end, temp_cost, sample_loss, right_token, whole_token,
                       (right_token + 0.) / whole_token))
                if sample_loss > 1e8 or str(sample_loss) == "nan":
                    print "ERROR: LOSS EXPLOSION (>1e8) ! PLEASE SET PROPER PARAMETERS AND STRUCTURE! EXIT...."
                    exit(0)
                sys.stdout.flush()
                sample_loss = 0
            loss.backward()
            optimizer.step()
            model.zero_grad()
        temp_time = time.time()
        temp_cost = temp_time - temp_start
        print("     Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" %
              (end, temp_cost, sample_loss, right_token, whole_token,
               (right_token + 0.) / whole_token))

        epoch_finish = time.time()
        epoch_cost = epoch_finish - epoch_start
        print(
            "Epoch: %s training finished. Time: %.2fs, speed: %.2fst/s,  total loss: %s"
            % (idx, epoch_cost, train_num / epoch_cost, total_loss))
        print "totalloss:", total_loss
        if total_loss > 1e8 or str(total_loss) == "nan":
            print "ERROR: LOSS EXPLOSION (>1e8) ! PLEASE SET PROPER PARAMETERS AND STRUCTURE! EXIT...."
            exit(0)
        # continue
        speed, acc, p, r, f, _, _ = evaluate(data, model, "dev")
        dev_finish = time.time()
        dev_cost = dev_finish - epoch_finish

        if data.seg:
            current_score = f
            print(
                "Dev: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f"
                % (dev_cost, speed, acc, p, r, f))
        else:
            current_score = acc
            print("Dev: time: %.2fs speed: %.2fst/s; acc: %.4f" %
                  (dev_cost, speed, acc))

        if current_score > best_dev:
            if data.seg:
                print "Exceed previous best f score:", best_dev
            else:
                print "Exceed previous best acc score:", best_dev

            model_name = data.model_dir + ".model"
            torch.save(model.state_dict(), model_name)
            best_dev = current_score

        gc.collect()
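
The lr_decay helper called above for SGD is not shown in this example; a common implementation (assumed here, the original schedule may differ) applies an inverse-time decay to every parameter group:

def lr_decay(optimizer, epoch, decay_rate, init_lr):
    # assumed schedule: lr_t = init_lr / (1 + decay_rate * t)
    lr = init_lr / (1 + decay_rate * epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer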
Exemple #9
0
def train(data, model_file):
    print "Training model..."

    model = SeqModel(data)
    wordseq = WordSequence(data, False, True, data.use_char)
    if opt.self_adv == 'grad':
        wordseq_adv = WordSequence(data, False, True, data.use_char)
    elif opt.self_adv == 'label':
        wordseq_adv = WordSequence(data, False, True, data.use_char)
        model_adv = SeqModel(data)
    else:
        wordseq_adv = None

    if data.optimizer.lower() == "sgd":
        optimizer = optim.SGD(model.parameters(), lr=data.HP_lr, momentum=data.HP_momentum, weight_decay=data.HP_l2)
    elif data.optimizer.lower() == "adagrad":
        optimizer = optim.Adagrad(model.parameters(), lr=data.HP_lr, weight_decay=data.HP_l2)
    elif data.optimizer.lower() == "adadelta":
        optimizer = optim.Adadelta(model.parameters(), lr=data.HP_lr, weight_decay=data.HP_l2)
    elif data.optimizer.lower() == "rmsprop":
        optimizer = optim.RMSprop(model.parameters(), lr=data.HP_lr, weight_decay=data.HP_l2)
    elif data.optimizer.lower() == "adam":
        if opt.self_adv == 'grad':
            iter_parameter = itertools.chain(*map(list, [wordseq.parameters(), wordseq_adv.parameters(), model.parameters()]))
            optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)
        elif opt.self_adv == 'label':
            iter_parameter = itertools.chain(*map(list, [wordseq.parameters(), model.parameters()]))
            optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)
            iter_parameter = itertools.chain(*map(list, [wordseq_adv.parameters(), model_adv.parameters()]))
            optimizer_adv = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)

        else:
            iter_parameter = itertools.chain(*map(list, [wordseq.parameters(), model.parameters()]))
            optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)

    else:
        print("Optimizer illegal: %s" % (data.optimizer))
        exit(0)
    best_dev = -10

    if data.tune_wordemb == False:
        my_utils.freeze_net(wordseq.wordrep.word_embedding)
        if opt.self_adv != 'no':
            my_utils.freeze_net(wordseq_adv.wordrep.word_embedding)


    # data.HP_iteration = 1
    ## start training
    for idx in range(data.HP_iteration):
        epoch_start = time.time()
        temp_start = epoch_start
        print("epoch: %s/%s" % (idx, data.HP_iteration))
        if data.optimizer.lower() == "sgd":
            optimizer = lr_decay(optimizer, idx, data.HP_lr_decay, data.HP_lr)
        instance_count = 0
        sample_id = 0
        sample_loss = 0
        total_loss = 0
        right_token = 0
        whole_token = 0
        random.shuffle(data.train_Ids)
        ## set model in train mode
        wordseq.train()
        wordseq.zero_grad()
        if opt.self_adv == 'grad':
            wordseq_adv.train()
            wordseq_adv.zero_grad()
        elif opt.self_adv == 'label':
            wordseq_adv.train()
            wordseq_adv.zero_grad()
            model_adv.train()
            model_adv.zero_grad()
        model.train()
        model.zero_grad()
        batch_size = data.HP_batch_size
        batch_id = 0
        train_num = len(data.train_Ids)
        total_batch = train_num // batch_size + 1
        for batch_id in range(total_batch):
            start = batch_id * batch_size
            end = (batch_id + 1) * batch_size
            if end > train_num:
                end = train_num
            instance = data.train_Ids[start:end]
            if not instance:
                continue
            batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask,\
                batch_permute_label = batchify_with_label(instance, data.HP_gpu)
            instance_count += 1

            if opt.self_adv == 'grad':
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,batch_charrecover, None, None)
                hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, None, None)
                loss, tag_seq = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_label, mask)
                loss.backward()
                my_utils.reverse_grad(wordseq_adv)
                optimizer.step()
                wordseq.zero_grad()
                wordseq_adv.zero_grad()
                model.zero_grad()

            elif opt.self_adv == 'label' :
                wordseq.unfreeze_net()
                wordseq_adv.freeze_net()
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,batch_charrecover, None, None)
                hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, None, None)
                loss, tag_seq = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_label, mask)
                loss.backward()
                optimizer.step()
                wordseq.zero_grad()
                wordseq_adv.zero_grad()
                model.zero_grad()

                wordseq.freeze_net()
                wordseq_adv.unfreeze_net()
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,batch_charrecover, None, None)
                hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, None, None)
                loss_adv, _ = model_adv.neg_log_likelihood_loss(hidden, hidden_adv, batch_permute_label, mask)
                loss_adv.backward()
                optimizer_adv.step()
                wordseq.zero_grad()
                wordseq_adv.zero_grad()
                model_adv.zero_grad()

            else:
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, None, None)
                hidden_adv = None
                loss, tag_seq = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_label, mask)
                loss.backward()
                optimizer.step()
                wordseq.zero_grad()
                model.zero_grad()


            # right, whole = predict_check(tag_seq, batch_label, mask)
            # right_token += right
            # whole_token += whole
            sample_loss += loss.data.item()
            total_loss += loss.data.item()
            if end % 500 == 0:
                # temp_time = time.time()
                # temp_cost = temp_time - temp_start
                # temp_start = temp_time
                # print("     Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" % (
                # end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token))
                if sample_loss > 1e8 or str(sample_loss) == "nan":
                    print "ERROR: LOSS EXPLOSION (>1e8) ! PLEASE SET PROPER PARAMETERS AND STRUCTURE! EXIT...."
                    exit(0)
                sys.stdout.flush()
                sample_loss = 0

        # temp_time = time.time()
        # temp_cost = temp_time - temp_start
        # print("     Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" % (
        # end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token))

        epoch_finish = time.time()
        epoch_cost = epoch_finish - epoch_start
        print("epoch: %s training finished. Time: %.2fs, speed: %.2fst/s,  total loss: %s" % (
        idx, epoch_cost, train_num / epoch_cost, total_loss))
        print "totalloss:", total_loss
        if total_loss > 1e8 or str(total_loss) == "nan":
            print "ERROR: LOSS EXPLOSION (>1e8) ! PLEASE SET PROPER PARAMETERS AND STRUCTURE! EXIT...."
            exit(0)
        # continue
        speed, acc, p, r, f, _, _ = evaluate(data, wordseq, model, "test")

        dev_finish = time.time()
        dev_cost = dev_finish - epoch_finish

        if data.seg:
            current_score = f
            print("Dev: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % (
            dev_cost, speed, acc, p, r, f))
        else:
            current_score = acc
            print("Dev: time: %.2fs speed: %.2fst/s; acc: %.4f" % (dev_cost, speed, acc))

        if current_score > best_dev:
            if data.seg:
                print "Exceed previous best f score:", best_dev
            else:
                print "Exceed previous best acc score:", best_dev

            torch.save(wordseq.state_dict(), os.path.join(model_file, 'wordseq.pkl'))
            if opt.self_adv == 'grad':
                torch.save(wordseq_adv.state_dict(), os.path.join(model_file, 'wordseq_adv.pkl'))
            elif opt.self_adv == 'label':
                torch.save(wordseq_adv.state_dict(), os.path.join(model_file, 'wordseq_adv.pkl'))
                torch.save(model_adv.state_dict(), os.path.join(model_file, 'model_adv.pkl'))
            model_name = os.path.join(model_file, 'model.pkl')
            torch.save(model.state_dict(), model_name)
            best_dev = current_score

        gc.collect()
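
In the 'grad' self-adversarial branch above, my_utils.reverse_grad flips the gradients of the adversarial word-sequence layer before the optimizer step; a minimal sketch of such a helper (assumed, not the project's own implementation):

def reverse_grad(module):
    # negate the accumulated gradients so the following optimizer.step() moves
    # the adversarial parameters in the direction that increases the task loss
    for p in module.parameters():
        if p.grad is not None:
            p.grad = -p.grad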
Exemple #10
0
        # batch size always 1
        train_loader = DataLoader(MyDataset(instances_train),
                                  1,
                                  shuffle=True,
                                  collate_fn=my_collate)
        test_loader = DataLoader(MyDataset(instances_test),
                                 1,
                                 shuffle=False,
                                 collate_fn=my_collate)

        optimizer = optim.Adam(vsm_model.parameters(),
                               lr=opt.lr,
                               weight_decay=opt.l2)

        if opt.tune_wordemb == False:
            freeze_net(vsm_model.word_embedding)

        best_acc = -10

        bad_counter = 0

        logging.info("start training ...")

        for idx in range(opt.iter):
            epoch_start = time.time()

            vsm_model.train()

            train_iter = iter(train_loader)
            num_iter = len(train_loader)
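
freeze_net, used above when opt.tune_wordemb is False, is not defined in this snippet; a typical implementation (assumed) simply switches off gradient tracking for a module's parameters so the embeddings stay fixed during training:

def freeze_net(module):
    # no gradients are computed or applied for these parameters
    for p in module.parameters():
        p.requires_grad = False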
Exemple #11
0
def train(other_dir):

    logging.info("loading ... vocab")
    word_vocab = pickle.load(open(os.path.join(opt.pretrain, 'word_vocab.pkl'), 'rb'))
    postag_vocab = pickle.load(open(os.path.join(opt.pretrain, 'postag_vocab.pkl'), 'rb'))
    relation_vocab = pickle.load(open(os.path.join(opt.pretrain, 'relation_vocab.pkl'), 'rb'))
    entity_type_vocab = pickle.load(open(os.path.join(opt.pretrain, 'entity_type_vocab.pkl'), 'rb'))
    entity_vocab = pickle.load(open(os.path.join(opt.pretrain, 'entity_vocab.pkl'), 'rb'))
    position_vocab1 = pickle.load(open(os.path.join(opt.pretrain, 'position_vocab1.pkl'), 'rb'))
    position_vocab2 = pickle.load(open(os.path.join(opt.pretrain, 'position_vocab2.pkl'), 'rb'))
    tok_num_betw_vocab = pickle.load(open(os.path.join(opt.pretrain, 'tok_num_betw_vocab.pkl'), 'rb'))
    et_num_vocab = pickle.load(open(os.path.join(opt.pretrain, 'et_num_vocab.pkl'), 'rb'))


    my_collate = my_utils.sorted_collate if opt.model == 'lstm' else my_utils.unsorted_collate

    # test only on the main domain
    test_X = pickle.load(open(os.path.join(opt.pretrain, 'test_X.pkl'), 'rb'))
    test_Y = pickle.load(open(os.path.join(opt.pretrain, 'test_Y.pkl'), 'rb'))
    test_Other = pickle.load(open(os.path.join(opt.pretrain, 'test_Other.pkl'), 'rb'))
    logging.info("total test instance {}".format(len(test_Y)))
    test_loader = DataLoader(my_utils.RelationDataset(test_X, test_Y),
                              opt.batch_size, shuffle=False, collate_fn=my_collate) # drop_last=True

    # train on the main as well as other domains
    domains = ['main']
    train_loaders, train_iters, unk_loaders, unk_iters = {}, {}, {}, {}
    train_X = pickle.load(open(os.path.join(opt.pretrain, 'train_X.pkl'), 'rb'))
    train_Y = pickle.load(open(os.path.join(opt.pretrain, 'train_Y.pkl'), 'rb'))
    logging.info("total training instance {}".format(len(train_Y)))
    train_loaders['main'], train_iters['main'] = makeDatasetWithoutUnknown(train_X, train_Y, relation_vocab, True, my_collate)
    unk_loaders['main'], unk_iters['main'] = makeDatasetUnknown(train_X, train_Y, relation_vocab, my_collate, opt.unk_ratio)

    other_data = {}  # keep each auxiliary domain's data so its unknown set can be rebuilt per epoch
    for other in other_dir:
        domains.append(other)
        other_X = pickle.load(open(os.path.join(opt.pretrain, 'other_{}_X.pkl'.format(other)), 'rb'))
        other_Y = pickle.load(open(os.path.join(opt.pretrain, 'other_{}_Y.pkl'.format(other)), 'rb'))
        logging.info("other {} instance {}".format(other, len(other_Y)))
        other_data[other] = (other_X, other_Y)
        train_loaders[other], train_iters[other] = makeDatasetWithoutUnknown(other_X, other_Y, relation_vocab, True, my_collate)
        unk_loaders[other], unk_iters[other] = makeDatasetUnknown(other_X, other_Y, relation_vocab, my_collate, opt.unk_ratio)

    opt.domains = domains


    F_s = None
    F_d = {}
    if opt.model.lower() == 'lstm':
        F_s = LSTMFeatureExtractor(word_vocab, position_vocab1, position_vocab2,
                                                 opt.F_layers, opt.shared_hidden_size, opt.dropout)
        for domain in opt.domains:
            F_d[domain] = LSTMFeatureExtractor(word_vocab, position_vocab1, position_vocab2,
                                                 opt.F_layers, opt.shared_hidden_size, opt.dropout)
    elif opt.model.lower() == 'cnn':
        F_s = CNNFeatureExtractor(word_vocab, postag_vocab, position_vocab1, position_vocab2,
                                                opt.F_layers, opt.shared_hidden_size,
                                  opt.kernel_num, opt.kernel_sizes, opt.dropout)
        for domain in opt.domains:
            F_d[domain] = CNNFeatureExtractor(word_vocab, postag_vocab, position_vocab1, position_vocab2,
                                                opt.F_layers, opt.shared_hidden_size,
                                  opt.kernel_num, opt.kernel_sizes, opt.dropout)
    else:
        raise RuntimeError('Unknown feature extractor {}'.format(opt.model))

    if opt.model_high == 'capsule':
        C = capsule.CapsuleNet(2*opt.shared_hidden_size, relation_vocab, entity_type_vocab, entity_vocab, tok_num_betw_vocab,
                                         et_num_vocab)
    elif opt.model_high == 'mlp':
        C = baseline.MLP(2*opt.shared_hidden_size, relation_vocab, entity_type_vocab, entity_vocab, tok_num_betw_vocab,
                                         et_num_vocab)
    else:
        raise RuntimeError('Unknown model {}'.format(opt.model_high))

    if opt.adv:
        D = DomainClassifier(1, opt.shared_hidden_size, opt.shared_hidden_size,
                             len(opt.domains), opt.loss, opt.dropout, True)

    if torch.cuda.is_available():
        F_s, C = F_s.cuda(opt.gpu), C.cuda(opt.gpu)
        for f_d in F_d.values():
            f_d = f_d.cuda(opt.gpu)
        if opt.adv:
            D = D.cuda(opt.gpu)

    iter_parameter = itertools.chain(
        *map(list, [F_s.parameters() if F_s else [], C.parameters()] + [f.parameters() for f in F_d.values()]))
    optimizer = optim.Adam(iter_parameter, lr=opt.learning_rate)
    if opt.adv:
        optimizerD = optim.Adam(D.parameters(), lr=opt.learning_rate)


    best_acc = 0.0
    logging.info("start training ...")
    for epoch in range(opt.max_epoch):

        F_s.train()
        C.train()
        for f in F_d.values():
            f.train()
        if opt.adv:
            D.train()

        # domain accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy
        if opt.adv:
            d_correct, d_total = 0, 0

        # conceptually view 1 epoch as 1 epoch of the main domain
        num_iter = len(train_loaders['main'])

        for i in tqdm(range(num_iter)):

            if opt.adv:
                # D iterations
                my_utils.freeze_net(F_s)
                for f_d in F_d.values():
                    my_utils.freeze_net(f_d)
                my_utils.freeze_net(C)
                my_utils.unfreeze_net(D)

                if opt.tune_wordemb == False:
                    my_utils.freeze_net(F_s.word_emb)
                    for f_d in F_d.values():
                        my_utils.freeze_net(f_d.word_emb)
                # WGAN n_critic trick since D trains slower
                n_critic = opt.n_critic
                if opt.wgan_trick:
                    if opt.n_critic > 0 and ((epoch == 0 and i < 25) or i % 500 == 0):
                        n_critic = 100

                for _ in range(n_critic):
                    D.zero_grad()

                    # train on both labeled and unlabeled domains
                    for domain in opt.domains:
                        # targets not used
                        x2, x1, _ = my_utils.endless_get_next_batch_without_rebatch(train_loaders[domain],
                                                                                       train_iters[domain])
                        d_targets = my_utils.get_domain_label(opt.loss, domain, len(x2[1]))
                        shared_feat = F_s(x2, x1)
                        d_outputs = D(shared_feat)
                        # D accuracy
                        _, pred = torch.max(d_outputs, 1)
                        d_total += len(x2[1])
                        if opt.loss.lower() == 'l2':
                            _, tgt_indices = torch.max(d_targets, 1)
                            d_correct += (pred == tgt_indices).sum().data.item()
                            l_d = functional.mse_loss(d_outputs, d_targets)
                            l_d.backward()
                        else:
                            d_correct += (pred == d_targets).sum().data.item()
                            l_d = functional.nll_loss(d_outputs, d_targets)
                            l_d.backward()

                    optimizerD.step()

            # F&C iteration
            my_utils.unfreeze_net(F_s)
            for f_d in F_d.values():
                my_utils.unfreeze_net(f_d)
            my_utils.unfreeze_net(C)
            if opt.adv:
                my_utils.freeze_net(D)
            if opt.tune_wordemb == False:
                my_utils.freeze_net(F_s.word_emb)
                for f_d in F_d.values():
                    my_utils.freeze_net(f_d.word_emb)

            F_s.zero_grad()
            for f_d in F_d.values():
                f_d.zero_grad()
            C.zero_grad()

            for domain in opt.domains:

                x2, x1, targets = my_utils.endless_get_next_batch_without_rebatch(train_loaders[domain], train_iters[domain])
                shared_feat = F_s(x2, x1)
                domain_feat = F_d[domain](x2, x1)
                features = torch.cat((shared_feat, domain_feat), dim=1)

                if opt.model_high == 'capsule':
                    c_outputs, x_recon = C.forward(features, x2, x1, targets)
                    l_c = C.loss(targets, c_outputs, features, x2, x1, x_recon, opt.lam_recon)
                elif opt.model_high == 'mlp':
                    c_outputs = C.forward(features, x2, x1)
                    l_c = C.loss(targets, c_outputs)
                else:
                    raise RuntimeError('Unknown model {}'.format(opt.model_high))

                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().data.item()

                # training with unknown
                x2, x1, targets = my_utils.endless_get_next_batch_without_rebatch(unk_loaders[domain], unk_iters[domain])
                shared_feat = F_s(x2, x1)
                domain_feat = F_d[domain](x2, x1)
                features = torch.cat((shared_feat, domain_feat), dim=1)

                if opt.model_high == 'capsule':
                    c_outputs, x_recon = C.forward(features, x2, x1, targets)
                    l_c = C.loss(targets, c_outputs, features, x2, x1, x_recon, opt.lam_recon)
                elif opt.model_high == 'mlp':
                    c_outputs = C.forward(features, x2, x1)
                    l_c = C.loss(targets, c_outputs)
                else:
                    raise RuntimeError('Unknown model {}'.format(opt.model_high))

                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().data.item()

            if opt.adv:
                # update F with D gradients on all domains
                for domain in opt.domains:
                    x2, x1, _ = my_utils.endless_get_next_batch_without_rebatch(train_loaders[domain],
                                                                                   train_iters[domain])
                    shared_feat = F_s(x2, x1)
                    d_outputs = D(shared_feat)
                    if opt.loss.lower() == 'gr':
                        d_targets = my_utils.get_domain_label(opt.loss, domain, len(x2[1]))
                        l_d = functional.nll_loss(d_outputs, d_targets)
                        if opt.lambd > 0:
                            l_d *= -opt.lambd
                    elif opt.loss.lower() == 'bs':
                        d_targets = my_utils.get_random_domain_label(opt.loss, len(x2[1]))
                        l_d = functional.kl_div(d_outputs, d_targets, size_average=False)
                        if opt.lambd > 0:
                            l_d *= opt.lambd
                    elif opt.loss.lower() == 'l2':
                        d_targets = my_utils.get_random_domain_label(opt.loss, len(x2[1]))
                        l_d = functional.mse_loss(d_outputs, d_targets)
                        if opt.lambd > 0:
                            l_d *= opt.lambd
                    l_d.backward()

            torch.nn.utils.clip_grad_norm_(iter_parameter, opt.grad_clip)
            optimizer.step()

        # regenerate unknown dataset after one epoch
        unk_loaders['main'], unk_iters['main'] = makeDatasetUnknown(train_X, train_Y, relation_vocab, my_collate, opt.unk_ratio)
        for other in other_dir:
            other_X, other_Y = other_data[other]  # use this domain's own data, not the last one loaded
            unk_loaders[other], unk_iters[other] = makeDatasetUnknown(other_X, other_Y, relation_vocab, my_collate, opt.unk_ratio)

        logging.info('epoch {} end'.format(epoch))
        if opt.adv and d_total > 0:
            logging.info('D Training Accuracy: {}%'.format(100.0*d_correct/d_total))
        logging.info('Training accuracy:')
        logging.info('\t'.join(opt.domains))
        logging.info('\t'.join([str(100.0*correct[d]/total[d]) for d in opt.domains]))

        test_accuracy = evaluate(F_s, F_d['main'], C, test_loader, test_Other)
        logging.info('Test Accuracy: {}%'.format(test_accuracy))

        if test_accuracy > best_acc:
            best_acc = test_accuracy
            torch.save(F_s.state_dict(), '{}/F_s.pth'.format(opt.output))
            for d in opt.domains:
                torch.save(F_d[d].state_dict(), '{}/F_d_{}.pth'.format(opt.output, d))
            torch.save(C.state_dict(), '{}/C.pth'.format(opt.output))
            if opt.adv:
                torch.save(D.state_dict(), '{}/D.pth'.format(opt.output))
            pickle.dump(test_Other, open(os.path.join(opt.output, 'results.pkl'), "wb"), True)
            logging.info('New best accuracy: {}'.format(best_acc))


    logging.info("training completed")