def load_waveglow(fp16=False):
    """Construct a WaveGlow model (nn.Module with additional infer(input) method).

    Loads a checkpoint from the module-level ``waveglow_path`` onto the
    module-level ``device``.  For detailed information on model input and
    output, training recipes, inference and performance visit:
    github.com/NVIDIA/DeepLearningExamples and/or ngc.nvidia.com

    Args:
        fp16 (bool): If True, convert the model to half precision while
            keeping batch-norm layers and the invertible 1x1 convolutions
            in fp32.

    Returns:
        waveglow.WaveGlow: model with the checkpoint weights loaded.
    """
    from PyTorch.SpeechSynthesis.Tacotron2.waveglow import model as waveglow
    from PyTorch.SpeechSynthesis.Tacotron2.models import batchnorm_to_float

    # NOTE(review): the original signature was ``fp16=fp16``, which evaluates
    # a global at definition time and raises NameError if it does not exist.
    # A plain ``False`` default is the safe, explicit equivalent.
    ckpt = torch.load(waveglow_path, map_location=device)

    state_dict = ckpt['state_dict']
    # Strip the 'module.' prefix left by DistributedDataParallel, if present.
    if checkpoint_from_distributed(state_dict):
        state_dict = unwrap_distributed(state_dict)

    m = waveglow.WaveGlow(**ckpt['config'])
    if fp16:
        m = batchnorm_to_float(m.half())
        # The invertible 1x1 convolutions stay in fp32 — presumably their
        # inverse is unstable in half precision (TODO confirm upstream).
        for mat in m.convinv:
            mat.float()

    m.load_state_dict(state_dict)
    return m
def nvidia_waveglow(pretrained=True, **kwargs):
    """Construct a WaveGlow model (nn.Module with additional infer(input) method).

    For detailed information on model input and output, training recipes,
    inference and performance visit: github.com/NVIDIA/DeepLearningExamples
    and/or ngc.nvidia.com

    Args:
        pretrained (bool): If True, download and load a checkpoint pretrained
            on the LJ Speech dataset; otherwise build the model from a default
            config, optionally overridden by ``kwargs``.
        model_math (str, 'fp32'): returns a model in given precision
            ('fp32' or 'fp16').
        **kwargs: when ``pretrained`` is False, any key matching a top-level
            config entry or a ``WN_config`` entry overrides that entry.

    Returns:
        waveglow.WaveGlow: the constructed (and possibly pretrained) model.
    """
    from PyTorch.SpeechSynthesis.Tacotron2.waveglow import model as waveglow
    from PyTorch.SpeechSynthesis.Tacotron2.models import batchnorm_to_float

    # Idiom: dict.get avoids the double lookup of `in` + subscript.
    fp16 = kwargs.get("model_math") == "fp16"

    if pretrained:
        if fp16:
            checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
        else:
            checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
        ckpt_file = "waveglow_ckpt.pt"
        urllib.request.urlretrieve(checkpoint, ckpt_file)
        ckpt = torch.load(ckpt_file)
        state_dict = ckpt['state_dict']
        # Strip the 'module.' prefix left by DistributedDataParallel.
        if checkpoint_from_distributed(state_dict):
            state_dict = unwrap_distributed(state_dict)
        config = ckpt['config']
    else:
        config = {
            'n_mel_channels': 80,
            'n_flows': 12,
            'n_group': 8,
            'n_early_every': 4,
            'n_early_size': 2,
            'WN_config': {
                'n_layers': 8,
                'kernel_size': 3,
                'n_channels': 512,
            },
        }
        # Allow callers to override any default config entry via kwargs.
        for k, v in kwargs.items():
            if k in config.keys():
                config[k] = v
            elif k in config['WN_config'].keys():
                config['WN_config'][k] = v

    m = waveglow.WaveGlow(**config)

    if fp16:
        m = batchnorm_to_float(m.half())
        # The invertible 1x1 convolutions stay in fp32 — presumably their
        # inverse is unstable in half precision (TODO confirm upstream).
        for mat in m.convinv:
            mat.float()

    if pretrained:
        m.load_state_dict(state_dict)

    return m
def load_waveglow(weightpath='checkpoints/waveglow_20200314.pth'):
    """Load a WaveGlow model from a local checkpoint file.

    Args:
        weightpath (str): path to a checkpoint containing 'state_dict'
            and 'config' entries.

    Returns:
        waveglow.WaveGlow: the model with the checkpoint weights loaded.
    """
    # NOTE(review): this redefines ``load_waveglow`` from earlier in the file
    # with a different signature; the later definition wins at import time —
    # confirm only one is intended.
    checkpoint = torch.load(weightpath)

    weights = checkpoint['state_dict']
    # Strip the 'module.' prefix left by DistributedDataParallel, if present.
    if checkpoint_from_distributed(weights):
        weights = unwrap_distributed(weights)

    model = waveglow.WaveGlow(**checkpoint['config'])
    model.load_state_dict(weights)
    return model