def word_align(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):

    def collate(examples):
        ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt = zip(*examples)
        # Pad source/target ids to the longest sequence in the batch.
        ids_src = pad_sequence(ids_src, batch_first=True, padding_value=tokenizer.pad_token_id)
        ids_tgt = pad_sequence(ids_tgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        return ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt

    dataset = LineByLineTextDataset(tokenizer, args, file_path=args.data_file)
    sampler = SequentialSampler(dataset)
    dataloader = DataLoader(
        dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate
    )

    model.to(args.device)
    model = delete_encoding_layers(model)  # prune unused upper layers (helper defined elsewhere)
    model.eval()
    tqdm_iterator = trange(dataset.__len__(), desc="Extracting alignments")
    with open(args.output_file, 'w') as writer:
        for batch in dataloader:
            with torch.no_grad():
                ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt = batch
                word_aligns_list = model.get_aligned_word(
                    ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, args.device, 0, 0,
                    align_layer=args.align_layer, extraction=args.extraction,
                    softmax_threshold=args.softmax_threshold, test=True,
                )
                # Write one line of "src_idx-tgt_idx" pairs per sentence pair.
                for word_aligns in word_aligns_list:
                    output_str = []
                    for word_align in word_aligns:
                        output_str.append(f'{word_align[0]}-{word_align[1]}')
                    writer.write(' '.join(output_str) + '\n')
                tqdm_iterator.update(len(ids_src))
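# `delete_encoding_layers` is referenced above but not defined in this section. The
# sketch below is a hypothetical stand-in, not the repository's implementation: it
# keeps only the first `keep_layers` Transformer blocks, on the assumption that
# alignments are read from an intermediate layer and the upper layers are unused at
# test time. The attribute path `model.bert.encoder.layer` is also an assumption and
# may differ for the actual model class.
import copy
import torch.nn as nn

def delete_encoding_layers_sketch(model, keep_layers=8):
    trimmed = copy.deepcopy(model)
    # Keep only the lowest `keep_layers` encoder blocks.
    trimmed.bert.encoder.layer = nn.ModuleList(trimmed.bert.encoder.layer[:keep_layers])
    trimmed.config.num_hidden_layers = keep_layers
    return trimmed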
def word_align(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):

    def collate(examples):
        worker_ids, ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sents_src, sents_tgt = zip(*examples)
        # Pad source/target ids to the longest sequence in the batch.
        ids_src = pad_sequence(ids_src, batch_first=True, padding_value=tokenizer.pad_token_id)
        ids_tgt = pad_sequence(ids_tgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        return worker_ids, ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sents_src, sents_tgt

    # Per-worker byte offsets into the input file, so each dataloader worker reads its own chunk.
    offsets = find_offsets(args.data_file, args.num_workers)
    dataset = LineByLineTextDataset(tokenizer, file_path=args.data_file, offsets=offsets)
    dataloader = DataLoader(
        dataset, batch_size=args.batch_size, collate_fn=collate, num_workers=args.num_workers
    )

    model.to(args.device)
    model.eval()
    tqdm_iterator = trange(0, desc="Extracting")

    # One writer per worker; the resulting part files are merged at the end.
    writers = open_writer_list(args.output_file, args.num_workers)
    if args.output_prob_file is not None:
        prob_writers = open_writer_list(args.output_prob_file, args.num_workers)
    if args.output_word_file is not None:
        word_writers = open_writer_list(args.output_word_file, args.num_workers)

    for batch in dataloader:
        with torch.no_grad():
            worker_ids, ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sents_src, sents_tgt = batch
            word_aligns_list = model.get_aligned_word(
                ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, args.device, 0, 0,
                align_layer=args.align_layer, extraction=args.extraction,
                softmax_threshold=args.softmax_threshold, test=True,
                output_prob=(args.output_prob_file is not None),
            )
            for worker_id, word_aligns, sent_src, sent_tgt in zip(worker_ids, word_aligns_list, sents_src, sents_tgt):
                output_str = []
                if args.output_prob_file is not None:
                    output_prob_str = []
                if args.output_word_file is not None:
                    output_word_str = []
                for word_align in word_aligns:
                    if word_align[0] != -1:
                        output_str.append(f'{word_align[0]}-{word_align[1]}')
                        if args.output_prob_file is not None:
                            output_prob_str.append(f'{word_aligns[word_align]}')
                        if args.output_word_file is not None:
                            output_word_str.append(f'{sent_src[word_align[0]]}<sep>{sent_tgt[word_align[1]]}')
                writers[worker_id].write(' '.join(output_str) + '\n')
                if args.output_prob_file is not None:
                    prob_writers[worker_id].write(' '.join(output_prob_str) + '\n')
                if args.output_word_file is not None:
                    word_writers[worker_id].write(' '.join(output_word_str) + '\n')
            tqdm_iterator.update(len(ids_src))

    merge_files(writers)
    if args.output_prob_file is not None:
        merge_files(prob_writers)
    if args.output_word_file is not None:
        merge_files(word_writers)
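# `find_offsets`, `open_writer_list`, and `merge_files` are referenced above but not
# shown. The following is a minimal sketch inferred only from the call sites; the
# part-file naming and the exact merging strategy are assumptions, not the
# repository's actual helpers.
import os

def find_offsets_sketch(filename, num_workers):
    # Compute `num_workers` starting byte offsets, each advanced to the start of
    # the next full line, so every dataloader worker reads a disjoint chunk.
    with open(filename, "rb") as f:
        size = os.fstat(f.fileno()).st_size
        chunk_size = size // num_workers
        offsets = [0] * num_workers
        for i in range(1, num_workers):
            f.seek(chunk_size * i)
            f.readline()  # skip the (possibly partial) current line
            offsets[i] = f.tell()
    return offsets

def open_writer_list_sketch(filename, num_workers):
    # Worker 0 owns the final output file; the other workers write to part files.
    writers = [open(filename, "w+", encoding="utf-8")]
    writers += [open(f"{filename}.part{i}", "w+", encoding="utf-8") for i in range(1, num_workers)]
    return writers

def merge_files_sketch(writers):
    # Append each part file onto the first writer, then delete the parts.
    for writer in writers[1:]:
        writer.seek(0)
        writers[0].write(writer.read())
        writer.close()
        os.remove(writer.name)
    writers[0].close()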
def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix="") -> Dict:
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_output_dir = args.output_dir
    eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)

    if args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly

    def collate(examples):
        model.eval()
        examples_src, examples_tgt, examples_srctgt, examples_tgtsrc, langid_srctgt, langid_tgtsrc, psi_examples_srctgt, psi_labels = [], [], [], [], [], [], [], []
        src_len = tgt_len = 0
        bpe2word_map_src, bpe2word_map_tgt = [], []
        for example in examples:
            end_id = example[0][0][-1].view(-1)

            src_id = example[0][0][:args.block_size]
            src_id = torch.cat([src_id[:-1], end_id])
            tgt_id = example[1][0][:args.block_size]
            tgt_id = torch.cat([tgt_id[:-1], end_id])

            half_block_size = int(args.block_size / 2)
            half_src_id = example[0][0][:half_block_size]
            half_src_id = torch.cat([half_src_id[:-1], end_id])
            half_tgt_id = example[1][0][:half_block_size]
            half_tgt_id = torch.cat([half_tgt_id[:-1], end_id])

            examples_src.append(src_id)
            examples_tgt.append(tgt_id)
            src_len = max(src_len, len(src_id))
            tgt_len = max(tgt_len, len(tgt_id))

            srctgt = torch.cat([half_src_id, half_tgt_id])
            langid = torch.cat([torch.ones_like(half_src_id), torch.ones_like(half_tgt_id) * 2])
            examples_srctgt.append(srctgt)
            langid_srctgt.append(langid)

            tgtsrc = torch.cat([half_tgt_id, half_src_id])
            langid = torch.cat([torch.ones_like(half_tgt_id), torch.ones_like(half_src_id) * 2])
            examples_tgtsrc.append(tgtsrc)
            langid_tgtsrc.append(langid)

            # [neg, neg] pair
            neg_half_src_id = example[-2][0][:half_block_size]
            neg_half_src_id = torch.cat([neg_half_src_id[:-1], end_id])
            neg_half_tgt_id = example[-1][0][:half_block_size]
            neg_half_tgt_id = torch.cat([neg_half_tgt_id[:-1], end_id])
            neg_srctgt = torch.cat([neg_half_src_id, neg_half_tgt_id])
            psi_examples_srctgt.append(neg_srctgt)
            psi_labels.append(1)

            # [pos, neg] pair
            neg_srctgt = torch.cat([half_src_id, neg_half_tgt_id])
            psi_examples_srctgt.append(neg_srctgt)
            psi_labels.append(0)

            bpe2word_map_src.append(example[2])
            bpe2word_map_tgt.append(example[3])

        examples_src = pad_sequence(examples_src, batch_first=True, padding_value=tokenizer.pad_token_id)
        examples_tgt = pad_sequence(examples_tgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        examples_srctgt = pad_sequence(examples_srctgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        langid_srctgt = pad_sequence(langid_srctgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        examples_tgtsrc = pad_sequence(examples_tgtsrc, batch_first=True, padding_value=tokenizer.pad_token_id)
        langid_tgtsrc = pad_sequence(langid_tgtsrc, batch_first=True, padding_value=tokenizer.pad_token_id)
        psi_examples_srctgt = pad_sequence(psi_examples_srctgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        psi_labels = torch.tensor(psi_labels)
        guides = model.get_aligned_word(
            examples_src, examples_tgt, bpe2word_map_src, bpe2word_map_tgt, args.device, src_len, tgt_len,
            align_layer=args.align_layer, extraction=args.extraction, softmax_threshold=args.softmax_threshold,
        )
        return examples_src, examples_tgt, guides, examples_srctgt, langid_srctgt, examples_tgtsrc, langid_tgtsrc, psi_examples_srctgt, psi_labels

    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate
    )

    # multi-gpu evaluate
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info(" Num examples = %d", len(eval_dataset))
    logger.info(" Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()
    set_seed(args)  # Added here for reproducibility

    def post_loss(loss, tot_loss):
        if args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        tot_loss += loss.item()
        return tot_loss

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        with torch.no_grad():
            if args.train_so or args.train_co:
                inputs_src, inputs_tgt = batch[0].clone(), batch[1].clone()
                inputs_src, inputs_tgt = inputs_src.to(args.device), inputs_tgt.to(args.device)
                attention_mask_src, attention_mask_tgt = (inputs_src != 0), (inputs_tgt != 0)
                guide = batch[2].to(args.device)
                loss = model(
                    inputs_src=inputs_src, inputs_tgt=inputs_tgt,
                    attention_mask_src=attention_mask_src, attention_mask_tgt=attention_mask_tgt,
                    guide=guide, align_layer=args.align_layer, extraction=args.extraction,
                    softmax_threshold=args.softmax_threshold, train_so=args.train_so, train_co=args.train_co,
                )
                eval_loss = post_loss(loss, eval_loss)

            if args.train_mlm:
                inputs_src, labels_src = mask_tokens(batch[0], tokenizer, args)
                inputs_tgt, labels_tgt = mask_tokens(batch[1], tokenizer, args)
                inputs_src, inputs_tgt = inputs_src.to(args.device), inputs_tgt.to(args.device)
                labels_src, labels_tgt = labels_src.to(args.device), labels_tgt.to(args.device)
                loss = model(inputs_src=inputs_src, labels_src=labels_src)
                eval_loss = post_loss(loss, eval_loss)
                loss = model(inputs_src=inputs_tgt, labels_src=labels_tgt)
                eval_loss = post_loss(loss, eval_loss)

            if args.train_tlm:
                select_ids = [0, 1]
                if not args.train_tlm_full:
                    select_ids = [0]
                for select_id in select_ids:
                    for lang_id in [1, 2]:
                        inputs_srctgt, labels_srctgt = mask_tokens(batch[3 + select_id * 2], tokenizer, args, batch[4 + select_id * 2], lang_id)
                        inputs_srctgt, labels_srctgt = inputs_srctgt.to(args.device), labels_srctgt.to(args.device)
                        loss = model(inputs_src=inputs_srctgt, labels_src=labels_srctgt)
                        eval_loss = post_loss(loss, eval_loss)

            if args.train_psi:
                loss = model(inputs_src=batch[7].to(args.device), labels_psi=batch[8].to(args.device))
                eval_loss = post_loss(loss, eval_loss)

            nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))

    result = {"perplexity": perplexity}

    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info(" %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result
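# `mask_tokens` is used above for the MLM/TLM objectives but is not shown in this
# section. The sketch below is a standard BERT-style masking routine, given for
# reference only: the 80/10/10 split follows the common recipe, but the actual
# `mask_tokens` in this codebase also accepts a language-id mask (for TLM) and may
# differ in details, so treat the name and signature as assumptions.
import torch

def mask_tokens_sketch(inputs, tokenizer, mlm_probability=0.15):
    # Select ~15% of non-special tokens; of those, 80% -> [MASK], 10% -> random
    # token, 10% -> unchanged. Unselected positions get label -100 so the MLM
    # loss ignores them.
    labels = inputs.clone()
    probability_matrix = torch.full(labels.shape, mlm_probability)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100

    # 80% of the masked positions: replace with the [MASK] token.
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the masked positions: replace with a random token.
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # The remaining 10% keep their original token.
    return inputs, labels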
def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples):
        model.eval()
        examples_src, examples_tgt, examples_srctgt, examples_tgtsrc, langid_srctgt, langid_tgtsrc, psi_examples_srctgt, psi_labels = [], [], [], [], [], [], [], []
        src_len = tgt_len = 0
        bpe2word_map_src, bpe2word_map_tgt = [], []
        for example in examples:
            end_id = example[0][0][-1].view(-1)

            src_id = example[0][0][:args.block_size]
            src_id = torch.cat([src_id[:-1], end_id])
            tgt_id = example[1][0][:args.block_size]
            tgt_id = torch.cat([tgt_id[:-1], end_id])

            half_block_size = int(args.block_size / 2)
            half_src_id = example[0][0][:half_block_size]
            half_src_id = torch.cat([half_src_id[:-1], end_id])
            half_tgt_id = example[1][0][:half_block_size]
            half_tgt_id = torch.cat([half_tgt_id[:-1], end_id])

            examples_src.append(src_id)
            examples_tgt.append(tgt_id)
            src_len = max(src_len, len(src_id))
            tgt_len = max(tgt_len, len(tgt_id))

            srctgt = torch.cat([half_src_id, half_tgt_id])
            langid = torch.cat([torch.ones_like(half_src_id), torch.ones_like(half_tgt_id) * 2])
            examples_srctgt.append(srctgt)
            langid_srctgt.append(langid)

            tgtsrc = torch.cat([half_tgt_id, half_src_id])
            langid = torch.cat([torch.ones_like(half_tgt_id), torch.ones_like(half_src_id) * 2])
            examples_tgtsrc.append(tgtsrc)
            langid_tgtsrc.append(langid)

            # [neg, neg] pair
            neg_half_src_id = example[-2][0][:half_block_size]
            neg_half_src_id = torch.cat([neg_half_src_id[:-1], end_id])
            neg_half_tgt_id = example[-1][0][:half_block_size]
            neg_half_tgt_id = torch.cat([neg_half_tgt_id[:-1], end_id])
            if random.random() > 0.5:
                neg_srctgt = torch.cat([neg_half_src_id, neg_half_tgt_id])
            else:
                neg_srctgt = torch.cat([neg_half_tgt_id, neg_half_src_id])
            psi_examples_srctgt.append(neg_srctgt)
            psi_labels.append(1)

            # [pos, neg] pair
            rd = random.random()
            if rd > 0.75:
                neg_srctgt = torch.cat([half_src_id, neg_half_tgt_id])
            elif rd > 0.5:
                neg_srctgt = torch.cat([neg_half_src_id, half_tgt_id])
            elif rd > 0.25:
                neg_srctgt = torch.cat([half_tgt_id, neg_half_src_id])
            else:
                neg_srctgt = torch.cat([neg_half_tgt_id, half_src_id])
            psi_examples_srctgt.append(neg_srctgt)
            psi_labels.append(0)

            bpe2word_map_src.append(example[2])
            bpe2word_map_tgt.append(example[3])

        examples_src = pad_sequence(examples_src, batch_first=True, padding_value=tokenizer.pad_token_id)
        examples_tgt = pad_sequence(examples_tgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        examples_srctgt = pad_sequence(examples_srctgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        langid_srctgt = pad_sequence(langid_srctgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        examples_tgtsrc = pad_sequence(examples_tgtsrc, batch_first=True, padding_value=tokenizer.pad_token_id)
        langid_tgtsrc = pad_sequence(langid_tgtsrc, batch_first=True, padding_value=tokenizer.pad_token_id)
        psi_examples_srctgt = pad_sequence(psi_examples_srctgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        psi_labels = torch.tensor(psi_labels)
        if args.n_gpu > 1 or args.local_rank != -1:
            guides = model.module.get_aligned_word(
                examples_src, examples_tgt, bpe2word_map_src, bpe2word_map_tgt, args.device, src_len, tgt_len,
                align_layer=args.align_layer, extraction=args.extraction, softmax_threshold=args.softmax_threshold,
            )
        else:
            guides = model.get_aligned_word(
                examples_src, examples_tgt, bpe2word_map_src, bpe2word_map_tgt, args.device, src_len, tgt_len,
                align_layer=args.align_layer, extraction=args.extraction, softmax_threshold=args.softmax_threshold,
            )
        return examples_src, examples_tgt, guides, examples_srctgt, langid_srctgt, examples_tgtsrc, langid_tgtsrc, psi_examples_srctgt, psi_labels

    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
    )

    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
    if args.max_steps > 0 and args.max_steps < t_total:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    # Check if continuing training from a checkpoint

    tr_loss, logging_loss = 0.0, 0.0

    model_to_resize = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    set_seed(args)  # Added here for reproducibility

    def backward_loss(loss, tot_loss):
        if args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps
        tot_loss += loss.item()
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        return tot_loss

    tqdm_iterator = trange(int(t_total), desc="Iteration", disable=args.local_rank not in [-1, 0])
    for _ in range(int(args.num_train_epochs)):
        for step, batch in enumerate(train_dataloader):
            model.train()

            if args.train_so or args.train_co:
                inputs_src, inputs_tgt = batch[0].clone(), batch[1].clone()
                inputs_src, inputs_tgt = inputs_src.to(args.device), inputs_tgt.to(args.device)
                attention_mask_src, attention_mask_tgt = (inputs_src != 0), (inputs_tgt != 0)
                guide = batch[2].to(args.device)
                loss = model(
                    inputs_src=inputs_src, inputs_tgt=inputs_tgt,
                    attention_mask_src=attention_mask_src, attention_mask_tgt=attention_mask_tgt,
                    guide=guide, align_layer=args.align_layer, extraction=args.extraction,
                    softmax_threshold=args.softmax_threshold, train_so=args.train_so, train_co=args.train_co,
                )
                tr_loss = backward_loss(loss, tr_loss)

            if args.train_mlm:
                inputs_src, labels_src = mask_tokens(batch[0], tokenizer, args)
                inputs_tgt, labels_tgt = mask_tokens(batch[1], tokenizer, args)
                inputs_src, inputs_tgt = inputs_src.to(args.device), inputs_tgt.to(args.device)
                labels_src, labels_tgt = labels_src.to(args.device), labels_tgt.to(args.device)
                loss = model(inputs_src=inputs_src, labels_src=labels_src)
                tr_loss = backward_loss(loss, tr_loss)
                loss = model(inputs_src=inputs_tgt, labels_src=labels_tgt)
                tr_loss = backward_loss(loss, tr_loss)

            if args.train_tlm:
                rand_ids = [0, 1]
                if not args.train_tlm_full:
                    rand_ids = [int(random.random() > 0.5)]
                for rand_id in rand_ids:
                    select_srctgt = batch[int(3 + rand_id * 2)]
                    select_langid = batch[int(4 + rand_id * 2)]
                    for lang_id in [1, 2]:
                        inputs_srctgt, labels_srctgt = mask_tokens(select_srctgt, tokenizer, args, select_langid, lang_id)
                        inputs_srctgt, labels_srctgt = inputs_srctgt.to(args.device), labels_srctgt.to(args.device)
                        loss = model(inputs_src=inputs_srctgt, labels_src=labels_srctgt)
                        tr_loss = backward_loss(loss, tr_loss)

            if args.train_psi:
                loss = model(inputs_src=batch[7].to(args.device), labels_psi=batch[8].to(args.device), align_layer=args.align_layer + 1)
                tr_loss = backward_loss(loss, tr_loss)

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                tqdm_iterator.update()

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logger.info(" Step %s. Training loss = %s", str(global_step), str((tr_loss - logging_loss) / args.logging_steps))
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if global_step > t_total:
                break
        if global_step > t_total:
            break

    return global_step, tr_loss / global_step