Example #1
def multi_process_main(
    args: Any,
    use_output_queue: bool,
    start_rank: int = 0,
    init_fn: Optional[Callable[[], None]] = None,
):
    pytorch_translate_options.print_args(args)
    torch_mp = torch.multiprocessing.get_context("spawn")

    # Create a thread to listen for errors in the child processes.
    error_queue = torch_mp.SimpleQueue()
    error_handler = pytorch_translate_utils.ErrorHandler(error_queue)
    # SimpleQueue doesn't work for output_queue: its put() appears to block
    # when the parent process hasn't consumed the results with get().
    output_queue = torch_mp.Queue() if use_output_queue else None

    # Train with multiprocessing.
    processes = []
    for i in range(torch.cuda.device_count()):
        args.distributed_rank = start_rank + i
        args.device_id = i
        processes.append(
            torch_mp.Process(
                target=multi_process_train,
                args=(args, error_queue, output_queue, init_fn),
                daemon=True,
            )
        )
        processes[i].start()
        error_handler.add_child(processes[i].pid)

    return (processes, error_handler, output_queue)
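
A minimal caller sketch for the returned triple. That each worker puts exactly one result on output_queue is an assumption; draining before join() matters because a multiprocessing Queue's feeder thread can block a child's exit until the parent consumes its items.

# Hypothetical caller: drain one result per GPU worker, then reap the children.
processes, error_handler, output_queue = multi_process_main(
    args, use_output_queue=True
)
results = [output_queue.get() for _ in range(torch.cuda.device_count())]
for p in processes:
    p.join()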
Example #2
def multi_process_main(
    args: Any,
    start_rank: int = 0,
    init_fn: Optional[Callable[[], None]] = None,
    trainer_class=None,
    **train_step_kwargs,
):
    pytorch_translate_options.print_args(args)
    output_queue = torch.multiprocessing.get_context("spawn").Queue()
    # Train with multiprocessing.
    spawn_context = torch.multiprocessing.spawn(
        fn=multi_process_train,
        args=(
            args,
            output_queue,
            start_rank,
            init_fn,
            trainer_class,
            train_step_kwargs,
        ),
        nprocs=args.local_num_gpus,
        # We don't block here to allow caller to process output_queue in
        # parallel with training.
        join=False,
    )
    return (spawn_context, output_queue)
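
With join=False the call returns immediately, so the caller drains output_queue itself and then joins the context. A sketch of that pattern (one result per worker is an assumption; the join() loop mirrors what torch.multiprocessing.spawn does internally when join=True):

# Hypothetical caller: consume results while training runs, then reap workers.
spawn_context, output_queue = multi_process_main(args)
results = [output_queue.get() for _ in range(args.local_num_gpus)]
# ProcessContext.join() returns True once every worker has exited and
# re-raises the first child exception it sees.
while not spawn_context.join():
    pass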
Example #3
def save_top_k(args):
    """
    This function runs forward computation on an ensemble of trained models
    using binarized parallel training data and saves the top-k probabilities
    and their corresponding token indices for each output step.

    Note that the Python binary accepts all generation params, but ignores
    inapplicable ones (such as those related to output length). --max-tokens
    is of particular importance to prevent memory errors.
    """
    pytorch_translate_options.print_args(args)
    use_cuda = torch.cuda.is_available() and not getattr(args, "cpu", False)

    (
        models,
        model_args,
        task,
    ) = pytorch_translate_utils.load_diverse_ensemble_for_inference(
        args.path.split(CHECKPOINT_PATHS_DELIMITER)
    )
    for model in models:
        model.eval()
        if use_cuda:
            model.cuda()

    append_eos_to_source = model_args[0].append_eos_to_source
    reverse_source = model_args[0].reverse_source
    assert all(
        a.append_eos_to_source == append_eos_to_source
        and a.reverse_source == reverse_source
        for a in model_args
    )
    assert (
        args.source_binary_file != "" and args.target_binary_file != ""
    ), "collect_top_k_probs requires binarized data."
    task.load_dataset(args.gen_subset, args.source_binary_file, args.target_binary_file)

    assert (
        args.top_k_probs_binary_file != ""
    ), "must specify output file (--top-k-probs-binary-file)!"
    output_path = args.top_k_probs_binary_file

    dataset = task.dataset(args.gen_subset)

    top_k_scores, top_k_indices = compute_top_k(
        task=task,
        models=models,
        dataset=dataset,
        k=args.k_probs_to_collect,
        use_cuda=use_cuda,
        max_tokens=args.teacher_max_tokens,
        max_sentences=args.max_sentences,
        progress_bar_args=args,
    )

    np.savez(output_path, top_k_scores=top_k_scores, top_k_indices=top_k_indices)
    print(
        f"Saved top {top_k_scores.shape[1]} probs for a total of "
        f"{top_k_scores.shape[0]} tokens to file {output_path}"
    )
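
The saved arrays can be read back with np.load; a short sketch (the shapes follow from the final print statement, and np.savez appends ".npz" when the path lacks it):

# Load the top-k probabilities written by save_top_k().
saved = np.load(
    output_path if output_path.endswith(".npz") else output_path + ".npz"
)
top_k_scores = saved["top_k_scores"]    # shape (num_tokens, k): probabilities
top_k_indices = saved["top_k_indices"]  # shape (num_tokens, k): vocab indices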
Example #4
def generate(args):
    pytorch_translate_options.print_args(args)

    # Setup task
    task = tasks.setup_task(args)

    models, model_args = pytorch_translate_utils.load_diverse_ensemble_for_inference(
        args.path.split(":"), task)
    args.source_lang = model_args[0].source_lang
    args.target_lang = model_args[0].target_lang

    append_eos_to_source = model_args[0].append_eos_to_source
    reverse_source = model_args[0].reverse_source
    assert all(a.append_eos_to_source == append_eos_to_source
               and a.reverse_source == reverse_source for a in model_args)
    if args.source_binary_file != "":
        assert args.target_binary_file != ""
        task.load_dataset(args.gen_subset, args.source_binary_file,
                          args.target_binary_file)
    elif pytorch_translate_data.is_multilingual(args):
        task.set_encoder_langs(model_args[0].multiling_encoder_lang)
        task.set_decoder_langs(model_args[0].multiling_decoder_lang)
        task.load_dataset_from_text_multilingual(
            args.gen_subset,
            source_text_file=args.source_text_file[0],
            target_text_file=args.target_text_file,
            source_lang_id=task.get_encoder_lang_id(
                args.multiling_source_lang[0]),
            target_lang_id=task.get_decoder_lang_id(
                args.multiling_target_lang[0]),
            append_eos=append_eos_to_source,
            reverse_source=reverse_source,
        )
    elif args.source_ensembling:
        task.load_multisource_dataset_from_text(
            args.gen_subset,
            source_text_files=args.source_text_file,
            target_text_file=args.target_text_file,
            append_eos=append_eos_to_source,
            reverse_source=reverse_source,
        )
    else:
        task.load_dataset_from_text(
            args.gen_subset,
            source_text_file=args.source_text_file[0],
            target_text_file=args.target_text_file,
            append_eos=append_eos_to_source,
            reverse_source=reverse_source,
        )

    scorer, num_sentences, gen_timer, _ = _generate_score(
        models=models, args=args, task=task, dataset=task.dataset(args.gen_subset)
    )
    print(f"| Translated {num_sentences} sentences ({gen_timer.n} tokens) "
          f"in {gen_timer.sum:.1f}s ({1. / gen_timer.avg:.2f} tokens/s)")
    print(f"| Generate {args.gen_subset} with beam={args.beam}: "
          f"{scorer.result_string()}")
    return scorer.score()
Example #5
def main():
    parser = argparse.ArgumentParser(description="PyTorch Translate - preprocessing")
    pytorch_translate_options.add_verbosity_args(parser)
    pytorch_translate_options.add_preprocessing_args(parser)
    args = parser.parse_args()
    pytorch_translate_options.validate_preprocessing_args(args)
    pytorch_translate_options.print_args(args)
    preprocess_corpora(args)
Example #6
def single_process_main(args):
    """Train the model for multiple epochs."""
    pytorch_translate_options.print_args(args)
    extra_state, trainer, task, epoch_itr = setup_training(args)
    train(
        args=args,
        extra_state=extra_state,
        trainer=trainer,
        task=task,
        epoch_itr=epoch_itr,
    )
Example #7
def single_process_main(args, trainer_class=Trainer, **train_step_kwargs):
    """Train the model for multiple epochs."""
    pytorch_translate_options.print_args(args)
    trainer, task, epoch_itr = setup_training(args, trainer_class)
    extra_state, epoch_itr, checkpoint_manager = setup_training_state(
        args=args, trainer=trainer, task=task, epoch_itr=epoch_itr)
    train(
        args=args,
        extra_state=extra_state,
        trainer=trainer,
        task=task,
        epoch_itr=epoch_itr,
        checkpoint_manager=checkpoint_manager,
        **train_step_kwargs,
    )
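
Because both the trainer class and per-step kwargs are injected, a caller can customize training without editing the loop. A hypothetical example (LoggingTrainer and log_interval are made-up names, not part of the codebase):

# Hypothetical: any Trainer subclass can be swapped in; extra kwargs are
# forwarded to train() as **train_step_kwargs.
class LoggingTrainer(Trainer):
    pass  # override whichever hooks the experiment needs

single_process_main(args, trainer_class=LoggingTrainer, log_interval=100)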
Example #8
def main(args, single_process_train):
    if args.distributed_world_size == 1:
        return single_process_train(args)

    mp = multiprocessing.get_context("spawn")

    # Create a thread to listen for errors in the child processes.
    error_queue = mp.SimpleQueue()
    error_handler = ErrorHandler(error_queue)

    # Train with multiprocessing.
    procs = []
    for i in range(args.distributed_world_size):
        args.distributed_rank = i
        args.device_id = i
        procs.append(
            mp.Process(target=run,
                       args=(args, single_process_train, error_queue),
                       daemon=True))
        procs[i].start()
        error_handler.add_child(procs[i].pid)
    for p in procs:
        p.join()


if __name__ == "__main__":
    parser = get_parser_with_args()
    args = adversarial_options.parse_args_and_adversary(parser)
    validate_and_set_default_args(args)
    pytorch_translate_options.print_args(args)
    main(args, single_process_main)
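
ErrorHandler comes from the surrounding codebase. As a rough illustration only, not the actual class, the pattern is a daemon thread that blocks on error_queue, stops the registered children, and surfaces the failing child's traceback in the parent; the run wrapper shown is likewise just a plausible shape for the Process target above:

import os
import signal
import sys
import threading
import traceback

class MinimalErrorHandler:
    """Illustrative stand-in for ErrorHandler; not the real implementation."""

    def __init__(self, error_queue):
        self.error_queue = error_queue
        self.children_pids = []
        # Daemon thread: blocks on the queue without outliving the parent.
        threading.Thread(target=self._listen, daemon=True).start()

    def add_child(self, pid):
        self.children_pids.append(pid)

    def _listen(self):
        rank, tb = self.error_queue.get()  # blocks until a child reports
        for pid in self.children_pids:
            os.kill(pid, signal.SIGTERM)  # stop the remaining workers
        sys.stderr.write(f"Worker {rank} failed:\n{tb}")
        os.kill(os.getpid(), signal.SIGINT)  # interrupt the parent's join()s

def run(args, single_process_train, error_queue):
    # Assumed shape of the `run` target used above (not the actual code).
    try:
        single_process_train(args)
    except Exception:
        error_queue.put((args.distributed_rank, traceback.format_exc()))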
Example #9
def generate(args):
    pytorch_translate_options.print_args(args)

    src_dict = pytorch_translate_dictionary.Dictionary.load(args.source_vocab_file)
    dst_dict = pytorch_translate_dictionary.Dictionary.load(args.target_vocab_file)
    use_char_source = args.char_source_vocab_file != ""
    if use_char_source:
        char_source_dict = pytorch_translate_dictionary.Dictionary.load(
            args.char_source_vocab_file
        )
        # this attribute is used for CharSourceModel construction
        args.char_source_dict_size = len(char_source_dict)
    else:
        char_source_dict = None

    dataset = data.LanguageDatasets(
        src=args.source_lang, dst=args.target_lang, src_dict=src_dict, dst_dict=dst_dict
    )
    models, model_args = pytorch_translate_utils.load_diverse_ensemble_for_inference(
        args.path, dataset.src_dict, dataset.dst_dict
    )
    append_eos_to_source = model_args[0].append_eos_to_source
    reverse_source = model_args[0].reverse_source
    assert all(
        a.append_eos_to_source == append_eos_to_source
        and a.reverse_source == reverse_source
        for a in model_args
    )
    if args.source_binary_file != "":
        assert args.target_binary_file != ""
        dst_dataset = pytorch_translate_data.InMemoryNumpyDataset.create_from_file(
            args.target_binary_file
        )
        if use_char_source:
            src_dataset = char_data.InMemoryNumpyWordCharDataset.create_from_file(
                args.source_binary_file
            )
            gen_split = char_data.LanguagePairSourceCharDataset(
                src=src_dataset,
                dst=dst_dataset,
                pad_idx=src_dict.pad(),
                eos_idx=dst_dict.eos(),
            )
        else:
            src_dataset = pytorch_translate_data.InMemoryNumpyDataset.create_from_file(
                args.source_binary_file
            )
            gen_split = data.LanguagePairDataset(
                src=src_dataset,
                dst=dst_dataset,
                pad_idx=src_dict.pad(),
                eos_idx=dst_dict.eos(),
            )
    elif pytorch_translate_data.is_multilingual(args):
        gen_split = pytorch_translate_data.make_language_pair_dataset_from_text_multilingual(
            source_text_file=args.source_text_file[0],
            target_text_file=args.target_text_file,
            source_lang_id=args.multiling_source_lang_id,
            target_lang_id=args.multiling_target_lang_id,
            source_dict=src_dict,
            target_dict=dst_dict,
            append_eos=append_eos_to_source,
            reverse_source=reverse_source,
        )
    elif args.source_ensembling:
        gen_split = multisource_data.make_multisource_language_pair_dataset_from_text(
            source_text_files=args.source_text_file,
            target_text_file=args.target_text_file,
            source_dict=src_dict,
            target_dict=dst_dict,
            append_eos=append_eos_to_source,
            reverse_source=reverse_source,
        )
    else:
        gen_split = pytorch_translate_data.make_language_pair_dataset_from_text(
            source_text_file=args.source_text_file[0],
            target_text_file=args.target_text_file,
            source_dict=src_dict,
            target_dict=dst_dict,
            append_eos=append_eos_to_source,
            reverse_source=reverse_source,
            char_source_dict=char_source_dict,
        )
    dataset.splits[args.gen_subset] = gen_split

    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    print(f"| [{dataset.src}] dictionary: {len(dataset.src_dict)} types")
    print(f"| [{dataset.dst}] dictionary: {len(dataset.dst_dict)} types")
    print(f"| {args.gen_subset} {len(dataset.splits[args.gen_subset])} examples")
    scorer, num_sentences, gen_timer, _ = _generate_score(
        models=models, args=args, dataset=dataset, dataset_split=args.gen_subset
    )
    print(
        f"| Translated {num_sentences} sentences ({gen_timer.n} tokens) "
        f"in {gen_timer.sum:.1f}s ({1. / gen_timer.avg:.2f} tokens/s)"
    )
    print(
        f"| Generate {args.gen_subset} with beam={args.beam}: "
        f"{scorer.result_string()}"
    )
    return scorer.score()
Example #10
def generate(args):
    pytorch_translate_options.print_args(args)

    models, model_args, task = pytorch_translate_utils.load_diverse_ensemble_for_inference(
        args.path.split(CHECKPOINT_PATHS_DELIMITER)
    )
    args.source_lang = model_args[0].source_lang
    args.target_lang = model_args[0].target_lang

    append_eos_to_source = model_args[0].append_eos_to_source
    reverse_source = model_args[0].reverse_source
    assert all(
        a.append_eos_to_source == append_eos_to_source
        and a.reverse_source == reverse_source
        for a in model_args
    )
    if args.source_binary_file != "":
        assert args.target_binary_file != ""
        if isinstance(task, PytorchTranslateTask):
            task.load_dataset(
                args.gen_subset,
                args.source_binary_file,
                args.target_binary_file,
                is_npz=args.is_npz,
            )
        else:
            task.load_dataset(
                args.gen_subset, args.source_binary_file, args.target_binary_file
            )
    elif pytorch_translate_data.is_multilingual_many_to_one(args):
        task.set_encoder_langs(model_args[0].multiling_encoder_lang)
        task.set_decoder_langs(model_args[0].multiling_decoder_lang)
        task.load_dataset_from_text_multilingual(
            args.gen_subset,
            source_text_file=args.source_text_file[0],
            target_text_file=args.target_text_file,
            source_lang_id=task.get_encoder_lang_id(args.multiling_source_lang[0]),
            target_lang_id=task.get_decoder_lang_id(args.multiling_target_lang[0]),
            append_eos=append_eos_to_source,
            reverse_source=reverse_source,
        )
    elif args.source_ensembling:
        task.load_multisource_dataset_from_text(
            args.gen_subset,
            source_text_files=args.source_text_file,
            target_text_file=args.target_text_file,
            append_eos=append_eos_to_source,
            reverse_source=reverse_source,
        )
    else:
        task.load_dataset_from_text(
            args.gen_subset,
            source_text_file=args.source_text_file[0],
            target_text_file=args.target_text_file,
            append_eos=append_eos_to_source,
            reverse_source=reverse_source,
        )

    lang_pair = None
    if isinstance(task, PyTorchTranslateMultiTask):
        if args.source_lang and args.target_lang:
            lang_pair = args.source_lang + "-" + args.target_lang
        else:
            lang_pair = "src-tgt"
    scorer, num_sentences, gen_timer, _ = generate_score(
        args=args,
        task=task,
        dataset=task.dataset(args.gen_subset),
        lang_pair=lang_pair,
        models=models,
    )
    print(
        f"| Translated {num_sentences} sentences ({gen_timer.n} tokens) "
        f"in {gen_timer.sum:.1f}s ({1. / gen_timer.avg:.2f} tokens/s)"
    )
    print(
        f"| Generate {args.gen_subset} with beam={args.beam}: "
        f"{scorer.result_string()}"
    )
    return scorer.score()