Beispiel #1
0
def pase_url(ckpt, model_config, refresh=False, **kwargs):
    """
        Build the PASE upstream from remote resources
            ckpt (str): URL of the checkpoint
            model_config (str): URL of the model config
            refresh (bool): whether to download ckpt/config again if existed
            **kwargs: passed through to pase_local
    """
    # Resolve both URLs (checkpoint first, then config) to local file paths.
    local_ckpt, local_config = (
        _urls_to_filepaths(url, refresh=refresh) for url in (ckpt, model_config)
    )
    return pase_local(local_ckpt, local_config, **kwargs)
def byol_a_url(ckpt, refresh=False, *args, **kwargs):
    """
        Load BYOL-A from a checkpoint URL
            ckpt (str): URL of the checkpoint
            refresh (bool): whether to download ckpt again if existed
            *args, **kwargs: passed through to byol_a_local
    """
    local_path = _urls_to_filepaths(ckpt, refresh=refresh)
    return byol_a_local(local_path, *args, **kwargs)
Beispiel #3
0
def mockingjay_url(ckpt, refresh=False, *args, **kwargs):
    """
        Load Mockingjay from a checkpoint URL
            ckpt (str): URL of the checkpoint
            refresh (bool): whether to download ckpt again if existed
            *args, **kwargs: passed through to mockingjay_local
    """
    downloaded = _urls_to_filepaths(ckpt, refresh=refresh)
    return mockingjay_local(downloaded, *args, **kwargs)
Beispiel #4
0
def audio_albert_url(ckpt, refresh=False, *args, **kwargs):
    """
        Load Audio ALBERT from a checkpoint URL
            ckpt (str): URL of the checkpoint
            refresh (bool): whether to download ckpt again if existed
            *args, **kwargs: passed through to audio_albert_local
    """
    ckpt_path = _urls_to_filepaths(ckpt, refresh=refresh)
    return audio_albert_local(ckpt_path, *args, **kwargs)
Beispiel #5
0
def timit_posteriorgram_url(ckpt, refresh=False, *args, **kwargs):
    """
        Load the TIMIT posteriorgram model from a checkpoint URL
            ckpt (str): URL of the checkpoint
            refresh (bool): whether to download ckpt again if existed
            *args, **kwargs: passed through to timit_posteriorgram_local
    """
    cached_ckpt = _urls_to_filepaths(ckpt, refresh=refresh)
    return timit_posteriorgram_local(cached_ckpt, *args, **kwargs)
Beispiel #6
0
def decoar_url(ckpt, refresh=False, **kwargs):
    """
        Load DeCoAR from a checkpoint URL
            ckpt (str): URL of the checkpoint
            refresh (bool): whether to download ckpt again if existed
            **kwargs: passed through to decoar_local
    """
    return decoar_local(_urls_to_filepaths(ckpt, refresh=refresh), **kwargs)
Beispiel #7
0
def wav2vec2_url(ckpt, refresh=False, *args, **kwargs):
    """
        The model from URL
            ckpt (str): URL of the checkpoint (resolved via _urls_to_filepaths)
            refresh (bool): whether to download ckpt again if existed
            *args, **kwargs: passed through to wav2vec2_local
    """
    local_ckpt = _urls_to_filepaths(ckpt, refresh=refresh)
    return wav2vec2_local(local_ckpt, *args, **kwargs)
Beispiel #8
0
def spec_augment_url(ckpt, refresh=False, *args, **kwargs):
    """
        Load the SpecAugment-pretrained model from a checkpoint URL
            ckpt (str): URL of the checkpoint
            refresh (bool): whether to download ckpt again if existed
            *args, **kwargs: passed through to spec_augment_local
    """
    fetched = _urls_to_filepaths(ckpt, refresh=refresh)
    return spec_augment_local(fetched, *args, **kwargs)
Beispiel #9
0
def vq_wav2vec_url(ckpt, refresh=False, *args, **kwargs):
    """
        The model from URL
            ckpt (str): URL of the checkpoint (resolved via _urls_to_filepaths)
            refresh (bool): whether to download ckpt again if existed
            *args, **kwargs: passed through to vq_wav2vec_local, e.g.
                feature_selection (str): z, codewords, codeids
    """
    local_ckpt = _urls_to_filepaths(ckpt, refresh=refresh)
    return vq_wav2vec_local(local_ckpt, *args, **kwargs)
Beispiel #10
0
def extract(files, args):
    """
    Run the ``args.upstream`` model over each audio file and save its output.

    For every entry of ``files`` the audio is resolved with
    ``_urls_to_filepaths`` (honouring ``args.refresh``), loaded with
    torchaudio, flattened to a 1-D tensor on ``args.device``, and passed to
    the model as a one-element list. The first element of the model output is
    moved to CPU and saved to
    ``<args.extract_dir>/<args.upstream>/<last path component>.pth``.
    """
    out_dir = os.path.join(args.extract_dir, args.upstream)
    os.makedirs(out_dir, exist_ok=True)

    upstream = getattr(hubconf, args.upstream)().to(device=args.device)
    upstream.eval()

    with torch.no_grad():
        for url in files:
            local_path = _urls_to_filepaths(url, refresh=args.refresh)
            waveform, _sample_rate = torchaudio.load(local_path)
            # view(-1) flattens everything into 1-D; for multi-channel audio
            # this concatenates the channels.
            # NOTE(review): presumably inputs are mono — confirm.
            waveform = waveform.view(-1).to(device=args.device)
            features = upstream([waveform])[0].detach().cpu()

            basename = url.split('/')[-1]
            torch.save(features, os.path.join(out_dir, basename + '.pth'))
Beispiel #11
0
def test(files, args):
    """
    Compare fresh model outputs against reference ``.pth`` files on a server.

    For each audio file, the matching reference URL is
    ``http://140.112.21.12:9000/extracted/<args.upstream>/<basename>.pth``.
    Both are resolved with ``_urls_to_filepaths``, and the model is run on the
    flattened waveform. The maximum absolute difference to the reference,
    floored at 1e-4, is appended as one line to the ``args.report`` file.
    """
    base_url = f'http://140.112.21.12:9000/extracted/{args.upstream}/'
    reference_urls = [base_url + file.split('/')[-1] + '.pth' for file in files]

    upstream = getattr(hubconf, args.upstream)().to(device=args.device)
    upstream.eval()

    with torch.no_grad():
        for file, reference_url in zip(files, reference_urls):
            audio_path, reference_path = _urls_to_filepaths(
                file, reference_url, refresh=args.refresh)
            waveform, _ = torchaudio.load(audio_path)
            waveform = waveform.view(-1).to(device=args.device)
            features = upstream([waveform])[0].detach().cpu()

            largest_gap = (features - torch.load(reference_path)).abs().max().item()
            # Never report a tolerance below 1e-4.
            max_diff = max(largest_gap, 0.0001)
            with open(args.report, 'a') as handle:
                handle.write(f'{max_diff} atol is required for torch.allclose to pass {args.upstream} + {file}.\n')
def vq_wav2vec_kmeans_roberta(refresh=False, **kwargs):
    """
    Build the vq-wav2vec (k-means) + RoBERTa upstream.

    Downloads fairseq's ``bert_kmeans.tar``, extracts it into a
    ``vq_wav2vec_kmeans_roberta/`` directory next to the downloaded archive,
    and hands that directory plus its single ``*.pt`` checkpoint to
    ``_vq_wav2vec_roberta`` together with the ``vq_wav2vec_kmeans`` upstream.

    Args:
        refresh (bool): forwarded to ``hubconf.vq_wav2vec_kmeans`` and
            ``_urls_to_filepaths``.
        **kwargs: extra options for ``_vq_wav2vec_roberta``. The keys
            ``model_name_or_path`` and ``checkpoint_file`` are overwritten.

    Raises:
        subprocess.CalledProcessError: if ``tar`` exits with a non-zero status.
        RuntimeError: if the extracted directory does not contain exactly one
            ``*.pt`` file.
    """
    import subprocess

    vq_wav2vec = hubconf.vq_wav2vec_kmeans(refresh=refresh)

    tar_file = _urls_to_filepaths(
        'https://dl.fbaipublicfiles.com/fairseq/wav2vec/bert_kmeans.tar',
        refresh=refresh)
    tar_dir = os.path.join(os.path.dirname(tar_file),
                           'vq_wav2vec_kmeans_roberta/')
    os.makedirs(tar_dir, exist_ok=True)
    # Pass argv as a list (no shell) so paths with spaces or shell
    # metacharacters are handled safely. check=True surfaces extraction
    # failures instead of silently ignoring tar's exit status.
    subprocess.run(['tar', '-xf', tar_file, '-C', tar_dir], check=True)

    pt_files = glob.glob(os.path.join(tar_dir, '*.pt'))
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(pt_files) != 1:
        raise RuntimeError(
            f'Expected exactly one *.pt file in {tar_dir}, '
            f'found {len(pt_files)}: {pt_files}')
    pt_file = pt_files[0]

    kwargs['model_name_or_path'] = tar_dir
    kwargs['checkpoint_file'] = pt_file
    return _vq_wav2vec_roberta(vq_wav2vec, **kwargs)
Beispiel #13
0
def cpc_url(ckpt, refresh=False, *args, **kwargs):
    """
        The model from URL
            ckpt (str): URL of the checkpoint
            refresh (bool): whether to download ckpt again if existed
            *args, **kwargs: passed through to cpc_local
    """
    # Forward `refresh` like every other *_url loader; it was previously
    # accepted but silently ignored.
    return cpc_local(_urls_to_filepaths(ckpt, refresh=refresh), *args, **kwargs)