# Imports used by the training examples below. Project-local modules and
# symbols (my_utils, ner, relation_extraction, capsule, capsule_em, baseline,
# WordSequence, WordRep, SeqModel, ClassifyModel, HiddenLayer,
# LSTMFeatureExtractor, CNNFeatureExtractor, batchify_with_label,
# randomSampler, and the global opt) are assumed to be provided by the
# surrounding project.
import itertools
import logging
import os
import pickle
import random
import time

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm


def makeRelationDataset(re_X_positive, re_Y_positive, re_X_negative,
                        re_Y_negative, ratio, b_shuffle, my_collate,
                        batch_size):
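    # Build a training set that keeps every positive relation instance and a
    # randomly sampled `ratio` fraction of the negative ("unknown") instances,
    # then return the corresponding DataLoader together with an iterator on it.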

    indices = random.sample(range(len(re_X_negative)),
                            int(len(re_X_negative) * ratio))

    temp_X = list(re_X_positive) + [re_X_negative[i] for i in indices]
    temp_Y = list(re_Y_positive) + [re_Y_negative[i] for i in indices]

    data_set = my_utils.RelationDataset(temp_X, temp_Y)

    data_loader = DataLoader(data_set,
                             batch_size,
                             shuffle=b_shuffle,
                             collate_fn=my_collate)
    it = iter(data_loader)
    return data_loader, it
def makeDatasetWithoutUnknown(test_X, test_Y, relation_vocab, b_shuffle, my_collate, batch_size):
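    # Build a DataLoader over the instances whose label is not the unknown
    # relation ("</unk>").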
    test_X_remove_unk = []
    test_Y_remove_unk = []
    for x, y in zip(test_X, test_Y):
        if y != relation_vocab.get_index("</unk>"):
            test_X_remove_unk.append(x)
            test_Y_remove_unk.append(y)

    test_set = my_utils.RelationDataset(test_X_remove_unk, test_Y_remove_unk)
    test_loader = DataLoader(test_set, batch_size, shuffle=b_shuffle, collate_fn=my_collate)
    it = iter(test_loader)
    logging.info("instance after removing unknown, {}".format(len(test_Y_remove_unk)))
    return test_loader, it
def makeDatasetUnknown(test_X, test_Y, relation_vocab, my_collate, ratio, batch_size):
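    # Build a DataLoader over the instances labelled as the unknown relation
    # ("</unk>"); the project-local randomSampler(labels, ratio) presumably
    # draws only a `ratio` share of them each epoch.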
    test_X_remove_unk = []
    test_Y_remove_unk = []
    for x, y in zip(test_X, test_Y):
        if y == relation_vocab.get_index("</unk>"):
            test_X_remove_unk.append(x)
            test_Y_remove_unk.append(y)

    test_set = my_utils.RelationDataset(test_X_remove_unk, test_Y_remove_unk)

    test_loader = DataLoader(test_set,
                             batch_size,
                             shuffle=False,
                             sampler=randomSampler(test_Y_remove_unk, ratio),
                             collate_fn=my_collate)
    it = iter(test_loader)

    return test_loader, it
def makeDatasetForEachClass(train_X, train_Y, relation_vocab, my_collate):
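    # Group the training instances by relation class and build one dataset,
    # random sampler and batch-size-1 DataLoader per class (used by the
    # class-balanced training strategy).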
    train_X_classified = {}
    train_Y_classified = {}
    for x, y in zip(train_X, train_Y):
        class_name = relation_vocab.lookup_id2str(y)
        if class_name in train_X_classified:
            train_X_classified[class_name].append(x)
            train_Y_classified[class_name].append(y)
        else:
            train_X_classified[class_name] = [x]
            train_Y_classified[class_name] = [y]

    train_sets = []  # each class gets its own dataset, sampler, and loader
    train_loaders = []
    train_samples = []
    train_numbers = []
    train_iters = []
    for class_name in train_Y_classified:
        x = train_X_classified[class_name]
        y = train_Y_classified[class_name]
        train_numbers.append((class_name, len(y)))

        train_set = my_utils.RelationDataset(x, y)
        train_sets.append(train_set)
        train_sampler = torch.utils.data.sampler.RandomSampler(train_set)
        train_samples.append(train_sampler)
        train_loader = DataLoader(train_set,
                                  1,
                                  shuffle=False,
                                  sampler=train_sampler,
                                  collate_fn=my_collate)
        train_loaders.append(train_loader)
        train_iter = iter(train_loader)
        train_iters.append(train_iter)

    return train_loaders, train_iters, train_numbers
def train(data, dir):
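    # Train the relation classifier alone. Each iteration consumes one batch of
    # known relations from train_loader and one batch of "unknown" instances
    # from unk_loader. opt.self_adv selects an optional adversarial variant:
    # 'grad' reverses gradients through wordseq_adv, 'label' trains a second
    # model on permuted labels.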

    my_collate = my_utils.sorted_collate1

    train_loader, train_iter = makeDatasetWithoutUnknown(
        data.re_train_X, data.re_train_Y,
        data.re_feature_alphabets[data.re_feature_name2id['[RELATION]']],
        True, my_collate, data.HP_batch_size)
    num_iter = len(train_loader)
    unk_loader, unk_iter = makeDatasetUnknown(
        data.re_train_X, data.re_train_Y,
        data.re_feature_alphabets[data.re_feature_name2id['[RELATION]']],
        my_collate, data.unk_ratio, data.HP_batch_size)

    test_loader = DataLoader(my_utils.RelationDataset(data.re_test_X, data.re_test_Y),
                             data.HP_batch_size, shuffle=False,
                             collate_fn=my_collate)

    wordseq = WordSequence(data, True, False, False)
    model = ClassifyModel(data)
    if torch.cuda.is_available():
        model = model.cuda(data.HP_gpu)

    if opt.self_adv == 'grad':
        wordseq_adv = WordSequence(data, True, False, False)
    elif opt.self_adv == 'label':
        wordseq_adv = WordSequence(data, True, False, False)
        model_adv = ClassifyModel(data)
        if torch.cuda.is_available():
            model_adv = model_adv.cuda(data.HP_gpu)
    else:
        wordseq_adv = None

    if opt.self_adv == 'grad':
        iter_parameter = itertools.chain(
            *map(list, [wordseq.parameters(), wordseq_adv.parameters(), model.parameters()]))
        optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)
    elif opt.self_adv == 'label':
        iter_parameter = itertools.chain(*map(list, [wordseq.parameters(), model.parameters()]))
        optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)
        iter_parameter = itertools.chain(*map(list, [wordseq_adv.parameters(), model_adv.parameters()]))
        optimizer_adv = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)

    else:
        iter_parameter = itertools.chain(*map(list, [wordseq.parameters(), model.parameters()]))
        optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)

    if data.tune_wordemb == False:
        my_utils.freeze_net(wordseq.wordrep.word_embedding)
        if opt.self_adv != 'no':
            my_utils.freeze_net(wordseq_adv.wordrep.word_embedding)

    best_acc = 0.0
    logging.info("start training ...")
    for epoch in range(data.max_epoch):
        wordseq.train()
        wordseq.zero_grad()
        model.train()
        model.zero_grad()
        if opt.self_adv == 'grad':
            wordseq_adv.train()
            wordseq_adv.zero_grad()
        elif opt.self_adv == 'label':
            wordseq_adv.train()
            wordseq_adv.zero_grad()
            model_adv.train()
            model_adv.zero_grad()
        correct, total = 0, 0

        for i in range(num_iter):
            [batch_word, batch_features, batch_wordlen, batch_wordrecover, \
            batch_char, batch_charlen, batch_charrecover, \
            position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \
            tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(train_loader, train_iter)


            if opt.self_adv == 'grad':
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,
                                         batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,
                                         batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                loss, pred = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen,
                                                           e1_token, e1_length, e2_token, e2_length, e1_type, e2_type,
                                                           tok_num_betw, et_num, targets)
                loss.backward()
                my_utils.reverse_grad(wordseq_adv)
                optimizer.step()
                wordseq.zero_grad()
                wordseq_adv.zero_grad()
                model.zero_grad()

            elif opt.self_adv == 'label' :
                wordseq.unfreeze_net()
                wordseq_adv.freeze_net()
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,
                                         batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,
                                         batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                loss, pred = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen,
                                                           e1_token, e1_length, e2_token, e2_length, e1_type, e2_type,
                                                           tok_num_betw, et_num, targets)
                loss.backward()
                optimizer.step()
                wordseq.zero_grad()
                wordseq_adv.zero_grad()
                model.zero_grad()

                wordseq.freeze_net()
                wordseq_adv.unfreeze_net()
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,
                                         batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,
                                         batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                loss_adv, _ = model_adv.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen,
                                                           e1_token, e1_length, e2_token, e2_length, e1_type, e2_type,
                                                           tok_num_betw, et_num, targets_permute)
                loss_adv.backward()
                optimizer_adv.step()
                wordseq.zero_grad()
                wordseq_adv.zero_grad()
                model_adv.zero_grad()

            else:
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,
                                         batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                hidden_adv = None
                loss, pred = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen,
                                                           e1_token, e1_length, e2_token, e2_length, e1_type, e2_type,
                                                           tok_num_betw, et_num, targets)
                loss.backward()
                optimizer.step()
                wordseq.zero_grad()
                model.zero_grad()



            total += targets.size(0)
            correct += (pred == targets).sum().item()


            [batch_word, batch_features, batch_wordlen, batch_wordrecover, \
            batch_char, batch_charlen, batch_charrecover, \
            position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \
            tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(unk_loader, unk_iter)

            hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,
                                     batch_charrecover, position1_seq_tensor, position2_seq_tensor)
            hidden_adv = None
            loss, pred = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen,
                                                       e1_token, e1_length, e2_token, e2_length, e1_type, e2_type,
                                                       tok_num_betw, et_num, targets)
            loss.backward()
            optimizer.step()
            wordseq.zero_grad()
            model.zero_grad()


        unk_loader, unk_iter = makeDatasetUnknown(data.re_train_X, data.re_train_Y,
                                                  data.re_feature_alphabets[data.re_feature_name2id['[RELATION]']],
                                                  my_collate, data.unk_ratio, data.HP_batch_size)

        logging.info('epoch {} end'.format(epoch))
        logging.info('Train Accuracy: {}%'.format(100.0 * correct / total))

        test_accuracy = evaluate(wordseq, model, test_loader)
        # test_accuracy = evaluate(m, test_loader)
        logging.info('Test Accuracy: {}%'.format(test_accuracy))

        if test_accuracy > best_acc:
            best_acc = test_accuracy
            torch.save(wordseq.state_dict(), os.path.join(dir, 'wordseq.pkl'))
            torch.save(model.state_dict(), '{}/model.pkl'.format(dir))
            if opt.self_adv == 'grad':
                torch.save(wordseq_adv.state_dict(), os.path.join(dir, 'wordseq_adv.pkl'))
            elif opt.self_adv == 'label':
                torch.save(wordseq_adv.state_dict(), os.path.join(dir, 'wordseq_adv.pkl'))
                torch.save(model_adv.state_dict(), os.path.join(dir, 'model_adv.pkl'))
            logging.info('New best accuracy: {}'.format(best_acc))


    logging.info("training completed")
def joint_train(data, old_data, opt):
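    # Jointly train the NER tagger (seq_wordseq + seq_model) and the relation
    # classifier (classify_wordseq + classify_model). When a pretrained model
    # directory is given, the LSTM weights and the overlapping rows of every
    # embedding table are copied from the old models first. The best dev score
    # of each task is checkpointed separately; an early-stop counter grows
    # while neither score improves, and training stops once it exceeds
    # 2 * data.patience.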

    if not os.path.exists(opt.output):
        os.makedirs(opt.output)

    if opt.pretrained_model_dir != 'None':
        seq_model = SeqModel(data)
        if opt.test_in_cpu:
            seq_model.load_state_dict(
                torch.load(os.path.join(opt.pretrained_model_dir,
                                        'ner_model.pkl'),
                           map_location='cpu'))
        else:
            seq_model.load_state_dict(
                torch.load(
                    os.path.join(opt.pretrained_model_dir, 'ner_model.pkl')))

        if (data.label_alphabet_size != seq_model.crf.tagset_size)\
                or (data.HP_hidden_dim != seq_model.hidden2tag.weight.size(1)):
            raise RuntimeError("ner_model not compatible")

        seq_wordseq = WordSequence(data, False, True, True, True)

        if ((data.word_emb_dim != seq_wordseq.wordrep.word_embedding.embedding_dim)\
            or (data.char_emb_dim != seq_wordseq.wordrep.char_feature.char_embeddings.embedding_dim)\
            or (data.feature_emb_dims[0] != seq_wordseq.wordrep.feature_embedding_dims[0])\
            or (data.feature_emb_dims[1] != seq_wordseq.wordrep.feature_embedding_dims[1])):
            raise RuntimeError("ner_wordseq not compatible")

        old_seq_wordseq = WordSequence(old_data, False, True, True, True)
        if opt.test_in_cpu:
            old_seq_wordseq.load_state_dict(
                torch.load(os.path.join(opt.pretrained_model_dir,
                                        'ner_wordseq.pkl'),
                           map_location='cpu'))
        else:
            old_seq_wordseq.load_state_dict(
                torch.load(
                    os.path.join(opt.pretrained_model_dir, 'ner_wordseq.pkl')))

        # sd = old_seq_wordseq.lstm.state_dict()

        for old, new in zip(old_seq_wordseq.lstm.parameters(),
                            seq_wordseq.lstm.parameters()):
            new.data.copy_(old)

        vocab_size = old_seq_wordseq.wordrep.word_embedding.num_embeddings
        seq_wordseq.wordrep.word_embedding.weight.data[:vocab_size, :] = \
            old_seq_wordseq.wordrep.word_embedding.weight.data[:vocab_size, :]

        vocab_size = old_seq_wordseq.wordrep.char_feature.char_embeddings.num_embeddings
        seq_wordseq.wordrep.char_feature.char_embeddings.weight.data[:vocab_size, :] = \
            old_seq_wordseq.wordrep.char_feature.char_embeddings.weight.data[:vocab_size, :]

        for i, feature_embedding in enumerate(
                old_seq_wordseq.wordrep.feature_embeddings):
            vocab_size = feature_embedding.num_embeddings
            seq_wordseq.wordrep.feature_embeddings[i].weight.data[:vocab_size, :] = \
                feature_embedding.weight.data[:vocab_size, :]

        # for word in data.word_alphabet.iteritems():
        #
        #     old_seq_wordseq.wordrep.word_embedding.weight.data data.word_alphabet.get_index(word)
        #

        classify_wordseq = WordSequence(data, True, False, True, False)

        if ((data.word_emb_dim != classify_wordseq.wordrep.word_embedding.embedding_dim)\
            or (data.re_feature_emb_dims[data.re_feature_name2id['[POSITION]']] != classify_wordseq.wordrep.position1_emb.embedding_dim)\
            or (data.feature_emb_dims[1] != classify_wordseq.wordrep.feature_embedding_dims[0])):
            raise RuntimeError("re_wordseq not compatible")

        old_classify_wordseq = WordSequence(old_data, True, False, True, False)
        if opt.test_in_cpu:
            old_classify_wordseq.load_state_dict(
                torch.load(os.path.join(opt.pretrained_model_dir,
                                        're_wordseq.pkl'),
                           map_location='cpu'))
        else:
            old_classify_wordseq.load_state_dict(
                torch.load(
                    os.path.join(opt.pretrained_model_dir, 're_wordseq.pkl')))

        for old, new in zip(old_classify_wordseq.lstm.parameters(),
                            classify_wordseq.lstm.parameters()):
            new.data.copy_(old)

        vocab_size = old_classify_wordseq.wordrep.word_embedding.num_embeddings
        classify_wordseq.wordrep.word_embedding.weight.data[:vocab_size, :] = \
            old_classify_wordseq.wordrep.word_embedding.weight.data[:vocab_size, :]

        vocab_size = old_classify_wordseq.wordrep.position1_emb.num_embeddings
        classify_wordseq.wordrep.position1_emb.weight.data[:vocab_size, :] = \
            old_classify_wordseq.wordrep.position1_emb.weight.data[:vocab_size, :]

        vocab_size = old_classify_wordseq.wordrep.position2_emb.num_embeddings
        classify_wordseq.wordrep.position2_emb.weight.data[:vocab_size, :] = \
            old_classify_wordseq.wordrep.position2_emb.weight.data[:vocab_size, :]

        vocab_size = old_classify_wordseq.wordrep.feature_embeddings[0].num_embeddings
        classify_wordseq.wordrep.feature_embeddings[0].weight.data[:vocab_size, :] = \
            old_classify_wordseq.wordrep.feature_embeddings[0].weight.data[:vocab_size, :]

        classify_model = ClassifyModel(data)

        old_classify_model = ClassifyModel(old_data)
        if opt.test_in_cpu:
            old_classify_model.load_state_dict(
                torch.load(os.path.join(opt.pretrained_model_dir,
                                        're_model.pkl'),
                           map_location='cpu'))
        else:
            old_classify_model.load_state_dict(
                torch.load(
                    os.path.join(opt.pretrained_model_dir, 're_model.pkl')))

        if (data.re_feature_alphabet_sizes[data.re_feature_name2id['[RELATION]']] != old_classify_model.linear.weight.size(0)
                or data.re_feature_emb_dims[data.re_feature_name2id['[ENTITY_TYPE]']] != old_classify_model.entity_type_emb.embedding_dim
                or data.re_feature_emb_dims[data.re_feature_name2id['[ENTITY]']] != old_classify_model.entity_emb.embedding_dim
                or data.re_feature_emb_dims[data.re_feature_name2id['[TOKEN_NUM]']] != old_classify_model.tok_num_betw_emb.embedding_dim
                or data.re_feature_emb_dims[data.re_feature_name2id['[ENTITY_NUM]']] != old_classify_model.et_num_emb.embedding_dim):
            raise RuntimeError("re_model not compatible")

        vocab_size = old_classify_model.entity_type_emb.num_embeddings
        classify_model.entity_type_emb.weight.data[:vocab_size, :] = \
            old_classify_model.entity_type_emb.weight.data[:vocab_size, :]

        vocab_size = old_classify_model.entity_emb.num_embeddings
        classify_model.entity_emb.weight.data[:vocab_size, :] = \
            old_classify_model.entity_emb.weight.data[:vocab_size, :]

        vocab_size = old_classify_model.tok_num_betw_emb.num_embeddings
        classify_model.tok_num_betw_emb.weight.data[:vocab_size, :] = \
            old_classify_model.tok_num_betw_emb.weight.data[:vocab_size, :]

        vocab_size = old_classify_model.et_num_emb.num_embeddings
        classify_model.et_num_emb.weight.data[:vocab_size, :] = \
            old_classify_model.et_num_emb.weight.data[:vocab_size, :]

    else:
        seq_model = SeqModel(data)
        seq_wordseq = WordSequence(data, False, True, True, True)

        classify_wordseq = WordSequence(data, True, False, True, False)
        classify_model = ClassifyModel(data)

    iter_parameter = itertools.chain(
        *map(list, [seq_wordseq.parameters(),
                    seq_model.parameters()]))
    seq_optimizer = optim.Adam(iter_parameter,
                               lr=data.HP_lr,
                               weight_decay=data.HP_l2)
    iter_parameter = itertools.chain(*map(
        list, [classify_wordseq.parameters(),
               classify_model.parameters()]))
    classify_optimizer = optim.Adam(iter_parameter,
                                    lr=data.HP_lr,
                                    weight_decay=data.HP_l2)

    if data.tune_wordemb == False:
        my_utils.freeze_net(seq_wordseq.wordrep.word_embedding)
        my_utils.freeze_net(classify_wordseq.wordrep.word_embedding)

    re_X_positive = []
    re_Y_positive = []
    re_X_negative = []
    re_Y_negative = []
    relation_vocab = data.re_feature_alphabets[
        data.re_feature_name2id['[RELATION]']]
    my_collate = my_utils.sorted_collate1
    for x, y in zip(data.re_train_X, data.re_train_Y):
        if y != relation_vocab.get_index("</unk>"):
            re_X_positive.append(x)
            re_Y_positive.append(y)
        else:
            re_X_negative.append(x)
            re_Y_negative.append(y)

    re_dev_loader = DataLoader(my_utils.RelationDataset(
        data.re_dev_X, data.re_dev_Y),
                               data.HP_batch_size,
                               shuffle=False,
                               collate_fn=my_collate)
    # re_test_loader = DataLoader(my_utils.RelationDataset(data.re_test_X, data.re_test_Y), data.HP_batch_size, shuffle=False, collate_fn=my_collate)

    best_ner_score = -1
    best_re_score = -1
    count_performance_not_grow = 0

    for idx in range(data.HP_iteration):
        epoch_start = time.time()

        seq_wordseq.train()
        seq_wordseq.zero_grad()
        seq_model.train()
        seq_model.zero_grad()

        classify_wordseq.train()
        classify_wordseq.zero_grad()
        classify_model.train()
        classify_model.zero_grad()

        batch_size = data.HP_batch_size

        random.shuffle(data.train_Ids)
        ner_train_num = len(data.train_Ids)
        ner_total_batch = ner_train_num // batch_size + 1

        re_train_loader, re_train_iter = makeRelationDataset(
            re_X_positive, re_Y_positive, re_X_negative, re_Y_negative,
            data.unk_ratio, True, my_collate, data.HP_batch_size)
        re_total_batch = len(re_train_loader)

        total_batch = max(ner_total_batch, re_total_batch)
        min_batch = min(ner_total_batch, re_total_batch)

        for batch_id in range(total_batch):

            if batch_id < ner_total_batch:
                start = batch_id * batch_size
                end = (batch_id + 1) * batch_size
                if end > ner_train_num:
                    end = ner_train_num
                instance = data.train_Ids[start:end]
                batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask, \
                    batch_permute_label = batchify_with_label(instance, data.HP_gpu)

                hidden = seq_wordseq.forward(batch_word, batch_features,
                                             batch_wordlen, batch_char,
                                             batch_charlen, batch_charrecover,
                                             None, None)
                hidden_adv = None
                loss, tag_seq = seq_model.neg_log_likelihood_loss(
                    hidden, hidden_adv, batch_label, mask)
                loss.backward()
                seq_optimizer.step()
                seq_wordseq.zero_grad()
                seq_model.zero_grad()

            if batch_id < re_total_batch:
                [batch_word, batch_features, batch_wordlen, batch_wordrecover, \
                 batch_char, batch_charlen, batch_charrecover, \
                 position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \
                 tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(
                    re_train_loader, re_train_iter)

                hidden = classify_wordseq.forward(batch_word, batch_features,
                                                  batch_wordlen, batch_char,
                                                  batch_charlen,
                                                  batch_charrecover,
                                                  position1_seq_tensor,
                                                  position2_seq_tensor)
                hidden_adv = None
                loss, pred = classify_model.neg_log_likelihood_loss(
                    hidden, hidden_adv, batch_wordlen, e1_token, e1_length,
                    e2_token, e2_length, e1_type, e2_type, tok_num_betw,
                    et_num, targets)
                loss.backward()
                classify_optimizer.step()
                classify_wordseq.zero_grad()
                classify_model.zero_grad()

        epoch_finish = time.time()
        logging.info("epoch: %s training finished. Time: %.2fs" %
                     (idx, epoch_finish - epoch_start))

        _, _, _, _, ner_score, _, _ = ner.evaluate(data, seq_wordseq,
                                                   seq_model, "dev")
        logging.info("ner evaluate: f: %.4f" % (ner_score))
        if ner_score > best_ner_score:
            logging.info("new best score: ner: %.4f" % (ner_score))
            best_ner_score = ner_score

            torch.save(seq_wordseq.state_dict(),
                       os.path.join(opt.output, 'ner_wordseq.pkl'))
            torch.save(seq_model.state_dict(),
                       os.path.join(opt.output, 'ner_model.pkl'))

            count_performance_not_grow = 0

            # _, _, _, _, test_ner_score, _, _ = ner.evaluate(data, seq_wordseq, seq_model, "test")
            # logging.info("ner evaluate on test: f: %.4f" % (test_ner_score))

        else:
            count_performance_not_grow += 1

        re_score = relation_extraction.evaluate(classify_wordseq,
                                                classify_model, re_dev_loader)
        logging.info("re evaluate: f: %.4f" % (re_score))
        if re_score > best_re_score:
            logging.info("new best score: re: %.4f" % (re_score))
            best_re_score = re_score

            torch.save(classify_wordseq.state_dict(),
                       os.path.join(opt.output, 're_wordseq.pkl'))
            torch.save(classify_model.state_dict(),
                       os.path.join(opt.output, 're_model.pkl'))

            count_performance_not_grow = 0

            # test_re_score = relation_extraction.evaluate(classify_wordseq, classify_model, re_test_loader)
            # logging.info("re evaluate on test: f: %.4f" % (test_re_score))
        else:
            count_performance_not_grow += 1

        if count_performance_not_grow > 2 * data.patience:
            logging.info("early stop")
            break

    logging.info("train finished")
def train():
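    # Standalone relation-classification training on pre-pickled vocabularies
    # and instances. opt.strategy controls how batches are drawn: 'all' (every
    # instance), 'no-unk' (drop unknown relations), 'balance' (one loader per
    # relation class), or 'part-unk' (known relations plus a sampled share of
    # unknown ones).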

    logging.info("loading ... vocab")
    word_vocab = pickle.load(
        open(os.path.join(opt.pretrain, 'word_vocab.pkl'), 'rb'))
    postag_vocab = pickle.load(
        open(os.path.join(opt.pretrain, 'postag_vocab.pkl'), 'rb'))
    relation_vocab = pickle.load(
        open(os.path.join(opt.pretrain, 'relation_vocab.pkl'), 'rb'))
    entity_type_vocab = pickle.load(
        open(os.path.join(opt.pretrain, 'entity_type_vocab.pkl'), 'rb'))
    entity_vocab = pickle.load(
        open(os.path.join(opt.pretrain, 'entity_vocab.pkl'), 'rb'))
    position_vocab1 = pickle.load(
        open(os.path.join(opt.pretrain, 'position_vocab1.pkl'), 'rb'))
    position_vocab2 = pickle.load(
        open(os.path.join(opt.pretrain, 'position_vocab2.pkl'), 'rb'))
    tok_num_betw_vocab = pickle.load(
        open(os.path.join(opt.pretrain, 'tok_num_betw_vocab.pkl'), 'rb'))
    et_num_vocab = pickle.load(
        open(os.path.join(opt.pretrain, 'et_num_vocab.pkl'), 'rb'))

    # One relation instance is composed of X (a pair of entities and their context), Y (relation label).
    train_X = pickle.load(open(os.path.join(opt.pretrain, 'train_X.pkl'),
                               'rb'))
    train_Y = pickle.load(open(os.path.join(opt.pretrain, 'train_Y.pkl'),
                               'rb'))
    logging.info("total training instance {}".format(len(train_Y)))

    my_collate = my_utils.sorted_collate if opt.model == 'lstm' else my_utils.unsorted_collate

    if opt.strategy == 'all':
        train_loader = DataLoader(my_utils.RelationDataset(train_X, train_Y),
                                  opt.batch_size,
                                  shuffle=True,
                                  collate_fn=my_collate)
        train_iter = iter(train_loader)
        num_iter = len(train_loader)
    elif opt.strategy == 'no-unk':
        train_loader, train_iter = makeDatasetWithoutUnknown(
            train_X, train_Y, relation_vocab, True, my_collate, opt.batch_size)
        num_iter = len(train_loader)
    elif opt.strategy == 'balance':
        train_loaders, train_iters, train_numbers = makeDatasetForEachClass(
            train_X, train_Y, relation_vocab, my_collate)
        for t in train_numbers:
            logging.info(t)
        # use the median instance count over all classes except the unknown class
        num_iter = int(
            np.median(
                np.array([
                    num for class_name, num in train_numbers
                    if class_name != relation_vocab.unk_tok
                ])))
    elif opt.strategy == 'part-unk':
        train_loader, train_iter = makeDatasetWithoutUnknown(
            train_X, train_Y, relation_vocab, True, my_collate, opt.batch_size)
        num_iter = len(train_loader)
        unk_loader, unk_iter = makeDatasetUnknown(train_X, train_Y,
                                                  relation_vocab, my_collate,
                                                  opt.unk_ratio, opt.batch_size)
    else:
        raise RuntimeError("unsupport training strategy")

    test_X = pickle.load(open(os.path.join(opt.pretrain, 'test_X.pkl'), 'rb'))
    test_Y = pickle.load(open(os.path.join(opt.pretrain, 'test_Y.pkl'), 'rb'))
    test_Other = pickle.load(
        open(os.path.join(opt.pretrain, 'test_Other.pkl'), 'rb'))
    logging.info("total test instance {}".format(len(test_Y)))
    #test_loader, _ = makeDatasetWithoutUnknown(test_X, test_Y, relation_vocab, False, my_collate)
    test_loader = DataLoader(my_utils.RelationDataset(test_X, test_Y),
                             opt.batch_size,
                             shuffle=False,
                             collate_fn=my_collate)  # drop_last=True

    if opt.model.lower() == 'lstm':
        feature_extractor = LSTMFeatureExtractor(word_vocab, position_vocab1,
                                                 position_vocab2, opt.F_layers,
                                                 opt.shared_hidden_size,
                                                 opt.dropout)
    elif opt.model.lower() == 'cnn':
        feature_extractor = CNNFeatureExtractor(word_vocab, postag_vocab,
                                                position_vocab1,
                                                position_vocab2, opt.F_layers,
                                                opt.shared_hidden_size,
                                                opt.kernel_num,
                                                opt.kernel_sizes, opt.dropout)
    else:
        raise RuntimeError('Unknown feature extractor {}'.format(opt.model))
    if torch.cuda.is_available():
        feature_extractor = feature_extractor.cuda(opt.gpu)

    if opt.model_high == 'capsule':
        m = capsule.CapsuleNet(opt.shared_hidden_size, relation_vocab,
                               entity_type_vocab, entity_vocab,
                               tok_num_betw_vocab, et_num_vocab)
    elif opt.model_high == 'capsule_em':
        m = capsule_em.CapsuleNet_EM(opt.shared_hidden_size, relation_vocab)
    elif opt.model_high == 'mlp':
        m = baseline.MLP(opt.shared_hidden_size, relation_vocab,
                         entity_type_vocab, entity_vocab, tok_num_betw_vocab,
                         et_num_vocab)
        # m = baseline.SentimentClassifier(opt.shared_hidden_size, relation_vocab, entity_type_vocab, entity_vocab, tok_num_betw_vocab,
        #                                  et_num_vocab,
        #                                  opt.F_layers, opt.shared_hidden_size, opt.dropout, opt.model_high_bn)
    else:
        raise RuntimeError('Unknown model {}'.format(opt.model_high))
    if torch.cuda.is_available():
        m = m.cuda(opt.gpu)

    # keep the parameters in a list so they can be reused for gradient clipping
    # below; a bare itertools.chain would already be exhausted by optim.Adam
    iter_parameter = list(
        itertools.chain(feature_extractor.parameters(), m.parameters()))
    optimizer = optim.Adam(iter_parameter, lr=opt.learning_rate)

    if opt.tune_wordemb == False:
        my_utils.freeze_layer(feature_extractor.word_emb)
        #feature_extractor.word_emb.weight.requires_grad = False

    best_acc = 0.0
    logging.info("start training ...")
    for epoch in range(opt.max_epoch):
        feature_extractor.train()
        m.train()
        correct, total = 0, 0

        for i in tqdm(range(num_iter)):

            if opt.strategy == 'all' or opt.strategy == 'no-unk' or opt.strategy == 'part-unk':
                x2, x1, targets = my_utils.endless_get_next_batch_without_rebatch(
                    train_loader, train_iter)

            elif opt.strategy == 'balance':
                x2, x1, targets = my_utils.endless_get_next_batch(
                    train_loaders, train_iters)
            else:
                raise RuntimeError("unsupport training strategy")

            hidden_features = feature_extractor.forward(x2, x1)

            if opt.model_high == 'capsule':
                outputs, x_recon = m.forward(hidden_features, x2, x1, targets)
                loss = m.loss(targets, outputs, hidden_features, x2, x1,
                              x_recon, opt.lam_recon)
            elif opt.model_high == 'mlp':
                outputs = m.forward(hidden_features, x2, x1)
                loss = m.loss(targets, outputs)
            else:
                raise RuntimeError('Unknown model {}'.format(opt.model_high))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(iter_parameter, opt.grad_clip)
            optimizer.step()

            total += targets.size(0)
            _, pred = torch.max(outputs, 1)
            correct += (pred == targets).sum().item()

            if opt.strategy == 'part-unk':
                x2, x1, targets = my_utils.endless_get_next_batch_without_rebatch(
                    unk_loader, unk_iter)

                hidden_features = feature_extractor.forward(x2, x1)

                if opt.model_high == 'capsule':
                    outputs, x_recon = m.forward(hidden_features, x2, x1,
                                                 targets)
                    loss = m.loss(targets, outputs, hidden_features, x2, x1,
                                  x_recon, opt.lam_recon)
                elif opt.model_high == 'mlp':
                    outputs = m.forward(hidden_features, x2, x1)
                    loss = m.loss(targets, outputs)
                else:
                    raise RuntimeError('Unknown model {}'.format(
                        opt.model_high))

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(iter_parameter, opt.grad_clip)
                optimizer.step()

        if opt.strategy == 'part-unk':
            unk_loader, unk_iter = makeDatasetUnknown(train_X, train_Y,
                                                      relation_vocab,
                                                      my_collate, opt.unk_ratio,
                                                      opt.batch_size)

        logging.info('epoch {} end'.format(epoch))
        logging.info('Train Accuracy: {}%'.format(100.0 * correct / total))

        test_accuracy = evaluate(feature_extractor, m, test_loader, test_Other)
        # test_accuracy = evaluate(m, test_loader)
        logging.info('Test Accuracy: {}%'.format(test_accuracy))

        if test_accuracy > best_acc:
            best_acc = test_accuracy
            torch.save(feature_extractor.state_dict(),
                       '{}/feature_extractor.pth'.format(opt.output))
            torch.save(m.state_dict(), '{}/model.pth'.format(opt.output))
            pickle.dump(test_Other,
                        open(os.path.join(opt.output, 'results.pkl'), "wb"),
                        True)
            logging.info('New best accuracy: {}'.format(best_acc))

    logging.info("training completed")
def train(data, ner_dir, re_dir):
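    # Multi-task NER + RE training with a stack of opt.hidden_num hidden layers
    # per task. For the first min_batch batches both tasks are fed jointly and
    # their hidden states are mixed with a weight alpha that grows from
    # 1 / task_num towards 1 layer by layer; the remaining batches train each
    # task on its own.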

    task_num = 2
    init_alpha = 1.0 / task_num
    if opt.hidden_num <= 1:
        step_alpha = 0
    else:
        step_alpha = (1.0 - init_alpha) / (opt.hidden_num - 1)

    ner_wordrep = WordRep(data, False, True, True, data.use_char)
    ner_hiddenlist = []
    for i in range(opt.hidden_num):
        if i == 0:
            input_size = data.word_emb_dim+data.HP_char_hidden_dim+data.feature_emb_dims[data.feature_name2id['[Cap]']]+ \
                         data.feature_emb_dims[data.feature_name2id['[POS]']]
            output_size = data.HP_hidden_dim
        else:
            input_size = data.HP_hidden_dim
            output_size = data.HP_hidden_dim

        temp = HiddenLayer(data, input_size, output_size)
        ner_hiddenlist.append(temp)

    seq_model = SeqModel(data)

    re_wordrep = WordRep(data, True, False, True, False)
    re_hiddenlist = []
    for i in range(opt.hidden_num):
        if i == 0:
            input_size = data.word_emb_dim + data.feature_emb_dims[data.feature_name2id['[POS]']]+\
                         2*data.re_feature_emb_dims[data.re_feature_name2id['[POSITION]']]
            output_size = data.HP_hidden_dim
        else:
            input_size = data.HP_hidden_dim
            output_size = data.HP_hidden_dim

        temp = HiddenLayer(data, input_size, output_size)
        re_hiddenlist.append(temp)

    classify_model = ClassifyModel(data)

    iter_parameter = itertools.chain(
        *map(list, [ner_wordrep.parameters(),
                    seq_model.parameters()] +
             [f.parameters() for f in ner_hiddenlist]))
    ner_optimizer = optim.Adam(iter_parameter,
                               lr=data.HP_lr,
                               weight_decay=data.HP_l2)
    iter_parameter = itertools.chain(
        *map(list, [re_wordrep.parameters(),
                    classify_model.parameters()] +
             [f.parameters() for f in re_hiddenlist]))
    re_optimizer = optim.Adam(iter_parameter,
                              lr=data.HP_lr,
                              weight_decay=data.HP_l2)

    if data.tune_wordemb == False:
        my_utils.freeze_net(ner_wordrep.word_embedding)
        my_utils.freeze_net(re_wordrep.word_embedding)

    re_X_positive = []
    re_Y_positive = []
    re_X_negative = []
    re_Y_negative = []
    relation_vocab = data.re_feature_alphabets[
        data.re_feature_name2id['[RELATION]']]
    my_collate = my_utils.sorted_collate1
    for x, y in zip(data.re_train_X, data.re_train_Y):
        if y != relation_vocab.get_index("</unk>"):
            re_X_positive.append(x)
            re_Y_positive.append(y)
        else:
            re_X_negative.append(x)
            re_Y_negative.append(y)

    re_test_loader = DataLoader(my_utils.RelationDataset(
        data.re_test_X, data.re_test_Y),
                                data.HP_batch_size,
                                shuffle=False,
                                collate_fn=my_collate)

    best_ner_score = -1
    best_re_score = -1

    for idx in range(data.HP_iteration):
        epoch_start = time.time()

        ner_wordrep.train()
        ner_wordrep.zero_grad()
        for hidden_layer in ner_hiddenlist:
            hidden_layer.train()
            hidden_layer.zero_grad()
        seq_model.train()
        seq_model.zero_grad()

        re_wordrep.train()
        re_wordrep.zero_grad()
        for hidden_layer in re_hiddenlist:
            hidden_layer.train()
            hidden_layer.zero_grad()
        classify_model.train()
        classify_model.zero_grad()

        batch_size = data.HP_batch_size

        random.shuffle(data.train_Ids)
        ner_train_num = len(data.train_Ids)
        ner_total_batch = ner_train_num // batch_size + 1

        re_train_loader, re_train_iter = makeRelationDataset(
            re_X_positive, re_Y_positive, re_X_negative, re_Y_negative,
            data.unk_ratio, True, my_collate, data.HP_batch_size)
        re_total_batch = len(re_train_loader)

        total_batch = max(ner_total_batch, re_total_batch)
        min_batch = min(ner_total_batch, re_total_batch)

        for batch_id in range(total_batch):

            if batch_id < min_batch:
                start = batch_id * batch_size
                end = (batch_id + 1) * batch_size
                if end > ner_train_num:
                    end = ner_train_num
                instance = data.train_Ids[start:end]
                ner_batch_word, ner_batch_features, ner_batch_wordlen, ner_batch_wordrecover, ner_batch_char, ner_batch_charlen, \
                ner_batch_charrecover, ner_batch_label, ner_mask, ner_batch_permute_label = batchify_with_label(instance, data.HP_gpu)

                [re_batch_word, re_batch_features, re_batch_wordlen, re_batch_wordrecover, re_batch_char, re_batch_charlen,
                 re_batch_charrecover, re_position1_seq_tensor, re_position2_seq_tensor, re_e1_token, re_e1_length, re_e2_token, re_e2_length,
                 re_e1_type, re_e2_type, re_tok_num_betw, re_et_num], [re_targets, re_targets_permute] = \
                    my_utils.endless_get_next_batch_without_rebatch1(re_train_loader, re_train_iter)

                if ner_batch_word.size(0) != re_batch_word.size(0):
                    continue  # skip this step when the NER and RE batches have different sizes

                ner_word_rep = ner_wordrep.forward(
                    ner_batch_word, ner_batch_features, ner_batch_wordlen,
                    ner_batch_char, ner_batch_charlen, ner_batch_charrecover,
                    None, None)

                re_word_rep = re_wordrep.forward(
                    re_batch_word, re_batch_features, re_batch_wordlen,
                    re_batch_char, re_batch_charlen, re_batch_charrecover,
                    re_position1_seq_tensor, re_position2_seq_tensor)
                alpha = init_alpha
                ner_hidden = ner_word_rep
                re_hidden = re_word_rep
                for i in range(opt.hidden_num):
                    if alpha > 0.99:
                        use_attn = False

                        ner_lstm_out, ner_att_out = ner_hiddenlist[i].forward(
                            ner_hidden, ner_batch_wordlen, use_attn)
                        re_lstm_out, re_att_out = re_hiddenlist[i].forward(
                            re_hidden, re_batch_wordlen, use_attn)

                        ner_hidden = ner_lstm_out
                        re_hidden = re_lstm_out
                    else:
                        use_attn = True

                        ner_lstm_out, ner_att_out = ner_hiddenlist[i].forward(
                            ner_hidden, ner_batch_wordlen, use_attn)
                        re_lstm_out, re_att_out = re_hiddenlist[i].forward(
                            re_hidden, re_batch_wordlen, use_attn)

                        ner_hidden = alpha * ner_lstm_out + (
                            1 - alpha) * re_att_out.unsqueeze(1)
                        re_hidden = alpha * re_lstm_out + (
                            1 - alpha) * ner_att_out.unsqueeze(1)

                    alpha += step_alpha

                ner_loss, ner_tag_seq = seq_model.neg_log_likelihood_loss(
                    ner_hidden, ner_batch_label, ner_mask)
                re_loss, re_pred = classify_model.neg_log_likelihood_loss(
                    re_hidden, re_batch_wordlen, re_e1_token, re_e1_length,
                    re_e2_token, re_e2_length, re_e1_type, re_e2_type,
                    re_tok_num_betw, re_et_num, re_targets)

                ner_loss.backward(retain_graph=True)
                re_loss.backward()

                ner_optimizer.step()
                re_optimizer.step()

                ner_wordrep.zero_grad()
                for hidden_layer in ner_hiddenlist:
                    hidden_layer.zero_grad()
                seq_model.zero_grad()

                re_wordrep.zero_grad()
                for hidden_layer in re_hiddenlist:
                    hidden_layer.zero_grad()
                classify_model.zero_grad()

            else:

                if batch_id < ner_total_batch:
                    start = batch_id * batch_size
                    end = (batch_id + 1) * batch_size
                    if end > ner_train_num:
                        end = ner_train_num
                    instance = data.train_Ids[start:end]
                    batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask, \
                        batch_permute_label = batchify_with_label(instance, data.HP_gpu)

                    ner_word_rep = ner_wordrep.forward(
                        batch_word, batch_features, batch_wordlen, batch_char,
                        batch_charlen, batch_charrecover, None, None)

                    ner_hidden = ner_word_rep
                    for i in range(opt.hidden_num):
                        ner_lstm_out, ner_att_out = ner_hiddenlist[i].forward(
                            ner_hidden, batch_wordlen, False)
                        ner_hidden = ner_lstm_out

                    loss, tag_seq = seq_model.neg_log_likelihood_loss(
                        ner_hidden, batch_label, mask)

                    loss.backward()
                    ner_optimizer.step()
                    ner_wordrep.zero_grad()
                    for hidden_layer in ner_hiddenlist:
                        hidden_layer.zero_grad()
                    seq_model.zero_grad()

                if batch_id < re_total_batch:
                    [batch_word, batch_features, batch_wordlen, batch_wordrecover, \
                     batch_char, batch_charlen, batch_charrecover, \
                     position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \
                     tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(
                        re_train_loader, re_train_iter)

                    re_word_rep = re_wordrep.forward(
                        batch_word, batch_features, batch_wordlen, batch_char,
                        batch_charlen, batch_charrecover, position1_seq_tensor,
                        position2_seq_tensor)

                    re_hidden = re_word_rep
                    for i in range(opt.hidden_num):
                        re_lstm_out, re_att_out = re_hiddenlist[i].forward(
                            re_hidden, batch_wordlen, False)
                        re_hidden = re_lstm_out

                    loss, pred = classify_model.neg_log_likelihood_loss(
                        re_hidden, batch_wordlen, e1_token, e1_length,
                        e2_token, e2_length, e1_type, e2_type, tok_num_betw,
                        et_num, targets)
                    loss.backward()
                    re_optimizer.step()
                    re_wordrep.zero_grad()
                    for hidden_layer in re_hiddenlist:
                        hidden_layer.zero_grad()
                    classify_model.zero_grad()

        epoch_finish = time.time()
        print("epoch: %s training finished. Time: %.2fs" %
              (idx, epoch_finish - epoch_start))

        ner_score = ner_evaluate(data, ner_wordrep, ner_hiddenlist, seq_model,
                                 "test")
        print("ner evaluate: f: %.4f" % (ner_score))

        re_score = re_evaluate(re_wordrep, re_hiddenlist, classify_model,
                               re_test_loader)
        print("re evaluate: f: %.4f" % (re_score))

        if ner_score + re_score > best_ner_score + best_re_score:
            print("new best score: ner: %.4f , re: %.4f" %
                  (ner_score, re_score))
            best_ner_score = ner_score
            best_re_score = re_score

            torch.save(ner_wordrep.state_dict(),
                       os.path.join(ner_dir, 'wordrep.pkl'))
            for i, hidden_layer in enumerate(ner_hiddenlist):
                torch.save(hidden_layer.state_dict(),
                           os.path.join(ner_dir, 'hidden_{}.pkl'.format(i)))
            torch.save(seq_model.state_dict(),
                       os.path.join(ner_dir, 'model.pkl'))

            torch.save(re_wordrep.state_dict(),
                       os.path.join(re_dir, 'wordrep.pkl'))
            for i, hidden_layer in enumerate(re_hiddenlist):
                torch.save(hidden_layer.state_dict(),
                           os.path.join(re_dir, 'hidden_{}.pkl'.format(i)))
            torch.save(classify_model.state_dict(),
                       os.path.join(re_dir, 'model.pkl'))
def pipeline(data, ner_dir, re_dir):
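    # Train the NER model and the relation classifier in the same epoch loop
    # but with separate optimizers and no parameter sharing; both are saved
    # whenever the summed NER + RE test score improves.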

    seq_model = SeqModel(data)
    seq_wordseq = WordSequence(data, False, True, True, data.use_char)

    classify_wordseq = WordSequence(data, True, False, True, False)
    classify_model = ClassifyModel(data)
    if torch.cuda.is_available():
        classify_model = classify_model.cuda(data.HP_gpu)

    iter_parameter = itertools.chain(
        *map(list, [seq_wordseq.parameters(),
                    seq_model.parameters()]))
    seq_optimizer = optim.Adam(iter_parameter,
                               lr=opt.ner_lr,
                               weight_decay=data.HP_l2)
    iter_parameter = itertools.chain(*map(
        list, [classify_wordseq.parameters(),
               classify_model.parameters()]))
    classify_optimizer = optim.Adam(iter_parameter,
                                    lr=opt.re_lr,
                                    weight_decay=data.HP_l2)

    if data.tune_wordemb == False:
        my_utils.freeze_net(seq_wordseq.wordrep.word_embedding)
        my_utils.freeze_net(classify_wordseq.wordrep.word_embedding)

    re_X_positive = []
    re_Y_positive = []
    re_X_negative = []
    re_Y_negative = []
    relation_vocab = data.re_feature_alphabets[
        data.re_feature_name2id['[RELATION]']]
    my_collate = my_utils.sorted_collate1
    for x, y in zip(data.re_train_X, data.re_train_Y):
        if y != relation_vocab.get_index("</unk>"):
            re_X_positive.append(x)
            re_Y_positive.append(y)
        else:
            re_X_negative.append(x)
            re_Y_negative.append(y)

    re_test_loader = DataLoader(my_utils.RelationDataset(
        data.re_test_X, data.re_test_Y),
                                data.HP_batch_size,
                                shuffle=False,
                                collate_fn=my_collate)

    best_ner_score = -1
    best_re_score = -1

    for idx in range(data.HP_iteration):
        epoch_start = time.time()

        seq_wordseq.train()
        seq_wordseq.zero_grad()
        seq_model.train()
        seq_model.zero_grad()

        classify_wordseq.train()
        classify_wordseq.zero_grad()
        classify_model.train()
        classify_model.zero_grad()

        batch_size = data.HP_batch_size

        random.shuffle(data.train_Ids)
        ner_train_num = len(data.train_Ids)
        ner_total_batch = ner_train_num // batch_size + 1

        re_train_loader, re_train_iter = makeRelationDataset(
            re_X_positive, re_Y_positive, re_X_negative, re_Y_negative,
            data.unk_ratio, True, my_collate, data.HP_batch_size)
        re_total_batch = len(re_train_loader)

        total_batch = max(ner_total_batch, re_total_batch)
        min_batch = min(ner_total_batch, re_total_batch)

        for batch_id in range(total_batch):

            if batch_id < ner_total_batch:
                start = batch_id * batch_size
                end = (batch_id + 1) * batch_size
                if end > ner_train_num:
                    end = ner_train_num
                instance = data.train_Ids[start:end]
                batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask, \
                    batch_permute_label = batchify_with_label(instance, data.HP_gpu)

                hidden = seq_wordseq.forward(batch_word, batch_features,
                                             batch_wordlen, batch_char,
                                             batch_charlen, batch_charrecover,
                                             None, None)
                hidden_adv = None
                loss, tag_seq = seq_model.neg_log_likelihood_loss(
                    hidden, hidden_adv, batch_label, mask)
                loss.backward()
                seq_optimizer.step()
                seq_wordseq.zero_grad()
                seq_model.zero_grad()

            if batch_id < re_total_batch:
                [batch_word, batch_features, batch_wordlen, batch_wordrecover, \
                 batch_char, batch_charlen, batch_charrecover, \
                 position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \
                 tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(
                    re_train_loader, re_train_iter)

                hidden = classify_wordseq.forward(batch_word, batch_features,
                                                  batch_wordlen, batch_char,
                                                  batch_charlen,
                                                  batch_charrecover,
                                                  position1_seq_tensor,
                                                  position2_seq_tensor)
                hidden_adv = None
                loss, pred = classify_model.neg_log_likelihood_loss(
                    hidden, hidden_adv, batch_wordlen, e1_token, e1_length,
                    e2_token, e2_length, e1_type, e2_type, tok_num_betw,
                    et_num, targets)
                loss.backward()
                classify_optimizer.step()
                classify_wordseq.zero_grad()
                classify_model.zero_grad()

        epoch_finish = time.time()
        print("epoch: %s training finished. Time: %.2fs" %
              (idx, epoch_finish - epoch_start))

        # _, _, _, _, f, _, _ = ner.evaluate(data, seq_wordseq, seq_model, "test")
        ner_score = ner.evaluate1(data, seq_wordseq, seq_model, "test")
        print("ner evaluate: f: %.4f" % (ner_score))

        re_score = relation_extraction.evaluate(classify_wordseq,
                                                classify_model, re_test_loader)
        print("re evaluate: f: %.4f" % (re_score))

        if ner_score + re_score > best_ner_score + best_re_score:
            print("new best score: ner: %.4f , re: %.4f" %
                  (ner_score, re_score))
            best_ner_score = ner_score
            best_re_score = re_score

            torch.save(seq_wordseq.state_dict(),
                       os.path.join(ner_dir, 'wordseq.pkl'))
            torch.save(seq_model.state_dict(),
                       os.path.join(ner_dir, 'model.pkl'))
            torch.save(classify_wordseq.state_dict(),
                       os.path.join(re_dir, 'wordseq.pkl'))
            torch.save(classify_model.state_dict(),
                       os.path.join(re_dir, 'model.pkl'))
Exemple #10
def train(other_dir):
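    # multi-domain relation-extraction training: a shared feature extractor (F_s),
    # one domain-specific extractor per domain (F_d), a relation classifier (C)
    # and, when opt.adv is set, an adversarial domain discriminator (D)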

    logging.info("loading ... vocab")
    word_vocab = pickle.load(open(os.path.join(opt.pretrain, 'word_vocab.pkl'), 'rb'))
    postag_vocab = pickle.load(open(os.path.join(opt.pretrain, 'postag_vocab.pkl'), 'rb'))
    relation_vocab = pickle.load(open(os.path.join(opt.pretrain, 'relation_vocab.pkl'), 'rb'))
    entity_type_vocab = pickle.load(open(os.path.join(opt.pretrain, 'entity_type_vocab.pkl'), 'rb'))
    entity_vocab = pickle.load(open(os.path.join(opt.pretrain, 'entity_vocab.pkl'), 'rb'))
    position_vocab1 = pickle.load(open(os.path.join(opt.pretrain, 'position_vocab1.pkl'), 'rb'))
    position_vocab2 = pickle.load(open(os.path.join(opt.pretrain, 'position_vocab2.pkl'), 'rb'))
    tok_num_betw_vocab = pickle.load(open(os.path.join(opt.pretrain, 'tok_num_betw_vocab.pkl'), 'rb'))
    et_num_vocab = pickle.load(open(os.path.join(opt.pretrain, 'et_num_vocab.pkl'), 'rb'))


    my_collate = my_utils.sorted_collate if opt.model == 'lstm' else my_utils.unsorted_collate

    # test only on the main domain
    test_X = pickle.load(open(os.path.join(opt.pretrain, 'test_X.pkl'), 'rb'))
    test_Y = pickle.load(open(os.path.join(opt.pretrain, 'test_Y.pkl'), 'rb'))
    test_Other = pickle.load(open(os.path.join(opt.pretrain, 'test_Other.pkl'), 'rb'))
    logging.info("total test instance {}".format(len(test_Y)))
    test_loader = DataLoader(my_utils.RelationDataset(test_X, test_Y),
                             opt.batch_size, shuffle=False, collate_fn=my_collate)  # drop_last=True

    # train on the main as well as other domains
    domains = ['main']
    train_loaders, train_iters, unk_loaders, unk_iters = {}, {}, {}, {}
    train_X = pickle.load(open(os.path.join(opt.pretrain, 'train_X.pkl'), 'rb'))
    train_Y = pickle.load(open(os.path.join(opt.pretrain, 'train_Y.pkl'), 'rb'))
    logging.info("total training instance {}".format(len(train_Y)))
    # batch size assumed to be opt.batch_size, the same value used for the test loader
    train_loaders['main'], train_iters['main'] = makeDatasetWithoutUnknown(train_X, train_Y, relation_vocab, True, my_collate, opt.batch_size)
    unk_loaders['main'], unk_iters['main'] = makeDatasetUnknown(train_X, train_Y, relation_vocab, my_collate, opt.unk_ratio, opt.batch_size)

    other_data = {}  # keep each domain's instances so the unknown sets can be rebuilt every epoch
    for other in other_dir:
        domains.append(other)
        other_X = pickle.load(open(os.path.join(opt.pretrain, 'other_{}_X.pkl'.format(other)), 'rb'))
        other_Y = pickle.load(open(os.path.join(opt.pretrain, 'other_{}_Y.pkl'.format(other)), 'rb'))
        other_data[other] = (other_X, other_Y)
        logging.info("other {} instance {}".format(other, len(other_Y)))
        train_loaders[other], train_iters[other] = makeDatasetWithoutUnknown(other_X, other_Y, relation_vocab, True, my_collate, opt.batch_size)
        unk_loaders[other], unk_iters[other] = makeDatasetUnknown(other_X, other_Y, relation_vocab, my_collate, opt.unk_ratio, opt.batch_size)

    opt.domains = domains


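    # shared extractor F_s plus one domain-specific extractor F_d per domain;
    # their outputs are concatenated before being fed to the classifier C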
    F_s = None
    F_d = {}
    if opt.model.lower() == 'lstm':
        F_s = LSTMFeatureExtractor(word_vocab, position_vocab1, position_vocab2,
                                                 opt.F_layers, opt.shared_hidden_size, opt.dropout)
        for domain in opt.domains:
            F_d[domain] = LSTMFeatureExtractor(word_vocab, position_vocab1, position_vocab2,
                                                 opt.F_layers, opt.shared_hidden_size, opt.dropout)
    elif opt.model.lower() == 'cnn':
        F_s = CNNFeatureExtractor(word_vocab, postag_vocab, position_vocab1, position_vocab2,
                                                opt.F_layers, opt.shared_hidden_size,
                                  opt.kernel_num, opt.kernel_sizes, opt.dropout)
        for domain in opt.domains:
            F_d[domain] = CNNFeatureExtractor(word_vocab, postag_vocab, position_vocab1, position_vocab2,
                                                opt.F_layers, opt.shared_hidden_size,
                                  opt.kernel_num, opt.kernel_sizes, opt.dropout)
    else:
        raise RuntimeError('Unknown feature extractor {}'.format(opt.model))

    if opt.model_high == 'capsule':
        C = capsule.CapsuleNet(2*opt.shared_hidden_size, relation_vocab, entity_type_vocab, entity_vocab, tok_num_betw_vocab,
                                         et_num_vocab)
    elif opt.model_high == 'mlp':
        C = baseline.MLP(2*opt.shared_hidden_size, relation_vocab, entity_type_vocab, entity_vocab, tok_num_betw_vocab,
                                         et_num_vocab)
    else:
        raise RuntimeError('Unknown model {}'.format(opt.model_high))

    if opt.adv:
        D = DomainClassifier(1, opt.shared_hidden_size, opt.shared_hidden_size,
                             len(opt.domains), opt.loss, opt.dropout, True)

    if torch.cuda.is_available():
        F_s, C = F_s.cuda(opt.gpu), C.cuda(opt.gpu)
        for f_d in F_d.values():
            f_d.cuda(opt.gpu)  # nn.Module.cuda() moves parameters in place
        if opt.adv:
            D = D.cuda(opt.gpu)

    # materialize the parameter chain into a list: the Adam constructor and the
    # per-batch gradient clipping below both need to iterate over it
    iter_parameter = list(itertools.chain(
        *map(list, [F_s.parameters() if F_s else [], C.parameters()] + [f.parameters() for f in F_d.values()])))
    optimizer = optim.Adam(iter_parameter, lr=opt.learning_rate)
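    # the discriminator D gets its own optimizer so its updates stay separate
    # from those of the feature extractors and the classifier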
    if opt.adv:
        optimizerD = optim.Adam(D.parameters(), lr=opt.learning_rate)


    best_acc = 0.0
    logging.info("start training ...")
    for epoch in range(opt.max_epoch):

        F_s.train()
        C.train()
        for f in F_d.values():
            f.train()
        if opt.adv:
            D.train()

        # domain accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy
        if opt.adv:
            d_correct, d_total = 0, 0

        # conceptually view 1 epoch as 1 epoch of the main domain
        num_iter = len(train_loaders['main'])

        for i in tqdm(range(num_iter)):

            if opt.adv:
                # D iterations
                my_utils.freeze_net(F_s)
                # map() is lazy in Python 3, so apply freeze_net explicitly
                for f_d in F_d.values():
                    my_utils.freeze_net(f_d)
                my_utils.freeze_net(C)
                my_utils.unfreeze_net(D)

                if not opt.tune_wordemb:
                    my_utils.freeze_net(F_s.word_emb)
                    for f_d in F_d.values():
                        my_utils.freeze_net(f_d.word_emb)
                # WGAN n_critic trick since D trains slower
                n_critic = opt.n_critic
                if opt.wgan_trick:
                    if opt.n_critic > 0 and ((epoch == 0 and i < 25) or i % 500 == 0):
                        n_critic = 100

                for _ in range(n_critic):
                    D.zero_grad()

                    # train on both labeled and unlabeled domains
                    for domain in opt.domains:
                        # targets not used
                        x2, x1, _ = my_utils.endless_get_next_batch_without_rebatch(train_loaders[domain],
                                                                                       train_iters[domain])
                        d_targets = my_utils.get_domain_label(opt.loss, domain, len(x2[1]))
                        shared_feat = F_s(x2, x1)
                        d_outputs = D(shared_feat)
                        # D accuracy
                        _, pred = torch.max(d_outputs, 1)
                        d_total += len(x2[1])
                        if opt.loss.lower() == 'l2':
                            _, tgt_indices = torch.max(d_targets, 1)
                            d_correct += (pred == tgt_indices).sum().data.item()
                            l_d = functional.mse_loss(d_outputs, d_targets)
                            l_d.backward()
                        else:
                            d_correct += (pred == d_targets).sum().data.item()
                            l_d = functional.nll_loss(d_outputs, d_targets)
                            l_d.backward()

                    optimizerD.step()

            # F&C iteration
            my_utils.unfreeze_net(F_s)
            # map() is lazy in Python 3, so apply unfreeze_net explicitly
            for f_d in F_d.values():
                my_utils.unfreeze_net(f_d)
            my_utils.unfreeze_net(C)
            if opt.adv:
                my_utils.freeze_net(D)
            if not opt.tune_wordemb:
                my_utils.freeze_net(F_s.word_emb)
                for f_d in F_d.values():
                    my_utils.freeze_net(f_d.word_emb)

            F_s.zero_grad()
            for f_d in F_d.values():
                f_d.zero_grad()
            C.zero_grad()

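            # per-domain classification pass: concatenate shared and domain-specific
            # features, then train C on known relations and on the subsampled
            # unknown ("null") relations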
            for domain in opt.domains:

                x2, x1, targets = my_utils.endless_get_next_batch_without_rebatch(train_loaders[domain], train_iters[domain])
                shared_feat = F_s(x2, x1)
                domain_feat = F_d[domain](x2, x1)
                features = torch.cat((shared_feat, domain_feat), dim=1)

                if opt.model_high == 'capsule':
                    c_outputs, x_recon = C.forward(features, x2, x1, targets)
                    l_c = C.loss(targets, c_outputs, features, x2, x1, x_recon, opt.lam_recon)
                elif opt.model_high == 'mlp':
                    c_outputs = C.forward(features, x2, x1)
                    l_c = C.loss(targets, c_outputs)
                else:
                    raise RuntimeError('Unknown model {}'.format(opt.model_high))

                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().data.item()

                # training with unknown
                x2, x1, targets = my_utils.endless_get_next_batch_without_rebatch(unk_loaders[domain], unk_iters[domain])
                shared_feat = F_s(x2, x1)
                domain_feat = F_d[domain](x2, x1)
                features = torch.cat((shared_feat, domain_feat), dim=1)

                if opt.model_high == 'capsule':
                    c_outputs, x_recon = C.forward(features, x2, x1, targets)
                    l_c = C.loss(targets, c_outputs, features, x2, x1, x_recon, opt.lam_recon)
                elif opt.model_high == 'mlp':
                    c_outputs = C.forward(features, x2, x1)
                    l_c = C.loss(targets, c_outputs)
                else:
                    raise RuntimeError('Unknown model {}'.format(opt.model_high))

                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().data.item()

            if opt.adv:
                # update F with D gradients on all domains
                for domain in opt.domains:
                    x2, x1, _ = my_utils.endless_get_next_batch_without_rebatch(train_loaders[domain],
                                                                                   train_iters[domain])
                    shared_feat = F_s(x2, x1)
                    d_outputs = D(shared_feat)
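                    # 'gr' reverses the domain-classification loss (gradient-reversal style),
                    # 'bs' pulls D's output toward a random soft domain label via KL divergence,
                    # and 'l2' does the same with an MSE objective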
                    if opt.loss.lower() == 'gr':
                        d_targets = my_utils.get_domain_label(opt.loss, domain, len(x2[1]))
                        l_d = functional.nll_loss(d_outputs, d_targets)
                        if opt.lambd > 0:
                            l_d *= -opt.lambd
                    elif opt.loss.lower() == 'bs':
                        d_targets = my_utils.get_random_domain_label(opt.loss, len(x2[1]))
                        l_d = functional.kl_div(d_outputs, d_targets, reduction='sum')
                        if opt.lambd > 0:
                            l_d *= opt.lambd
                    elif opt.loss.lower() == 'l2':
                        d_targets = my_utils.get_random_domain_label(opt.loss, len(x2[1]))
                        l_d = functional.mse_loss(d_outputs, d_targets)
                        if opt.lambd > 0:
                            l_d *= opt.lambd
                    l_d.backward()

            torch.nn.utils.clip_grad_norm_(iter_parameter, opt.grad_clip)
            optimizer.step()

        # regenerate the unknown-relation datasets after each epoch so a fresh
        # random subsample of unknown instances is drawn
        unk_loaders['main'], unk_iters['main'] = makeDatasetUnknown(train_X, train_Y, relation_vocab, my_collate, opt.unk_ratio, opt.batch_size)
        for other in other_dir:
            other_X, other_Y = other_data[other]
            unk_loaders[other], unk_iters[other] = makeDatasetUnknown(other_X, other_Y, relation_vocab, my_collate, opt.unk_ratio, opt.batch_size)

        logging.info('epoch {} end'.format(epoch))
        if opt.adv and d_total > 0:
            logging.info('D Training Accuracy: {}%'.format(100.0*d_correct/d_total))
        logging.info('Training accuracy:')
        logging.info('\t'.join(opt.domains))
        logging.info('\t'.join([str(100.0*correct[d]/total[d]) for d in opt.domains]))

        test_accuracy = evaluate(F_s, F_d['main'], C, test_loader, test_Other)
        logging.info('Test Accuracy: {}%'.format(test_accuracy))

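        # checkpoint every component whenever the test accuracy improves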
        if test_accuracy > best_acc:
            best_acc = test_accuracy
            torch.save(F_s.state_dict(), '{}/F_s.pth'.format(opt.output))
            for d in opt.domains:
                torch.save(F_d[d].state_dict(), '{}/F_d_{}.pth'.format(opt.output, d))
            torch.save(C.state_dict(), '{}/C.pth'.format(opt.output))
            if opt.adv:
                torch.save(D.state_dict(), '{}/D.pth'.format(opt.output))
            pickle.dump(test_Other, open(os.path.join(opt.output, 'results.pkl'), "wb"), True)
            logging.info('New best accuracy: {}'.format(best_acc))


    logging.info("training completed")