Example No. 1
def load_waveglow(fp16=fp16):  # default taken from a module-level fp16 flag
  '''Constructs a WaveGlow model (an nn.Module with an additional infer(input) method).
    For detailed information on model input and output, training recipes, inference and performance
    visit: github.com/NVIDIA/DeepLearningExamples and/or ngc.nvidia.com

    Args:
        fp16 (bool): If True, returns the model in fp16 precision
            (the invertible 1x1 convolutions are kept in fp32).
  '''
  import torch

  from PyTorch.SpeechSynthesis.Tacotron2.waveglow import model as waveglow
  from PyTorch.SpeechSynthesis.Tacotron2.models import batchnorm_to_float

  # waveglow_path, device, checkpoint_from_distributed and unwrap_distributed
  # are expected to be defined at module level.
  ckpt_file = waveglow_path
  ckpt = torch.load(ckpt_file, map_location=device)
  
  state_dict = ckpt['state_dict']
  # Strip the 'module.' prefixes left by DistributedDataParallel, if present.
  if checkpoint_from_distributed(state_dict):
    state_dict = unwrap_distributed(state_dict)
  config = ckpt['config']

  m = waveglow.WaveGlow(**config)

  if fp16:
    # Halve the weights but keep batch norm and the invertible 1x1 convolutions in fp32.
    m = batchnorm_to_float(m.half())
    for mat in m.convinv:
      mat.float()

  m.load_state_dict(state_dict)

  return m
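
A minimal usage sketch for the loader above. It assumes waveglow_path points at a valid checkpoint and device is a CUDA device (WaveGlow's infer path samples its noise on the GPU); the mel tensor is a random placeholder standing in for a real (batch, 80, frames) spectrogram:

model = load_waveglow(fp16=False).to(device).eval()

mel = torch.randn(1, 80, 620, device=device)  # placeholder mel spectrogram, not real data
with torch.no_grad():
    audio = model.infer(mel)                  # synthesized waveform, one row per batch item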
Example No. 2
def nvidia_waveglow(pretrained=True, **kwargs):
    """Constructs a WaveGlow model (nn.module with additional infer(input) method).
    For detailed information on model input and output, training recipes, inference and performance
    visit: github.com/NVIDIA/DeepLearningExamples and/or ngc.nvidia.com

    Args:
        pretrained (bool): If True, returns a model pretrained on LJ Speech dataset.
        model_math (str, 'fp32'): returns a model in given precision ('fp32' or 'fp16')
    """

    import torch
    import urllib.request

    from PyTorch.SpeechSynthesis.Tacotron2.waveglow import model as waveglow
    from PyTorch.SpeechSynthesis.Tacotron2.models import batchnorm_to_float

    fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"

    if pretrained:
        if fp16:
            checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
        else:
            checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
        ckpt_file = "waveglow_ckpt.pt"
        urllib.request.urlretrieve(checkpoint, ckpt_file)
        ckpt = torch.load(ckpt_file)
        state_dict = ckpt['state_dict']
        if checkpoint_from_distributed(state_dict):
            state_dict = unwrap_distributed(state_dict)
        config = ckpt['config']
    else:
        # Default WaveGlow hyperparameters for LJ Speech; any kwarg matching a
        # top-level or WN_config key below overrides the default.
        config = {
            'n_mel_channels': 80,
            'n_flows': 12,
            'n_group': 8,
            'n_early_every': 4,
            'n_early_size': 2,
            'WN_config': {
                'n_layers': 8,
                'kernel_size': 3,
                'n_channels': 512
            }
        }
        for k, v in kwargs.items():
            if k in config.keys():
                config[k] = v
            elif k in config['WN_config'].keys():
                config['WN_config'][k] = v

    m = waveglow.WaveGlow(**config)

    if fp16:
        m = batchnorm_to_float(m.half())
        for mat in m.convinv:
            mat.float()

    if pretrained:
        m.load_state_dict(state_dict)

    return m
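
As a sketch of how the hub entry point above can be called without a pretrained checkpoint: keyword arguments are routed either into the top-level config or into WN_config, so n_flows lands in the former and n_channels in the latter. The mel tensor is again a random placeholder:

model = nvidia_waveglow(pretrained=False, n_flows=6, n_channels=256)
model = model.cuda().eval()

mel = torch.randn(1, 80, 620).cuda()     # placeholder mel spectrogram (80 mel channels)
with torch.no_grad():
    audio = model.infer(mel, sigma=0.6)  # sigma scales the noise sampled during inference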
Example No. 3
def load_waveglow(weightpath='checkpoints/waveglow_20200314.pth'):
    """Loads a WaveGlow model from a local checkpoint file."""
    import torch

    # waveglow, checkpoint_from_distributed and unwrap_distributed are assumed
    # to be imported/defined at module level, as in the examples above.
    ckpt = torch.load(weightpath)
    state_dict = ckpt['state_dict']
    if checkpoint_from_distributed(state_dict):
        state_dict = unwrap_distributed(state_dict)
    config = ckpt['config']

    waveglow_model = waveglow.WaveGlow(**config)
    waveglow_model.load_state_dict(state_dict)
    return waveglow_model
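
A usage sketch for the local-checkpoint loader above, writing the result to a WAV file. The 80 mel channels and the 22050 Hz sampling rate are assumptions based on the usual LJ Speech setup, not values read from the snippet:

from scipy.io.wavfile import write

model = load_waveglow().cuda().eval()

mel = torch.randn(1, 80, 620).cuda()   # placeholder mel spectrogram
with torch.no_grad():
    audio = model.infer(mel)

write('audio.wav', 22050, audio[0].float().cpu().numpy())  # float samples in roughly [-1, 1]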