Beispiel #1
0
def get_model(model_name, model_config, to_fp16, to_cuda, training=True):
    """Build the network selected by ``model_name``.

    Args:
        model_name: ``'Tacotron2'`` or ``'WaveGlow'``.
        model_config: Mapping expanded as keyword arguments into the model
            constructor.
        to_fp16: When truthy, the model is converted with ``.half()`` and then
            passed through ``batchnorm_to_float`` and ``lstmcell_to_float``;
            for WaveGlow every entry of ``model.convinv`` additionally gets
            ``.float()`` called on it.
        to_cuda: When truthy, the model is moved with ``.cuda()``.
        training: Accepted for interface compatibility; not used here.

    Returns:
        The constructed (and optionally converted / moved) model.

    Raises:
        NotImplementedError: If ``model_name`` is not one of the supported
            names; the offending name is the exception argument.
    """
    if model_name not in ('Tacotron2', 'WaveGlow'):
        raise NotImplementedError(model_name)

    is_waveglow = model_name == 'WaveGlow'
    # Only the class on the selected branch is looked up.
    if is_waveglow:
        network = WaveGlow(**model_config)
    else:
        network = Tacotron2(**model_config)

    if to_fp16:
        # Half precision first, then the two helpers get a chance to restore
        # selected layers (presumably to fp32 -- see their definitions).
        network = lstmcell_to_float(batchnorm_to_float(network.half()))
        if is_waveglow:
            for inv_conv in network.convinv:
                inv_conv.float()

    return network.cuda() if to_cuda else network
Beispiel #2
0
class WaveGlowInferencer(object):
    """Load a WaveGlow checkpoint and run it on mel inputs.

    The model hyper-parameters are taken from the argument parser built by
    ``waveglow.arg_parser.parse_waveglow_args``, evaluated with
    ``parse_known_args()`` (i.e. against the process command line, ignoring
    arguments the parser does not know).

    The instance can be used as a context manager; entering yields the
    inferencer itself and exiting performs no cleanup.

    Args:
        ckpt_file: Path (or file-like object) handed to ``torch.load``; the
            loaded object must contain a ``'state_dict'`` entry.
        device: Device used as ``map_location`` when loading and as the
            target for ``to_device_async``.
        use_fp16: When truthy, the model (and the inputs to ``infer``) are
            converted with ``.half()``.
        use_denoiser: When truthy, a ``Denoiser`` is built from the model and
            applied to the output of ``infer``.
    """

    def __init__(self, ckpt_file, device='cuda', use_fp16=False, use_denoiser=False):
        self.ckpt_file = ckpt_file
        self.device = device
        self.use_fp16 = use_fp16
        self.use_denoiser = use_denoiser

        # Local import: only needed while building the model configuration.
        from waveglow.arg_parser import parse_waveglow_args
        parser = argparse.ArgumentParser()
        model_parser = parse_waveglow_args(parser)
        # parse_known_args() tolerates unrelated command-line arguments.
        args, _ = model_parser.parse_known_args()
        model_config = dict(
            n_mel_channels=args.n_mel_channels,
            n_flows=args.flows,
            n_group=args.groups,
            n_early_every=args.early_every,
            n_early_size=args.early_size,
            WN_config=dict(
                n_layers=args.wn_layers,
                kernel_size=args.wn_kernel_size,
                n_channels=args.wn_channels,
            ),
        )
        self.model = WaveGlow(**model_config)

        checkpoint = torch.load(self.ckpt_file, map_location=self.device)
        state_dict = unwrap_distributed(checkpoint['state_dict'])
        self.model.load_state_dict(state_dict)

        self.model = to_device_async(self.model, self.device)

        # remove_weightnorm takes the model explicitly and its return value
        # replaces the stored model.
        self.model = self.model.remove_weightnorm(self.model)

        self.model.eval()

        if self.use_fp16:
            self.model = self.model.half()

        if self.use_denoiser:
            self.denoiser = Denoiser(self.model, device=device)
            self.denoiser = to_device_async(self.denoiser, self.device)

            tprint('Using WaveGlow denoiser.')

    def __enter__(self):
        """Return the inferencer so ``with ... as x`` binds a usable object."""
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        """Do nothing on exit; exceptions raised in the body propagate."""
        return None

    def infer(self, mels):
        """Run the model on ``mels`` and return the result as float32.

        Args:
            mels: Input passed to ``self.model.infer``; it must support
                ``.half()`` when ``use_fp16`` is set (presumably a torch
                tensor of mel-spectrograms -- confirm against callers).

        Returns:
            The output of ``self.model.infer(mels, sigma=0.6)``, passed
            through the denoiser with ``strength=0.01`` when enabled, and
            converted with ``.float()``.
        """
        # NOTE(review): this call is not wrapped in torch.no_grad(); callers
        # that want to avoid autograd bookkeeping must do so themselves.
        if self.use_fp16:
            mels = mels.half()
        mels = to_device_async(mels, self.device)
        wavs = self.model.infer(mels, sigma=0.6)

        if self.use_denoiser:
            wavs = self.denoiser(wavs, strength=0.01)

        return wavs.float()