def load_waveglow(filename, waveglow_config):
    """Load a WaveGlow model from a checkpoint and prepare it for CUDA inference.

    The checkpoint is unpickled with a custom pickle module that resolves
    classes from the module ``glow`` as ``waveglow.model`` instead, so that
    checkpoints pickled under the old module name can still be loaded.

    Two checkpoint layouts are handled:

    * ``blob`` contains ``'state_dict'``: a fresh ``WaveGlow(**waveglow_config)``
      is built on CUDA and the weights are loaded into it, with any leading
      ``module.`` stripped from the parameter names;
    * otherwise the model object is taken directly from ``blob['model']``.

    Args:
        filename: checkpoint path or file-like object, passed to ``torch.load``.
        waveglow_config: keyword arguments for the ``WaveGlow`` constructor;
            only used when the checkpoint contains ``'state_dict'``.

    Returns:
        The model after ``split_cond_layers`` and ``remove_weightnorm`` have
        been applied, moved to CUDA and switched to eval mode.

    Security:
        Unpickling can execute arbitrary code; only load trusted checkpoints.
    """

    class RenamingUnpickler(pickle.Unpickler):
        """``pickle.Unpickler`` that redirects the module ``glow``."""

        def find_class(self, module, name):
            # Classes pickled under the old module name ``glow`` now live in
            # ``waveglow.model``.
            if module == 'glow':
                module = 'waveglow.model'
            return super().find_class(module, name)

    class RenamingPickleModule:
        """Minimal stand-in for the ``pickle`` module handed to ``torch.load``.

        The class object itself (not an instance) is passed to ``torch.load``.
        ``Unpickler`` must be a real class because ``torch.load`` may subclass
        ``pickle_module.Unpickler``. A class object also carries the
        ``__name__`` attribute that ``torch.load`` may inspect.
        """

        Unpickler = RenamingUnpickler

        @staticmethod
        def load(f, **pickle_load_args):
            return RenamingUnpickler(f, **pickle_load_args).load()

    # NOTE(security): arbitrary objects are unpickled here -- trusted files only.
    blob = torch.load(filename, pickle_module=RenamingPickleModule)

    if 'state_dict' in blob:
        waveglow = WaveGlow(**waveglow_config).cuda()
        # Strip the ``module.`` prefix (typically added by DataParallel-style
        # wrappers) so the keys match the bare model's parameter names.
        prefix = 'module.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in blob['state_dict'].items()
        }
        waveglow.load_state_dict(state_dict)
    else:
        # Otherwise the checkpoint is expected to hold the model object itself.
        waveglow = blob['model']

    waveglow = split_cond_layers(waveglow)
    waveglow = waveglow.remove_weightnorm(waveglow)
    waveglow.cuda().eval()
    return waveglow
class WaveGlowInferencer(object):
    """Wraps a WaveGlow checkpoint for mel-spectrogram-to-waveform inference.

    Can be used as a context manager; ``__enter__`` returns the inferencer
    itself and ``__exit__`` performs no cleanup and does not suppress
    exceptions.
    """

    def __init__(self, ckpt_file, device='cuda', use_fp16=False, use_denoiser=False):
        """Build WaveGlow, load its weights and prepare it for inference.

        The architecture hyper-parameters are read from the command line
        (``sys.argv``) via ``waveglow.arg_parser.parse_waveglow_args`` and
        ``parse_known_args``. Arguments that parser does not define are
        ignored.

        Args:
            ckpt_file: checkpoint passed to ``torch.load``; it must contain
                a ``'state_dict'`` entry.
            device: device used as ``map_location`` when loading and as the
                target for ``to_device_async``.
            use_fp16: if true, the model is cast with ``.half()`` and
                ``infer`` casts its input with ``.half()`` as well.
            use_denoiser: if true, a ``Denoiser`` is built around the model
                and applied to every ``infer`` output.
        """
        self.ckpt_file = ckpt_file
        self.device = device
        self.use_fp16 = use_fp16
        self.use_denoiser = use_denoiser

        # Imported lazily: requires the ``waveglow`` package to be importable.
        from waveglow.arg_parser import parse_waveglow_args

        model_parser = parse_waveglow_args(argparse.ArgumentParser())
        args, _ = model_parser.parse_known_args()

        model_config = dict(
            n_mel_channels=args.n_mel_channels,
            n_flows=args.flows,
            n_group=args.groups,
            n_early_every=args.early_every,
            n_early_size=args.early_size,
            WN_config=dict(
                n_layers=args.wn_layers,
                kernel_size=args.wn_kernel_size,
                n_channels=args.wn_channels,
            ),
        )
        self.model = WaveGlow(**model_config)

        checkpoint = torch.load(self.ckpt_file, map_location=self.device)
        state_dict = unwrap_distributed(checkpoint['state_dict'])
        self.model.load_state_dict(state_dict)

        self.model = to_device_async(self.model, self.device)
        self.model = self.model.remove_weightnorm(self.model)
        self.model.eval()

        if self.use_fp16:
            self.model = self.model.half()

        if self.use_denoiser:
            # Built after the optional fp16 cast so it wraps the final model.
            self.denoiser = Denoiser(self.model, device=device)
            self.denoiser = to_device_async(self.denoiser, self.device)
            tprint('Using WaveGlow denoiser.')

    def __enter__(self):
        """Return the inferencer so ``with ... as w`` binds a usable object."""
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        """No cleanup; returns None so exceptions propagate."""
        return None

    def infer(self, mels, sigma=0.6, denoiser_strength=0.01):
        """Synthesize waveforms from mel spectrograms.

        Args:
            mels: mel-spectrogram tensor. It is presumably shaped
                (batch, n_mel_channels, frames) -- TODO confirm against
                ``WaveGlow.infer``. Cast to half when ``use_fp16`` is set,
                then moved to ``self.device``.
            sigma: value forwarded as ``sigma`` to ``WaveGlow.infer``
                (defaults to the previously hard-coded 0.6).
            denoiser_strength: value forwarded as ``strength`` to the
                denoiser when ``use_denoiser`` is set (default 0.01).

        Returns:
            The generated waveform tensor, cast with ``.float()``.
        """
        if self.use_fp16:
            mels = mels.half()
        mels = to_device_async(mels, self.device)
        wavs = self.model.infer(mels, sigma=sigma)

        if self.use_denoiser:
            wavs = self.denoiser(wavs, strength=denoiser_strength)

        return wavs.float()