Example #1
0
def word_align(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):
    """Extract alignments for ``args.data_file`` and write them to ``args.output_file``.

    Examples are read in order from a ``LineByLineTextDataset``, batched, and
    passed to ``model.get_aligned_word``. One line is written per example,
    containing space-separated ``a-b`` entries built from the first two items
    of each alignment returned by the model.

    Args:
        args: Options object; this function reads ``data_file``,
            ``output_file``, ``batch_size``, ``device``, ``align_layer``,
            ``extraction`` and ``softmax_threshold`` (it is also handed to the
            dataset constructor).
        model: Model exposing ``get_aligned_word``. It is moved to
            ``args.device``, passed through ``delete_encoding_layers`` and put
            into eval mode.
        tokenizer: Supplies ``pad_token_id`` for batch padding and is passed
            to the dataset constructor.
    """

    def collate(examples):
        # Transpose the list of per-example tuples into per-field tuples and
        # pad the two id sequences to a common length within the batch.
        ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt = zip(*examples)
        pad_id = tokenizer.pad_token_id
        ids_src = pad_sequence(ids_src, batch_first=True, padding_value=pad_id)
        ids_tgt = pad_sequence(ids_tgt, batch_first=True, padding_value=pad_id)
        return ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt

    dataset = LineByLineTextDataset(tokenizer, args, file_path=args.data_file)
    # Sequential sampling keeps output lines in the same order as the dataset.
    dataloader = DataLoader(
        dataset,
        sampler=SequentialSampler(dataset),
        batch_size=args.batch_size,
        collate_fn=collate,
    )

    model.to(args.device)
    model = delete_encoding_layers(model)
    model.eval()

    # Both the output file and the progress bar are context-managed so they
    # are closed even if alignment extraction raises part-way through.
    with open(args.output_file, 'w') as writer, trange(len(dataset), desc="Extracting alignments") as progress:
        for batch in dataloader:
            with torch.no_grad():
                ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt = batch
                word_aligns_list = model.get_aligned_word(
                    ids_src,
                    ids_tgt,
                    bpe2word_map_src,
                    bpe2word_map_tgt,
                    args.device,
                    0,
                    0,
                    align_layer=args.align_layer,
                    extraction=args.extraction,
                    softmax_threshold=args.softmax_threshold,
                    test=True,
                )
                for word_aligns in word_aligns_list:
                    line = ' '.join(f'{pair[0]}-{pair[1]}' for pair in word_aligns)
                    writer.write(line + '\n')
                # Advance by the number of examples in this batch.
                progress.update(len(ids_src))
Example #2
0
def word_align(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):
    """Extract alignments for ``args.data_file`` using per-worker output files.

    The input file is split with ``find_offsets`` into ``args.num_workers``
    parts and loaded through a ``DataLoader`` with that many workers. Each
    example carries a worker id, and its output line is written to the writer
    at that index. After all batches are processed, every writer list is
    handed to ``merge_files``.

    Per example, one line of space-separated ``a-b`` entries is written to the
    alignment output. Entries whose first item is ``-1`` are skipped.
    When ``args.output_prob_file`` is set, the value stored under each kept
    alignment is written to the probability output; when
    ``args.output_word_file`` is set, the corresponding
    ``source<sep>target`` items from the example's sentences are written.

    Args:
        args: Options object; this function reads ``data_file``,
            ``num_workers``, ``batch_size``, ``device``, ``output_file``,
            ``output_prob_file``, ``output_word_file``, ``align_layer``,
            ``extraction`` and ``softmax_threshold``.
        model: Model exposing ``get_aligned_word``; moved to ``args.device``
            and put into eval mode.
        tokenizer: Supplies ``pad_token_id`` for batch padding and is passed
            to the dataset constructor.
    """

    def collate(examples):
        # Transpose per-example tuples into per-field tuples and pad the two
        # id sequences; all other fields are passed through unchanged.
        worker_ids, ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sents_src, sents_tgt = zip(*examples)
        pad_id = tokenizer.pad_token_id
        ids_src = pad_sequence(ids_src, batch_first=True, padding_value=pad_id)
        ids_tgt = pad_sequence(ids_tgt, batch_first=True, padding_value=pad_id)
        return worker_ids, ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sents_src, sents_tgt

    # The optional outputs are decided once up front instead of re-testing
    # the args on every alignment.
    write_probs = args.output_prob_file is not None
    write_words = args.output_word_file is not None

    offsets = find_offsets(args.data_file, args.num_workers)
    dataset = LineByLineTextDataset(tokenizer, file_path=args.data_file, offsets=offsets)
    dataloader = DataLoader(
        dataset, batch_size=args.batch_size, collate_fn=collate, num_workers=args.num_workers
    )

    model.to(args.device)
    model.eval()

    writers = open_writer_list(args.output_file, args.num_workers)
    prob_writers = open_writer_list(args.output_prob_file, args.num_workers) if write_probs else None
    word_writers = open_writer_list(args.output_word_file, args.num_workers) if write_words else None

    # The progress bar is created with a total of 0 (as in the original call)
    # and is closed by the context manager once the loop ends or raises.
    with trange(0, desc="Extracting") as progress:
        for batch in dataloader:
            with torch.no_grad():
                worker_ids, ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sents_src, sents_tgt = batch
                word_aligns_list = model.get_aligned_word(
                    ids_src,
                    ids_tgt,
                    bpe2word_map_src,
                    bpe2word_map_tgt,
                    args.device,
                    0,
                    0,
                    align_layer=args.align_layer,
                    extraction=args.extraction,
                    softmax_threshold=args.softmax_threshold,
                    test=True,
                    output_prob=write_probs,
                )
                for worker_id, word_aligns, sent_src, sent_tgt in zip(worker_ids, word_aligns_list, sents_src, sents_tgt):
                    align_strs, prob_strs, word_strs = [], [], []
                    for pair in word_aligns:
                        if pair[0] == -1:
                            # Entries flagged with -1 are not written.
                            continue
                        align_strs.append(f'{pair[0]}-{pair[1]}')
                        if write_probs:
                            # word_aligns is indexed by the alignment itself
                            # to obtain the value written to the prob file.
                            prob_strs.append(f'{word_aligns[pair]}')
                        if write_words:
                            word_strs.append(f'{sent_src[pair[0]]}<sep>{sent_tgt[pair[1]]}')
                    writers[worker_id].write(' '.join(align_strs)+'\n')
                    if write_probs:
                        prob_writers[worker_id].write(' '.join(prob_strs)+'\n')
                    if write_words:
                        word_writers[worker_id].write(' '.join(word_strs)+'\n')
                progress.update(len(ids_src))

    merge_files(writers)
    if write_probs:
        merge_files(prob_writers)
    if write_words:
        merge_files(word_writers)
Example #3
0
def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix="") -> Dict:
    """Evaluate ``model`` on the evaluation dataset and report perplexity.

    For every batch, the loss of each enabled objective (``args.train_so`` /
    ``args.train_co``, ``args.train_mlm``, ``args.train_tlm``,
    ``args.train_psi``) is computed and added to a running total. The total
    is divided by the number of batches and exponentiated to give the
    reported perplexity, which is logged and written to
    ``<args.output_dir>/<prefix>/eval_results.txt``.

    Args:
        args: Options object. Reads ``output_dir``, ``local_rank``,
            ``per_gpu_eval_batch_size``, ``n_gpu``, ``block_size``,
            ``device``, ``align_layer``, ``extraction``,
            ``softmax_threshold`` and the ``train_*`` flags; sets
            ``args.eval_batch_size``.
        model: Model to evaluate. It may already be wrapped in
            ``torch.nn.DataParallel``; it is wrapped here when
            ``args.n_gpu > 1`` and it is not.
        tokenizer: Supplies ``pad_token_id`` and is passed to
            ``load_and_cache_examples`` and ``mask_tokens``.
        prefix: Sub-directory (under ``args.output_dir``) for the results
            file; also included in log messages.

    Returns:
        A dict with the single key ``"perplexity"`` holding a torch tensor.

    Raises:
        ZeroDivisionError: If the evaluation dataloader yields no batches.
    """
    eval_output_dir = args.output_dir

    eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)

    if args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    def collate(examples):
        """Build one evaluation batch.

        Returns a 9-tuple, indexed by the evaluation loop below:
        0 examples_src, 1 examples_tgt, 2 guides, 3 examples_srctgt,
        4 langid_srctgt, 5 examples_tgtsrc, 6 langid_tgtsrc,
        7 psi_examples_srctgt, 8 psi_labels.
        """
        model.eval()
        examples_src, examples_tgt, examples_srctgt, examples_tgtsrc, langid_srctgt, langid_tgtsrc, psi_examples_srctgt, psi_labels = [], [], [], [], [], [], [], []
        src_len = tgt_len = 0
        bpe2word_map_src, bpe2word_map_tgt = [], []
        # Concatenated (src+tgt) inputs give each side half of the block.
        half_block_size = int(args.block_size/2)
        for example in examples:
            # Last id of the source sequence; re-appended after truncation so
            # every truncated sequence still ends with it.
            end_id = example[0][0][-1].view(-1)

            src_id = example[0][0][:args.block_size]
            src_id = torch.cat([src_id[:-1], end_id])
            tgt_id = example[1][0][:args.block_size]
            tgt_id = torch.cat([tgt_id[:-1], end_id])

            half_src_id = example[0][0][:half_block_size]
            half_src_id = torch.cat([half_src_id[:-1], end_id])
            half_tgt_id = example[1][0][:half_block_size]
            half_tgt_id = torch.cat([half_tgt_id[:-1], end_id])

            examples_src.append(src_id)
            examples_tgt.append(tgt_id)
            src_len = max(src_len, len(src_id))
            tgt_len = max(tgt_len, len(tgt_id))

            # Source followed by target; language ids are 1 for the first
            # segment and 2 for the second.
            srctgt = torch.cat([half_src_id, half_tgt_id])
            langid = torch.cat([torch.ones_like(half_src_id), torch.ones_like(half_tgt_id)*2])
            examples_srctgt.append(srctgt)
            langid_srctgt.append(langid)

            # Target followed by source, with the same 1/2 segment labelling.
            tgtsrc = torch.cat([half_tgt_id, half_src_id])
            langid = torch.cat([torch.ones_like(half_tgt_id), torch.ones_like(half_src_id)*2])
            examples_tgtsrc.append(tgtsrc)
            langid_tgtsrc.append(langid)

            # [neg, neg] pair, built from the last two items of the example
            neg_half_src_id = example[-2][0][:half_block_size]
            neg_half_src_id = torch.cat([neg_half_src_id[:-1], end_id])
            neg_half_tgt_id = example[-1][0][:half_block_size]
            neg_half_tgt_id = torch.cat([neg_half_tgt_id[:-1], end_id])
            neg_srctgt = torch.cat([neg_half_src_id, neg_half_tgt_id])
            psi_examples_srctgt.append(neg_srctgt)
            psi_labels.append(1)

            # [pos, neg] pair
            neg_srctgt = torch.cat([half_src_id, neg_half_tgt_id])
            psi_examples_srctgt.append(neg_srctgt)
            psi_labels.append(0)

            bpe2word_map_src.append(example[2])
            bpe2word_map_tgt.append(example[3])

        pad_id = tokenizer.pad_token_id
        examples_src = pad_sequence(examples_src, batch_first=True, padding_value=pad_id)
        examples_tgt = pad_sequence(examples_tgt, batch_first=True, padding_value=pad_id)
        examples_srctgt = pad_sequence(examples_srctgt, batch_first=True, padding_value=pad_id)
        langid_srctgt = pad_sequence(langid_srctgt, batch_first=True, padding_value=pad_id)
        examples_tgtsrc = pad_sequence(examples_tgtsrc, batch_first=True, padding_value=pad_id)
        langid_tgtsrc = pad_sequence(langid_tgtsrc, batch_first=True, padding_value=pad_id)
        psi_examples_srctgt = pad_sequence(psi_examples_srctgt, batch_first=True, padding_value=pad_id)
        psi_labels = torch.tensor(psi_labels)

        # collate runs lazily during iteration, i.e. after `model` may have
        # been rebound to a DataParallel wrapper below. DataParallel does not
        # expose custom methods of the wrapped model, so unwrap it first.
        align_model = model.module if hasattr(model, "module") else model
        guides = align_model.get_aligned_word(examples_src, examples_tgt, bpe2word_map_src, bpe2word_map_tgt, args.device, src_len, tgt_len, align_layer=args.align_layer, extraction=args.extraction, softmax_threshold=args.softmax_threshold)

        return examples_src, examples_tgt, guides, examples_srctgt, langid_srctgt, examples_tgtsrc, langid_tgtsrc, psi_examples_srctgt, psi_labels

    # Evaluation iterates the dataset in order.
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate
    )

    # multi-gpu evaluate; skip if the caller already passed a wrapped model
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()
    set_seed(args)  # Added here for reproducibility

    def post_loss(loss, tot_loss):
        """Add ``loss`` (averaged across GPUs when needed) to ``tot_loss``."""
        if args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        tot_loss += loss.item()
        return tot_loss

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        with torch.no_grad():
            if args.train_so or args.train_co:
                inputs_src, inputs_tgt = batch[0].clone(), batch[1].clone()
                inputs_src, inputs_tgt = inputs_src.to(args.device), inputs_tgt.to(args.device)
                # NOTE(review): the mask treats id 0 as padding, while batches
                # are padded with tokenizer.pad_token_id — confirm they match.
                attention_mask_src, attention_mask_tgt = (inputs_src!=0), (inputs_tgt!=0)
                guide = batch[2].to(args.device)
                loss = model(inputs_src=inputs_src, inputs_tgt=inputs_tgt, attention_mask_src=attention_mask_src, attention_mask_tgt=attention_mask_tgt, guide=guide, align_layer=args.align_layer, extraction=args.extraction, softmax_threshold=args.softmax_threshold, train_so=args.train_so, train_co=args.train_co)
                eval_loss = post_loss(loss, eval_loss)

            if args.train_mlm:
                inputs_src, labels_src = mask_tokens(batch[0], tokenizer, args)
                inputs_tgt, labels_tgt = mask_tokens(batch[1], tokenizer, args)
                inputs_src, inputs_tgt = inputs_src.to(args.device), inputs_tgt.to(args.device)
                labels_src, labels_tgt = labels_src.to(args.device), labels_tgt.to(args.device)
                loss = model(inputs_src=inputs_src, labels_src=labels_src)
                eval_loss = post_loss(loss, eval_loss)

                loss = model(inputs_src=inputs_tgt, labels_src=labels_tgt)
                eval_loss = post_loss(loss, eval_loss)

            if args.train_tlm:
                # select_id 0 -> src+tgt inputs (batch[3], batch[4]);
                # select_id 1 -> tgt+src inputs (batch[5], batch[6]).
                select_ids = [0, 1]
                if not args.train_tlm_full:
                    select_ids = [0]
                for select_id in select_ids:
                    for lang_id in [1, 2]:
                        inputs_srctgt, labels_srctgt = mask_tokens(batch[3+select_id*2], tokenizer, args, batch[4+select_id*2], lang_id)
                        inputs_srctgt, labels_srctgt = inputs_srctgt.to(args.device), labels_srctgt.to(args.device)
                        loss = model(inputs_src=inputs_srctgt, labels_src=labels_srctgt)
                        eval_loss = post_loss(loss, eval_loss)

            if args.train_psi:
                loss = model(inputs_src=batch[7].to(args.device), labels_psi=batch[8].to(args.device))
                eval_loss = post_loss(loss, eval_loss)

        nb_eval_steps += 1

    # The summed losses of all enabled objectives are divided by the number
    # of batches (not by the number of loss terms).
    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))

    result = {"perplexity": perplexity}

    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
    # A non-empty prefix adds a sub-directory that may not exist yet.
    output_eval_dir = os.path.dirname(output_eval_file)
    if output_eval_dir:
        os.makedirs(output_eval_dir, exist_ok=True)
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result