def pase_url(ckpt, model_config, refresh=False, **kwargs):
    """
    Load the PASE upstream from a remote checkpoint URL and a remote config URL.

        ckpt (str): URL of the checkpoint
        model_config (str): URL of the model config
        refresh (bool): forwarded to `_urls_to_filepaths` for both downloads
        **kwargs: forwarded to `pase_local`
    """
    # Resolve the checkpoint first, then the config (same order as before).
    local_ckpt = _urls_to_filepaths(ckpt, refresh=refresh)
    local_config = _urls_to_filepaths(model_config, refresh=refresh)
    return pase_local(local_ckpt, local_config, **kwargs)
def byol_a_url(ckpt, refresh=False, *args, **kwargs):
    """
    Load the BYOL-A upstream from a checkpoint URL.

        ckpt (str): URL of the checkpoint
        refresh (bool): forwarded to `_urls_to_filepaths`
        *args, **kwargs: forwarded to `byol_a_local`
    """
    ckpt_path = _urls_to_filepaths(ckpt, refresh=refresh)
    return byol_a_local(ckpt_path, *args, **kwargs)
def mockingjay_url(ckpt, refresh=False, *args, **kwargs):
    """
    Load the Mockingjay upstream from a checkpoint URL.

        ckpt (str): URL of the checkpoint
        refresh (bool): forwarded to `_urls_to_filepaths`
        *args, **kwargs: forwarded to `mockingjay_local`
    """
    ckpt_path = _urls_to_filepaths(ckpt, refresh=refresh)
    return mockingjay_local(ckpt_path, *args, **kwargs)
def audio_albert_url(ckpt, refresh=False, *args, **kwargs):
    """
    Load the Audio ALBERT upstream from a checkpoint URL.

        ckpt (str): URL of the checkpoint
        refresh (bool): forwarded to `_urls_to_filepaths`
        *args, **kwargs: forwarded to `audio_albert_local`
    """
    ckpt_path = _urls_to_filepaths(ckpt, refresh=refresh)
    return audio_albert_local(ckpt_path, *args, **kwargs)
def timit_posteriorgram_url(ckpt, refresh=False, *args, **kwargs):
    """
    Load the TIMIT posteriorgram upstream from a checkpoint URL.

        ckpt (str): URL of the checkpoint
        refresh (bool): forwarded to `_urls_to_filepaths`
        *args, **kwargs: forwarded to `timit_posteriorgram_local`
    """
    ckpt_path = _urls_to_filepaths(ckpt, refresh=refresh)
    return timit_posteriorgram_local(ckpt_path, *args, **kwargs)
def decoar_url(ckpt, refresh=False, **kwargs):
    """
    Load the DeCoAR upstream from a checkpoint URL.

        ckpt (str): URL of the checkpoint
        refresh (bool): forwarded to `_urls_to_filepaths`
        **kwargs: forwarded to `decoar_local`
    """
    return decoar_local(_urls_to_filepaths(ckpt, refresh=refresh), **kwargs)
def wav2vec2_url(ckpt, refresh=False, *args, **kwargs):
    """
    Load the wav2vec 2.0 upstream from a checkpoint URL.

        ckpt (str): URL of the checkpoint
        refresh (bool): whether to download ckpt/config again if existed
        *args, **kwargs: forwarded to `wav2vec2_local`
    """
    ckpt_path = _urls_to_filepaths(ckpt, refresh=refresh)
    return wav2vec2_local(ckpt_path, *args, **kwargs)
def spec_augment_url(ckpt, refresh=False, *args, **kwargs):
    """
    Load the SpecAugment upstream from a checkpoint URL.

        ckpt (str): URL of the checkpoint
        refresh (bool): forwarded to `_urls_to_filepaths`
        *args, **kwargs: forwarded to `spec_augment_local`
    """
    ckpt_path = _urls_to_filepaths(ckpt, refresh=refresh)
    return spec_augment_local(ckpt_path, *args, **kwargs)
def vq_wav2vec_url(ckpt, refresh=False, *args, **kwargs):
    """
    Load the vq-wav2vec upstream from a checkpoint URL.

        ckpt (str): URL of the checkpoint
        refresh (bool): whether to download ckpt/config again if existed
        *args, **kwargs: forwarded to `vq_wav2vec_local`; presumably this is
            where `feature_selection` (z, codewords, codeids) is supplied —
            confirm against `vq_wav2vec_local`
    """
    ckpt_path = _urls_to_filepaths(ckpt, refresh=refresh)
    return vq_wav2vec_local(ckpt_path, *args, **kwargs)
def extract(files, args):
    """
    Run an upstream model over audio files and save each representation.

    For every entry of `files`: resolve it to a local path via
    `_urls_to_filepaths`, load it with torchaudio, flatten it to 1-D, run the
    hubconf entry named `args.upstream` on it, and save the first output
    (detached, on CPU) to `<args.extract_dir>/<args.upstream>/<basename>.pth`.

        files: iterable of audio locations (passed to `_urls_to_filepaths`)
        args: namespace read for `extract_dir`, `upstream`, `device`, `refresh`
    """
    out_dir = os.path.join(args.extract_dir, args.upstream)
    os.makedirs(out_dir, exist_ok=True)

    upstream = getattr(hubconf, args.upstream)()
    upstream = upstream.to(device=args.device)
    upstream.eval()

    with torch.no_grad():
        for location in files:
            local_path = _urls_to_filepaths(location, refresh=args.refresh)
            waveform, _sample_rate = torchaudio.load(local_path)
            waveform = waveform.view(-1).to(device=args.device)
            representation = upstream([waveform])[0].detach().cpu()
            basename = location.split('/')[-1]
            torch.save(representation, os.path.join(out_dir, basename + '.pth'))
def test(files, args):
    """
    Compare fresh upstream outputs against reference `.pth` files on a server.

    For each entry of `files`, downloads the matching reference tensor from
    `http://140.112.21.12:9000/extracted/<args.upstream>/<basename>.pth`,
    recomputes the representation with the hubconf entry `args.upstream`, and
    appends to `args.report` the max absolute difference (floored at 1e-4)
    needed for `torch.allclose` to pass.

        files: iterable of audio locations (passed to `_urls_to_filepaths`)
        args: namespace read for `upstream`, `device`, `refresh`, `report`
    """
    # Build every reference URL up front, before constructing the model.
    reference_urls = [
        f'http://140.112.21.12:9000/extracted/{args.upstream}/' + location.split('/')[-1] + '.pth'
        for location in files
    ]

    upstream = getattr(hubconf, args.upstream)().to(device=args.device)
    upstream.eval()

    with torch.no_grad():
        for location, reference_url in zip(files, reference_urls):
            audio_path, reference_path = _urls_to_filepaths(
                location, reference_url, refresh=args.refresh
            )
            waveform, _sample_rate = torchaudio.load(audio_path)
            waveform = waveform.view(-1).to(device=args.device)
            current = upstream([waveform])[0].detach().cpu()
            # NOTE(review): torch.load unpickles a file fetched over plain
            # HTTP; this is only safe if that server and network are trusted.
            reference = torch.load(reference_path)
            max_diff = max((current - reference).abs().max().item(), 0.0001)
            with open(args.report, 'a') as report:
                report.write(
                    f'{max_diff} atol is required for torch.allclose to pass {args.upstream} + {location}.\n'
                )
def vq_wav2vec_kmeans_roberta(refresh=False, **kwargs):
    """
    Build the vq-wav2vec (k-means) + RoBERTa upstream.

    Loads `vq_wav2vec_kmeans` from hubconf, downloads fairseq's
    `bert_kmeans.tar`, and extracts it into a `vq_wav2vec_kmeans_roberta/`
    directory next to the downloaded archive. It then hands that directory and
    the single `*.pt` checkpoint found inside it to `_vq_wav2vec_roberta`.

        refresh (bool): forwarded to `vq_wav2vec_kmeans` and to the archive
            download via `_urls_to_filepaths`
        **kwargs: forwarded to `_vq_wav2vec_roberta`; the keys
            `model_name_or_path` and `checkpoint_file` are overwritten

    Raises:
        subprocess.CalledProcessError: if `tar` exits with a non-zero status
        RuntimeError: if the extracted directory does not contain exactly one
            `*.pt` file
    """
    import subprocess

    vq_wav2vec = hubconf.vq_wav2vec_kmeans(refresh=refresh)
    tar_file = _urls_to_filepaths(
        'https://dl.fbaipublicfiles.com/fairseq/wav2vec/bert_kmeans.tar',
        refresh=refresh,
    )
    tar_dir = os.path.join(os.path.dirname(tar_file), 'vq_wav2vec_kmeans_roberta/')
    os.makedirs(tar_dir, exist_ok=True)

    # Pass argv as a list (no shell), so paths containing spaces or shell
    # metacharacters work. check=True surfaces a failed extraction instead of
    # silently continuing, which is what os.system's ignored exit code did.
    subprocess.run(['tar', '-xf', tar_file, '-C', tar_dir], check=True)

    pt_files = glob.glob(os.path.join(tar_dir, '*.pt'))
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(pt_files) != 1:
        raise RuntimeError(
            f'Expected exactly one *.pt file in {tar_dir}, '
            f'found {len(pt_files)}: {pt_files}'
        )

    kwargs['model_name_or_path'] = tar_dir
    kwargs['checkpoint_file'] = pt_files[0]
    return _vq_wav2vec_roberta(vq_wav2vec, **kwargs)
def cpc_url(ckpt, refresh=False, *args, **kwargs):
    """
    Load the CPC upstream from a checkpoint URL.

        ckpt (str): URL of the checkpoint
        refresh (bool): whether to download ckpt again if existed; forwarded
            to `_urls_to_filepaths` (it was previously dropped, so
            `refresh=True` had no effect)
        *args, **kwargs: forwarded to `cpc_local`
    """
    ckpt_path = _urls_to_filepaths(ckpt, refresh=refresh)
    return cpc_local(ckpt_path, *args, **kwargs)