示例#1
0
def test_separate():
    """Check that ``separate`` runs an overlap-add-wrapped ConvTasNet on a torch tensor.

    A tiny ConvTasNet is wrapped in ``LambdaOverlapAdd`` with a 1000-sample
    window. It is fed one random single-channel 8000-sample signal shaped
    ``(batch, channels, time)``. The separated output must come back as a
    torch tensor whose batch and time dimensions match the input and whose
    second dimension holds the ``n_src`` estimated sources.
    """
    n_src = 2
    nnet = ConvTasNet(
        n_src=n_src,
        n_repeats=2,
        n_blocks=3,
        bn_chan=16,
        hid_chan=4,
        skip_chan=8,
        n_filters=32,
    )
    # Test torch input: shape (batch=1, channels=1, time=8000).
    wav = torch.rand(1, 1, 8000)
    # n_src is left as None. Presumably LambdaOverlapAdd then takes the source
    # count from nnet's output instead of checking it -- TODO confirm.
    model = LambdaOverlapAdd(nnet, n_src=None, window_size=1000)
    out = separate(model, wav)

    # Without these checks the test would pass on any non-raising output.
    assert isinstance(out, torch.Tensor), f"expected torch.Tensor, got {type(out)}"
    assert tuple(out.shape) == (wav.shape[0], n_src, wav.shape[-1]), (
        f"unexpected output shape {tuple(out.shape)}"
    )
示例#2
0
def _build_infer_parser():
    """Build the ``argparse`` parser describing the command-line interface of ``infer``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("url_or_path", type=str, help="Path to the pretrained model.")
    parser.add_argument(
        "--files",
        required=True,
        type=str,
        help="Path to the wav files to separate. Also supports list of filenames, "
        "directory names and globs.",
        nargs="+",
    )
    parser.add_argument(
        "-f",
        "--force-overwrite",
        action="store_true",
        help="Whether to overwrite output wav files.",
    )
    parser.add_argument(
        "-r",
        "--resample",
        action="store_true",
        help="Whether to resample wrong sample rate input files.",
    )
    parser.add_argument(
        "-w",
        "--ola-window",
        type=validate_window_length,
        default=None,
        help="Overlap-add window to use. If not set (default), "
        "overlap-add is not used.",
    )
    parser.add_argument(
        "--ola-hop",
        type=validate_window_length,
        default=None,
        help="Overlap-add hop length in samples. Defaults to ola-window // 2. "
        "Only used if --ola-window is set.",
    )
    parser.add_argument(
        "--ola-window-type",
        type=str,
        default="hanning",
        help="Type of overlap-add window to use. Only used if --ola-window is set.",
    )
    parser.add_argument(
        "--ola-no-reorder",
        action="store_true",
        help="Disable automatic reordering of overlap-add chunk. "
        "See asteroid.dsp.LambdaOverlapAdd for details. "
        "Only used if --ola-window is set.",
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        default=None,
        type=str,
        help="Output directory to save files.",
    )
    return parser


def infer(argv=None):
    """CLI function to run pretrained model inference on wav files.

    Loads a pretrained model from ``url_or_path`` and optionally wraps it in
    ``LambdaOverlapAdd`` when ``--ola-window`` is given. It then calls
    ``separate`` once for every file resolved from ``--files``.

    Args:
        argv (list of str, optional): Command-line arguments to parse.
            Defaults to ``None``, in which case ``argparse`` reads
            ``sys.argv[1:]``, as before. Pass an explicit list to drive the CLI
            programmatically, e.g. from tests or other scripts.
    """
    args = _build_infer_parser().parse_args(argv)

    model = BaseModel.from_pretrained(pretrained_model_conf_or_path=args.url_or_path)
    if args.ola_window is not None:
        # Overlap-add is opt-in: it is only applied when a window is given.
        # A missing --ola-hop is forwarded as None. Per the option's help text,
        # the hop then falls back to ola-window // 2.
        model = LambdaOverlapAdd(
            model,
            n_src=None,
            window_size=args.ola_window,
            hop_size=args.ola_hop,
            window=args.ola_window_type,
            reorder_chunks=not args.ola_no_reorder,
        )
    # --files may mix filenames, directories and globs (see its help text).
    # Presumably this expands them into a flat list of wav paths -- confirm
    # against _process_files_as_list.
    file_list = _process_files_as_list(args.files)

    for f in file_list:
        separate(
            model,
            f,
            force_overwrite=args.force_overwrite,
            output_dir=args.output_dir,
            resample=args.resample,
        )