# Shared imports for the two word_align variants below. Project-internal
# helpers (LineByLineTextDataset, delete_encoding_layers, find_offsets,
# open_writer_list, merge_files) are assumed to be importable from the
# surrounding repository.
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import trange
from transformers import PreTrainedModel, PreTrainedTokenizer


def word_align(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):
    """Single-process variant: prunes the encoder, then writes one line of
    'i-j' word-alignment pairs per sentence pair to args.output_file."""

    def collate(examples):
        # Pad source and target token-id sequences into rectangular batches.
        ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt = zip(*examples)
        ids_src = pad_sequence(ids_src, batch_first=True, padding_value=tokenizer.pad_token_id)
        ids_tgt = pad_sequence(ids_tgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        return ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt

    dataset = LineByLineTextDataset(tokenizer, args, file_path=args.data_file)
    sampler = SequentialSampler(dataset)
    dataloader = DataLoader(
        dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate
    )

    model.to(args.device)
    # Project-specific helper that removes encoder layers before extraction.
    model = delete_encoding_layers(model)
    model.eval()

    tqdm_iterator = trange(len(dataset), desc="Extracting alignments")
    with open(args.output_file, 'w') as writer:
        for batch in dataloader:
            with torch.no_grad():
                ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt = batch
                word_aligns_list = model.get_aligned_word(
                    ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt,
                    args.device, 0, 0,
                    align_layer=args.align_layer,
                    extraction=args.extraction,
                    softmax_threshold=args.softmax_threshold,
                    test=True,
                )
                # One set of word alignments per sentence pair in the batch.
                for word_aligns in word_aligns_list:
                    output_str = []
                    for word_align in word_aligns:
                        output_str.append(f'{word_align[0]}-{word_align[1]}')
                    writer.write(' '.join(output_str) + '\n')
                tqdm_iterator.update(len(ids_src))
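# NOTE: delete_encoding_layers is a project-internal helper that is not part
# of this listing; the call site only shows that it takes the model and
# returns a (presumably truncated) one. The sketch below is a hypothetical
# illustration of one common way to prune a Hugging Face BERT-style encoder.
# The attribute path `model.bert.encoder.layer` and the default keep-count of
# 8 layers are assumptions, not taken from the source.
import copy
from torch import nn

def delete_encoding_layers(model, num_layers_to_keep=8):
    # Hypothetical sketch: keep only the bottom Transformer layers.
    pruned = copy.deepcopy(model)
    # Assumption: the encoder stack is reachable as model.bert.encoder.layer,
    # as it is for BertForMaskedLM-style models.
    pruned.bert.encoder.layer = nn.ModuleList(
        pruned.bert.encoder.layer[:num_layers_to_keep]
    )
    pruned.config.num_hidden_layers = num_layers_to_keep
    return pruned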
def word_align(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):
    """Multi-worker variant: shards args.data_file across args.num_workers
    dataloader workers and optionally writes alignment probabilities and
    aligned word pairs next to the 'i-j' alignments."""

    def collate(examples):
        worker_ids, ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sents_src, sents_tgt = zip(*examples)
        ids_src = pad_sequence(ids_src, batch_first=True, padding_value=tokenizer.pad_token_id)
        ids_tgt = pad_sequence(ids_tgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        return worker_ids, ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sents_src, sents_tgt

    # Byte offsets that split the input file into one shard per worker.
    offsets = find_offsets(args.data_file, args.num_workers)
    dataset = LineByLineTextDataset(tokenizer, file_path=args.data_file, offsets=offsets)
    dataloader = DataLoader(
        dataset, batch_size=args.batch_size, collate_fn=collate, num_workers=args.num_workers
    )

    model.to(args.device)
    model.eval()
    tqdm_iterator = trange(0, desc="Extracting")

    # One writer per worker; the per-worker outputs are merged at the end so
    # that the final files preserve the input line order.
    writers = open_writer_list(args.output_file, args.num_workers)
    if args.output_prob_file is not None:
        prob_writers = open_writer_list(args.output_prob_file, args.num_workers)
    if args.output_word_file is not None:
        word_writers = open_writer_list(args.output_word_file, args.num_workers)

    for batch in dataloader:
        with torch.no_grad():
            worker_ids, ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sents_src, sents_tgt = batch
            word_aligns_list = model.get_aligned_word(
                ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt,
                args.device, 0, 0,
                align_layer=args.align_layer,
                extraction=args.extraction,
                softmax_threshold=args.softmax_threshold,
                test=True,
                output_prob=(args.output_prob_file is not None),
            )
            for worker_id, word_aligns, sent_src, sent_tgt in zip(worker_ids, word_aligns_list, sents_src, sents_tgt):
                output_str = []
                if args.output_prob_file is not None:
                    output_prob_str = []
                if args.output_word_file is not None:
                    output_word_str = []
                for word_align in word_aligns:
                    if word_align[0] != -1:
                        output_str.append(f'{word_align[0]}-{word_align[1]}')
                        if args.output_prob_file is not None:
                            # With output_prob=True, word_aligns maps each
                            # (src, tgt) pair to its alignment probability.
                            output_prob_str.append(f'{word_aligns[word_align]}')
                        if args.output_word_file is not None:
                            output_word_str.append(f'{sent_src[word_align[0]]}<sep>{sent_tgt[word_align[1]]}')
                writers[worker_id].write(' '.join(output_str) + '\n')
                if args.output_prob_file is not None:
                    prob_writers[worker_id].write(' '.join(output_prob_str) + '\n')
                if args.output_word_file is not None:
                    word_writers[worker_id].write(' '.join(output_word_str) + '\n')
            tqdm_iterator.update(len(ids_src))

    # Concatenate the per-worker outputs into the final files.
    merge_files(writers)
    if args.output_prob_file is not None:
        merge_files(prob_writers)
    if args.output_word_file is not None:
        merge_files(word_writers)
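# NOTE: find_offsets, open_writer_list and merge_files are small I/O helpers
# that are not shown in this listing. The sketch below is a hypothetical
# illustration consistent with the call sites above: find_offsets splits the
# input file into byte ranges (one per worker) snapped to line boundaries,
# open_writer_list opens the target file plus one temporary buffer per extra
# worker, and merge_files appends those buffers back in worker order. The
# bodies are assumptions, not the repository's actual implementation.
import os
import tempfile

def find_offsets(filename, num_workers):
    # Split the file into num_workers byte ranges aligned to line boundaries.
    size = os.path.getsize(filename)
    chunk = size // max(num_workers, 1)
    offsets = [0] * (num_workers + 1)
    with open(filename, 'rb') as f:
        for i in range(1, num_workers):
            f.seek(i * chunk)
            f.readline()          # advance to the start of the next full line
            offsets[i] = f.tell()
    offsets[num_workers] = size
    return offsets

def open_writer_list(filename, num_workers):
    # Worker 0 writes directly to the target file; the remaining workers
    # write to temporary buffers that are merged back at the end.
    writers = [open(filename, 'w', encoding='utf-8')]
    writers.extend(tempfile.TemporaryFile(mode='w+', encoding='utf-8')
                   for _ in range(1, num_workers))
    return writers

def merge_files(writers):
    # Append each temporary buffer to the main file, preserving worker order.
    for writer in writers[1:]:
        writer.seek(0)
        writers[0].write(writer.read())
        writer.close()
    writers[0].close()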