Example #1
def word_align(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):
    def collate(examples):
        ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt = zip(*examples)
        ids_src = pad_sequence(ids_src,
                               batch_first=True,
                               padding_value=tokenizer.pad_token_id)
        ids_tgt = pad_sequence(ids_tgt,
                               batch_first=True,
                               padding_value=tokenizer.pad_token_id)
        return ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt

    dataset = LineByLineTextDataset(tokenizer, args, file_path=args.data_file)
    sampler = SequentialSampler(dataset)
    dataloader = DataLoader(dataset,
                            sampler=sampler,
                            batch_size=args.batch_size,
                            collate_fn=collate)

    model.to(args.device)
    model.eval()
    tqdm_iterator = trange(len(dataset), desc="Extracting")
    with open(args.output_file, 'w') as writer:
        for batch in dataloader:
            with torch.no_grad():
                ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt = batch
                word_aligns_list = model.get_aligned_word(
                    ids_src,
                    ids_tgt,
                    bpe2word_map_src,
                    bpe2word_map_tgt,
                    args.device,
                    0,
                    0,
                    align_layer=args.align_layer,
                    extraction=args.extraction,
                    softmax_threshold=args.softmax_threshold,
                    test=True)
                for word_aligns in word_aligns_list:
                    output_str = []
                    for align_pair in word_aligns:
                        output_str.append(f'{align_pair[0]}-{align_pair[1]}')
                    writer.write(' '.join(output_str) + '\n')
                tqdm_iterator.update(len(ids_src))
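
# --- Added usage sketch (not part of the original example) ---
# A minimal argument namespace for calling word_align above. The attribute
# names are exactly the ones the function reads; the concrete values (paths,
# layer index, threshold) are illustrative assumptions only.
from argparse import Namespace
import torch

example_args = Namespace(
    data_file="data/test.src-tgt",   # one "source ||| target" pair per line (see Example #4)
    output_file="alignments.txt",    # "i-j" word-index pairs are written here
    batch_size=32,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    align_layer=8,
    extraction="softmax",
    softmax_threshold=0.001,
)
# word_align(example_args, model, tokenizer)  # model/tokenizer loaded elsewhere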
Example #2
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples):
        model.eval()
        examples_src, examples_tgt = [], []
        examples_srctgt, examples_tgtsrc = [], []
        langid_srctgt, langid_tgtsrc = [], []
        psi_examples_srctgt, psi_labels = [], []
        src_len = tgt_len = 0
        bpe2word_map_src, bpe2word_map_tgt = [], []
        for example in examples:
            end_id = example[0][0][-1].view(-1)

            src_id = example[0][0][:args.block_size]
            src_id = torch.cat([src_id[:-1], end_id])
            tgt_id = example[1][0][:args.block_size]
            tgt_id = torch.cat([tgt_id[:-1], end_id])

            half_block_size = args.block_size // 2
            half_src_id = example[0][0][:half_block_size]
            half_src_id = torch.cat([half_src_id[:-1], end_id])
            half_tgt_id = example[1][0][:half_block_size]
            half_tgt_id = torch.cat([half_tgt_id[:-1], end_id])

            examples_src.append(src_id)
            examples_tgt.append(tgt_id)
            src_len = max(src_len, len(src_id))
            tgt_len = max(tgt_len, len(tgt_id))

            srctgt = torch.cat([half_src_id, half_tgt_id])
            langid = torch.cat([
                torch.ones_like(half_src_id),
                torch.ones_like(half_tgt_id) * 2
            ])
            examples_srctgt.append(srctgt)
            langid_srctgt.append(langid)

            tgtsrc = torch.cat([half_tgt_id, half_src_id])
            langid = torch.cat([
                torch.ones_like(half_tgt_id),
                torch.ones_like(half_src_id) * 2
            ])
            examples_tgtsrc.append(tgtsrc)
            langid_tgtsrc.append(langid)

            # [neg, neg] pair
            neg_half_src_id = example[-2][0][:half_block_size]
            neg_half_src_id = torch.cat([neg_half_src_id[:-1], end_id])
            neg_half_tgt_id = example[-1][0][:half_block_size]
            neg_half_tgt_id = torch.cat([neg_half_tgt_id[:-1], end_id])
            if random.random() > 0.5:
                neg_srctgt = torch.cat([neg_half_src_id, neg_half_tgt_id])
            else:
                neg_srctgt = torch.cat([neg_half_tgt_id, neg_half_src_id])
            psi_examples_srctgt.append(neg_srctgt)
            psi_labels.append(1)

            # [pos, neg] pair
            rd = random.random()
            if rd > 0.75:
                neg_srctgt = torch.cat([half_src_id, neg_half_tgt_id])
            elif rd > 0.5:
                neg_srctgt = torch.cat([neg_half_src_id, half_tgt_id])
            elif rd > 0.25:
                neg_srctgt = torch.cat([half_tgt_id, neg_half_src_id])
            else:
                neg_srctgt = torch.cat([neg_half_tgt_id, half_src_id])
            psi_examples_srctgt.append(neg_srctgt)
            psi_labels.append(0)

            bpe2word_map_src.append(example[2])
            bpe2word_map_tgt.append(example[3])

        examples_src = pad_sequence(examples_src,
                                    batch_first=True,
                                    padding_value=tokenizer.pad_token_id)
        examples_tgt = pad_sequence(examples_tgt,
                                    batch_first=True,
                                    padding_value=tokenizer.pad_token_id)
        examples_srctgt = pad_sequence(examples_srctgt,
                                       batch_first=True,
                                       padding_value=tokenizer.pad_token_id)
        langid_srctgt = pad_sequence(langid_srctgt,
                                     batch_first=True,
                                     padding_value=tokenizer.pad_token_id)
        examples_tgtsrc = pad_sequence(examples_tgtsrc,
                                       batch_first=True,
                                       padding_value=tokenizer.pad_token_id)
        langid_tgtsrc = pad_sequence(langid_tgtsrc,
                                     batch_first=True,
                                     padding_value=tokenizer.pad_token_id)
        psi_examples_srctgt = pad_sequence(
            psi_examples_srctgt,
            batch_first=True,
            padding_value=tokenizer.pad_token_id)
        psi_labels = torch.tensor(psi_labels)
        guides = model.get_aligned_word(
            examples_src,
            examples_tgt,
            bpe2word_map_src,
            bpe2word_map_tgt,
            args.device,
            src_len,
            tgt_len,
            align_layer=args.align_layer,
            extraction=args.extraction,
            softmax_threshold=args.softmax_threshold)
        return (examples_src, examples_tgt, guides,
                examples_srctgt, langid_srctgt,
                examples_tgtsrc, langid_tgtsrc,
                psi_examples_srctgt, psi_labels)

    train_sampler = (RandomSampler(train_dataset) if args.local_rank == -1
                     else DistributedSampler(train_dataset))
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    t_total = (len(train_dataloader) // args.gradient_accumulation_steps
               * args.num_train_epochs)
    if args.max_steps > 0 and args.max_steps < t_total:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]

    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    # Check if continuing training from a checkpoint
    tr_loss, logging_loss = 0.0, 0.0

    # Take care of distributed/parallel training
    model_to_resize = model.module if hasattr(model, "module") else model
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    set_seed(args)  # Added here for reproducibility

    def backward_loss(loss, tot_loss):
        if args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        tot_loss += loss.item()
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        return tot_loss

    tqdm_iterator = trange(int(t_total),
                           desc="Iteration",
                           disable=args.local_rank not in [-1, 0])
    for _ in range(int(args.num_train_epochs)):
        for step, batch in enumerate(train_dataloader):

            model.train()

            if args.train_so or args.train_co:
                inputs_src, inputs_tgt = batch[0].clone(), batch[1].clone()
                inputs_src = inputs_src.to(args.device)
                inputs_tgt = inputs_tgt.to(args.device)
                # mask out padding positions (the pad id is 0 for BERT-style tokenizers)
                attention_mask_src = (inputs_src != tokenizer.pad_token_id)
                attention_mask_tgt = (inputs_tgt != tokenizer.pad_token_id)
                guide = batch[2].to(args.device)
                loss = model(inputs_src=inputs_src,
                             inputs_tgt=inputs_tgt,
                             attention_mask_src=attention_mask_src,
                             attention_mask_tgt=attention_mask_tgt,
                             guide=guide,
                             align_layer=args.align_layer,
                             extraction=args.extraction,
                             softmax_threshold=args.softmax_threshold,
                             train_so=args.train_so,
                             train_co=args.train_co)
                tr_loss = backward_loss(loss, tr_loss)

            if args.train_mlm:
                inputs_src, labels_src = mask_tokens(batch[0], tokenizer, args)
                inputs_tgt, labels_tgt = mask_tokens(batch[1], tokenizer, args)
                inputs_src = inputs_src.to(args.device)
                inputs_tgt = inputs_tgt.to(args.device)
                labels_src = labels_src.to(args.device)
                labels_tgt = labels_tgt.to(args.device)
                loss = model(inputs_src=inputs_src, labels_src=labels_src)
                tr_loss = backward_loss(loss, tr_loss)

                loss = model(inputs_src=inputs_tgt, labels_src=labels_tgt)
                tr_loss = backward_loss(loss, tr_loss)

            if args.train_tlm:
                rand_ids = [0, 1]
                if not args.train_tlm_full:
                    rand_ids = [int(random.random() > 0.5)]
                for rand_id in rand_ids:
                    select_srctgt = batch[3 + rand_id * 2]
                    select_langid = batch[4 + rand_id * 2]
                    for lang_id in [1, 2]:
                        inputs_srctgt, labels_srctgt = mask_tokens(
                            select_srctgt, tokenizer, args, select_langid,
                            lang_id)
                        inputs_srctgt, labels_srctgt = inputs_srctgt.to(
                            args.device), labels_srctgt.to(args.device)
                        loss = model(inputs_src=inputs_srctgt,
                                     labels_src=labels_srctgt)
                        tr_loss = backward_loss(loss, tr_loss)

            if args.train_psi:
                loss = model(inputs_src=batch[7].to(args.device),
                             labels_psi=batch[8].to(args.device),
                             align_layer=args.align_layer + 1)
                tr_loss = backward_loss(loss, tr_loss)

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                tqdm_iterator.update()

                if (args.local_rank in [-1, 0] and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    logger.info(
                        "  Step %s. Training loss = %s", str(global_step),
                        str((tr_loss - logging_loss) / args.logging_steps))
                    logging_loss = tr_loss

                if (args.local_rank in [-1, 0] and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if global_step > t_total:
                break
        if global_step > t_total:
            break

    return global_step, tr_loss / global_step
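
# --- Added reader's note ---
# The training loop above indexes each collated batch positionally. The layout,
# read off collate's return statement (annotations added for orientation):
#   batch[0] examples_src         padded source ids                      (MLM / SO / CO)
#   batch[1] examples_tgt         padded target ids                      (MLM / SO / CO)
#   batch[2] guides               alignment guide from get_aligned_word  (SO / CO)
#   batch[3] examples_srctgt      src-half + tgt-half ids                (TLM)
#   batch[4] langid_srctgt        language ids for batch[3]              (TLM)
#   batch[5] examples_tgtsrc      tgt-half + src-half ids                (TLM)
#   batch[6] langid_tgtsrc        language ids for batch[5]              (TLM)
#   batch[7] psi_examples_srctgt  sentence-pair inputs                   (PSI)
#   batch[8] psi_labels           1 for the [neg, neg] pair, 0 for the mixed pair (PSI)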
Example #3
def evaluate(args,
             model: PreTrainedModel,
             tokenizer: PreTrainedTokenizer,
             prefix="") -> Dict:
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_output_dir = args.output_dir

    eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)

    if args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    # Note that DistributedSampler samples randomly
    def collate(examples):
        model.eval()
        examples_src, examples_tgt = [], []
        examples_srctgt, examples_tgtsrc = [], []
        langid_srctgt, langid_tgtsrc = [], []
        psi_examples_srctgt, psi_labels = [], []
        src_len = tgt_len = 0
        bpe2word_map_src, bpe2word_map_tgt = [], []
        for example in examples:
            end_id = example[0][0][-1].view(-1)

            src_id = example[0][0][:args.block_size]
            src_id = torch.cat([src_id[:-1], end_id])
            tgt_id = example[1][0][:args.block_size]
            tgt_id = torch.cat([tgt_id[:-1], end_id])

            half_block_size = args.block_size // 2
            half_src_id = example[0][0][:half_block_size]
            half_src_id = torch.cat([half_src_id[:-1], end_id])
            half_tgt_id = example[1][0][:half_block_size]
            half_tgt_id = torch.cat([half_tgt_id[:-1], end_id])

            examples_src.append(src_id)
            examples_tgt.append(tgt_id)
            src_len = max(src_len, len(src_id))
            tgt_len = max(tgt_len, len(tgt_id))

            srctgt = torch.cat([half_src_id, half_tgt_id])
            langid = torch.cat([
                torch.ones_like(half_src_id),
                torch.ones_like(half_tgt_id) * 2
            ])
            examples_srctgt.append(srctgt)
            langid_srctgt.append(langid)

            tgtsrc = torch.cat([half_tgt_id, half_src_id])
            langid = torch.cat([
                torch.ones_like(half_tgt_id),
                torch.ones_like(half_src_id) * 2
            ])
            examples_tgtsrc.append(tgtsrc)
            langid_tgtsrc.append(langid)

            # [neg, neg] pair
            neg_half_src_id = example[-2][0][:half_block_size]
            neg_half_src_id = torch.cat([neg_half_src_id[:-1], end_id])
            neg_half_tgt_id = example[-1][0][:half_block_size]
            neg_half_tgt_id = torch.cat([neg_half_tgt_id[:-1], end_id])
            neg_srctgt = torch.cat([neg_half_src_id, neg_half_tgt_id])
            psi_examples_srctgt.append(neg_srctgt)
            psi_labels.append(1)

            # [pos, neg] pair
            neg_srctgt = torch.cat([half_src_id, neg_half_tgt_id])
            psi_examples_srctgt.append(neg_srctgt)
            psi_labels.append(0)

            bpe2word_map_src.append(example[2])
            bpe2word_map_tgt.append(example[3])

        examples_src = pad_sequence(examples_src,
                                    batch_first=True,
                                    padding_value=tokenizer.pad_token_id)
        examples_tgt = pad_sequence(examples_tgt,
                                    batch_first=True,
                                    padding_value=tokenizer.pad_token_id)
        examples_srctgt = pad_sequence(examples_srctgt,
                                       batch_first=True,
                                       padding_value=tokenizer.pad_token_id)
        langid_srctgt = pad_sequence(langid_srctgt,
                                     batch_first=True,
                                     padding_value=tokenizer.pad_token_id)
        examples_tgtsrc = pad_sequence(examples_tgtsrc,
                                       batch_first=True,
                                       padding_value=tokenizer.pad_token_id)
        langid_tgtsrc = pad_sequence(langid_tgtsrc,
                                     batch_first=True,
                                     padding_value=tokenizer.pad_token_id)
        psi_examples_srctgt = pad_sequence(
            psi_examples_srctgt,
            batch_first=True,
            padding_value=tokenizer.pad_token_id)
        psi_labels = torch.tensor(psi_labels)
        guides = model.get_aligned_word(
            examples_src,
            examples_tgt,
            bpe2word_map_src,
            bpe2word_map_tgt,
            args.device,
            src_len,
            tgt_len,
            align_layer=args.align_layer,
            extraction=args.extraction,
            softmax_threshold=args.softmax_threshold)
        return (examples_src, examples_tgt, guides,
                examples_srctgt, langid_srctgt,
                examples_tgtsrc, langid_tgtsrc,
                psi_examples_srctgt, psi_labels)

    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size,
                                 collate_fn=collate)

    # multi-gpu evaluate
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()
    set_seed(args)  # Added here for reproducibility

    def post_loss(loss, tot_loss):
        if args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel evaluation
        tot_loss += loss.item()
        return tot_loss

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        with torch.no_grad():
            if args.train_so or args.train_co:
                inputs_src, inputs_tgt = batch[0].clone(), batch[1].clone()
                inputs_src = inputs_src.to(args.device)
                inputs_tgt = inputs_tgt.to(args.device)
                # mask out padding positions (the pad id is 0 for BERT-style tokenizers)
                attention_mask_src = (inputs_src != tokenizer.pad_token_id)
                attention_mask_tgt = (inputs_tgt != tokenizer.pad_token_id)
                guide = batch[2].to(args.device)
                loss = model(inputs_src=inputs_src,
                             inputs_tgt=inputs_tgt,
                             attention_mask_src=attention_mask_src,
                             attention_mask_tgt=attention_mask_tgt,
                             guide=guide,
                             align_layer=args.align_layer,
                             extraction=args.extraction,
                             softmax_threshold=args.softmax_threshold,
                             train_so=args.train_so,
                             train_co=args.train_co)
                eval_loss = post_loss(loss, eval_loss)

            if args.train_mlm:
                inputs_src, labels_src = mask_tokens(batch[0], tokenizer, args)
                inputs_tgt, labels_tgt = mask_tokens(batch[1], tokenizer, args)
                inputs_src = inputs_src.to(args.device)
                inputs_tgt = inputs_tgt.to(args.device)
                labels_src = labels_src.to(args.device)
                labels_tgt = labels_tgt.to(args.device)
                loss = model(inputs_src=inputs_src, labels_src=labels_src)
                eval_loss = post_loss(loss, eval_loss)

                loss = model(inputs_src=inputs_tgt, labels_src=labels_tgt)
                eval_loss = post_loss(loss, eval_loss)

            if args.train_tlm:
                select_ids = [0, 1]
                if not args.train_tlm_full:
                    select_ids = [0]
                for select_id in select_ids:
                    for lang_id in [1, 2]:
                        inputs_srctgt, labels_srctgt = mask_tokens(
                            batch[3 + select_id * 2], tokenizer, args,
                            batch[4 + select_id * 2], lang_id)
                        inputs_srctgt, labels_srctgt = inputs_srctgt.to(
                            args.device), labels_srctgt.to(args.device)
                        loss = model(inputs_src=inputs_srctgt,
                                     labels_src=labels_srctgt)
                        eval_loss = post_loss(loss, eval_loss)

            if args.train_psi:
                loss = model(inputs_src=batch[7].to(args.device),
                             labels_psi=batch[8].to(args.device))
                eval_loss = post_loss(loss, eval_loss)

        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))

    result = {"perplexity": perplexity}

    output_eval_file = os.path.join(eval_output_dir, prefix,
                                    "eval_results.txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result
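
# --- Added numeric sketch ---
# evaluate() reports a single metric: the exponential of the loss averaged over
# all evaluation batches (summed across whichever objectives are enabled).
# The value below is made up purely to illustrate that final step.
import torch

mean_eval_loss = 2.31                                # hypothetical averaged loss
perplexity = torch.exp(torch.tensor(mean_eval_loss))
print(f"perplexity = {perplexity.item():.2f}")       # ~10.07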
Example #4
def word_align(args,
               model: PreTrainedModel,
               tokenizer: PreTrainedTokenizer,
               output_word_alignments=False):
    def collate(examples):
        ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt = zip(*examples)
        ids_src = pad_sequence(ids_src,
                               batch_first=True,
                               padding_value=tokenizer.pad_token_id)
        ids_tgt = pad_sequence(ids_tgt,
                               batch_first=True,
                               padding_value=tokenizer.pad_token_id)
        return ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt

    dataset = LineByLineTextDataset(tokenizer, args, file_path=args.data_file)
    sampler = SequentialSampler(dataset)
    dataloader = DataLoader(dataset,
                            sampler=sampler,
                            batch_size=args.batch_size,
                            collate_fn=collate)

    model.to(args.device)
    model.eval()
    tqdm_iterator = trange(len(dataset), desc="Extracting")
    with open(args.output_file, 'w') as writer:
        for batch in dataloader:
            with torch.no_grad():
                ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt = batch
                word_aligns_list = model.get_aligned_word(
                    ids_src,
                    ids_tgt,
                    bpe2word_map_src,
                    bpe2word_map_tgt,
                    args.device,
                    0,
                    0,
                    align_layer=args.align_layer,
                    extraction=args.extraction,
                    softmax_threshold=args.softmax_threshold,
                    test=True)
                for word_aligns in word_aligns_list:
                    output_str = []
                    for align_pair in word_aligns:
                        output_str.append(f'{align_pair[0]}-{align_pair[1]}')
                    writer.write(' '.join(output_str) + '\n')
                tqdm_iterator.update(len(ids_src))

    if output_word_alignments:
        # splitlines() avoids a trailing empty entry that would break the
        # "src ||| tgt" split below when the files end with a newline
        with open(args.output_file, 'r') as fh:
            outputf = fh.read().splitlines()
        with open(args.data_file, 'r') as fh:
            datalines = fh.read().splitlines()

        with open(args.output_file + ".outtxt", 'w') as fwriter:
            for indices, line in zip(outputf, datalines):
                srcline, tgtline = line.split(' ||| ')
                indices = indices.split()
                srcwrds = srcline.split()
                tgtwrds = tgtline.split()
                output_wrds = []
                for wrd in indices:
                    srcix, tgtix = wrd.split("-")
                    srcix, tgtix = int(srcix), int(tgtix)
                    output_wrds.append(f"{srcwrds[srcix]}-{tgtwrds[tgtix]}")
                fwriter.write(' '.join(output_wrds) + '\n')
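
# --- Added usage sketch ---
# With output_word_alignments=True, word_align expects args.data_file in
# "source ||| target" format (matching the split on ' ||| ' above) and writes
# two files: index pairs to args.output_file and word pairs to the ".outtxt"
# file. A hypothetical single-line illustration (sentence and alignment are
# made up, not model output):
#
#   data_file line          :  we love music ||| wir lieben Musik
#   output_file line        :  0-0 1-1 2-2
#   output_file + ".outtxt" :  we-wir love-lieben music-Musik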