import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import trange
from transformers import PreTrainedModel, PreTrainedTokenizer

# LineByLineTextDataset is assumed to be defined elsewhere in this module:
# it yields one (ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt)
# tuple per "source ||| target" line in the data file.


def word_align(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, output_word_alignments=False):
    def collate(examples):
        # Pad the source and target subword-id sequences to the longest
        # sequence in the batch; the BPE-to-word maps stay as tuples.
        ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt = zip(*examples)
        ids_src = pad_sequence(ids_src, batch_first=True, padding_value=tokenizer.pad_token_id)
        ids_tgt = pad_sequence(ids_tgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        return ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt

    dataset = LineByLineTextDataset(tokenizer, args, file_path=args.data_file)
    sampler = SequentialSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate)

    model.to(args.device)
    model.eval()
    tqdm_iterator = trange(len(dataset), desc="Extracting")
    with open(args.output_file, 'w') as writer:
        for batch in dataloader:
            with torch.no_grad():
                ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt = batch
                word_aligns_list = model.get_aligned_word(
                    ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt,
                    args.device, 0, 0,
                    align_layer=args.align_layer,
                    extraction=args.extraction,
                    softmax_threshold=args.softmax_threshold,
                    test=True,
                )
                # Write one line per sentence pair: space-separated "i-j"
                # pairs of aligned (source, target) word indices.
                for word_aligns in word_aligns_list:
                    output_str = []
                    for word_align in word_aligns:
                        output_str.append(f'{word_align[0]}-{word_align[1]}')
                    writer.write(' '.join(output_str) + '\n')
                tqdm_iterator.update(len(ids_src))

    if output_word_alignments:
        # Re-read the index alignments together with the parallel data and
        # emit human-readable "srcword-tgtword" pairs to a sidecar file.
        # splitlines() (rather than split("\n")) avoids the trailing empty
        # entry that would crash the ' ||| ' unpack below.
        with open(args.output_file, 'r') as fh:
            outputf = fh.read().splitlines()
        with open(args.data_file, 'r') as fh:
            datalines = fh.read().splitlines()
        with open(args.output_file + ".outtxt", 'w') as fwriter:
            for indices, line in zip(outputf, datalines):
                srcline, tgtline = line.split(' ||| ')
                srcwrds = srcline.split()
                tgtwrds = tgtline.split()
                output_wrds = []
                for wrd in indices.split():
                    srcix, tgtix = map(int, wrd.split("-"))
                    output_wrds.append(f"{srcwrds[srcix]}-{tgtwrds[tgtix]}")
                fwriter.write(' '.join(output_wrds) + '\n')
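# --- Usage sketch (illustrative, not from the original source) ---
# A minimal, hedged example of driving word_align() from a script. The
# attribute names on `args` mirror the fields the function body reads
# (LineByLineTextDataset may read further fields); the file names and
# hyperparameter values are assumptions. `model` must expose a
# get_aligned_word() method (the project's custom alignment model does;
# a vanilla transformers model does not), so the loading lines are left
# as labeled placeholders rather than invented API calls.
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(
        data_file='data.src-tgt',      # one "source ||| target" pair per line
        output_file='alignments.out',  # receives the "i-j" index pairs
        batch_size=32,
        align_layer=8,
        extraction='softmax',
        softmax_threshold=0.001,
        device='cuda' if torch.cuda.is_available() else 'cpu',
    )
    model = ...      # placeholder: alignment model exposing get_aligned_word()
    tokenizer = ...  # placeholder: the matching PreTrainedTokenizer
    word_align(args, model, tokenizer, output_word_alignments=True)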