Beispiel #1
0
def eval_(output_dir, t_labels, p_labels, text):
    """Score predicted entity tags against gold tags stored in three aligned files.

    The three files are read line by line in lockstep: one sentence per line in
    ``text`` and one whitespace-separated tag sequence per line in ``t_labels``
    (gold) and ``p_labels`` (predicted).  Tags are converted to character
    offsets with ``get_biluo`` / ``offset_from_biluo``, grouped per entity type
    and passed to ``calculate_prediction_quality``.  The set of gold entity
    types and the resulting scores are printed; nothing is returned.

    Args:
        output_dir: Directory that contains the three files.
        t_labels: File name of the gold tag sequences.
        p_labels: File name of the predicted tag sequences.
        text: File name of the raw sentences.
    """

    def _group_by_type(offsets, seen_types=None):
        # Map entity type -> list of (start, stop) spans; optionally record
        # every type encountered in ``seen_types``.
        grouped = {}
        for start, stop, ent_type in offsets:
            ent_type = ent_type.replace('_', '')
            if seen_types is not None:
                seen_types.add(ent_type)
            grouped.setdefault(ent_type, []).append((start, stop))
        return grouped

    ne_class_list = set()
    true_labels_for_testing = []
    results_of_prediction = []
    # The files are written elsewhere in this file with encoding='utf8', so
    # read them back the same way instead of relying on the platform default.
    with open(os.path.join(output_dir, t_labels), 'r', encoding='utf8') as t, \
            open(os.path.join(output_dir, p_labels), 'r', encoding='utf8') as p, \
            open(os.path.join(output_dir, text), 'r', encoding='utf8') as textf:
        # zip stops at the shortest file, so trailing unmatched lines are ignored.
        for sentence, true_labels, predicted_labels in zip(textf, t, p):
            true_labels = true_labels.strip().replace('_', '-').split()
            predicted_labels = predicted_labels.strip().replace('_', '-').split()
            biluo_tags_true = get_biluo(true_labels)
            biluo_tags_predicted = get_biluo(predicted_labels)
            doc = Doc(sentence.strip())
            offset_true_labels = offset_from_biluo(doc, biluo_tags_true)
            offset_predicted_labels = offset_from_biluo(doc, biluo_tags_predicted)

            # Only gold entities contribute to the list of evaluated classes.
            true_labels_for_testing.append(
                _group_by_type(offset_true_labels, ne_class_list))
            results_of_prediction.append(
                _group_by_type(offset_predicted_labels))

    from eval.quality import calculate_prediction_quality
    print(ne_class_list)
    f1, precision, recall, results = \
        calculate_prediction_quality(true_labels_for_testing,
                                     results_of_prediction,
                                     tuple(ne_class_list))
    print(f1, precision, recall, results)
    def add_knowledge_with_vm(self, args, sent_batch, label_batch):
        """
        Inject lookup-table entities into a sentence and build its visible matrix.

        Only the first element of each batch is read (``sent_batch[0]`` and
        ``label_batch[0]``), so every returned list holds exactly one item.
        The sentence is cut into chunks at the entity spans derived from the
        label string; each chunk is tokenised, optionally followed by the
        tokens of the entities found for it in ``self.lookup_table``.

        input: args - namespace; reads use_kg, reverse_order, max_entities,
                      padding, truncate and seq_length
               sent_batch - list of sentences, e.g., ["abcd", "efgh"]
               label_batch - list of whitespace-separated tag strings
        return: know_sent_batch - list of token sequences with entity tokens
                                  inserted
                position_batch - list of position indices, one per token
                visible_matrix_batch - list of visible matrices
                seg_batch - list of segment tags (0 first sub-token of a word,
                            2 further sub-tokens, 1 injected entity token,
                            3 padding)
        """
        text_ = sent_batch[0]
        label_ = label_batch[0]

        tag_labels_true = label_.strip().replace('_', '-').split()
        biluo_tags_true = get_biluo(tag_labels_true)

        doc = Doc(text_)
        offset_true_labels = offset_from_biluo(doc, biluo_tags_true)

        chunk_start = 0
        chunks = []

        # Convert text into chunks: the text before each entity, then the entity
        for start, end, _ in offset_true_labels:
            chunk_text = text_[chunk_start: start].strip()
            chunk_entity = text_[start: end].strip()
            chunk_start = end

            if chunk_text:
                chunks.append(chunk_text)

            if chunk_entity:
                chunks.append(chunk_entity)

        # Append the last chunk if not empty
        last_chunk = text_[chunk_start:].strip()
        if last_chunk:
            chunks.append(last_chunk)
        chunks = [chunks]

        know_sent_batch = []
        position_batch = []
        visible_matrix_batch = []
        seg_batch = []
        for split_sent in chunks:
            # create tree
            sent_tree = []
            pos_idx_tree = []
            abs_idx_tree = []
            pos_idx = -1
            abs_idx = -1
            abs_idx_src = []
            num_chunks = len(split_sent)
            for idx, token_original in enumerate(split_sent):
                know_entities = []
                if args.use_kg:
                    all_entities = list(self.lookup_table.get(token_original.lower(), []))
                    if args.reverse_order:
                        # Take the last max_entities items.  Clamp at 0: a
                        # negative start would be read as an offset from the
                        # end and drop entities whenever fewer than
                        # max_entities are available.
                        start_index = max(len(all_entities) - args.max_entities, 0)
                        know_entities = all_entities[start_index:]
                    else:
                        # Take the first max_entities items
                        know_entities = all_entities[:args.max_entities]
                    know_entities = [ent.replace('_', ' ') for ent in know_entities]

                # Tokenize the data
                cur_tokens = []
                for tok in token_original.split():
                    cur_tokens.extend(self.tokenize_word(tok))

                # The first chunk is prefixed with CLS, the last suffixed with SEP
                if idx == 0:
                    cls_token = self.tokenizer.cls_token
                    cur_tokens = [cls_token] + cur_tokens
                    token_original = cls_token + ' ' + token_original
                if idx == num_chunks - 1:
                    sep_token = self.tokenizer.sep_token
                    cur_tokens = cur_tokens + [sep_token]
                    token_original = token_original + ' ' + sep_token

                entities = []
                for ent in know_entities:
                    entity = []
                    for word in ent.split():
                        entity.extend(self.tokenize_word(word))
                    # Skip entities that yield no tokens; they would otherwise
                    # leave an empty index list behind.
                    if entity:
                        entities.append(entity)

                sent_tree.append((token_original, cur_tokens, entities))

                if token_original in self.special_tags:
                    token_pos_idx = [pos_idx+1]
                    token_abs_idx = [abs_idx+1]
                else:
                    token_pos_idx = [pos_idx+i for i in range(1, len(cur_tokens)+1)]
                    token_abs_idx = [abs_idx+i for i in range(1, len(cur_tokens)+1)]
                abs_idx = token_abs_idx[-1]

                # Every entity continues the position count from the end of
                # its chunk, while the absolute index keeps running across
                # all entities.
                entities_pos_idx = []
                entities_abs_idx = []
                for ent in entities:
                    ent_pos_idx = [token_pos_idx[-1] + i for i in range(1, len(ent)+1)]
                    entities_pos_idx.append(ent_pos_idx)
                    ent_abs_idx = [abs_idx + i for i in range(1, len(ent)+1)]
                    abs_idx = ent_abs_idx[-1]
                    entities_abs_idx.append(ent_abs_idx)

                pos_idx_tree.append((token_pos_idx, entities_pos_idx))
                pos_idx = token_pos_idx[-1]
                abs_idx_tree.append((token_abs_idx, entities_abs_idx))
                abs_idx_src += token_abs_idx

            # Get know_sent and pos
            know_sent = []
            pos = []
            seg = []
            for tree_item, pos_item in zip(sent_tree, pos_idx_tree):
                token_original, word, chunk_entities = tree_item
                token_pos, chunk_entities_pos = pos_item

                for tok in token_original.split():
                    if tok in self.special_tags:
                        seg += [0]
                    else:
                        num_subwords = len(self.tokenize_word(tok))
                        # 0 for the first sub-token, 2 for each added sub-token
                        seg += [0] + [2] * max(num_subwords - 1, 0)

                # Append the subwords in know_sent
                know_sent += word

                pos += token_pos
                for add_word, add_pos in zip(chunk_entities, chunk_entities_pos):
                    know_sent += add_word
                    seg += [1] * len(add_word)
                    pos += list(add_pos)

            token_num = len(know_sent)
            # Calculate visible matrix
            visible_matrix = np.zeros((token_num, token_num))
            for src_ids, ent_ids in abs_idx_tree:
                # Chunk tokens see all chunk tokens plus their own entities
                visible_abs_idx = abs_idx_src + [i for ent in ent_ids for i in ent]
                for row in src_ids:
                    visible_matrix[row, visible_abs_idx] = 1
                # Entity tokens see only their entity and the owning chunk
                for ent in ent_ids:
                    visible_abs_idx = ent + src_ids
                    for row in ent:
                        visible_matrix[row, visible_abs_idx] = 1

            src_length = len(know_sent)
            if args.padding:
                if len(know_sent) < args.seq_length:
                    pad_num = args.seq_length - src_length
                    know_sent += [self.tokenizer.pad_token] * pad_num
                    seg += [3] * pad_num
                    pos += [args.seq_length - 1] * pad_num
                    visible_matrix = np.pad(visible_matrix, ((0, pad_num), (0, pad_num)), 'constant')  # pad 0
                else:
                    know_sent = know_sent[:args.seq_length]
                    seg = seg[:args.seq_length]
                    pos = pos[:args.seq_length]
                    visible_matrix = visible_matrix[:args.seq_length, :args.seq_length]

            if args.truncate and src_length > args.seq_length:
                know_sent = know_sent[:args.seq_length]
                seg = seg[:args.seq_length]
                pos = pos[:args.seq_length]
                visible_matrix = visible_matrix[:args.seq_length, :args.seq_length]

            know_sent_batch.append(know_sent)
            position_batch.append(pos)
            visible_matrix_batch.append(visible_matrix)
            seg_batch.append(seg)

        return know_sent_batch, position_batch, visible_matrix_batch, seg_batch
Beispiel #3
0
# Script fragment: walks the rows of data_df and splits each text into
# entity / non-entity chunks. The fragment is cut off after the last line.
ignored = 0  # number of rows skipped because words and tags do not line up
tag_features = {}
new_texts = []
updated_labels = []
for index, row in data_df.iterrows():
    text_ = row.text
    words = text_.split()
    doc = Doc(text_)
    labels = row.labels
    tag_labels_true = labels.strip().replace('_', '-').split()
    # Skip rows whose whitespace token count differs from the tag count.
    if len(words) != len(tag_labels_true):
        ignored += 1
        # print(index, row.text)
        continue
    biluo_tags_true = get_biluo(tag_labels_true)
    offset_true_labels = offset_from_biluo(doc, biluo_tags_true)
    # if index == 49:
    # print(text_)
    # print(tag_labels_true)
    # print(offset_true_labels)
    new_labels = labels.split()

    chunk_start = 0
    chunks = []
    # Convert text into chunks: the text before each entity span, then the
    # entity span itself; chunk_start tracks where the previous span ended.
    for start, end, _ in offset_true_labels:
        chunk_text = text_[chunk_start:start].strip()
        chunk_entity = text_[start:end].strip()
        chunk_start = end

        if chunk_text:
def train(args):
    """Train a tagger on stacked transformer embeddings and score it on the test split.

    Steps: load vocabularies and data, encode all sentences with three stacked
    transformer embeddings, train with BertAdam while checkpointing the model
    whenever the loss accumulated over ``args.patience`` updates improves,
    reload the best checkpoint, write gold tags / predicted tags / text to
    ``<output_dir>/<prefix>_{label,predict,text}.txt`` and print the scores
    returned by ``calculate_prediction_quality``.

    Args:
        args: Namespace.  Read here: data_dir, vocab, tag_set, batch_size,
            pooling_operation, learning_rate, output_dir, train_text, epochs,
            patience, max_steps.  ``vocab_size``, ``number_of_tags`` and
            ``embedding_dim`` are written back onto it before the model is
            built.
    """

    def _group_by_type(offsets, seen_types=None):
        # Map entity type -> list of (start, stop) spans; optionally record
        # every type encountered in ``seen_types``.
        grouped = {}
        for start, stop, ent_type in offsets:
            ent_type = ent_type.replace('_', '')
            if seen_types is not None:
                seen_types.add(ent_type)
            grouped.setdefault(ent_type, []).append((start, stop))
        return grouped

    vocab_path = os.path.join(args.data_dir, args.vocab)
    tag_path = os.path.join(args.data_dir, args.tag_set)
    word_to_idx, idx_to_word, tag_to_idx, idx_to_tag = load_vocabs(vocab_path, tag_path)
    train_sentences, train_labels, test_sentences, test_labels = prepare_text(args, tag_to_idx)

    device = get_device(args)
    start = time.time()
    transformer_names = (
        'distilbert-base-multilingual-cased',
        'distilroberta-base',
        'sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens',
    )
    # All three embeddings share the same configuration; only the model differs.
    encoder = StackTransformerEmbeddings([
        TransformerWordEmbeddings(name,
                                  layers='-1',
                                  batch_size=args.batch_size,
                                  pooling_operation=args.pooling_operation)
        for name in transformer_names
    ])

    train_sentences_encoded = encoder.encode(train_sentences)
    test_sentences_encoded = encoder.encode(test_sentences)

    print(f'Encoding time:{time.time() - start}')

    # Update the Namespace
    args.vocab_size = len(idx_to_word)
    args.number_of_tags = len(idx_to_tag)

    # Update the embedding dim
    args.embedding_dim = encoder.embedding_length

    model = build_model(args, device)
    print(model)
    model = model.to(device)

    betas = (0.9, 0.999)
    eps = 1e-8
    optimizer = BertAdam(model, lr=args.learning_rate, b1=betas[0], b2=betas[1], e=eps)

    pad_id = word_to_idx['PAD']
    pad_id_labels = tag_to_idx['PAD']

    batcher = SamplingBatcherStackedTransformers(np.asarray(train_sentences_encoded, dtype=object),
                                                 np.asarray(train_labels, dtype=object),
                                                 batch_size=args.batch_size,
                                                 pad_id=pad_id,
                                                 pad_id_labels=pad_id_labels,
                                                 embedding_length=encoder.embedding_length,
                                                 device=device)

    # The counter starts at 1 and is bumped before each step, so the first
    # optimisation step is counted as update 2.
    updates = 1
    total_loss = 0
    best_loss = +inf
    stop_training = False
    output_dir = args.output_dir
    # exist_ok tolerates an existing directory without hiding real failures
    # such as missing permissions.
    os.makedirs(output_dir, exist_ok=True)

    name_parts = args.train_text.split('_')
    prefix = name_parts[0] if len(name_parts) > 1 else args.train_text.split('.')[0]
    best_model_path = f'{output_dir}/{prefix}_best_model.pt'

    start_time = time.time()
    for epoch in range(args.epochs):
        for batch in batcher:
            updates += 1
            optimizer.zero_grad()
            loss = model.score(batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.data
            if updates % args.patience == 0:
                print(f'Epoch: {epoch}, Updates:{updates}, Loss: {total_loss}')
                # Checkpoint when the loss summed over the last window improved.
                if best_loss > total_loss:
                    save_state(best_model_path, model, loss_fn, optimizer,
                               updates, args=args)
                    best_loss = total_loss
                total_loss = 0
            if updates % args.max_steps == 0:
                stop_training = True
                break

        if stop_training:
            break

    print('Training time:{}'.format(time.time() - start_time))

    def get_idx_to_tag(label_ids):
        return [idx_to_tag.get(idx) for idx in label_ids]

    model, model_args = load_model_state(best_model_path, device)
    model = model.to(device)
    # Put the reloaded model into inference mode before predicting.
    model.eval()
    batcher_test = SamplingBatcherStackedTransformers(np.asarray(test_sentences_encoded, dtype=object),
                                                      np.asarray(test_labels, dtype=object),
                                                      batch_size=args.batch_size,
                                                      pad_id=pad_id,
                                                      pad_id_labels=pad_id_labels,
                                                      embedding_length=encoder.embedding_length,
                                                      device=device)
    ne_class_list = set()
    true_labels_for_testing = []
    results_of_prediction = []
    with open(f'{output_dir}/{prefix}_label.txt', 'w', encoding='utf8') as t, \
            open(f'{output_dir}/{prefix}_predict.txt', 'w', encoding='utf8') as p, \
            open(f'{output_dir}/{prefix}_text.txt', 'w', encoding='utf8') as textf:
        with torch.no_grad():
            # predict() method returns final labels not the label_ids
            preds = predict_no_attn(batcher_test, model, idx_to_tag)
            for text_, labels, predict_labels in zip(test_sentences, test_labels, preds):
                tag_labels_predicted = ' '.join(predict_labels)
                tag_labels_true = ' '.join(get_idx_to_tag(labels))

                p.write(tag_labels_predicted + '\n')
                t.write(tag_labels_true + '\n')
                textf.write(text_ + '\n')

                tag_labels_true = tag_labels_true.strip().replace('_', '-').split()
                tag_labels_predicted = tag_labels_predicted.strip().replace('_', '-').split()
                biluo_tags_true = get_biluo(tag_labels_true)
                biluo_tags_predicted = get_biluo(tag_labels_predicted)

                doc = Doc(text_)
                offset_true_labels = offset_from_biluo(doc, biluo_tags_true)
                offset_predicted_labels = offset_from_biluo(doc, biluo_tags_predicted)

                # Only gold entities contribute to the list of evaluated classes.
                true_labels_for_testing.append(
                    _group_by_type(offset_true_labels, ne_class_list))
                results_of_prediction.append(
                    _group_by_type(offset_predicted_labels))

    from eval.quality import calculate_prediction_quality
    f1, precision, recall, results = \
        calculate_prediction_quality(true_labels_for_testing,
                                     results_of_prediction,
                                     tuple(ne_class_list))
    print(f1, precision, recall, results)
Beispiel #5
0
def train(args):
    """Train a tagger on index-encoded sentences and score it on the test split.

    Steps: obtain vocabularies and data from ``prepare(args)``, train with
    BertAdam while checkpointing the model whenever the loss accumulated over
    ``args.patience`` updates improves, reload the best checkpoint, tag each
    test sentence one at a time, write gold tags / predicted tags / text to
    ``<output_dir>/<prefix>_{label,predict,text}.txt`` and print the scores
    returned by ``calculate_prediction_quality``.

    Args:
        args: Namespace.  Read here: cpu, learning_rate, batch_size,
            output_dir, train_text, epochs, patience, max_steps.
            ``vocab_size`` and ``number_of_tags`` are written back onto it
            before the model is built.
    """

    def _group_by_type(offsets, seen_types=None):
        # Map entity type -> list of (start, stop) spans; optionally record
        # every type encountered in ``seen_types``.
        grouped = {}
        for start, stop, ent_type in offsets:
            ent_type = ent_type.replace('_', '')
            if seen_types is not None:
                seen_types.add(ent_type)
            grouped.setdefault(ent_type, []).append((start, stop))
        return grouped

    idx_to_word, idx_to_tag, train_sentences, train_labels, test_sentences, test_labels = prepare(
        args)
    word_to_idx = {word: idx for idx, word in idx_to_word.items()}

    args.vocab_size = len(idx_to_word)
    args.number_of_tags = len(idx_to_tag)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda and not args.cpu else "cpu")

    model = build_model(args)
    print(model)
    model = model.to(device)

    betas = (0.9, 0.999)
    eps = 1e-8
    optimizer = BertAdam(model,
                         lr=args.learning_rate,
                         b1=betas[0],
                         b2=betas[1],
                         e=eps)
    pad_id = word_to_idx['PAD']
    batcher = SamplingBatcher(np.asarray(train_sentences, dtype=object),
                              np.asarray(train_labels, dtype=object),
                              batch_size=args.batch_size,
                              pad_id=pad_id)

    # The counter starts at 1 and is bumped before each step, so the first
    # optimisation step is counted as update 2.
    updates = 1
    total_loss = 0
    best_loss = +inf
    stop_training = False

    output_dir = args.output_dir
    # exist_ok tolerates an existing directory without hiding real failures
    # such as missing permissions.
    os.makedirs(output_dir, exist_ok=True)

    name_parts = args.train_text.split('_')
    prefix = name_parts[0] if len(name_parts) > 1 else args.train_text.split('.')[0]
    best_model_path = f'{output_dir}/{prefix}_best_model.pt'

    start_time = time.time()
    for epoch in range(args.epochs):
        for batch in batcher:
            updates += 1
            # The batch also carries a length and an input mask; neither is
            # needed for the loss.
            batch_data, batch_labels, _, _, mask_y = batch
            optimizer.zero_grad()
            batch_data = batch_data.to(device)
            batch_labels = batch_labels.to(device)
            mask_y = mask_y.to(device)
            attn_mask = get_attn_pad_mask(batch_data, batch_data, pad_id)
            output_batch = model(batch_data, attn_mask)
            loss = loss_fn(output_batch, batch_labels, mask_y)

            loss.backward()
            optimizer.step()

            total_loss += loss.data
            if updates % args.patience == 0:
                print(f'Epoch: {epoch}, Updates:{updates}, Loss: {total_loss}')
                # Checkpoint when the loss summed over the last window improved.
                if best_loss > total_loss:
                    save_state(best_model_path, model,
                               loss_fn, optimizer, updates)
                    best_loss = total_loss
                total_loss = 0
            if updates % args.max_steps == 0:
                stop_training = True
                break

        if stop_training:
            break

    print('Training time:{}'.format(time.time() - start_time))

    def get_idx_to_tag(label_ids):
        return [idx_to_tag.get(idx) for idx in label_ids]

    def get_idx_to_word(words_ids):
        return [idx_to_word.get(idx) for idx in words_ids]

    updates = load_model_state(best_model_path, model)
    ne_class_list = set()
    true_labels_for_testing = []
    results_of_prediction = []
    with open(f'{output_dir}/{prefix}_label.txt', 'w', encoding='utf8') as t, \
            open(f'{output_dir}/{prefix}_predict.txt', 'w', encoding='utf8') as p, \
            open(f'{output_dir}/{prefix}_text.txt', 'w', encoding='utf8') as textf:
        with torch.no_grad():
            model.eval()
            for text, label in zip(test_sentences, test_labels):
                text_tensor = torch.LongTensor(text).unsqueeze(0).to(device)
                # NOTE(review): training calls model(batch_data, attn_mask) but
                # no mask is passed here - confirm the model's forward accepts
                # a single argument.
                predict = model(text_tensor)
                predicted_labels = predict.argmax(dim=1).view(-1).cpu().tolist()
                # Gold ids never feed the model, so they stay on the CPU.
                true_labels = torch.LongTensor(label).view(-1).tolist()

                tag_labels_predicted = ' '.join(get_idx_to_tag(predicted_labels))
                tag_labels_true = ' '.join(get_idx_to_tag(true_labels))
                text_ = ' '.join(get_idx_to_word(text))
                p.write(tag_labels_predicted + '\n')
                t.write(tag_labels_true + '\n')
                textf.write(text_ + '\n')

                tag_labels_true = tag_labels_true.strip().replace('_', '-').split()
                tag_labels_predicted = tag_labels_predicted.strip().replace('_', '-').split()
                biluo_tags_true = get_biluo(tag_labels_true)
                biluo_tags_predicted = get_biluo(tag_labels_predicted)

                doc = Doc(text_)
                offset_true_labels = offset_from_biluo(doc, biluo_tags_true)
                offset_predicted_labels = offset_from_biluo(
                    doc, biluo_tags_predicted)

                # Only gold entities contribute to the list of evaluated classes.
                true_labels_for_testing.append(
                    _group_by_type(offset_true_labels, ne_class_list))
                results_of_prediction.append(
                    _group_by_type(offset_predicted_labels))

    from eval.quality import calculate_prediction_quality
    f1, precision, recall, results = \
        calculate_prediction_quality(true_labels_for_testing,
                                     results_of_prediction,
                                     tuple(ne_class_list))
    print(f1, precision, recall, results)
def decode(options):
    """Tag the test split with a saved model and print its entity-level scores.

    Loads the checkpoint named by ``options.model``, rebuilds the vocabularies
    from the paths stored in the checkpoint's own arguments, predicts tags for
    the test sentences, writes gold tags / predicted tags / text to
    ``<output_dir>/<prefix>_{label,predict,text}.txt`` and prints the scores
    returned by ``calculate_prediction_quality``.

    Args:
        options: Namespace.  Read here: test_text, output_dir, model and
            batch_size; it is also passed to ``get_device`` and ``prepare``.
    """

    def _group_by_type(offsets, seen_types=None):
        # Map entity type -> list of (start, stop) spans; optionally record
        # every type encountered in ``seen_types``.
        grouped = {}
        for start, stop, ent_type in offsets:
            ent_type = ent_type.replace('_', '')
            if seen_types is not None:
                seen_types.add(ent_type)
            grouped.setdefault(ent_type, []).append((start, stop))
        return grouped

    name_parts = options.test_text.split('_')
    prefix = name_parts[0] if len(name_parts) > 1 else options.test_text.split('.')[0]

    # Use the function's own parameter; the name ``args`` is not defined in
    # this function.
    device = get_device(options)
    output_dir = options.output_dir
    # exist_ok tolerates an existing directory without hiding real failures
    # such as missing permissions.
    os.makedirs(output_dir, exist_ok=True)
    model, model_args = load_model_state(options.model, device)
    model = model.to(device)
    # Put the loaded model into inference mode before predicting.
    model.eval()

    # Vocabulary locations come from the arguments saved with the checkpoint.
    vocab_path = os.path.join(model_args.data_dir, model_args.vocab)
    tag_path = os.path.join(model_args.data_dir, model_args.tag_set)
    word_to_idx, idx_to_word, tag_to_idx, idx_to_tag = load_vocabs(
        vocab_path, tag_path)

    *_, test_sentences, test_labels = prepare(options, word_to_idx, tag_to_idx)

    def get_idx_to_tag(label_ids):
        return [idx_to_tag.get(idx) for idx in label_ids]

    def get_idx_to_word(words_ids):
        return [idx_to_word.get(idx) for idx in words_ids]

    pad_id = word_to_idx['PAD']
    pad_id_labels = tag_to_idx['PAD']
    batcher_test = SamplingBatcher(np.asarray(test_sentences, dtype=object),
                                   np.asarray(test_labels, dtype=object),
                                   batch_size=options.batch_size,
                                   pad_id=pad_id,
                                   pad_id_labels=pad_id_labels)
    ne_class_list = set()
    true_labels_for_testing = []
    results_of_prediction = []
    with open(f'{output_dir}/{prefix}_label.txt', 'w', encoding='utf8') as t, \
            open(f'{output_dir}/{prefix}_predict.txt', 'w', encoding='utf8') as p, \
            open(f'{output_dir}/{prefix}_text.txt', 'w', encoding='utf8') as textf:
        with torch.no_grad():
            preds = predict(batcher_test, model, idx_to_tag, pad_id=pad_id)
            for text, labels, predict_labels in zip(test_sentences,
                                                    test_labels, preds):
                tag_labels_predicted = ' '.join(predict_labels)
                tag_labels_true = ' '.join(get_idx_to_tag(labels))
                text_ = ' '.join(get_idx_to_word(text))
                p.write(tag_labels_predicted + '\n')
                t.write(tag_labels_true + '\n')
                textf.write(text_ + '\n')

                tag_labels_true = tag_labels_true.strip().replace('_', '-').split()
                tag_labels_predicted = tag_labels_predicted.strip().replace('_', '-').split()
                biluo_tags_true = get_biluo(tag_labels_true)
                biluo_tags_predicted = get_biluo(tag_labels_predicted)

                doc = Doc(text_)
                offset_true_labels = offset_from_biluo(doc, biluo_tags_true)
                offset_predicted_labels = offset_from_biluo(
                    doc, biluo_tags_predicted)

                # Only gold entities contribute to the list of evaluated classes.
                true_labels_for_testing.append(
                    _group_by_type(offset_true_labels, ne_class_list))
                results_of_prediction.append(
                    _group_by_type(offset_predicted_labels))

    from eval.quality import calculate_prediction_quality
    f1, precision, recall, results = \
        calculate_prediction_quality(true_labels_for_testing,
                                     results_of_prediction,
                                     tuple(ne_class_list))
    print(f1, precision, recall, results)
        else:
            labels_predict_all.append(labels_predict)
            labels_true_all.append(labels_true)

    true_labels_final = []
    predicted_labels_final = []
    ne_class_list = set()
    for line_id, line in enumerate(zip(text_all, labels_true_all, labels_predict_all)):
        text, true_labels, pred_labels = line
        pred_labels = [p.replace('_', '-') for p in pred_labels]
        true_labels = [t.replace('_', '-') for t in true_labels]

        biluo_tags_true = get_biluo(true_labels)
        biluo_tags_predicted = get_biluo(pred_labels)
        doc = Doc(text.strip())
        offset_true_labels = offset_from_biluo(doc, biluo_tags_true)
        offset_predicted_labels = offset_from_biluo(doc, biluo_tags_predicted)

        ent_labels = dict()
        for ent in offset_true_labels:
            start, stop, ent_type = ent
            ent_type = ent_type.replace('_', '')
            ne_class_list.add(ent_type)
            if ent_type in ent_labels:
                ent_labels[ent_type].append((start, stop))
            else:
                ent_labels[ent_type] = [(start, stop)]
        true_labels_final.append(ent_labels)

        ent_labels = dict()
        for ent in offset_predicted_labels:
Beispiel #8
0
def _group_spans_by_type(entity_offsets):
    """Group ``(start, stop, ent_type)`` triples by entity type.

    Underscores are removed from each entity type before it is used as a key.

    Returns:
        A plain dict mapping entity type to a list of ``(start, stop)`` pairs,
        in the order the triples were given.
    """
    grouped = {}
    for start, stop, ent_type in entity_offsets:
        grouped.setdefault(ent_type.replace('_', ''), []).append((start, stop))
    return grouped


def train(args):
    """Train a tagger on embedded sentences and evaluate the best checkpoint.

    Steps, in order:
      1. Load vocabularies and the train/test data (``load_vocabs``,
         ``prepare_flair``).
      2. Embed every train and test sentence with a multilingual DistilBERT
         transformer embedding, ``args.batch_size`` sentences at a time.
      3. Train with Adam; every ``args.patience`` updates the accumulated loss
         is printed and, if it is the lowest seen so far, the model is saved
         to ``<output_dir>/<prefix>_best_model.pt``. Training stops after
         ``args.epochs`` epochs or once the update counter is a multiple of
         ``args.max_steps``, whichever comes first.
      4. Reload the best checkpoint, predict on the test set, write the true
         labels, predicted labels and texts to ``<prefix>_label.txt``,
         ``<prefix>_predict.txt`` and ``<prefix>_text.txt`` in the output
         directory, and print the scores returned by
         ``calculate_prediction_quality``.

    Args:
        args: Namespace-like object. Attributes read here: ``data_dir``,
            ``vocab``, ``tag_set``, ``batch_size``, ``output_dir``,
            ``train_text``, ``epochs``, ``patience``, ``max_steps``.
            ``vocab_size`` and ``number_of_tags`` are set on it before the
            model is built.

    Returns:
        None. Results are printed and written to files.
    """
    vocab_path = os.path.join(args.data_dir, args.vocab)
    tag_path = os.path.join(args.data_dir, args.tag_set)
    word_to_idx, idx_to_word, tag_to_idx, idx_to_tag = load_vocabs(
        vocab_path, tag_path)
    train_sentences, train_labels, test_sentences, test_labels = prepare_flair(
        args, tag_to_idx)

    device = get_device(args)
    flair.device = device

    start = time.time()
    # Multilingual DistilBERT, last layer only.
    bert_embedding = TransformerWordEmbeddings(
        'distilbert-base-multilingual-cased',
        layers='-1',
        batch_size=args.batch_size)
    # Kept as a stack so further embeddings can be combined later.
    embeddings = StackedEmbeddings(embeddings=[bert_embedding])

    def _embed_in_batches(sentences):
        # Stop exactly at the end of the data: the previous loop bounds ran
        # one or two extra iterations and passed empty slices to embed().
        for begin in range(0, len(sentences), args.batch_size):
            embeddings.embed(sentences[begin:begin + args.batch_size])

    # Embed words in the train and test sentences.
    _embed_in_batches(train_sentences)
    _embed_in_batches(test_sentences)

    print(f'Encoding time:{time.time() - start}')

    # Update the Namespace with the sizes the model builder needs.
    args.vocab_size = len(idx_to_word)
    args.number_of_tags = len(idx_to_tag)

    model = build_model(args, device)
    print(model)
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters())

    pad_id = word_to_idx['PAD']
    pad_id_labels = tag_to_idx['PAD']

    batcher = SamplingBatcherFlair(
        np.asarray(train_sentences, dtype=object),
        np.asarray(train_labels, dtype=object),
        batch_size=args.batch_size,
        pad_id=pad_id,
        pad_id_labels=pad_id_labels,
        embedding_length=embeddings.embedding_length)

    updates = 1
    total_loss = 0
    best_loss = +inf
    stop_training = False
    output_dir = args.output_dir
    # exist_ok avoids failing when the directory is already there, without
    # hiding unrelated errors the way a bare ``except: pass`` would.
    os.makedirs(output_dir, exist_ok=True)

    # File prefix: text before the first '_' if there is one, otherwise the
    # text before the first '.'.
    underscore_parts = args.train_text.split('_')
    if len(underscore_parts) > 1:
        prefix = underscore_parts[0]
    else:
        prefix = args.train_text.split('.')[0]

    start_time = time.time()
    for epoch in range(args.epochs):
        for batch in batcher:
            updates += 1
            optimizer.zero_grad()
            loss = model.score(batch)
            loss.backward()
            optimizer.step()
            # backward() without arguments implies a scalar loss, so item()
            # is safe and keeps the running total a plain Python float.
            total_loss += loss.item()
            if updates % args.patience == 0:
                print(f'Epoch: {epoch}, Updates:{updates}, Loss: {total_loss}')
                if best_loss > total_loss:
                    # NOTE(review): loss_fn is not defined in this function;
                    # presumably a module-level name - confirm.
                    save_state(f'{output_dir}/{prefix}_best_model.pt',
                               model,
                               loss_fn,
                               optimizer,
                               updates,
                               args=args)
                    best_loss = total_loss
                # The loss is accumulated per patience window, not globally.
                total_loss = 0
            if updates % args.max_steps == 0:
                stop_training = True
                break

        if stop_training:
            break

    print('Training time:{}'.format(time.time() - start_time))

    # Evaluate the best saved checkpoint rather than the last training state.
    model, model_args = load_model_state(
        f'{output_dir}/{prefix}_best_model.pt', device)
    model = model.to(device)
    batcher_test = SamplingBatcherFlair(
        np.asarray(test_sentences, dtype=object),
        np.asarray(test_labels, dtype=object),
        batch_size=args.batch_size,
        pad_id=pad_id,
        pad_id_labels=pad_id_labels,
        embedding_length=embeddings.embedding_length)
    ne_class_list = set()
    true_labels_for_testing = []
    results_of_prediction = []
    with open(f'{output_dir}/{prefix}_label.txt', 'w', encoding='utf8') as t, \
            open(f'{output_dir}/{prefix}_predict.txt', 'w', encoding='utf8') as p, \
            open(f'{output_dir}/{prefix}_text.txt', 'w', encoding='utf8') as textf:
        with torch.no_grad():
            # The predictions are joined as strings below, so they are
            # expected to be label names rather than label ids.
            preds = predict_no_attn(batcher_test, model, idx_to_tag)
            for sentence, label_ids, predicted in zip(test_sentences,
                                                      test_labels, preds):
                true_line = ' '.join([idx_to_tag.get(idx) for idx in label_ids])
                raw_text = sentence.to_original_text()
                predicted_line = ' '.join(predicted)

                p.write(predicted_line + '\n')
                t.write(true_line + '\n')
                textf.write(raw_text + '\n')

                # Tags use '-' as separator for the BILUO conversion.
                true_tags = true_line.strip().replace('_', '-').split()
                predicted_tags = predicted_line.strip().replace(
                    '_', '-').split()
                biluo_tags_true = get_biluo(true_tags)
                biluo_tags_predicted = get_biluo(predicted_tags)

                doc = Doc(raw_text)
                offset_true_labels = offset_from_biluo(doc, biluo_tags_true)
                offset_predicted_labels = offset_from_biluo(
                    doc, biluo_tags_predicted)

                true_spans = _group_spans_by_type(offset_true_labels)
                # Only entity types seen in the gold labels are scored.
                ne_class_list.update(true_spans)
                true_labels_for_testing.append(true_spans)

                results_of_prediction.append(
                    _group_spans_by_type(offset_predicted_labels))

    # Imported here, as in the original, rather than at module level.
    from eval.quality import calculate_prediction_quality
    f1, precision, recall, results = \
        calculate_prediction_quality(true_labels_for_testing,
                                     results_of_prediction,
                                     tuple(ne_class_list))
    print(f1, precision, recall, results)