def infer():
    """CLI function to run pretrained model inference on wav files.

    Takes no Python arguments; everything is read from the command line:

    * ``url_or_path`` (str): Path to the pretrained model.
    * ``--files`` (one or more str, required): Path to the wav files to
      separate. Also supports list of filenames, directory names and globs.
    * ``-f`` / ``--force-overwrite`` (flag): Whether to overwrite output
      wav files.
    * ``-o`` / ``--output-dir`` (str): Output directory to save files.

    Raises:
        SystemExit: Raised by argparse when the command line is invalid,
            including when ``--files`` is not given.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("url_or_path", type=str, help="Path to the pretrained model.")
    # ``--files`` is required: the file list is consumed unconditionally
    # below, so a missing value must be reported as a usage error instead of
    # passing ``None`` on to ``_process_files_as_list``.
    parser.add_argument(
        "--files",
        required=True,
        type=str,
        help="Path to the wav files to separate. Also support list of filenames, "
        "directory names and globs.",
        nargs="+",
    )
    parser.add_argument(
        "-f",
        "--force-overwrite",
        action="store_true",
        help="Whether to overwrite output wav files.",
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        default=None,
        type=str,
        help="Output directory to save files.",
    )
    args = parser.parse_args()

    model = BaseModel.from_pretrained(pretrained_model_conf_or_path=args.url_or_path)
    # Expand the raw ``--files`` values into the list of files to process.
    file_list = _process_files_as_list(args.files)

    for f in file_list:
        model.separate(
            f,
            force_overwrite=args.force_overwrite,
            output_dir=args.output_dir,
        )
def select_model(model_name, task):
    """Load the pretrained model registered for a task / architecture pair.

    Args:
        model_name (str): Architecture name, e.g. ``"DPTNet"`` or
            ``"ConvTasNet"``.
        task (str): Either ``"Speech enhancement"`` or
            ``"Source separation"``.

    Returns:
        Whatever ``BaseModel.from_pretrained`` returns for the selected
        checkpoint identifier.

    Raises:
        ValueError: If no checkpoint is registered for the requested
            ``task`` / ``model_name`` combination.
    """
    checkpoints = {
        "Speech enhancement": {
            # Only DPTNet is currently enabled for enhancement. The entries
            # below were disabled in the original code and are kept here for
            # reference only:
            # "ConvTasNet": "JorisCos/ConvTasNet_Libri1Mix_enhsingle_16k",
            # "DPRNNTasNet": "JorisCos/DPRNNTasNet-ks2_Libri1Mix_enhsingle_16k",
            # "DCCRNet": "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k",
            # "DCUNet": "JorisCos/DCUNet_Libri1Mix_enhsingle_16k",
            "DPTNet": "JorisCos/DPTNet_Libri1Mix_enhsingle_16k",
        },
        "Source separation": {
            "ConvTasNet": "JorisCos/ConvTasNet_Libri2Mix_sepclean_16k",
            "DPRNNTasNet": "mpariente/DPRNNTasNet-ks2_WHAM_sepclean",
        },
    }

    try:
        path = checkpoints[task][model_name]
    except KeyError:
        # Without this guard an unknown combination left ``path`` unbound and
        # surfaced as an UnboundLocalError at the return statement.
        raise ValueError(
            f"No pretrained model available for task {task!r} "
            f"with model {model_name!r}."
        ) from None

    # Imported only after validation so invalid requests fail fast, before
    # the model library is loaded.
    from asteroid.models.base_models import BaseModel

    return BaseModel.from_pretrained(path)
def infer():
    """Command-line entry point: separate wav files with a pretrained model.

    Reads the model location, the input files and the optional overlap-add
    settings from the command line, loads the model, wraps it in
    ``LambdaOverlapAdd`` when ``--ola-window`` is given, and calls
    ``separate`` once per input file.
    """
    # Declarative option table: (flags, add_argument keyword arguments).
    # Order is preserved because it determines the order of the help output.
    option_specs = [
        (
            ("url_or_path",),
            {"type": str, "help": "Path to the pretrained model."},
        ),
        (
            ("--files",),
            {
                "default": None,
                "required": True,
                "type": str,
                "help": "Path to the wav files to separate. Also supports list of filenames, "
                "directory names and globs.",
                "nargs": "+",
            },
        ),
        (
            ("-f", "--force-overwrite"),
            {
                "action": "store_true",
                "help": "Whether to overwrite output wav files.",
            },
        ),
        (
            ("-r", "--resample"),
            {
                "action": "store_true",
                "help": "Whether to resample wrong sample rate input files.",
            },
        ),
        (
            ("-w", "--ola-window"),
            {
                "type": validate_window_length,
                "default": None,
                "help": "Overlap-add window to use. If not set (default), overlap-add is not used.",
            },
        ),
        (
            ("--ola-hop",),
            {
                "type": validate_window_length,
                "default": None,
                "help": "Overlap-add hop length in samples. Defaults to ola-window // 2. Only used if --ola-window is set.",
            },
        ),
        (
            ("--ola-window-type",),
            {
                "type": str,
                "default": "hanning",
                "help": "Type of overlap-add window to use. Only used if --ola-window is set.",
            },
        ),
        (
            ("--ola-no-reorder",),
            {
                "action": "store_true",
                "help": "Disable automatic reordering of overlap-add chunk. See asteroid.dsp.LambdaOverlapAdd for details. "
                "Only used if --ola-window is set.",
            },
        ),
        (
            ("-o", "--output-dir"),
            {
                "default": None,
                "type": str,
                "help": "Output directory to save files.",
            },
        ),
    ]

    cli = argparse.ArgumentParser()
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    opts = cli.parse_args()

    model = BaseModel.from_pretrained(pretrained_model_conf_or_path=opts.url_or_path)

    # Overlap-add processing is opt-in: only wrap the model when a window
    # size was supplied on the command line.
    overlap_add_requested = opts.ola_window is not None
    if overlap_add_requested:
        model = LambdaOverlapAdd(
            model,
            n_src=None,
            window_size=opts.ola_window,
            hop_size=opts.ola_hop,
            window=opts.ola_window_type,
            reorder_chunks=not opts.ola_no_reorder,
        )

    for wav_path in _process_files_as_list(opts.files):
        separate(
            model,
            wav_path,
            force_overwrite=opts.force_overwrite,
            output_dir=opts.output_dir,
            resample=opts.resample,
        )