def infer(flowtron_path, waveglow_path, output_dir, text, speaker_id, n_frames,
          sigma, gate_threshold, seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # load waveglow
    waveglow = torch.load(waveglow_path)['model'].cuda().eval()
    waveglow.cuda().half()
    for k in waveglow.convinv:
        k.float()
    waveglow.eval()

    # load flowtron
    model = Flowtron(**model_config).cuda()
    state_dict = torch.load(flowtron_path, map_location='cpu')['state_dict']
    model.load_state_dict(state_dict)
    model.eval()
    print("Loaded checkpoint '{}')".format(flowtron_path))

    ignore_keys = ['training_files', 'validation_files']
    trainset = Data(
        data_config['training_files'],
        **dict((k, v) for k, v in data_config.items() if k not in ignore_keys))
    speaker_vecs = trainset.get_speaker_id(speaker_id).cuda()
    text = trainset.get_text(text).cuda()
    speaker_vecs = speaker_vecs[None]
    text = text[None]

    with torch.no_grad():
        residual = torch.cuda.FloatTensor(1, 80, n_frames).normal_() * sigma
        mels, attentions = model.infer(residual,
                                       speaker_vecs,
                                       text,
                                       gate_threshold=gate_threshold)

    for k in range(len(attentions)):
        attention = torch.cat(attentions[k]).cpu().numpy()
        fig, axes = plt.subplots(1, 2, figsize=(16, 4))
        axes[0].imshow(mels[0].cpu().numpy(), origin='lower', aspect='auto')
        axes[1].imshow(attention[:, 0].transpose(),
                       origin='lower',
                       aspect='auto')
        fig.savefig(
            os.path.join(
                output_dir,
                'sid{}_sigma{}_attnlayer{}.png'.format(speaker_id, sigma, k)))
        plt.close("all")

    with torch.no_grad():
        audio = waveglow.infer(mels.half(), sigma=0.8).float()

    audio = audio.cpu().numpy()[0]
    # normalize audio for now
    audio = audio / np.abs(audio).max()
    print(audio.shape)

    write(
        os.path.join(output_dir, 'sid{}_sigma{}.wav'.format(speaker_id,
                                                            sigma)),
        data_config['sampling_rate'], audio)
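
A hedged usage sketch for the infer function above; all paths and parameter values are placeholders, and data_config/model_config are assumed to already be parsed from the Flowtron config file:

# Placeholder paths and values; data_config/model_config must already be defined globally.
infer(flowtron_path='models/flowtron_ljs.pt',
      waveglow_path='models/waveglow_256channels_universal_v5.pt',
      output_dir='results',
      text='It is well known that deep generative models have a rich latent space.',
      speaker_id=0, n_frames=400, sigma=0.5, gate_threshold=0.5, seed=1234)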
Example #2
def setup():
    # Parse configs.  Globals nicer in this case
    with open("flowtron/infer.json") as f:
        data = f.read()

    global config
    config = json.loads(data)

    global data_config
    data_config = config["data_config"]
    global model_config
    model_config = config["model_config"]

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False

    global flowtron
    global waveglow
    global trainset
    
    # load the speaker encoder from its pretrained weights
    encoder_weights = Path("encoder/saved_models/pretrained.pt")
    encoder.load_model(encoder_weights)

    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)

    #Load waveglow
    waveglow = torch.load("flowtron/tacotron2/waveglow/saved_models/waveglow_256channels_universal_v5.pt")['model'].cuda().eval()
    waveglow.cuda().half()
    for k in waveglow.convinv:
        k.float()
    waveglow.eval()
    
    #Load flowtron
    flowtron = Flowtron(**model_config).cuda()
    state_dict = torch.load("flowtron/saved_models/pretrained.pt", map_location='cpu')['model'].state_dict()
    flowtron.load_state_dict(state_dict)
    flowtron.eval()

    ignore_keys = ['training_files', 'validation_files']
    trainset = Data(
        data_config['training_files'],
        **dict((k, v) for k, v in data_config.items() if k not in ignore_keys))
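
An illustrative sketch of using the globals initialized by setup() for a single utterance, mirroring the inference pattern in the other examples (the text, speaker id, n_frames, and sigma values are placeholders):

setup()
speaker_vecs = trainset.get_speaker_id(0).cuda()[None]
text = trainset.get_text("Hello from Flowtron.").cuda()[None]
with torch.no_grad():
    residual = torch.cuda.FloatTensor(1, 80, 400).normal_() * 0.5
    mels, attentions = flowtron.infer(residual, speaker_vecs, text)
    audio = waveglow.infer(mels.half(), sigma=0.8).float()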
Example #3
def load_models(flowtron_path, waveglow_path):
    # load waveglow
    waveglow = torch.load(waveglow_path)['model'].cuda().eval()
    waveglow.cuda()
    for k in waveglow.convinv:
        k.float()
    waveglow.eval()

    # load flowtron
    try:
        model = Flowtron(**model_config).cuda()
        state_dict = torch.load(flowtron_path,
                                map_location='cpu')['state_dict']
        model.load_state_dict(state_dict)
    except KeyError:
        model = torch.load(flowtron_path)['model']

    model.eval()
    print("Loaded model '{}')".format(flowtron_path))

    return model, waveglow
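
A minimal sketch of calling load_models; the checkpoint paths are placeholders, and model_config is assumed to be populated globally before the call:

# Placeholder paths; model_config must already hold the parsed "model_config" section.
model, waveglow = load_models('models/flowtron_ljs.pt',
                              'models/waveglow_256channels_universal_v5.pt')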
Example #4
def infer(flowtron_path, waveglow_path, text, speaker_id, n_frames, sigma,
          seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # load waveglow
    waveglow = torch.load(waveglow_path)['model'].cuda().eval()
    waveglow.cuda().half()
    for k in waveglow.convinv:
        k.float()
    waveglow.eval()

    # load flowtron
    model = Flowtron(**model_config).cuda()
    cpt_dict = torch.load(flowtron_path)
    if 'model' in cpt_dict:
        dummy_dict = cpt_dict['model'].state_dict()
    else:
        dummy_dict = cpt_dict['state_dict']
    model.load_state_dict(dummy_dict)
    model.eval()

    print("Loaded checkpoint '{}')".format(flowtron_path))

    ignore_keys = ['training_files', 'validation_files']
    trainset = Data(
        data_config['training_files'],
        **dict((k, v) for k, v in data_config.items() if k not in ignore_keys))

    tic_prep = time.time()

    str_text = text
    num_char = len(str_text)
    num_word = len(str_text.split())

    speaker_vecs = trainset.get_speaker_id(speaker_id).cuda()
    text = trainset.get_text(text).cuda()

    speaker_vecs = speaker_vecs[None]
    text = text[None]
    toc_prep = time.time()

    # warm-up pass so that the timed run below measures pure Flowtron inference

    with torch.no_grad():
        tic_warmup = time.time()
        residual = torch.cuda.FloatTensor(1, 80, n_frames).normal_() * sigma
        mels, attentions = model.infer(residual, speaker_vecs, text)
        toc_warmup = time.time()

    tic_flowtron = time.time()
    with torch.no_grad(), torch.autograd.profiler.emit_nvtx():  # emit NVTX ranges for profiling
        tic_residual = time.time()
        residual = torch.cuda.FloatTensor(1, 80, n_frames).normal_() * sigma
        toc_residual = time.time()
        profiler.start()  # begin profiler capture around the Flowtron pass
        mels, attentions = model.infer(residual, speaker_vecs, text)
        profiler.stop()  # end profiler capture
        toc_flowtron = time.time()

    for k in range(len(attentions)):
        attention = torch.cat(attentions[k]).cpu().numpy()
        fig, axes = plt.subplots(1, 2, figsize=(16, 4))
        axes[0].imshow(mels[0].cpu().numpy(), origin='lower', aspect='auto')
        axes[1].imshow(attention[:, 0].transpose(),
                       origin='lower',
                       aspect='auto')
        fig.savefig('sid{}_sigma{}_attnlayer{}.png'.format(
            speaker_id, sigma, k))
        plt.close("all")

    tic_waveglow = time.time()
    audio = waveglow.infer(mels.half(), sigma=0.8).float()
    toc_waveglow = time.time()

    audio = audio.cpu().numpy()[0]
    # normalize audio for now
    audio = audio / np.abs(audio).max()

    len_audio = len(audio)
    dur_audio = len_audio / 22050  # assumes a 22050 Hz sampling rate
    num_frames = int(len_audio / 256)  # assumes a 256-sample hop length

    dur_prep = toc_prep - tic_prep
    dur_residual = toc_residual - tic_residual
    dur_flowtron_in = toc_flowtron - toc_residual
    dur_warmup = toc_warmup - tic_warmup
    dur_flowtron_out = toc_flowtron - tic_residual
    dur_waveglow = toc_waveglow - tic_waveglow
    dur_total = dur_prep + dur_flowtron_out + dur_waveglow

    RTF = dur_audio / dur_total

    str_text = "\n text : " + str_text
    str_num = "\n text {:d} char {:d} words  ".format(num_char, num_word)
    str_audio = "\n generated audio : {:2.3f} samples  {:2.3f} sec  with  {:d} mel frames ".format(
        len_audio, dur_audio, num_frames)
    str_perf = "\n total time {:2.3f} = text prep {:2.3f} + flowtron{:2.3f} + wg {:2.3f}  ".format(
        dur_total, dur_prep, dur_flowtron_out, dur_waveglow)
    str_flow = "\n total flowtron {:2.3f} = residual cal {:2.3f} + flowtron {:2.3f}  ".format(
        dur_flowtron_out, dur_residual, dur_flowtron_in)
    str_rtf = "\n RTF is {:2.3f} x  with warm up {:2.3f} ".format(
        RTF, dur_warmup)

    print(str_text, str_num, str_audio, str_perf, str_flow, str_rtf)

    write("sid{}_sigma{}.wav".format(speaker_id, sigma),
          data_config['sampling_rate'], audio)
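
For reference, the real-time factor (RTF) computed above is seconds of audio produced per second of wall-clock compute, so values above 1.0 mean faster-than-real-time synthesis; a minimal illustration with made-up numbers:

# Made-up numbers: ~4.2 s of audio synthesized in ~1.5 s of compute.
len_audio = 92610
dur_audio = len_audio / 22050      # seconds of audio at 22.05 kHz
dur_total = 0.05 + 1.30 + 0.15     # prep + flowtron + waveglow (seconds)
rtf = dur_audio / dur_total        # ~2.8x faster than real time
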
class AudioGeneratorFlowtron:
    models = {
        'flowtron': 'flowtron_model.pt',
    }

    waveglow = {'default': 'waveglow_256channels_universal_v5.pt'}

    def __init__(self):
        self.config_path = 'flowtron/config.json'
        self.models_path = os.getcwd() + '/models/'
        self.training_files_path = os.getcwd() + '/filelists/dataset_train.txt'
        with open(self.config_path) as f:
            data = f.read()
        self.config = json.loads(data)
        self.config['model_config']['n_speakers'] = 41
        self.lambd = 0.001
        self.sigma = 0.85
        self.waveglow_sigma = 1
        self.n_frames = 1800
        self.aggregation_type = 'batch'

        self.model = Flowtron(**self.config['model_config']).cuda()
        flowtron_path = self.models_path + self.models['flowtron']
        waveglow_path = self.models_path + self.waveglow['default']

        load = torch.load(flowtron_path, map_location='cpu')
        if 'state_dict' in load:
            state_dict = load['state_dict']
        else:
            state_dict = load['model'].state_dict()
        self.model.load_state_dict(state_dict, strict=False)
        self.model.eval()

        self.waveglow = torch.load(waveglow_path)['model']
        self.waveglow.cuda().eval()

        self.z_baseline = torch.cuda.FloatTensor(
            1, 80, self.n_frames).normal_() * self.sigma

        ignore_keys = ['training_files', 'validation_files']
        self.trainset = Data(
            self.training_files_path,
            **dict((k, v) for k, v in self.config['data_config'].items()
                   if k not in ignore_keys))

    def generate(self, text: str, speaker: int):
        speaker_vecs = self.trainset.get_speaker_id(speaker).cuda()
        speaker_vecs = speaker_vecs[None]
        text = self.trainset.get_text(text).cuda()
        text = text[None]

        with torch.no_grad():
            mel_baseline = self.model.infer(self.z_baseline, speaker_vecs,
                                            text)[0]

        with torch.no_grad():
            audio_base = self.waveglow.infer(mel_baseline,
                                             sigma=self.waveglow_sigma)

        audio = audio_base[0].data.cpu().numpy()
        return audio

    def prepare_dataset(self, dataset_path):
        dataset = Data(
            dataset_path,
            **dict((k, v) for k, v in self.config['data_config'].items()
                   if k not in ['training_files', 'validation_files']))
        return dataset
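
A hedged usage sketch for the class above; the text, speaker id, and output filename are placeholders, and the peak normalization mirrors the earlier examples:

from scipy.io.wavfile import write

generator = AudioGeneratorFlowtron()
audio = generator.generate("Hello from Flowtron.", speaker=3)
audio = audio / np.abs(audio).max()  # peak-normalize, as in the infer() examples
write("sid3_sample.wav",
      generator.config['data_config']['sampling_rate'], audio)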