Exemplo n.º 1
0
def infer():
    """CLI entry point: run pretrained model inference on wav files.

    The function takes no Python arguments; everything is read from the
    command line.

    Command-line arguments:
        url_or_path (str): Path to the pretrained model (positional).
        --files (str, one or more, required): Path to the wav files to
            separate. Also supports list of filenames, directory names and
            globs.
        -f, --force-overwrite: Whether to overwrite output wav files.
        -o, --output-dir (str): Output directory to save files.

    Raises:
        SystemExit: Raised by argparse when the command line is invalid,
            including when ``--files`` is not supplied.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("url_or_path",
                        type=str,
                        help="Path to the pretrained model.")
    parser.add_argument(
        "--files",
        # Fail fast with a usage error instead of handing ``None`` to
        # ``_process_files_as_list`` when the option is omitted.
        required=True,
        type=str,
        help=
        "Path to the wav files to separate. Also support list of filenames, "
        "directory names and globs.",
        nargs="+",
    )
    parser.add_argument(
        "-f",
        "--force-overwrite",
        action="store_true",
        help="Whether to overwrite output wav files.",
    )
    parser.add_argument("-o",
                        "--output-dir",
                        default=None,
                        type=str,
                        help="Output directory to save files.")
    args = parser.parse_args()

    model = BaseModel.from_pretrained(
        pretrained_model_conf_or_path=args.url_or_path)
    file_list = _process_files_as_list(args.files)

    # Each entry is separated independently with the same output options.
    for f in file_list:
        model.separate(f,
                       force_overwrite=args.force_overwrite,
                       output_dir=args.output_dir)
Exemplo n.º 2
0
 def select_model(model_name, task):
     """Load the pretrained model registered for a task / model name pair.

     Args:
         model_name (str): Name of the architecture, e.g. ``"DPTNet"``.
         task (str): Either ``"Speech enhancement"`` or
             ``"Source separation"``.

     Returns:
         The object returned by ``BaseModel.from_pretrained`` for the
         checkpoint identifier associated with the pair.

     Raises:
         ValueError: If ``task`` is unknown or ``model_name`` has no
             checkpoint registered for that task.
     """
     checkpoints = {
         "Speech enhancement": {
             # Entries below were disabled in the original code and are kept
             # here, still disabled, for reference.
             # "ConvTasNet": "JorisCos/ConvTasNet_Libri1Mix_enhsingle_16k",
             "DPTNet": "JorisCos/DPTNet_Libri1Mix_enhsingle_16k",
             # "DPRNNTasNet": "JorisCos/DPRNNTasNet-ks2_Libri1Mix_enhsingle_16k",
             # "DCCRNet": "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k",
             # "DCUNet": "JorisCos/DCUNet_Libri1Mix_enhsingle_16k",
         },
         "Source separation": {
             "ConvTasNet": "JorisCos/ConvTasNet_Libri2Mix_sepclean_16k",
             "DPRNNTasNet": "mpariente/DPRNNTasNet-ks2_WHAM_sepclean",
         },
     }

     # Validate before importing/loading anything so that a bad request
     # fails with a clear message instead of an unbound ``path`` variable.
     if task not in checkpoints:
         raise ValueError(
             f"Unsupported task {task!r}; expected one of "
             f"{sorted(checkpoints)}.")
     models_for_task = checkpoints[task]
     if model_name not in models_for_task:
         raise ValueError(
             f"Unsupported model {model_name!r} for task {task!r}; "
             f"expected one of {sorted(models_for_task)}.")

     from asteroid.models.base_models import BaseModel
     return BaseModel.from_pretrained(models_for_task[model_name])
Exemplo n.º 3
0
def infer():
    """Command-line entry point for separating wav files with a pretrained model.

    Parses the command line, loads the model designated by ``url_or_path``,
    wraps it in ``LambdaOverlapAdd`` when ``--ola-window`` is given, and then
    calls ``separate`` on every entry that ``_process_files_as_list`` returns
    for ``--files``.
    """
    cli = argparse.ArgumentParser()

    # Positional: which pretrained model to load.
    cli.add_argument(
        "url_or_path", type=str, help="Path to the pretrained model."
    )

    # Input files.
    cli.add_argument(
        "--files",
        nargs="+",
        type=str,
        required=True,
        default=None,
        help=(
            "Path to the wav files to separate. "
            "Also supports list of filenames, directory names and globs."
        ),
    )

    # Boolean switches.
    cli.add_argument(
        "-f", "--force-overwrite",
        help="Whether to overwrite output wav files.",
        action="store_true",
    )
    cli.add_argument(
        "-r", "--resample",
        help="Whether to resample wrong sample rate input files.",
        action="store_true",
    )

    # Overlap-add options; only meaningful when --ola-window is set.
    cli.add_argument(
        "-w", "--ola-window",
        default=None,
        type=validate_window_length,
        help=(
            "Overlap-add window to use. "
            "If not set (default), overlap-add is not used."
        ),
    )
    cli.add_argument(
        "--ola-hop",
        default=None,
        type=validate_window_length,
        help=(
            "Overlap-add hop length in samples. Defaults to ola-window // 2. "
            "Only used if --ola-window is set."
        ),
    )
    cli.add_argument(
        "--ola-window-type",
        default="hanning",
        type=str,
        help=(
            "Type of overlap-add window to use. "
            "Only used if --ola-window is set."
        ),
    )
    cli.add_argument(
        "--ola-no-reorder",
        action="store_true",
        help=(
            "Disable automatic reordering of overlap-add chunk. "
            "See asteroid.dsp.LambdaOverlapAdd for details. "
            "Only used if --ola-window is set."
        ),
    )

    # Output location.
    cli.add_argument(
        "-o", "--output-dir",
        type=str,
        default=None,
        help="Output directory to save files.",
    )

    opts = cli.parse_args()

    model = BaseModel.from_pretrained(
        pretrained_model_conf_or_path=opts.url_or_path
    )

    use_overlap_add = opts.ola_window is not None
    if use_overlap_add:
        overlap_add_options = {
            "n_src": None,
            "window_size": opts.ola_window,
            "hop_size": opts.ola_hop,
            "window": opts.ola_window_type,
            "reorder_chunks": not opts.ola_no_reorder,
        }
        model = LambdaOverlapAdd(model, **overlap_add_options)

    # The same separation options apply to every input file.
    separation_options = {
        "force_overwrite": opts.force_overwrite,
        "output_dir": opts.output_dir,
        "resample": opts.resample,
    }
    for wav_path in _process_files_as_list(opts.files):
        separate(model, wav_path, **separation_options)