Code example #1
0
File: run_align.py  Project: copperdong/awesome-align
def word_align(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):
    """Extract word alignments for every sentence pair in ``args.data_file``.

    The dataset is read sequentially, batched, and fed through the model's
    ``get_aligned_word``; one line of space-separated ``i-j`` word-index
    pairs is written to ``args.output_file`` per input sentence pair.

    Args:
        args: Namespace providing ``data_file``, ``output_file``,
            ``batch_size``, ``device``, ``align_layer``, ``extraction``
            and ``softmax_threshold``.
        model: Alignment model exposing ``get_aligned_word``.
        tokenizer: Tokenizer supplying ``pad_token_id`` for batch padding.
    """

    def collate(examples):
        # Pad the variable-length id sequences to a rectangular batch;
        # the bpe->word maps are passed through untouched.
        ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt = zip(*examples)
        ids_src = pad_sequence(ids_src,
                               batch_first=True,
                               padding_value=tokenizer.pad_token_id)
        ids_tgt = pad_sequence(ids_tgt,
                               batch_first=True,
                               padding_value=tokenizer.pad_token_id)
        return ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt

    dataset = LineByLineTextDataset(tokenizer, args, file_path=args.data_file)
    dataloader = DataLoader(dataset,
                            sampler=SequentialSampler(dataset),
                            batch_size=args.batch_size,
                            collate_fn=collate)

    model.to(args.device)
    model.eval()
    # len(dataset), not dataset.__len__(): call the protocol, not the dunder.
    tqdm_iterator = trange(len(dataset), desc="Extracting")
    try:
        # Explicit encoding: don't depend on the platform default.
        with open(args.output_file, 'w', encoding='utf-8') as writer:
            for batch in dataloader:
                with torch.no_grad():
                    ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt = batch
                    word_aligns_list = model.get_aligned_word(
                        ids_src,
                        ids_tgt,
                        bpe2word_map_src,
                        bpe2word_map_tgt,
                        args.device,
                        0,
                        0,
                        align_layer=args.align_layer,
                        extraction=args.extraction,
                        softmax_threshold=args.softmax_threshold,
                        test=True)
                    for word_aligns in word_aligns_list:
                        # One output line per sentence pair: "i-j" index pairs.
                        # 'pair' instead of 'word_align': don't shadow this function.
                        writer.write(
                            ' '.join(f'{pair[0]}-{pair[1]}'
                                     for pair in word_aligns) + '\n')
                    tqdm_iterator.update(len(ids_src))
    finally:
        # Close the progress bar so its display state is flushed.
        tqdm_iterator.close()
Code example #2
0
def word_align(args,
               model: PreTrainedModel,
               tokenizer: PreTrainedTokenizer,
               output_word_alignments=False):
    """Extract word alignments for every sentence pair in ``args.data_file``.

    Writes one line of space-separated ``i-j`` word-index pairs to
    ``args.output_file`` per sentence pair.  When ``output_word_alignments``
    is true, additionally writes ``<output_file>.outtxt`` containing the
    aligned surface words (``srcword-tgtword`` pairs), looked up from the
    ``src ||| tgt`` lines of ``args.data_file``.

    Args:
        args: Namespace providing ``data_file``, ``output_file``,
            ``batch_size``, ``device``, ``align_layer``, ``extraction``
            and ``softmax_threshold``.
        model: Alignment model exposing ``get_aligned_word``.
        tokenizer: Tokenizer supplying ``pad_token_id`` for batch padding.
        output_word_alignments: Also emit the word-level text file.
    """

    def collate(examples):
        # Pad the variable-length id sequences to a rectangular batch;
        # the bpe->word maps are passed through untouched.
        ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt = zip(*examples)
        ids_src = pad_sequence(ids_src,
                               batch_first=True,
                               padding_value=tokenizer.pad_token_id)
        ids_tgt = pad_sequence(ids_tgt,
                               batch_first=True,
                               padding_value=tokenizer.pad_token_id)
        return ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt

    dataset = LineByLineTextDataset(tokenizer, args, file_path=args.data_file)
    dataloader = DataLoader(dataset,
                            sampler=SequentialSampler(dataset),
                            batch_size=args.batch_size,
                            collate_fn=collate)

    model.to(args.device)
    model.eval()
    # len(dataset), not dataset.__len__(): call the protocol, not the dunder.
    tqdm_iterator = trange(len(dataset), desc="Extracting")
    try:
        # Explicit encoding: don't depend on the platform default.
        with open(args.output_file, 'w', encoding='utf-8') as writer:
            for batch in dataloader:
                with torch.no_grad():
                    ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt = batch
                    word_aligns_list = model.get_aligned_word(
                        ids_src,
                        ids_tgt,
                        bpe2word_map_src,
                        bpe2word_map_tgt,
                        args.device,
                        0,
                        0,
                        align_layer=args.align_layer,
                        extraction=args.extraction,
                        softmax_threshold=args.softmax_threshold,
                        test=True)
                    for word_aligns in word_aligns_list:
                        # One output line per sentence pair: "i-j" index pairs.
                        # 'pair' instead of 'word_align': don't shadow this function.
                        writer.write(
                            ' '.join(f'{pair[0]}-{pair[1]}'
                                     for pair in word_aligns) + '\n')
                    tqdm_iterator.update(len(ids_src))
    finally:
        # Close the progress bar so its display state is flushed.
        tqdm_iterator.close()

    if output_word_alignments:
        # splitlines() (not split("\n")): the original left a trailing empty
        # string from the final newline, and ''.split(' ||| ') yields a single
        # element, so the two-way unpack below raised ValueError on that pair.
        with open(args.output_file, 'r', encoding='utf-8') as fh:
            outputf = fh.read().splitlines()
        with open(args.data_file, 'r', encoding='utf-8') as fh:
            datalines = fh.read().splitlines()

        with open(args.output_file + ".outtxt", 'w', encoding='utf-8') as fwriter:
            for indices, line in zip(outputf, datalines):
                # Guard against stray blank corpus lines as well.
                if not line.strip():
                    continue
                srcline, tgtline = line.split(' ||| ')
                srcwrds = srcline.split()
                tgtwrds = tgtline.split()
                output_wrds = []
                for wrd in indices.split():
                    srcix, tgtix = wrd.split("-")
                    # Map each index pair back to the surface words.
                    output_wrds.append(
                        f"{srcwrds[int(srcix)]}-{tgtwrds[int(tgtix)]}")
                fwriter.write(' '.join(output_wrds) + '\n')