Exemplo n.º 1
0
def cli_main(
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
) -> None:
    """Parse training + pruning arguments and hand off to ``main``.

    Args:
        modify_parser: optional hook forwarded unchanged to
            ``options.parse_args_and_arch`` so callers can adjust the parser
            before arguments are parsed.

    The parsed namespace is converted to an OmegaConf config and run through
    ``distributed_utils.call_main``. When ``args.profile`` is truthy, that
    call is wrapped in ``torch.cuda.profiler.profile()`` and
    ``torch.autograd.profiler.emit_nvtx()``.
    """
    training_parser = options.get_training_parser()
    options.add_pruning_args(training_parser)
    parsed = options.parse_args_and_arch(
        training_parser, modify_parser=modify_parser
    )

    config = convert_namespace_to_omegaconf(parsed)

    if not parsed.profile:
        distributed_utils.call_main(config, main)
        return

    # A single multi-item ``with`` is equivalent to nesting the profiler
    # first and the NVTX emitter inside it.
    with torch.cuda.profiler.profile(), torch.autograd.profiler.emit_nvtx():
        distributed_utils.call_main(config, main)
Exemplo n.º 2
0
    group.add_argument("--one-minus", action="store_true")
    group.add_argument("--one-head", action="store_true")
    group.add_argument("--encoder-self-only",
                       action="store_true",
                       help="Only prune from the encoder self attention")
    group.add_argument("--encoder-decoder-only",
                       action="store_true",
                       help="Only prune from the encoder decoder attention")
    group.add_argument("--decoder-self-only",
                       action="store_true",
                       help="Only prune from the decoder self attention")


if __name__ == '__main__':
    # Build the training parser and extend it with this script's pruning
    # options plus the generation options.
    parser = options.get_training_parser()
    # Register the pruning options exactly once, via the module-level
    # add_pruning_args. Do not also call options.add_pruning_args(parser):
    # if it defines the same flags, argparse (default
    # conflict_handler='error') raises ArgumentError for the conflicting
    # option strings before any parsing happens.
    add_pruning_args(parser)
    options.add_generation_args(parser)
    args = options.parse_args_and_arch(parser)

    # Choose a launcher from the distributed settings. A positive port or an
    # explicit init method uses distributed_train. A world size above one
    # uses multiprocessing_train. Otherwise run main() in this process.
    if args.distributed_port > 0 or args.distributed_init_method is not None:
        from distributed_train import main as distributed_main

        distributed_main(args)
    elif args.distributed_world_size > 1:
        from multiprocessing_train import main as multiprocessing_main

        multiprocessing_main(args)
    else:
        main(args)