def get_model(model_name, model_config, to_fp16, to_cuda, training=True):
    """Build a model selected by name, optionally converting it to fp16 / CUDA.

    Args:
        model_name: ``'Tacotron2'`` or ``'WaveGlow'``; anything else raises.
        model_config: Keyword arguments forwarded to the model constructor.
        to_fp16: If true, cast the model with ``.half()`` and then pass it
            through ``batchnorm_to_float`` and ``lstmcell_to_float`` (helpers
            defined elsewhere; presumably they keep those layers in fp32 --
            confirm). For WaveGlow, every module in ``convinv`` is also
            cast back with ``.float()``.
        to_cuda: If true, the (possibly converted) model is moved with
            ``.cuda()`` as the final step.
        training: Accepted for call-site compatibility; not read here.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: If ``model_name`` is not a known model.
    """
    # Keep the dispatch lazy: each constructor name is only resolved when
    # its branch is taken.
    if model_name == 'Tacotron2':
        net = Tacotron2(**model_config)
    elif model_name == 'WaveGlow':
        net = WaveGlow(**model_config)
    else:
        raise NotImplementedError(model_name)

    if to_fp16:
        net = lstmcell_to_float(batchnorm_to_float(net.half()))
        if model_name == "WaveGlow":
            for inv_conv in net.convinv:
                inv_conv.float()

    return net.cuda() if to_cuda else net
class WaveGlowInferencer(object):
    """Wrap a WaveGlow vocoder loaded from a checkpoint for inference.

    The architecture hyper-parameters come from command-line arguments parsed
    with ``waveglow.arg_parser.parse_waveglow_args``; the weights come from
    ``ckpt_file``. The object can be used as a context manager, and entering
    it yields the inferencer itself.
    """

    def __init__(self, ckpt_file, device='cuda', use_fp16=False, use_denoiser=False):
        """Build WaveGlow, load checkpoint weights and prepare it for inference.

        Args:
            ckpt_file: Path given to ``torch.load``; the loaded object must be
                indexable by ``'state_dict'``.
            device: Used as ``torch.load``'s ``map_location`` and as the target
                device for ``to_device_async``.
            use_fp16: If true, the model is cast with ``.half()`` here and the
                input mels are cast to half in ``infer``.
            use_denoiser: If true, a ``Denoiser`` is built around the model and
                applied to every output of ``infer``.
        """
        self.ckpt_file = ckpt_file
        self.device = device
        self.use_fp16 = use_fp16
        self.use_denoiser = use_denoiser

        self.model = WaveGlow(**self._parse_model_config())

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from trusted sources.
        state_dict = torch.load(self.ckpt_file, map_location=self.device)['state_dict']
        # Presumably strips distributed-training wrapping from the keys -- confirm.
        state_dict = unwrap_distributed(state_dict)
        self.model.load_state_dict(state_dict)

        self.model = to_device_async(self.model, self.device)
        # remove_weightnorm takes the model explicitly and returns the result.
        self.model = self.model.remove_weightnorm(self.model)
        self.model.eval()

        if self.use_fp16:
            self.model = self.model.half()

        if self.use_denoiser:
            # Built after the optional fp16 cast so it wraps the final model.
            self.denoiser = Denoiser(self.model, device=device)
            self.denoiser = to_device_async(self.denoiser, self.device)
            tprint('Using WaveGlow denoiser.')

    @staticmethod
    def _parse_model_config():
        """Return WaveGlow constructor kwargs built from command-line arguments.

        ``parse_known_args`` is used so that arguments not defined by the
        WaveGlow parser are ignored instead of causing an error.
        """
        # Local import: the waveglow package is only needed once an
        # inferencer is actually constructed.
        from waveglow.arg_parser import parse_waveglow_args

        model_parser = parse_waveglow_args(argparse.ArgumentParser())
        args, _ = model_parser.parse_known_args()
        return dict(
            n_mel_channels=args.n_mel_channels,
            n_flows=args.flows,
            n_group=args.groups,
            n_early_every=args.early_every,
            n_early_size=args.early_size,
            WN_config=dict(
                n_layers=args.wn_layers,
                kernel_size=args.wn_kernel_size,
                n_channels=args.wn_channels,
            ),
        )

    def __enter__(self):
        """Enter the context and return this inferencer (for ``with ... as``)."""
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        """Leave the context; nothing to release and exceptions are not suppressed."""
        return None

    def infer(self, mels, sigma=0.6, strength=0.01):
        """Synthesize waveforms from mel spectrograms.

        Args:
            mels: Mel-spectrogram tensor. It is cast to half when ``use_fp16``
                is set and then moved to ``self.device``. Its expected shape is
                whatever ``WaveGlow.infer`` accepts (not checked here).
            sigma: Forwarded to ``self.model.infer`` as ``sigma``.
            strength: Forwarded to the denoiser as ``strength``; only used
                when ``use_denoiser`` is set.

        Returns:
            The generated waveforms, converted with ``.float()``.
        """
        if self.use_fp16:
            mels = mels.half()
        mels = to_device_async(mels, self.device)
        wavs = self.model.infer(mels, sigma=sigma)

        if self.use_denoiser:
            wavs = self.denoiser(wavs, strength=strength)

        return wavs.float()