Example #1
def main():

    # Parse Arguments
    parser = argparse.ArgumentParser(description='Train WaveRNN Vocoder')
    parser.add_argument('--lr',
                        '-l',
                        type=float,
                        help='[float] override hparams.py learning rate')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        help='[int] override hparams.py batch size')
    parser.add_argument('--force_train',
                        '-f',
                        action='store_true',
                        help='Forces the model to train past total steps')
    parser.add_argument('--gta',
                        '-g',
                        action='store_true',
                        help='train wavernn on GTA features')
    parser.add_argument(
        '--force_cpu',
        '-c',
        action='store_true',
        help='Forces CPU-only training, even when in CUDA capable environment')
    parser.add_argument('--hp_file',
                        metavar='FILE',
                        default='hparams.py',
                        help='The file to use for the hyperparameters')
    args = parser.parse_args()

    # Set hyperparameters
    hp.training_files = "tacotron2/filelists/transcripts_korean_final_final.txt"
    hp.validation_files = "tacotron2/filelists/transcripts_korean_final_validate.txt"
    hp.filter_length = 1024
    hp.n_mel_channels = 80
    hp.sampling_rate = 16000
    hp.mel_fmin = 0.0
    hp.mel_fmax = 8000.0
    hp.max_wav_value = 32768.0
    hp.n_frames_per_step = 1
    hp.configure(args.hp_file)  # load hparams from file
    if args.lr is None:
        args.lr = hp.voc_lr
    if args.batch_size is None:
        args.batch_size = hp.voc_batch_size

    paths = Paths("../data/", hp.voc_model_id, hp.tts_model_id)

    batch_size = args.batch_size
    force_train = args.force_train
    train_gta = args.gta
    lr = args.lr

    if not args.force_cpu and torch.cuda.is_available():
        device = torch.device('cuda')
        if batch_size % torch.cuda.device_count() != 0:
            raise ValueError(
                '`batch_size` must be evenly divisible by n_gpus!')
    else:
        device = torch.device('cpu')
    print('Using device:', device)

    print('\nInitialising Model...\n')

    # Instantiate WaveRNN Model
    voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                        fc_dims=hp.voc_fc_dims,
                        bits=hp.bits,
                        pad=hp.voc_pad,
                        upsample_factors=hp.voc_upsample_factors,
                        feat_dims=hp.num_mels,
                        compute_dims=hp.voc_compute_dims,
                        res_out_dims=hp.voc_res_out_dims,
                        res_blocks=hp.voc_res_blocks,
                        hop_length=hp.hop_length,
                        sample_rate=hp.sample_rate,
                        mode=hp.voc_mode).to(device)

    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    optimizer = optim.Adam(voc_model.parameters())
    restore_checkpoint('voc',
                       paths,
                       voc_model,
                       optimizer,
                       create_if_missing=True)

    train_set, test_set = get_vocoder_datasets(paths.data, batch_size,
                                               train_gta, hp)

    total_steps = 10_000_000 if force_train else hp.voc_total_steps

    simple_table([
        ('Remaining', str(
            (total_steps - voc_model.get_step()) // 1000) + 'k Steps'),
        ('Batch Size', batch_size), ('LR', lr),
        ('Sequence Len', hp.voc_seq_len), ('GTA Train', train_gta)
    ])

    loss_func = F.cross_entropy if voc_model.mode == 'RAW' else discretized_mix_logistic_loss

    voc_train_loop(paths, voc_model, loss_func, optimizer, train_set, test_set,
                   lr, total_steps)

    print('Training Complete.')
    print(
        'To continue training increase voc_total_steps in hparams.py or use --force_train'
    )
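
Every one of these examples repeats the same assertion on hp.voc_upsample_factors: the conditioning network stretches each mel frame by the product of its upsample factors, and that product must equal the STFT hop length so mel frames and audio samples stay aligned. A tiny check of that relationship, using assumed values rather than this project's hparams:

import numpy as np

upsample_factors = (5, 5, 8)       # assumed factors, purely for illustration
hop_length = 200                   # assumed hop length matching their product

assert np.cumprod(upsample_factors)[-1] == hop_length
n_frames = 40
print(n_frames * hop_length)       # 40 mel frames condition 8000 audio samples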
Example #2
def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path,
          ground_truth: bool, save_every: int, backup_every: int,
          force_restart: bool):
    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    # Instantiate the model
    print("Initializing the model...")
    model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                    fc_dims=hp.voc_fc_dims,
                    bits=hp.bits,
                    pad=hp.voc_pad,
                    upsample_factors=hp.voc_upsample_factors,
                    feat_dims=hp.num_mels,
                    compute_dims=hp.voc_compute_dims,
                    res_out_dims=hp.voc_res_out_dims,
                    res_blocks=hp.voc_res_blocks,
                    hop_length=hp.hop_length,
                    sample_rate=hp.sample_rate,
                    mode=hp.voc_mode).cuda()

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters())
    for p in optimizer.param_groups:
        p["lr"] = hp.voc_lr
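    # RAW mode predicts a categorical distribution over the 2**bits quantized
    # sample values (hence cross-entropy); MOL mode predicts mixture-of-logistics
    # parameters for float targets instead.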
    loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss

    # Load the weights
    model_dir = models_dir.joinpath(run_id)
    model_dir.mkdir(exist_ok=True)
    weights_fpath = model_dir.joinpath(run_id + ".pt")
    if force_restart or not weights_fpath.exists():
        print("\nStarting the training of WaveRNN from scratch\n")
        model.save(weights_fpath, optimizer)
    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(weights_fpath, optimizer)
        print("WaveRNN weights loaded from step %d" % model.step)

    # Initialize the dataset
    metadata_fpath = syn_dir.joinpath("train.txt") if ground_truth else \
        voc_dir.joinpath("synthesized.txt")
    mel_dir = syn_dir.joinpath("mels") if ground_truth else voc_dir.joinpath(
        "mels_gta")
    wav_dir = syn_dir.joinpath("audio")
    dataset = VocoderDataset(metadata_fpath, mel_dir, wav_dir)
    test_loader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True)

    # Begin the training
    simple_table([('Batch size', hp.voc_batch_size), ('LR', hp.voc_lr),
                  ('Sequence Len', hp.voc_seq_len)])

    for epoch in range(1, 350):
        data_loader = DataLoader(dataset,
                                 collate_fn=collate_vocoder,
                                 batch_size=hp.voc_batch_size,
                                 num_workers=2,
                                 shuffle=True,
                                 pin_memory=True)
        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(data_loader, 1):
            x, m, y = x.cuda(), m.cuda(), y.cuda()

            # Forward pass
            y_hat = model(x, m)
            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            y = y.unsqueeze(-1)
            print("y shape:", y.shape)
            print("y_hat shape:", y_hat.shape)
            # Backward pass
            loss = loss_func(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            speed = i / (time.time() - start)
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000

            if backup_every != 0 and step % backup_every == 0:
                model.checkpoint(model_dir, optimizer)

            if save_every != 0 and step % save_every == 0:
                model.save(weights_fpath, optimizer)

            msg = f"| Epoch: {epoch} ({i}/{len(data_loader)}) | " \
                f"Loss: {avg_loss:.4f} | {speed:.1f} " \
                f"steps/s | Step: {k}k | "
            stream(msg)

        gen_testset(model, test_loader, hp.voc_gen_at_checkpoint,
                    hp.voc_gen_batched, hp.voc_target, hp.voc_overlap,
                    model_dir)
        print("")
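
The transpose/unsqueeze pair before the loss call only rearranges tensors into the layout F.cross_entropy expects for multi-dimensional targets. A minimal shape sketch with made-up sizes (not tensors from this repo):

import torch
import torch.nn.functional as F

B, L, bits = 4, 100, 9
C = 2 ** bits                                  # quantized sample classes in RAW mode
y_hat = torch.randn(B, L, C)                   # logits per time step, as the model emits them
y = torch.randint(0, C, (B, L))                # integer sample labels

y_hat = y_hat.transpose(1, 2).unsqueeze(-1)    # -> [B, C, L, 1]: class dim must come second
y = y.unsqueeze(-1)                            # -> [B, L, 1]: matches the trailing dims
print(F.cross_entropy(y_hat, y).item())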
Example #3
def train(run_id='',
        syn_dir=None, voc_dirs=[], mel_dir_name='', models_dir=None, log_dir='',
        ground_truth=False,
        save_every=1000, backup_every=1000, log_every=1000,
        force_restart=False, total_epochs=10000, logger=None):
    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    # Instantiate the model
    print("Initializing the model...")
    model = WaveRNN(
        rnn_dims=hp.voc_rnn_dims, # 512
        fc_dims=hp.voc_fc_dims, # 512
        bits=hp.bits, # 9
        pad=hp.voc_pad, # 2
        upsample_factors=hp.voc_upsample_factors, # (3, 4, 5, 5) -> 300, (5,5,12)?
        feat_dims=hp.num_mels, # 80
        compute_dims=hp.voc_compute_dims, # 128
        res_out_dims=hp.voc_res_out_dims, # 128
        res_blocks=hp.voc_res_blocks, # 10
        hop_length=hp.hop_length, # 300
        sample_rate=hp.sample_rate, # 24000
        mode=hp.voc_mode # RAW (or MOL)
    ).cuda()

    # hp.apply_preemphasis in VocoderDataset
    # hp.mu_law in VocoderDataset
    # hp.voc_seq_len in VocoderDataset
    # hp.voc_lr in optimizer
    # hp.voc_batch_size for train

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters())
    for p in optimizer.param_groups:
        p["lr"] = hp.voc_lr # 0.0001
    loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss

    # Load the weights
    model_dir = models_dir.joinpath(run_id) # gta_model/gtaxxxx
    model_dir.mkdir(exist_ok=True)
    weights_fpath = model_dir.joinpath(run_id + ".pt") # gta_model/gtaxxx/gtaxxx.pt
    if force_restart or not weights_fpath.exists():
        print("\nStarting the training of WaveRNN from scratch\n")
        model.save(str(weights_fpath), optimizer)
    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(str(weights_fpath), optimizer)
        print("WaveRNN weights loaded from step %d" % model.step)

    # Initialize the dataset
    #metadata_fpath = syn_dir.joinpath("train.txt") if ground_truth else \
    #    voc_dir.joinpath("synthesized.txt")
    #mel_dir = syn_dir.joinpath("mels") if ground_truth else voc_dir.joinpath("mels_gta")
    #wav_dir = syn_dir.joinpath("audio")
    #dataset = VocoderDataset(metadata_fpath, mel_dir, wav_dir)
    #dataset = VocoderDataset(str(voc_dir), 'mels-gta-1099579078086', 'audio')
    dataset = VocoderDataset([str(voc_dir) for voc_dir in voc_dirs], mel_dir_name, 'audio')
    #test_loader = DataLoader(dataset,
    #                         batch_size=1,
    #                         shuffle=True,
    #                         pin_memory=True)

    # Begin the training
    simple_table([('Batch size', hp.voc_batch_size),
                  ('LR', hp.voc_lr),
                  ('Sequence Len', hp.voc_seq_len)])

    for epoch in range(1, total_epochs):
        data_loader = DataLoader(dataset,
                                 collate_fn=collate_vocoder,
                                 batch_size=hp.voc_batch_size,
                                 num_workers=30,
                                 shuffle=True,
                                 pin_memory=True)
        start = time.time()
        running_loss = 0.

        # start from 1
        for i, (x, y, m) in enumerate(data_loader, 1):
            # cur [B, L], future [B, L] bit label, mels [B, D, T]
            x, m, y = x.cuda(), m.cuda(), y.cuda()

            # Forward pass
            # [B, L], [B, D, T] -> [B, L, C]
            y_hat = model(x, m)
            if model.mode == 'RAW':
                # [B, L, C] -> [B, C, L, 1]
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            # [B, L, 1]
            y = y.unsqueeze(-1)

            # Backward pass
            # [B, C, L, 1], [B, L, 1]
            # cross_entropy for RAW
            loss = loss_func(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            speed = i / (time.time() - start)
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000

            if backup_every != 0 and step % backup_every == 0:
                model.checkpoint(str(model_dir), optimizer)

            if save_every != 0 and step % save_every == 0:
                model.save(str(weights_fpath), optimizer)

            if log_every != 0 and step % log_every == 0:
                logger.scalar_summary("loss", loss.item(), step)

            total_data = len(data_loader)
            msg = f"| Epoch: {epoch} ({i}/{total_data}) | " \
                  f"Loss: {avg_loss:.4f} | {speed:.1f} " \
                  f"steps/s | Step: {k}k | "
            stream(msg)


        #gen_testset(model, test_loader, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
        #            hp.voc_target, hp.voc_overlap, model_dir)
        print("")
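
The comments near the top of this example note that hp.mu_law, hp.apply_preemphasis and hp.voc_seq_len are applied inside VocoderDataset. For reference, a self-contained mu-law companding sketch under an assumed 9-bit setup (these helpers are not the repo's own):

import numpy as np

def mu_law_encode(x, bits=9):
    """Map floats in [-1, 1] to integer classes in [0, 2**bits - 1]."""
    mu = 2 ** bits - 1
    y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    return np.floor((y + 1) / 2 * mu + 0.5).astype(np.int64)

def mu_law_decode(labels, bits=9):
    """Invert mu_law_encode back to floats in [-1, 1]."""
    mu = 2 ** bits - 1
    y = 2 * (labels.astype(np.float64) / mu) - 1
    return np.sign(y) * ((1 + mu) ** np.abs(y) - 1) / mu

wav = np.sin(np.linspace(0, 8 * np.pi, 1000))
labels = mu_law_encode(wav)
print(np.abs(wav - mu_law_decode(labels)).max())   # small quantization error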
Example #4
def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer,
                   train_set, test_set, lr, total_steps):
    # Use same device as model parameters
    device = next(model.parameters()).device

    for g in optimizer.param_groups:
        g['lr'] = lr

    total_iters = len(train_set)
    epochs = (total_steps - model.get_step()) // total_iters + 1

    for e in range(1, epochs + 1):

        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(train_set, 1):
            x, m, y = x.to(device), m.to(device), y.to(device)

            # Parallelize model onto GPUs using workaround due to a Python bug
            if device.type == 'cuda' and torch.cuda.device_count() > 1:
                y_hat = data_parallel_workaround(model, x, m)
            else:
                y_hat = model(x, m)

            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)

            elif model.mode == 'MOL':
                y = y.float()

            y = y.unsqueeze(-1)

            loss = loss_func(y_hat, y)

            optimizer.zero_grad()
            loss.backward()
            if hp.voc_clip_grad_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hp.voc_clip_grad_norm)
                if np.isnan(grad_norm):
                    print('grad_norm was NaN!')
            optimizer.step()

            running_loss += loss.item()
            avg_loss = running_loss / i

            speed = i / (time.time() - start)

            step = model.get_step()
            k = step // 1000

            if step % hp.voc_checkpoint_every == 0:
                gen_testset(model, test_set, hp.voc_gen_at_checkpoint,
                            hp.voc_gen_batched, hp.voc_target, hp.voc_overlap,
                            paths.voc_output)
                ckpt_name = f'wave_step{k}K'
                save_checkpoint('voc',
                                paths,
                                model,
                                optimizer,
                                name=ckpt_name,
                                is_silent=True)

            msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
            stream(msg)

        # Must save latest optimizer state to ensure that resuming training
        # doesn't produce artifacts
        save_checkpoint('voc', paths, model, optimizer, is_silent=True)
        model.log(paths.voc_log, msg)
        print(' ')
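
clip_grad_norm_ above doubles as a divergence check because it returns the total gradient norm measured before clipping. A standalone sketch of that pattern with a toy model and an assumed threshold:

import math
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
loss = model(torch.randn(8, 10)).pow(2).mean()
loss.backward()

grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=4.0)  # assumed threshold
if math.isnan(float(grad_norm)):
    print('grad_norm was NaN!')
print(f'pre-clip grad norm: {float(grad_norm):.3f}')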
Example #5
class VideoAudioGenerator(nn.Module):
    def __init__(self,
                 dim_neck,
                 dim_emb,
                 dim_pre,
                 freq,
                 dim_spec=80,
                 is_train=False,
                 lr=0.001,
                 multigpu=False,
                 lambda_wavenet=0.001,
                 args=None,
                 residual=False,
                 attention_map=None,
                 use_256=False,
                 loss_content=False,
                 test_path=None):
        super(VideoAudioGenerator, self).__init__()

        self.encoder = MyEncoder(dim_neck, freq, num_mel=dim_spec)

        self.decoder = Decoder(dim_neck, 0, dim_pre, num_mel=dim_spec)
        self.postnet = Postnet(num_mel=dim_spec)
        if use_256:
            self.video_decoder = VideoGenerator(use_256=True)
        else:
            self.video_decoder = STAGE2_G(residual=residual)
        self.use_256 = use_256
        self.lambda_wavenet = lambda_wavenet
        self.loss_content = loss_content
        self.multigpu = multigpu
        self.test_path = test_path

        self.vocoder = WaveRNN(rnn_dims=hparams.voc_rnn_dims,
                               fc_dims=hparams.voc_fc_dims,
                               bits=hparams.bits,
                               pad=hparams.voc_pad,
                               upsample_factors=hparams.voc_upsample_factors,
                               feat_dims=hparams.num_mels,
                               compute_dims=hparams.voc_compute_dims,
                               res_out_dims=hparams.voc_res_out_dims,
                               res_blocks=hparams.voc_res_blocks,
                               hop_length=hparams.hop_size,
                               sample_rate=hparams.sample_rate,
                               mode=hparams.voc_mode)

        if is_train:
            self.criterionIdt = torch.nn.L1Loss(reduction='mean')
            self.opt_encoder = torch.optim.Adam(self.encoder.parameters(),
                                                lr=lr)
            self.opt_decoder = torch.optim.Adam(itertools.chain(
                self.decoder.parameters(), self.postnet.parameters()),
                                                lr=lr)
            self.opt_video_decoder = torch.optim.Adam(
                self.video_decoder.parameters(), lr=lr)

            self.opt_vocoder = torch.optim.Adam(self.vocoder.parameters(),
                                                lr=hparams.voc_lr)
            self.vocoder_loss_func = F.cross_entropy  # Only for RAW

        if multigpu:
            self.encoder = nn.DataParallel(self.encoder)
            self.decoder = nn.DataParallel(self.decoder)
            self.video_decoder = nn.DataParallel(self.video_decoder)
            self.postnet = nn.DataParallel(self.postnet)
            self.vocoder = nn.DataParallel(self.vocoder)

    def optimize_parameters_video(self,
                                  dataloader,
                                  epochs,
                                  device,
                                  display_freq=10,
                                  save_freq=1000,
                                  save_dir="./",
                                  experimentName="Train",
                                  initial_niter=0,
                                  load_model=None):
        writer = SummaryWriter(log_dir="logs/" + experimentName)
        if load_model is not None:
            print("Loading from %s..." % load_model)
            # self.load_state_dict(torch.load(load_model))
            d = torch.load(load_model)
            newdict = d.copy()
            for key, value in d.items():
                newkey = key
                if 'wavenet' in key:
                    newdict[key.replace('wavenet',
                                        'vocoder')] = newdict.pop(key)
                    newkey = key.replace('wavenet', 'vocoder')
                if self.multigpu and 'module' not in key:
                    newdict[newkey.replace('.', '.module.',
                                           1)] = newdict.pop(newkey)
                    newkey = newkey.replace('.', '.module.', 1)
                if newkey not in self.state_dict():
                    newdict.pop(newkey)
            print("Load " + str(len(newdict)) + " parameters!")
            self.load_state_dict(newdict, strict=False)
            print("AutoVC Model Loaded")
        niter = initial_niter
        for epoch in range(epochs):
            self.train()
            for i, data in enumerate(dataloader):
                # print("Processing ..." + str(name))
                speaker, mel, prev, wav, video, video_large = data
                speaker, mel, prev, wav, video, video_large = speaker.to(
                    device), mel.to(device), prev.to(device), wav.to(
                        device), video.to(device), video_large.to(device)
                codes, code_unsample = self.encoder(mel,
                                                    speaker,
                                                    return_unsample=True)

                tmp = []
                for code in codes:
                    tmp.append(
                        code.unsqueeze(1).expand(-1,
                                                 int(mel.size(1) / len(codes)),
                                                 -1))
                code_exp = torch.cat(tmp, dim=1)

                if not self.use_256:
                    v_stage1, v_stage2 = self.video_decoder(code_unsample,
                                                            train=True)
                else:
                    v_stage2 = self.video_decoder(code_unsample)
                mel_outputs = self.decoder(code_exp)
                mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1))
                mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(
                    2, 1)

                if self.loss_content:
                    _, recons_codes = self.encoder(mel_outputs_postnet,
                                                   speaker,
                                                   return_unsample=True)
                    loss_content = self.criterionIdt(code_unsample,
                                                     recons_codes)
                else:
                    loss_content = torch.from_numpy(np.array(0))

                if not self.use_256:
                    loss_video = self.criterionIdt(v_stage1,
                                                   video) + self.criterionIdt(
                                                       v_stage2, video_large)
                else:
                    loss_video = self.criterionIdt(v_stage2, video_large)

                loss_recon = self.criterionIdt(mel, mel_outputs)
                loss_recon0 = self.criterionIdt(mel, mel_outputs_postnet)
                loss_vocoder = 0

                if not self.multigpu:
                    y_hat = self.vocoder(
                        prev,
                        self.vocoder.pad_tensor(mel_outputs_postnet,
                                                hparams.voc_pad).transpose(
                                                    1, 2))
                else:
                    y_hat = self.vocoder(
                        prev,
                        self.vocoder.module.pad_tensor(
                            mel_outputs_postnet,
                            hparams.voc_pad).transpose(1, 2))
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
                # assert (0 <= wav < 2 ** 9).all()
                loss_vocoder = self.vocoder_loss_func(
                    y_hat,
                    wav.unsqueeze(-1).to(device))
                self.opt_vocoder.zero_grad()

                loss = loss_video + loss_recon + loss_recon0 + self.lambda_wavenet * loss_vocoder + loss_content

                self.opt_encoder.zero_grad()
                self.opt_decoder.zero_grad()
                self.opt_video_decoder.zero_grad()
                loss.backward()
                self.opt_encoder.step()
                self.opt_decoder.step()
                self.opt_video_decoder.step()
                self.opt_vocoder.step()

                if niter % display_freq == 0:
                    print("Epoch[%d] Iter[%d] Niter[%d] %s" %
                          (epoch, i, niter, loss.data.item()))
                    writer.add_scalars(
                        'data/Loss', {
                            'loss':
                            loss.data.item(),
                            'loss_video':
                            loss_video.data.item(),
                            'loss_audio':
                            loss_recon0.data.item() + loss_recon.data.item()
                        }, niter)

                if niter % save_freq == 0:
                    torch.cuda.empty_cache()  # Prevent Out of Memory
                    print("Saving and Testing...", end='\t')
                    torch.save(
                        self.state_dict(),
                        save_dir + '/Epoch' + str(epoch).zfill(3) + '_Iter' +
                        str(niter).zfill(8) + ".pkl")
                    # self.load_state_dict(torch.load('params.pkl'))
                    self.test_audiovideo(device, writer, niter)
                    print("Done")
                    self.train()
                torch.cuda.empty_cache()  # Prevent Out of Memory
                niter += 1

    def generate(self, mel, speaker, device='cuda:0'):
        mel, speaker = mel.to(device), speaker.to(device)
        if not self.multigpu:
            codes, code_unsample = self.encoder(mel,
                                                speaker,
                                                return_unsample=True)
        else:
            codes, code_unsample = self.encoder.module(mel,
                                                       speaker,
                                                       return_unsample=True)

        tmp = []
        for code in codes:
            tmp.append(
                code.unsqueeze(1).expand(-1, int(mel.size(1) / len(codes)),
                                         -1))
        code_exp = torch.cat(tmp, dim=1)

        if not self.multigpu:
            if not self.use_256:
                v_stage1, v_stage2 = self.video_decoder(code_unsample,
                                                        train=True)
            else:
                v_stage2 = self.video_decoder(code_unsample)
                v_stage1 = v_stage2
            mel_outputs = self.decoder(code_exp)
            mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1))
        else:
            if not self.use_256:
                v_stage1, v_stage2 = self.video_decoder.module(code_unsample,
                                                               train=True)
            else:
                v_stage2 = self.video_decoder.module(code_unsample)
                v_stage1 = v_stage2
            mel_outputs = self.decoder.module(code_exp)
            mel_outputs_postnet = self.postnet.module(
                mel_outputs.transpose(2, 1))

        mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1)

        return mel_outputs_postnet, v_stage1, v_stage2

    def test_video(self, device):
        wav, sr = librosa.load(
            "/mnt/lustre/dengkangle/cmu/datasets/video/obama_test.mp4",
            hparams.sample_rate)
        mel_basis = librosa.filters.mel(hparams.sample_rate,
                                        hparams.n_fft,
                                        n_mels=hparams.num_mels)
        linear_spec = np.abs(
            librosa.stft(wav,
                         n_fft=hparams.n_fft,
                         hop_length=hparams.hop_size,
                         win_length=hparams.win_size))
        mel_spec = mel_basis.dot(linear_spec)
        mel_db = 20 * np.log10(mel_spec)

        test_data = np.clip((mel_db + 120) / 125, 0, 1)
        test_data = torch.Tensor(pad_seq(test_data.T,
                                         hparams.freq)).unsqueeze(0).to(device)
        with torch.no_grad():
            codes, code_exp = self.encoder.module(test_data,
                                                  return_unsample=True)
            v_mid, v_hat = self.video_decoder.module(code_exp, train=True)

        reader = imageio.get_reader(
            "/mnt/lustre/dengkangle/cmu/datasets/video/obama_test.mp4",
            'ffmpeg',
            fps=20)
        frames = []
        for i, im in enumerate(reader):
            frames.append(np.array(im).transpose(2, 0, 1))
        frames = (np.array(frames) / 255 - 0.5) / 0.5
        return frames, v_mid[0:1], v_hat[0:1]

    def test_audiovideo(self, device, writer, niter):
        source_path = self.test_path

        mel_basis80 = librosa.filters.mel(hparams.sample_rate,
                                          hparams.n_fft,
                                          n_mels=80)

        wav, sr = librosa.load(source_path, hparams.sample_rate)
        wav = preemphasis(wav, hparams.preemphasis, hparams.preemphasize)

        linear_spec = np.abs(
            librosa.stft(wav,
                         n_fft=hparams.n_fft,
                         hop_length=hparams.hop_size,
                         win_length=hparams.win_size))
        mel_spec = mel_basis80.dot(linear_spec)
        mel_db = 20 * np.log10(mel_spec)
        source_spec = np.clip((mel_db + 120) / 125, 0, 1)

        source_embed = torch.from_numpy(np.array([0, 1])).float().unsqueeze(0)
        source_wav = wav

        source_spec = torch.Tensor(pad_seq(source_spec.T,
                                           hparams.freq)).unsqueeze(0)
        # print(source_spec.shape)

        with torch.no_grad():
            generated_spec, v_mid, v_hat = self.generate(
                source_spec, source_embed, device)

        generated_spec, v_mid, v_hat = generated_spec.cpu(), v_mid.cpu(
        ), v_hat.cpu()

        print("Generating Wavfile...")
        with torch.no_grad():
            if not self.multigpu:
                generated_wav = inv_preemphasis(
                    self.vocoder.generate(generated_spec.to(device).transpose(
                        2, 1),
                                          False,
                                          None,
                                          None,
                                          mu_law=True), hparams.preemphasis,
                    hparams.preemphasize)

            else:
                generated_wav = inv_preemphasis(
                    self.vocoder.module.generate(
                        generated_spec.to(device).transpose(2, 1),
                        False,
                        None,
                        None,
                        mu_law=True), hparams.preemphasis,
                    hparams.preemphasize)

        writer.add_video('generated', (v_hat.numpy() + 1) / 2,
                         global_step=niter)
        writer.add_video('mid', (v_mid.numpy() + 1) / 2, global_step=niter)
        writer.add_audio('ground_truth',
                         source_wav,
                         niter,
                         sample_rate=hparams.sample_rate)
        writer.add_audio('generated_wav',
                         generated_wav,
                         niter,
                         sample_rate=hparams.sample_rate)
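
test_video and test_audiovideo both build their conditioning features the same way: magnitude STFT, 80-band mel projection, conversion to dB, then clipping into [0, 1]. A standalone sketch of that pipeline with assumed analysis settings (not this repo's hparams):

import librosa
import numpy as np

sample_rate, n_fft, hop_size, win_size = 16000, 800, 200, 800   # assumed values
wav = np.sin(2 * np.pi * 440 * np.arange(sample_rate) / sample_rate)

mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=80)
linear_spec = np.abs(librosa.stft(wav, n_fft=n_fft,
                                  hop_length=hop_size, win_length=win_size))
mel_spec = mel_basis.dot(linear_spec)
mel_db = 20 * np.log10(mel_spec + 1e-10)          # epsilon added here to avoid log(0)
mel_norm = np.clip((mel_db + 120) / 125, 0, 1)    # map roughly [-120, 5] dB into [0, 1]
print(mel_norm.shape)                             # (80, n_frames)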
Example #6
class Generator(nn.Module):
    """Generator network."""
    def __init__(self,
                 dim_neck,
                 dim_emb,
                 dim_pre,
                 freq,
                 dim_spec=80,
                 is_train=False,
                 lr=0.001,
                 loss_content=True,
                 discriminator=False,
                 multigpu=False,
                 lambda_gan=0.0001,
                 lambda_wavenet=0.001,
                 args=None,
                 test_path_source=None,
                 test_path_target=None):
        super(Generator, self).__init__()

        self.encoder = MyEncoder(dim_neck, freq, num_mel=dim_spec)
        self.decoder = Decoder(dim_neck, 0, dim_pre, num_mel=dim_spec)
        self.postnet = Postnet(num_mel=dim_spec)

        if discriminator:
            self.dis = PatchDiscriminator(n_class=num_speakers)
            self.dis_criterion = GANLoss(use_lsgan=use_lsgan,
                                         tensor=torch.cuda.FloatTensor)
        else:
            self.dis = None

        self.loss_content = loss_content
        self.lambda_gan = lambda_gan
        self.lambda_wavenet = lambda_wavenet

        self.multigpu = multigpu
        self.prepare_test(dim_spec, test_path_source, test_path_target)

        self.vocoder = WaveRNN(rnn_dims=hparams.voc_rnn_dims,
                               fc_dims=hparams.voc_fc_dims,
                               bits=hparams.bits,
                               pad=hparams.voc_pad,
                               upsample_factors=hparams.voc_upsample_factors,
                               feat_dims=hparams.num_mels,
                               compute_dims=hparams.voc_compute_dims,
                               res_out_dims=hparams.voc_res_out_dims,
                               res_blocks=hparams.voc_res_blocks,
                               hop_length=hparams.hop_size,
                               sample_rate=hparams.sample_rate,
                               mode=hparams.voc_mode)

        if is_train:
            self.criterionIdt = torch.nn.L1Loss(reduction='mean')
            self.opt_encoder = torch.optim.Adam(self.encoder.parameters(),
                                                lr=lr)
            self.opt_decoder = torch.optim.Adam(itertools.chain(
                self.decoder.parameters(), self.postnet.parameters()),
                                                lr=lr)
            if discriminator:
                self.opt_dis = torch.optim.Adam(self.dis.parameters(), lr=lr)
            self.opt_vocoder = torch.optim.Adam(self.vocoder.parameters(),
                                                lr=hparams.voc_lr)
            self.vocoder_loss_func = F.cross_entropy  # Only for RAW

        if multigpu:
            self.encoder = nn.DataParallel(self.encoder)
            self.decoder = nn.DataParallel(self.decoder)
            self.postnet = nn.DataParallel(self.postnet)
            self.vocoder = nn.DataParallel(self.vocoder)
            if self.dis is not None:
                self.dis = nn.DataParallel(self.dis)

    def prepare_test(self, dim_spec, source_path=None, target_path=None):
        if source_path is None:
            source_path = "/mnt/lustre/dengkangle/cmu/datasets/audio/test/trump_02.wav"
        if target_path is None:
            target_path = "/mnt/lustre/dengkangle/cmu/datasets/audio/test/female.wav"
        # source_path = "/home/kangled/datasets/audio/Chaplin_01.wav"
        # target_path = "/home/kangled/datasets/audio/Obama_01.wav"

        mel_basis80 = librosa.filters.mel(hparams.sample_rate,
                                          hparams.n_fft,
                                          n_mels=80)

        wav, sr = librosa.load(source_path, hparams.sample_rate)
        wav = preemphasis(wav, hparams.preemphasis, hparams.preemphasize)
        linear_spec = np.abs(
            librosa.stft(wav,
                         n_fft=hparams.n_fft,
                         hop_length=hparams.hop_size,
                         win_length=hparams.win_size))
        mel_spec = mel_basis80.dot(linear_spec)
        mel_db = 20 * np.log10(mel_spec)
        source_spec = np.clip((mel_db + 120) / 125, 0, 1)
        # source_spec = mel_spec

        self.source_embed = torch.from_numpy(np.array([0, 1
                                                       ])).float().unsqueeze(0)
        self.source_wav = wav

        wav, sr = librosa.load(target_path, hparams.sample_rate)
        wav = preemphasis(wav, hparams.preemphasis, hparams.preemphasize)
        linear_spec = np.abs(
            librosa.stft(wav,
                         n_fft=hparams.n_fft,
                         hop_length=hparams.hop_size,
                         win_length=hparams.win_size))
        mel_spec = mel_basis80.dot(linear_spec)
        mel_db = 20 * np.log10(mel_spec)
        target_spec = np.clip((mel_db + 120) / 125, 0, 1)
        # target_spec = mel_spec

        self.target_embed = torch.from_numpy(np.array([1, 0
                                                       ])).float().unsqueeze(0)
        self.target_wav = wav

        self.source_spec = torch.Tensor(pad_seq(source_spec.T,
                                                hparams.freq)).unsqueeze(0)
        self.target_spec = torch.Tensor(pad_seq(target_spec.T,
                                                hparams.freq)).unsqueeze(0)

    def test_fixed(self, device):
        with torch.no_grad():
            t2s_spec = self.conversion(self.target_embed, self.source_embed,
                                       self.target_spec, device).cpu()
            s2s_spec = self.conversion(self.source_embed, self.source_embed,
                                       self.source_spec, device).cpu()
            s2t_spec = self.conversion(self.source_embed, self.target_embed,
                                       self.source_spec, device).cpu()
            t2t_spec = self.conversion(self.target_embed, self.target_embed,
                                       self.target_spec, device).cpu()

        ret_dic = {}
        ret_dic['A_fake_griffin'], sr = mel2wav(s2t_spec.numpy().squeeze(0).T)
        ret_dic['B_fake_griffin'], sr = mel2wav(t2s_spec.numpy().squeeze(0).T)
        ret_dic['A'] = self.source_wav
        ret_dic['B'] = self.target_wav

        with torch.no_grad():
            if not self.multigpu:
                ret_dic['A_fake_w'] = inv_preemphasis(
                    self.vocoder.generate(s2t_spec.to(device).transpose(2, 1),
                                          False,
                                          None,
                                          None,
                                          mu_law=True), hparams.preemphasis,
                    hparams.preemphasize)
                ret_dic['B_fake_w'] = inv_preemphasis(
                    self.vocoder.generate(t2s_spec.to(device).transpose(2, 1),
                                          False,
                                          None,
                                          None,
                                          mu_law=True), hparams.preemphasis,
                    hparams.preemphasize)
            else:
                ret_dic['A_fake_w'] = inv_preemphasis(
                    self.vocoder.module.generate(s2t_spec.to(device).transpose(
                        2, 1),
                                                 False,
                                                 None,
                                                 None,
                                                 mu_law=True),
                    hparams.preemphasis, hparams.preemphasize)
                ret_dic['B_fake_w'] = inv_preemphasis(
                    self.vocoder.module.generate(t2s_spec.to(device).transpose(
                        2, 1),
                                                 False,
                                                 None,
                                                 None,
                                                 mu_law=True),
                    hparams.preemphasis, hparams.preemphasize)
        return ret_dic, sr

    def conversion(self, speaker_org, speaker_trg, spec, device, speed=1):
        speaker_org, speaker_trg, spec = speaker_org.to(
            device), speaker_trg.to(device), spec.to(device)
        if not self.multigpu:
            codes = self.encoder(spec, speaker_org)
        else:
            codes = self.encoder.module(spec, speaker_org)
        tmp = []
        for code in codes:
            tmp.append(
                code.unsqueeze(1).expand(
                    -1, int(speed * spec.size(1) / len(codes)), -1))
        code_exp = torch.cat(tmp, dim=1)
        encoder_outputs = torch.cat((code_exp, speaker_trg.unsqueeze(1).expand(
            -1, code_exp.size(1), -1)),
                                    dim=-1)
        mel_outputs = self.decoder(
            code_exp) if not self.multigpu else self.decoder.module(code_exp)

        mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1))
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1)
        return mel_outputs_postnet

    def optimize_parameters(self,
                            dataloader,
                            epochs,
                            device,
                            display_freq=10,
                            save_freq=1000,
                            save_dir="./",
                            experimentName="Train",
                            load_model=None,
                            initial_niter=0):
        writer = SummaryWriter(log_dir="logs/" + experimentName)
        if load_model is not None:
            print("Loading from %s..." % load_model)
            # self.load_state_dict(torch.load(load_model))
            d = torch.load(load_model)
            newdict = d.copy()
            for key, value in d.items():
                newkey = key
                if 'wavenet' in key:
                    newdict[key.replace('wavenet',
                                        'vocoder')] = newdict.pop(key)
                    newkey = key.replace('wavenet', 'vocoder')
                if self.multigpu and 'module' not in key:
                    newdict[newkey.replace('.', '.module.',
                                           1)] = newdict.pop(newkey)
                    newkey = newkey.replace('.', '.module.', 1)
                if newkey not in self.state_dict():
                    newdict.pop(newkey)
            self.load_state_dict(newdict)
            print("AutoVC Model Loaded")
        niter = initial_niter
        for epoch in range(epochs):
            self.train()
            for i, data in enumerate(dataloader):
                speaker_org, spec, prev, wav = data
                loss_dict, loss_dict_discriminator, loss_dict_wavenet = \
                    self.train_step(spec.to(device), speaker_org.to(device), prev=prev.to(device), wav=wav.to(device), device=device)
                if niter % display_freq == 0:
                    print("Epoch[%d] Iter[%d] Niter[%d] %s %s %s" %
                          (epoch, i, niter, loss_dict, loss_dict_discriminator,
                           loss_dict_wavenet))
                    writer.add_scalars('data/Loss', loss_dict, niter)
                    if loss_dict_discriminator != {}:
                        writer.add_scalars('data/discriminator',
                                           loss_dict_discriminator, niter)
                    if loss_dict_wavenet != {}:
                        writer.add_scalars('data/wavenet', loss_dict_wavenet,
                                           niter)
                if niter % save_freq == 0:
                    print("Saving and Testing...", end='\t')
                    torch.save(
                        self.state_dict(),
                        save_dir + '/Epoch' + str(epoch).zfill(3) + '_Iter' +
                        str(niter).zfill(8) + ".pkl")
                    # self.load_state_dict(torch.load('params.pkl'))
                    if len(dataloader) >= 2:
                        wav_dic, sr = self.test_fixed(device)
                        for key, wav in wav_dic.items():
                            # print(wav.shape)
                            writer.add_audio(key, wav, niter, sample_rate=sr)
                    print("Done")
                    self.train()
                torch.cuda.empty_cache()  # Prevent Out of Memory
                niter += 1

    def train_step(self,
                   x,
                   c_org,
                   mask=None,
                   mask_code=None,
                   prev=None,
                   wav=None,
                   ret_content=False,
                   retain_graph=False,
                   device='cuda:0'):
        codes = self.encoder(x, c_org)
        # print(codes[0].shape)
        content = torch.cat([code.unsqueeze(1) for code in codes], dim=1)
        # print("content shape", content.shape)
        tmp = []
        for code in codes:
            tmp.append(
                code.unsqueeze(1).expand(-1, int(x.size(1) / len(codes)), -1))
        code_exp = torch.cat(tmp, dim=1)

        encoder_outputs = torch.cat(
            (code_exp, c_org.unsqueeze(1).expand(-1, x.size(1), -1)), dim=-1)

        mel_outputs = self.decoder(code_exp)

        mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1))
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1)

        loss_dict, loss_dict_discriminator, loss_dict_wavenet = {}, {}, {}

        loss_recon = self.criterionIdt(x, mel_outputs)
        loss_recon0 = self.criterionIdt(x, mel_outputs_postnet)
        loss_dict['recon'], loss_dict['recon0'] = loss_recon.data.item(
        ), loss_recon0.data.item()

        if self.loss_content:
            recons_codes = self.encoder(mel_outputs_postnet, c_org)
            recons_content = torch.cat(
                [code.unsqueeze(1) for code in recons_codes], dim=1)
            if mask is not None:
                loss_content = self.criterionIdt(
                    content.masked_select(mask_code.byte()),
                    recons_content.masked_select(mask_code.byte()))
            else:
                loss_content = self.criterionIdt(content, recons_content)
            loss_dict['content'] = loss_content.data.item()
        else:
            loss_content = torch.from_numpy(np.array(0))

        loss_gen, loss_dis, loss_vocoder = [torch.from_numpy(np.array(0))] * 3
        fake_mel = None
        if self.dis:
            # true_label = torch.from_numpy(np.ones(shape=(x.shape[0]))).to('cuda:0').long()
            # false_label = torch.from_numpy(np.zeros(shape=(x.shape[0]))).to('cuda:0').long()

            flip_speaker = 1 - c_org
            fake_mel = self.conversion(c_org, flip_speaker, x, device)

            loss_dis = self.dis_criterion(self.dis(x),
                                          True) + self.dis_criterion(
                                              self.dis(fake_mel), False)
            # +  self.dis_criterion(self.dis(mel_outputs_postnet), False)

            self.opt_dis.zero_grad()
            loss_dis.backward(retain_graph=True)
            self.opt_dis.step()
            loss_gen = self.dis_criterion(self.dis(fake_mel), True)
            # + self.dis_criterion(self.dis(mel_outputs_postnet), True)
            loss_dict_discriminator['dis'], loss_dict_discriminator[
                'gen'] = loss_dis.data.item(), loss_gen.data.item()

        if not self.multigpu:
            y_hat = self.vocoder(
                prev,
                self.vocoder.pad_tensor(mel_outputs_postnet,
                                        hparams.voc_pad).transpose(1, 2))
        else:
            y_hat = self.vocoder(
                prev,
                self.vocoder.module.pad_tensor(mel_outputs_postnet,
                                               hparams.voc_pad).transpose(
                                                   1, 2))
        y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
        # assert (0 <= wav < 2 ** 9).all()
        loss_vocoder = self.vocoder_loss_func(y_hat,
                                              wav.unsqueeze(-1).to(device))
        self.opt_vocoder.zero_grad()

        Loss = loss_recon + loss_recon0 + loss_content + \
               self.lambda_gan * loss_gen + self.lambda_wavenet * loss_vocoder
        loss_dict['total'] = Loss.data.item()
        self.opt_encoder.zero_grad()
        self.opt_decoder.zero_grad()
        Loss.backward(retain_graph=retain_graph)
        self.opt_encoder.step()
        self.opt_decoder.step()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.vocoder.parameters(),
                                                   65504.0)
        self.opt_vocoder.step()

        if ret_content:
            return loss_recon, loss_recon0, loss_content, Loss, content
        return loss_dict, loss_dict_discriminator, loss_dict_wavenet
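
The checkpoint-loading loop in optimize_parameters rewrites keys on the fly: 'wavenet' becomes 'vocoder' for the renamed submodule, and a '.module.' segment is spliced in when the live submodules are wrapped in nn.DataParallel. A toy sketch of the DataParallel half of that remapping (made-up module, not this class):

import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, parallel):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        if parallel:
            self.encoder = nn.DataParallel(self.encoder)

ckpt = Toy(parallel=False).state_dict()                           # keys like 'encoder.weight'
model = Toy(parallel=True)                                        # expects 'encoder.module.weight'
remapped = {k.replace('.', '.module.', 1): v for k, v in ckpt.items()}
model.load_state_dict(remapped)
print(sorted(model.state_dict().keys()))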
Example #7
class Generator(nn.Module):
    """Generator network."""
    def __init__(self,
                 dim_neck,
                 dim_emb,
                 dim_pre,
                 freq,
                 dim_spec=80,
                 is_train=False,
                 lr=0.001,
                 decoder_type='simple',
                 vocoder_type='wavenet',
                 encoder_type='default',
                 separate_encoder=True,
                 loss_content=True,
                 discriminator=False,
                 dis_type='patch',
                 multigpu=False,
                 cycle=False,
                 lambda_cycle=1,
                 num_speakers=-1,
                 idt_type='L2',
                 use_lsgan=True,
                 lambda_gan=0.0001,
                 train_wavenet=False,
                 lambda_wavenet=0.001,
                 args=None,
                 test_path_source=None,
                 test_path_target=None,
                 attention=False,
                 residual=False):
        super(Generator, self).__init__()

        if encoder_type == 'default':
            self.encoder = Encoder(dim_neck, dim_emb, freq)
        elif encoder_type == 'nospeaker' or encoder_type == 'single':
            self.encoder = MyEncoder(dim_neck, freq, num_mel=dim_spec)
        elif encoder_type == 'multiencoder':
            self.encoder = MultiEncoder(num_speakers, dim_neck, freq,
                                        separate_encoder)

        if encoder_type == 'multiencoder' or encoder_type == 'single':
            self.decoder = Decoder(dim_neck, 0, dim_pre, num_mel=dim_spec)
        elif decoder_type == 'simple':
            self.decoder = Decoder(dim_neck, dim_emb, dim_pre)
        elif decoder_type == 'tacotron':
            self.decoder = TacotronDecoder(hparams)
        elif decoder_type == 'multidecoder':
            self.decoder = MultiDecoder(num_speakers, dim_neck, dim_pre,
                                        multigpu)
        elif decoder_type == 'video':
            # self.decoder = VideoGenerator()
            self.decoder = STAGE2_G(residual=residual)
        self.postnet = Postnet(num_mel=dim_spec)
        if discriminator:
            if dis_type == 'patch':
                self.dis = PatchDiscriminator(n_class=num_speakers)
            else:
                self.dis = SpeakerDiscriminator()
            # self.dis_criterion = nn.CrossEntropyLoss(reduction='mean')
            self.dis_criterion = GANLoss(use_lsgan=use_lsgan,
                                         tensor=torch.cuda.FloatTensor)
        else:
            self.dis = None

        self.encoder_type = encoder_type
        self.decoder_type = decoder_type
        self.vocoder_type = vocoder_type
        self.loss_content = loss_content
        self.cycle = cycle
        self.lambda_cycle = lambda_cycle
        self.lambda_gan = lambda_gan
        self.lambda_wavenet = lambda_wavenet
        self.attention = attention

        self.multigpu = multigpu

        self.train_vocoder = train_wavenet
        if self.train_vocoder:
            self.vocoder = WaveRNN(
                rnn_dims=hparams.voc_rnn_dims,
                fc_dims=hparams.voc_fc_dims,
                bits=hparams.bits,
                pad=hparams.voc_pad,
                upsample_factors=hparams.voc_upsample_factors,
                feat_dims=hparams.num_mels,
                compute_dims=hparams.voc_compute_dims,
                res_out_dims=hparams.voc_res_out_dims,
                res_blocks=hparams.voc_res_blocks,
                hop_length=hparams.hop_size,
                sample_rate=hparams.sample_rate,
                mode=hparams.voc_mode)
            self.opt_vocoder = torch.optim.Adam(self.vocoder.parameters(),
                                                lr=hparams.voc_lr)
            self.vocoder_loss_func = F.cross_entropy  # Only for RAW

        if multigpu:
            self.encoder = nn.DataParallel(self.encoder)
            self.decoder = nn.DataParallel(self.decoder)
            self.postnet = nn.DataParallel(self.postnet)
            if hasattr(self, 'vocoder'):  # the vocoder only exists when train_wavenet is set
                self.vocoder = nn.DataParallel(self.vocoder)
            if self.dis is not None:
                self.dis = nn.DataParallel(self.dis)

    def forward(self, x, c_org, c_trg):

        codes = self.encoder(x, c_org)
        if c_trg is None:
            return torch.cat(codes, dim=-1)

        tmp = []
        for code in codes:
            tmp.append(
                code.unsqueeze(1).expand(-1, int(x.size(1) / len(codes)), -1))
        code_exp = torch.cat(tmp, dim=1)

        encoder_outputs = torch.cat(
            (code_exp, c_trg.unsqueeze(1).expand(-1, x.size(1), -1)), dim=-1)
        # (batch, T, 256+dim_neck)
        mel_outputs = self.decoder(encoder_outputs)

        mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1))
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1)

        mel_outputs = mel_outputs.unsqueeze(1)
        mel_outputs_postnet = mel_outputs_postnet.unsqueeze(1)

        return mel_outputs, mel_outputs_postnet, torch.cat(codes, dim=-1)

    def conversion(self, speaker_org, speaker_trg, spec, device, speed=1):
        speaker_org, speaker_trg, spec = speaker_org.to(
            device), speaker_trg.to(device), spec.to(device)
        if self.encoder_type == 'multiencoder':
            codes = self.encoder(spec, speaker_trg)
        else:
            if not self.multigpu:
                codes = self.encoder(spec, speaker_org)
            else:
                codes = self.encoder.module(spec, speaker_org)
        tmp = []
        for code in codes:
            tmp.append(
                code.unsqueeze(1).expand(
                    -1, int(speed * spec.size(1) / len(codes)), -1))
        code_exp = torch.cat(tmp, dim=1)
        if self.attention:
            code_exp = self.phonemeToken(code_exp)
        encoder_outputs = torch.cat((code_exp, speaker_trg.unsqueeze(1).expand(
            -1, code_exp.size(1), -1)),
                                    dim=-1)
        if self.encoder_type == 'multiencoder' or self.encoder_type == 'single':
            mel_outputs = self.decoder(
                code_exp) if not self.multigpu else self.decoder.module(
                    code_exp)
        elif self.decoder_type == 'simple':
            mel_outputs = self.decoder(encoder_outputs)
        elif self.decoder_type == 'tacotron':
            try:
                mel_outputs, _, alignments = self.decoder.inference(
                    memory=encoder_outputs)
            except:
                mel_outputs, _, alignments = self.decoder.module.inference(
                    memory=encoder_outputs)
            mel_outputs.transpose_(1, 2)
        elif self.decoder_type == 'multidecoder':
            mel_outputs = self.decoder(code_exp, speaker_trg)

        mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1))
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1)
        return mel_outputs_postnet
Example #8
class VideoAudioGenerator(nn.Module):
    def __init__(self,
                 dim_neck,
                 dim_emb,
                 dim_pre,
                 freq,
                 dim_spec=80,
                 is_train=False,
                 lr=0.001,
                 vocoder_type='wavenet',
                 multigpu=False,
                 num_speakers=-1,
                 idt_type='L2',
                 train_wavenet=False,
                 lambda_wavenet=0.001,
                 args=None,
                 test_path_source=None,
                 test_path_target=None,
                 residual=False,
                 attention_map=None,
                 use_256=False):
        super(VideoAudioGenerator, self).__init__()

        self.encoder = MyEncoder(dim_neck, freq, num_mel=dim_spec)

        self.decoder = Decoder(dim_neck, 0, dim_pre, num_mel=dim_spec)
        self.postnet = Postnet(num_mel=dim_spec)
        if use_256:
            self.video_decoder = VideoGenerator(use_256=True)
        else:
            self.video_decoder = STAGE2_G(residual=residual)
        self.use_256 = use_256
        self.vocoder_type = vocoder_type
        self.lambda_wavenet = lambda_wavenet

        self.multigpu = multigpu
        # self.prepare_test(dim_spec, test_path_source, test_path_target)

        self.train_vocoder = train_wavenet
        if self.train_vocoder:
            if vocoder_type in ('wavenet', 'griffin'):
                self.vocoder = WaveRNN(
                    rnn_dims=hparams.voc_rnn_dims,
                    fc_dims=hparams.voc_fc_dims,
                    bits=hparams.bits,
                    pad=hparams.voc_pad,
                    upsample_factors=hparams.voc_upsample_factors,
                    feat_dims=hparams.num_mels,
                    compute_dims=hparams.voc_compute_dims,
                    res_out_dims=hparams.voc_res_out_dims,
                    res_blocks=hparams.voc_res_blocks,
                    hop_length=hparams.hop_size,
                    sample_rate=hparams.sample_rate,
                    mode=hparams.voc_mode)
                self.opt_vocoder = torch.optim.Adam(self.vocoder.parameters(),
                                                    lr=hparams.voc_lr)
                self.vocoder_loss_func = F.cross_entropy  # Only for RAW

        if attention_map is not None:
            self.attention_map_large = np.load(attention_map)
            self.attention_map = cv2.resize(self.attention_map_large,
                                            dsize=(128, 128),
                                            interpolation=cv2.INTER_CUBIC)
            # self.attention_map_large = self.attention_map_large.astype(np.float64)
            # self.attention_map = self.attention_map.astype(np.float64)
            self.attention_map_large = torch.from_numpy(
                self.attention_map_large /
                self.attention_map_large.max()).float()
            self.attention_map = torch.from_numpy(
                self.attention_map / self.attention_map.max()).float()
            self.criterionVideo = torch.nn.L1Loss(reduction='none')
        else:
            self.attention_map = None

        if multigpu:
            self.encoder = nn.DataParallel(self.encoder)
            self.decoder = nn.DataParallel(self.decoder)
            self.video_decoder = nn.DataParallel(self.video_decoder)
            self.postnet = nn.DataParallel(self.postnet)
            if self.train_vocoder:
                # The vocoder attribute only exists when train_wavenet is enabled.
                self.vocoder = nn.DataParallel(self.vocoder)

    def generate(self, mel, speaker, device='cuda:0'):
        mel, speaker = mel.to(device), speaker.to(device)
        if not self.multigpu:
            codes, code_unsample = self.encoder(mel,
                                                speaker,
                                                return_unsample=True)
        else:
            codes, code_unsample = self.encoder.module(mel,
                                                       speaker,
                                                       return_unsample=True)

        # Up-sample each content code along time so the concatenation matches
        # the length of the input mel-spectrogram.
        tmp = []
        for code in codes:
            tmp.append(
                code.unsqueeze(1).expand(-1, int(mel.size(1) / len(codes)),
                                         -1))
        code_exp = torch.cat(tmp, dim=1)

        if not self.multigpu:
            if not self.use_256:
                v_stage1, v_stage2 = self.video_decoder(code_unsample,
                                                        train=True)
            else:
                v_stage2 = self.video_decoder(code_unsample)
                v_stage1 = v_stage2
            mel_outputs = self.decoder(code_exp)
            mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1))
        else:
            if not self.use_256:
                v_stage1, v_stage2 = self.video_decoder.module(code_unsample,
                                                               train=True)
            else:
                v_stage2 = self.video_decoder.module(code_unsample)
                v_stage1 = v_stage2
            mel_outputs = self.decoder.module(code_exp)
            mel_outputs_postnet = self.postnet.module(
                mel_outputs.transpose(2, 1))

        mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1)

        return mel_outputs_postnet, v_stage1, v_stage2
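
# ----- Hedged usage sketch (added for illustration; not part of the original) -----
# The constructor values and tensor shapes below are illustrative assumptions.
#
#   import torch
#   model = VideoAudioGenerator(dim_neck=32, dim_emb=256, dim_pre=512, freq=32)
#   mel = torch.randn(1, 128, 80)      # (batch, frames, n_mels) input mel-spectrogram
#   speaker = torch.randn(1, 256)      # speaker embedding (assumed size)
#   with torch.no_grad():
#       mel_post, video_stage1, video_stage2 = model.generate(mel, speaker, device='cpu')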
Example #9
def train(run_id: str, models_dir: Path, metadata_path: Path,
          weights_path: Path, ground_truth: bool, save_every: int,
          backup_every: int, force_restart: bool):
    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    # Instantiate the model
    print("Initializing the model...")
    model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                    fc_dims=hp.voc_fc_dims,
                    bits=hp.bits,
                    pad=hp.voc_pad,
                    upsample_factors=hp.voc_upsample_factors,
                    feat_dims=hp.num_mels,
                    compute_dims=hp.voc_compute_dims,
                    res_out_dims=hp.voc_res_out_dims,
                    res_blocks=hp.voc_res_blocks,
                    hop_length=hp.hop_length,
                    sample_rate=hp.sample_rate,
                    mode=hp.voc_mode).cuda()

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters())
    for p in optimizer.param_groups:
        p["lr"] = hp.voc_lr
    loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss

    # Load the weights
    model_dir = models_dir.joinpath(run_id)
    model_dir.mkdir(exist_ok=True)
    weights_fpath = weights_path
    metadata_fpath = metadata_path

    if force_restart:
        print("\nStarting the training of WaveRNN from scratch\n")
        model.save(weights_fpath, optimizer)
    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(weights_fpath, optimizer)
        print("WaveRNN weights loaded from step %d" % model.step)

    # Initialize the dataset

    dataset = VocoderDataset(metadata_fpath)

    test_loader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True)

    # Begin the training
    simple_table([('Batch size', hp.voc_batch_size), ('LR', hp.voc_lr),
                  ('Sequence Len', hp.voc_seq_len)])

    # Estimate the starting epoch from the checkpoint step; the constants 428000
    # and 110 appear to be hard-coded for this particular training run.
    epoch_start = int(
        (model.step - 428000) * 110 / dataset.get_number_of_samples())
    epoch_end = 200

    log_path = os.path.join(models_dir, "logs")
    if not os.path.isdir(log_path):
        os.mkdir(log_path)

    writer = SummaryWriter(log_path)
    print("Log path : " + log_path)

    print("Starting from epoch: " + str(epoch_start))

    for epoch in range(epoch_start, epoch_start + epoch_end):
        data_loader = DataLoader(dataset,
                                 collate_fn=collate_vocoder,
                                 batch_size=hp.voc_batch_size,
                                 num_workers=2,
                                 shuffle=True,
                                 pin_memory=True)
        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(data_loader, 1):
            x, m, y = x.cuda(), m.cuda(), y.cuda()

            # Forward pass
            y_hat = model(x, m)
            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            y = y.unsqueeze(-1)

            # Backward pass
            loss = loss_func(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            speed = i / (time.time() - start)
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000

            if backup_every != 0 and step % backup_every == 0:
                model.checkpoint(model_dir, optimizer)

            # if save_every != 0 and step % save_every == 0 :
            #     model.save(weights_fpath, optimizer)

            # Log the running-average loss and progress every 500 steps.
            if step % 500 == 0:
                writer.add_scalar('Loss/train', avg_loss,
                                  round(step / 1000, 1))
                msg = f"| Epoch: {epoch} ({i}/{len(data_loader)}) | " \
                    f"Loss: {avg_loss:.4f} | {speed:.1f} " \
                    f"steps/s | Step: {k}k | "
                print(msg, flush=True)

            # Every 15k steps, generate evaluation audio for inspection.
            if step % 15000 == 0:
                gen_testset(model, test_loader, hp.voc_gen_at_checkpoint,
                            hp.voc_gen_batched, hp.voc_target, hp.voc_overlap,
                            model_dir)
                gen_meltest(model, hp.voc_gen_batched, hp.voc_target,
                            hp.voc_overlap, model_dir)