Example #1
0
def load_model(weights_fpath, verbose=False):
    """Build the WaveRNN vocoder and load trained weights into it.

    The network is moved to CUDA when ``torch.cuda.is_available()`` is true and
    stays on the CPU otherwise. The loaded network (in eval mode) is stored in
    the module-level ``_model``, and the chosen ``torch.device`` in ``_device``.

    :param weights_fpath: path to a checkpoint readable by ``torch.load`` that
        contains a ``'model_state'`` entry holding the network's state dict.
    :param verbose: when True, print progress/diagnostic messages.
    """
    global _model, _device

    if verbose:
        print("Building Wave-RNN")
    _model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                     fc_dims=hp.voc_fc_dims,
                     bits=hp.bits,
                     pad=hp.voc_pad,
                     upsample_factors=hp.voc_upsample_factors,
                     feat_dims=hp.num_mels,
                     compute_dims=hp.voc_compute_dims,
                     res_out_dims=hp.voc_res_out_dims,
                     res_blocks=hp.voc_res_blocks,
                     hop_length=hp.hop_length,
                     sample_rate=hp.sample_rate,
                     mode=hp.voc_mode)

    if torch.cuda.is_available():
        _model = _model.cuda()
        _device = torch.device('cuda')
    else:
        _device = torch.device('cpu')

    if verbose:
        print("Loading model weights at %s" % weights_fpath)
    # Remap the checkpoint's tensors onto the device chosen above so a
    # GPU-saved checkpoint can still be loaded on a CPU-only machine.
    checkpoint = torch.load(weights_fpath, map_location=_device)
    _model.load_state_dict(checkpoint['model_state'])
    _model.eval()

    # Diagnostic output only when requested, so a default (quiet) load prints
    # nothing to stdout.
    if verbose:
        print(type(_model))


def load_model(weights_fpath, verbose=True):
    """Build the WaveRNN vocoder and load trained weights into it.

    The network is placed on CUDA when ``torch.cuda.is_available()`` is true
    and on the CPU otherwise, so loading works on machines without a GPU.
    The loaded network (in eval mode) is stored in the module-level ``_model``.

    :param weights_fpath: path to a checkpoint readable by ``torch.load`` that
        contains a ``'model_state'`` entry holding the network's state dict.
    :param verbose: when True, print progress messages.
    """
    global _model

    if verbose:
        print("Building Wave-RNN")
    # Fall back to the CPU instead of failing on hosts without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _model = WaveRNN(
        rnn_dims=hp.voc_rnn_dims,
        fc_dims=hp.voc_fc_dims,
        bits=hp.bits,
        pad=hp.voc_pad,
        upsample_factors=hp.voc_upsample_factors,
        feat_dims=hp.num_mels,
        compute_dims=hp.voc_compute_dims,
        res_out_dims=hp.voc_res_out_dims,
        res_blocks=hp.voc_res_blocks,
        hop_length=hp.hop_length,
        sample_rate=hp.sample_rate,
        mode=hp.voc_mode
    ).to(device)

    if verbose:
        print("Loading model weights at %s" % weights_fpath)
    # map_location keeps a GPU-saved checkpoint loadable on a CPU-only host.
    checkpoint = torch.load(weights_fpath, map_location=device)
    _model.load_state_dict(checkpoint['model_state'])
    _model.eval()
Example #3
0
class Model(object):
    """Holds a WaveRNN vocoder and converts mel spectrograms into waveforms.

    The network is built and loaded on the CPU by ``load_from``; until then
    ``infer_waveform`` refuses to run.
    """

    def __init__(self):
        # WaveRNN network; stays None until load_from() has been called.
        self._model = None

    def load_from(self, weights_fpath, verbose=True):
        """Build the WaveRNN vocoder on the CPU and load trained weights into it.

        :param weights_fpath: path to a checkpoint readable by ``torch.load``
            that contains a ``'model_state'`` entry holding the state dict.
        :param verbose: when True, print progress messages.
        """
        if verbose:
            print("Building Wave-RNN")
        self._model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                              fc_dims=hp.voc_fc_dims,
                              bits=hp.bits,
                              pad=hp.voc_pad,
                              upsample_factors=hp.voc_upsample_factors,
                              feat_dims=hp.num_mels,
                              compute_dims=hp.voc_compute_dims,
                              res_out_dims=hp.voc_res_out_dims,
                              res_blocks=hp.voc_res_blocks,
                              hop_length=hp.hop_length,
                              sample_rate=hp.sample_rate,
                              mode=hp.voc_mode)

        if verbose:
            print("Loading model weights at %s" % weights_fpath)
        # The network lives on the CPU, so remap any GPU-saved tensors there too.
        checkpoint = torch.load(weights_fpath,
                                map_location=torch.device('cpu'))
        self._model.load_state_dict(checkpoint['model_state'])
        self._model.eval()

    def is_loaded(self):
        """Return True once ``load_from`` has populated the network."""
        return self._model is not None

    def infer_waveform(self,
                       mel,
                       normalize=True,
                       batched=True,
                       target=8000,
                       overlap=800,
                       progress_callback=None):
        """
        Infers the waveform of a mel spectrogram output by the synthesizer (the format must match
        that of the synthesizer!)

        :param mel: numpy array holding one mel spectrogram; a leading batch axis
            is added here. Presumably shaped (num_mels, frames) — confirm against
            the synthesizer's output.
        :param normalize: when True, divide ``mel`` by ``hp.mel_max_abs_value``
            before inference.
        :param batched: forwarded to ``WaveRNN.generate``.
        :param target: forwarded to ``WaveRNN.generate`` (presumably the fold
            length in samples when batched — confirm).
        :param overlap: forwarded to ``WaveRNN.generate`` (presumably the overlap
            between folds in samples — confirm).
        :param progress_callback: optional callable forwarded to
            ``WaveRNN.generate``.
        :return: whatever ``WaveRNN.generate`` returns for this spectrogram.
        :raises RuntimeError: if ``load_from`` has not been called yet.
        """
        if self._model is None:
            # RuntimeError subclasses Exception, so existing handlers still match.
            raise RuntimeError("Please load Wave-RNN in memory before using it")

        if normalize:
            mel = mel / hp.mel_max_abs_value
        # Add a leading batch axis; the network is on the CPU, as is this tensor.
        mel = torch.from_numpy(mel[None, ...])
        wav = self._model.generate(mel, batched, target, overlap, hp.mu_law,
                                   progress_callback)
        return wav
Example #4
0
def load_model(weights_fpath, verbose=True):
    """Construct the WaveRNN vocoder on the best available device and load its weights.

    A CUDA device is used when ``torch.cuda.is_available()`` reports one;
    otherwise everything stays on the CPU. The network, switched to eval mode,
    is published through the module-level ``_model``.

    :param weights_fpath: checkpoint location (converted with ``str()`` before
        loading); the checkpoint must hold a ``'model_state'`` state dict.
    :param verbose: print progress messages when True.
    """
    global _model

    if verbose:
        print("Building Wave-RNN")
    use_gpu = torch.cuda.is_available()
    target_device = torch.device("cuda" if use_gpu else "cpu")

    network = WaveRNN(
        rnn_dims=hp.voc_rnn_dims,
        fc_dims=hp.voc_fc_dims,
        bits=hp.bits,
        pad=hp.voc_pad,
        upsample_factors=hp.voc_upsample_factors,
        feat_dims=hp.num_mels,
        compute_dims=hp.voc_compute_dims,
        res_out_dims=hp.voc_res_out_dims,
        res_blocks=hp.voc_res_blocks,
        hop_length=hp.hop_length,
        sample_rate=hp.sample_rate,
        mode=hp.voc_mode,
    )
    _model = network.to(target_device)

    if verbose:
        print("Loading model weights at %s" % weights_fpath)
    # Tensors are remapped onto the same device the network was moved to.
    saved = torch.load(str(weights_fpath), map_location=target_device)
    weights = saved['model_state']
    _model.load_state_dict(weights)
    _model.eval()