Example #1
def load_model(weights_fpath, verbose=False):
    global _model, _device

    if verbose:
        print("Building Wave-RNN")
    _model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                     fc_dims=hp.voc_fc_dims,
                     bits=hp.bits,
                     pad=hp.voc_pad,
                     upsample_factors=hp.voc_upsample_factors,
                     feat_dims=hp.num_mels,
                     compute_dims=hp.voc_compute_dims,
                     res_out_dims=hp.voc_res_out_dims,
                     res_blocks=hp.voc_res_blocks,
                     hop_length=hp.hop_length,
                     sample_rate=hp.sample_rate,
                     mode=hp.voc_mode)

    if torch.cuda.is_available():
        _model = _model.cuda()
        _device = torch.device('cuda')
    else:
        _device = torch.device('cpu')

    if verbose:
        print("Loading model weights at %s" % weights_fpath)
    checkpoint = torch.load(weights_fpath, _device)
    _model.load_state_dict(checkpoint['model_state'])
    _model.eval()

    print(type(_model))
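Example #1 (like Example #2 below) follows the module-global pattern: load_model populates _model once, and later inference code reuses it. A minimal usage sketch, assuming the repository exposes a generate-style method on WaveRNN — the method name, its arguments, and the checkpoint path below are illustrative assumptions, not taken from the example:

import torch

load_model("saved_models/vocoder.pt", verbose=True)  # hypothetical path

mel = torch.randn(80, 200)      # stand-in mel spectrogram [n_mels, T]
with torch.no_grad():
    wav = _model.generate(mel)  # assumed inference entry point on WaveRNN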
Example #2
def load_model(weights_fpath, verbose=True):
    global _model
    
    if verbose:
        print("Building Wave-RNN")
    _model = WaveRNN(
        rnn_dims=hp.voc_rnn_dims,
        fc_dims=hp.voc_fc_dims,
        bits=hp.bits,
        pad=hp.voc_pad,
        upsample_factors=hp.voc_upsample_factors,
        feat_dims=hp.num_mels,
        compute_dims=hp.voc_compute_dims,
        res_out_dims=hp.voc_res_out_dims,
        res_blocks=hp.voc_res_blocks,
        hop_length=hp.hop_length,
        sample_rate=hp.sample_rate,
        mode=hp.voc_mode
    ).cuda()
    
    if verbose:
        print("Loading model weights at %s" % weights_fpath)
    checkpoint = torch.load(weights_fpath)
    _model.load_state_dict(checkpoint['model_state'])
    _model.eval()
Example #3
    def load_from(self, weights_fpath, verbose=True):
        if verbose:
            print("Building Wave-RNN")
        self._model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                              fc_dims=hp.voc_fc_dims,
                              bits=hp.bits,
                              pad=hp.voc_pad,
                              upsample_factors=hp.voc_upsample_factors,
                              feat_dims=hp.num_mels,
                              compute_dims=hp.voc_compute_dims,
                              res_out_dims=hp.voc_res_out_dims,
                              res_blocks=hp.voc_res_blocks,
                              hop_length=hp.hop_length,
                              sample_rate=hp.sample_rate,
                              mode=hp.voc_mode)  #.cuda()

        if verbose:
            print("Loading model weights at %s" % weights_fpath)
        checkpoint = torch.load(weights_fpath,
                                map_location=torch.device('cpu'))
        self._model.load_state_dict(checkpoint['model_state'])
        self._model.eval()
Example #4
def load_model(weights_fpath, verbose=True):
    global _model

    if verbose:
        print("Building Wave-RNN")
    _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                     fc_dims=hp.voc_fc_dims,
                     bits=hp.bits,
                     pad=hp.voc_pad,
                     upsample_factors=hp.voc_upsample_factors,
                     feat_dims=hp.num_mels,
                     compute_dims=hp.voc_compute_dims,
                     res_out_dims=hp.voc_res_out_dims,
                     res_blocks=hp.voc_res_blocks,
                     hop_length=hp.hop_length,
                     sample_rate=hp.sample_rate,
                     mode=hp.voc_mode).to(_device)

    if verbose:
        print("Loading model weights at %s" % weights_fpath)
    checkpoint = torch.load(str(weights_fpath), map_location=_device)
    _model.load_state_dict(checkpoint['model_state'])
    _model.eval()
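Example #4 is the most portable loader of the four: unlike Example #2's bare torch.load(weights_fpath), passing map_location=_device remaps CUDA-saved tensors onto whatever device is actually present, so the same checkpoint also loads on CPU-only machines. The idea in isolation (the file name is illustrative):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Without map_location, a checkpoint saved on a GPU machine raises an error
# when deserialized where CUDA is unavailable.
checkpoint = torch.load("wavernn.pt", map_location=device)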
Example #5
def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path,
          ground_truth: bool, save_every: int, backup_every: int,
          force_restart: bool):
    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    # Instantiate the model
    print("Initializing the model...")
    model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                    fc_dims=hp.voc_fc_dims,
                    bits=hp.bits,
                    pad=hp.voc_pad,
                    upsample_factors=hp.voc_upsample_factors,
                    feat_dims=hp.num_mels,
                    compute_dims=hp.voc_compute_dims,
                    res_out_dims=hp.voc_res_out_dims,
                    res_blocks=hp.voc_res_blocks,
                    hop_length=hp.hop_length,
                    sample_rate=hp.sample_rate,
                    mode=hp.voc_mode).cuda()

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters())
    for p in optimizer.param_groups:
        p["lr"] = hp.voc_lr
    loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss

    # Load the weights
    model_dir = models_dir.joinpath(run_id)
    model_dir.mkdir(exist_ok=True)
    weights_fpath = model_dir.joinpath(run_id + ".pt")
    if force_restart or not weights_fpath.exists():
        print("\nStarting the training of WaveRNN from scratch\n")
        model.save(weights_fpath, optimizer)
    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(weights_fpath, optimizer)
        print("WaveRNN weights loaded from step %d" % model.step)

    # Initialize the dataset
    metadata_fpath = syn_dir.joinpath("train.txt") if ground_truth else \
        voc_dir.joinpath("synthesized.txt")
    mel_dir = syn_dir.joinpath("mels") if ground_truth else voc_dir.joinpath(
        "mels_gta")
    wav_dir = syn_dir.joinpath("audio")
    dataset = VocoderDataset(metadata_fpath, mel_dir, wav_dir)
    test_loader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True)

    # Begin the training
    simple_table([('Batch size', hp.voc_batch_size), ('LR', hp.voc_lr),
                  ('Sequence Len', hp.voc_seq_len)])

    for epoch in range(1, 350):
        data_loader = DataLoader(dataset,
                                 collate_fn=collate_vocoder,
                                 batch_size=hp.voc_batch_size,
                                 num_workers=2,
                                 shuffle=True,
                                 pin_memory=True)
        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(data_loader, 1):
            x, m, y = x.cuda(), m.cuda(), y.cuda()

            # Forward pass
            y_hat = model(x, m)
            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            y = y.unsqueeze(-1)
            print("y shape:", y.shape)
            print("y_hat shape:", y_hat.shape)
            # Backward pass
            loss = loss_func(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            speed = i / (time.time() - start)
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000

            if backup_every != 0 and step % backup_every == 0:
                model.checkpoint(model_dir, optimizer)

            if save_every != 0 and step % save_every == 0:
                model.save(weights_fpath, optimizer)

            msg = f"| Epoch: {epoch} ({i}/{len(data_loader)}) | " \
                f"Loss: {avg_loss:.4f} | {speed:.1f} " \
                f"steps/s | Step: {k}k | "
            stream(msg)

        gen_testset(model, test_loader, hp.voc_gen_at_checkpoint,
                    hp.voc_gen_batched, hp.voc_target, hp.voc_overlap,
                    model_dir)
        print("")
Example #6
def main():

    # Parse Arguments
    parser = argparse.ArgumentParser(description='Train WaveRNN Vocoder')
    parser.add_argument('--lr',
                        '-l',
                        type=float,
                        help='[float] override hparams.py learning rate')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        help='[int] override hparams.py batch size')
    parser.add_argument('--force_train',
                        '-f',
                        action='store_true',
                        help='Forces the model to train past total steps')
    parser.add_argument('--gta',
                        '-g',
                        action='store_true',
                        help='train wavernn on GTA features')
    parser.add_argument(
        '--force_cpu',
        '-c',
        action='store_true',
        help='Forces CPU-only training, even when in CUDA capable environment')
    parser.add_argument('--hp_file',
                        metavar='FILE',
                        default='hparams.py',
                        help='The file to use for the hyperparameters')
    args = parser.parse_args()

    # Set hyperparameters
    hp.training_files = "tacotron2/filelists/transcripts_korean_final_final.txt"
    hp.validation_files = "tacotron2/filelists/transcripts_korean_final_validate.txt"
    hp.filter_length = 1024
    hp.n_mel_channels = 80
    hp.sampling_rate = 16000
    hp.mel_fmin = 0.0
    hp.mel_fmax = 8000.0
    hp.max_wav_value = 32768.0
    hp.n_frames_per_step = 1
    hp.configure(args.hp_file)  # load hparams from file
    if args.lr is None:
        args.lr = hp.voc_lr
    if args.batch_size is None:
        args.batch_size = hp.voc_batch_size

    paths = Paths("../data/", hp.voc_model_id, hp.tts_model_id)

    batch_size = 64
    force_train = args.force_train
    train_gta = args.gta
    lr = args.lr

    if not args.force_cpu and torch.cuda.is_available():
        device = torch.device('cuda')
        if batch_size % torch.cuda.device_count() != 0:
            raise ValueError(
                '`batch_size` must be evenly divisible by n_gpus!')
    else:
        device = torch.device('cpu')
    print('Using device:', device)

    print('\nInitialising Model...\n')

    # Instantiate WaveRNN Model
    voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                        fc_dims=hp.voc_fc_dims,
                        bits=hp.bits,
                        pad=hp.voc_pad,
                        upsample_factors=hp.voc_upsample_factors,
                        feat_dims=hp.num_mels,
                        compute_dims=hp.voc_compute_dims,
                        res_out_dims=hp.voc_res_out_dims,
                        res_blocks=hp.voc_res_blocks,
                        hop_length=hp.hop_length,
                        sample_rate=hp.sample_rate,
                        mode=hp.voc_mode).to(device)

    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    optimizer = optim.Adam(voc_model.parameters())
    restore_checkpoint('voc',
                       paths,
                       voc_model,
                       optimizer,
                       create_if_missing=True)

    train_set, test_set = get_vocoder_datasets(paths.data, batch_size,
                                               train_gta, hp)

    total_steps = 10_000_000 if force_train else hp.voc_total_steps

    simple_table([
        ('Remaining', str(
            (total_steps - voc_model.get_step()) // 1000) + 'k Steps'),
        ('Batch Size', batch_size), ('LR', lr),
        ('Sequence Len', hp.voc_seq_len), ('GTA Train', train_gta)
    ])

    loss_func = F.cross_entropy if voc_model.mode == 'RAW' else discretized_mix_logistic_loss

    voc_train_loop(paths, voc_model, loss_func, optimizer, train_set, test_set,
                   lr, total_steps)

    print('Training Complete.')
    print(
        'To continue training increase voc_total_steps in hparams.py or use --force_train'
    )
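The assertion np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length (it also appears in Examples #5, #7 and #13) checks that the upsampling layers expand one mel frame into exactly one hop of audio samples. A worked check with illustrative values:

import numpy as np

upsample_factors = (5, 5, 12)  # illustrative per-layer upsampling scales
hop_length = 300               # audio samples per mel frame

# The factors must multiply out to the hop length: 5 * 5 * 12 == 300.
assert np.cumprod(upsample_factors)[-1] == hop_length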
Example #7
def train(run_id='',
          syn_dir=None, voc_dirs=[], mel_dir_name='', models_dir=None, log_dir='',
          ground_truth=False,
          save_every=1000, backup_every=1000, log_every=1000,
          force_restart=False, total_epochs=10000, logger=None):
    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    # Instantiate the model
    print("Initializing the model...")
    model = WaveRNN(
        rnn_dims=hp.voc_rnn_dims, # 512
        fc_dims=hp.voc_fc_dims, # 512
        bits=hp.bits, # 9
        pad=hp.voc_pad, # 2
        upsample_factors=hp.voc_upsample_factors, # (3, 4, 5, 5) -> 300, (5,5,12)?
        feat_dims=hp.num_mels, # 80
        compute_dims=hp.voc_compute_dims, # 128
        res_out_dims=hp.voc_res_out_dims, # 128
        res_blocks=hp.voc_res_blocks, # 10
        hop_length=hp.hop_length, # 300
        sample_rate=hp.sample_rate, # 24000
        mode=hp.voc_mode # RAW (or MOL)
    ).cuda()

    # hp.apply_preemphasis in VocoderDataset
    # hp.mu_law in VocoderDataset
    # hp.voc_seq_len in VocoderDataset
    # hp.voc_lr in optimizer
    # hp.voc_batch_size for train

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters())
    for p in optimizer.param_groups:
        p["lr"] = hp.voc_lr # 0.0001
    loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss

    # Load the weights
    model_dir = models_dir.joinpath(run_id) # gta_model/gtaxxxx
    model_dir.mkdir(exist_ok=True)
    weights_fpath = model_dir.joinpath(run_id + ".pt") # gta_model/gtaxxx/gtaxxx.pt
    if force_restart or not weights_fpath.exists():
        print("\nStarting the training of WaveRNN from scratch\n")
        model.save(str(weights_fpath), optimizer)
    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(str(weights_fpath), optimizer)
        print("WaveRNN weights loaded from step %d" % model.step)

    # Initialize the dataset
    #metadata_fpath = syn_dir.joinpath("train.txt") if ground_truth else \
    #    voc_dir.joinpath("synthesized.txt")
    #mel_dir = syn_dir.joinpath("mels") if ground_truth else voc_dir.joinpath("mels_gta")
    #wav_dir = syn_dir.joinpath("audio")
    #dataset = VocoderDataset(metadata_fpath, mel_dir, wav_dir)
    #dataset = VocoderDataset(str(voc_dir), 'mels-gta-1099579078086', 'audio')
    dataset = VocoderDataset([str(voc_dir) for voc_dir in voc_dirs], mel_dir_name, 'audio')
    #test_loader = DataLoader(dataset,
    #                         batch_size=1,
    #                         shuffle=True,
    #                         pin_memory=True)

    # Begin the training
    simple_table([('Batch size', hp.voc_batch_size),
                  ('LR', hp.voc_lr),
                  ('Sequence Len', hp.voc_seq_len)])

    for epoch in range(1, total_epochs):
        data_loader = DataLoader(dataset,
                                 collate_fn=collate_vocoder,
                                 batch_size=hp.voc_batch_size,
                                 num_workers=30,
                                 shuffle=True,
                                 pin_memory=True)
        start = time.time()
        running_loss = 0.

        # start from 1
        for i, (x, y, m) in enumerate(data_loader, 1):
            # cur [B, L], future [B, L] bit label, mels [B, D, T]
            x, m, y = x.cuda(), m.cuda(), y.cuda()

            # Forward pass
            # [B, L], [B, D, T] -> [B, L, C]
            y_hat = model(x, m)
            if model.mode == 'RAW':
                # [B, L, C] -> [B, C, L, 1]
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            # [B, L, 1]
            y = y.unsqueeze(-1)

            # Backward pass
            # [B, C, L, 1], [B, L, 1]
            # cross_entropy for RAW
            loss = loss_func(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            speed = i / (time.time() - start)
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000

            if backup_every != 0 and step % backup_every == 0:
                model.checkpoint(str(model_dir), optimizer)

            if save_every != 0 and step % save_every == 0:
                model.save(str(weights_fpath), optimizer)

            if log_every != 0 and step % log_every == 0:
                logger.scalar_summary("loss", loss.item(), step)

            msg = f"| Epoch: {epoch} ({i}/{len(data_loader)}) | " \
                  f"Loss: {avg_loss:.4f} | {speed:.1f} " \
                  f"steps/s | Step: {k}k | "
            stream(msg)


        #gen_testset(model, test_loader, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
        #            hp.voc_target, hp.voc_overlap, model_dir)
        print("")
Example #8
    gta = args.gta

    if not args.force_cpu and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print('Using device:', device)

    print('\nInitialising Model...\n')

    model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                    fc_dims=hp.voc_fc_dims,
                    bits=hp.bits,
                    pad=hp.voc_pad,
                    upsample_factors=hp.voc_upsample_factors,
                    feat_dims=hp.num_mels,
                    compute_dims=hp.voc_compute_dims,
                    res_out_dims=hp.voc_res_out_dims,
                    res_blocks=hp.voc_res_blocks,
                    hop_length=hp.hop_length,
                    sample_rate=hp.sample_rate,
                    mode=hp.voc_mode).to(device)

    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

    voc_weights = args.voc_weights if args.voc_weights else paths.voc_latest_weights

    model.load(voc_weights)

    simple_table([('Generation Mode', 'Batched' if batched else 'Unbatched'),
                  ('Target Samples', target if batched else 'N/A'),
                  ('Overlap Samples', overlap if batched else 'N/A')])
Example #9
    def __init__(self,
                 dim_neck,
                 dim_emb,
                 dim_pre,
                 freq,
                 dim_spec=80,
                 is_train=False,
                 lr=0.001,
                 multigpu=False,
                 lambda_wavenet=0.001,
                 args=None,
                 residual=False,
                 attention_map=None,
                 use_256=False,
                 loss_content=False,
                 test_path=None):
        super(VideoAudioGenerator, self).__init__()

        self.encoder = MyEncoder(dim_neck, freq, num_mel=dim_spec)

        self.decoder = Decoder(dim_neck, 0, dim_pre, num_mel=dim_spec)
        self.postnet = Postnet(num_mel=dim_spec)
        if use_256:
            self.video_decoder = VideoGenerator(use_256=True)
        else:
            self.video_decoder = STAGE2_G(residual=residual)
        self.use_256 = use_256
        self.lambda_wavenet = lambda_wavenet
        self.loss_content = loss_content
        self.multigpu = multigpu
        self.test_path = test_path

        self.vocoder = WaveRNN(rnn_dims=hparams.voc_rnn_dims,
                               fc_dims=hparams.voc_fc_dims,
                               bits=hparams.bits,
                               pad=hparams.voc_pad,
                               upsample_factors=hparams.voc_upsample_factors,
                               feat_dims=hparams.num_mels,
                               compute_dims=hparams.voc_compute_dims,
                               res_out_dims=hparams.voc_res_out_dims,
                               res_blocks=hparams.voc_res_blocks,
                               hop_length=hparams.hop_size,
                               sample_rate=hparams.sample_rate,
                               mode=hparams.voc_mode)

        if is_train:
            self.criterionIdt = torch.nn.L1Loss(reduction='mean')
            self.opt_encoder = torch.optim.Adam(self.encoder.parameters(),
                                                lr=lr)
            self.opt_decoder = torch.optim.Adam(itertools.chain(
                self.decoder.parameters(), self.postnet.parameters()),
                                                lr=lr)
            self.opt_video_decoder = torch.optim.Adam(
                self.video_decoder.parameters(), lr=lr)

            self.opt_vocoder = torch.optim.Adam(self.vocoder.parameters(),
                                                lr=hparams.voc_lr)
            self.vocoder_loss_func = F.cross_entropy  # Only for RAW

        if multigpu:
            self.encoder = nn.DataParallel(self.encoder)
            self.decoder = nn.DataParallel(self.decoder)
            self.video_decoder = nn.DataParallel(self.video_decoder)
            self.postnet = nn.DataParallel(self.postnet)
            self.vocoder = nn.DataParallel(self.vocoder)
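When multigpu is set, every submodule is wrapped in nn.DataParallel, which changes how the underlying model is reached (e.g. for checkpointing). A minimal reminder sketch with a stand-in module:

import torch.nn as nn

model = nn.Linear(80, 512)      # stand-in for any of the submodules above
model = nn.DataParallel(model)  # replicates the module across all visible GPUs
# The original module now lives under .module; checkpointing code should save
# model.module.state_dict() so the weights load without the wrapper prefix.
state = model.module.state_dict()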
Example #10
    def __init__(self,
                 dim_neck,
                 dim_emb,
                 dim_pre,
                 freq,
                 dim_spec=80,
                 is_train=False,
                 lr=0.001,
                 loss_content=True,
                 discriminator=False,
                 multigpu=False,
                 lambda_gan=0.0001,
                 lambda_wavenet=0.001,
                 args=None,
                 test_path_source=None,
                 test_path_target=None):
        super(Generator, self).__init__()

        self.encoder = MyEncoder(dim_neck, freq, num_mel=dim_spec)
        self.decoder = Decoder(dim_neck, 0, dim_pre, num_mel=dim_spec)
        self.postnet = Postnet(num_mel=dim_spec)

        if discriminator:
            self.dis = PatchDiscriminator(n_class=num_speakers)
            self.dis_criterion = GANLoss(use_lsgan=use_lsgan,
                                         tensor=torch.cuda.FloatTensor)
        else:
            self.dis = None

        self.loss_content = loss_content
        self.lambda_gan = lambda_gan
        self.lambda_wavenet = lambda_wavenet

        self.multigpu = multigpu
        self.prepare_test(dim_spec, test_path_source, test_path_target)

        self.vocoder = WaveRNN(rnn_dims=hparams.voc_rnn_dims,
                               fc_dims=hparams.voc_fc_dims,
                               bits=hparams.bits,
                               pad=hparams.voc_pad,
                               upsample_factors=hparams.voc_upsample_factors,
                               feat_dims=hparams.num_mels,
                               compute_dims=hparams.voc_compute_dims,
                               res_out_dims=hparams.voc_res_out_dims,
                               res_blocks=hparams.voc_res_blocks,
                               hop_length=hparams.hop_size,
                               sample_rate=hparams.sample_rate,
                               mode=hparams.voc_mode)

        if is_train:
            self.criterionIdt = torch.nn.L1Loss(reduction='mean')
            self.opt_encoder = torch.optim.Adam(self.encoder.parameters(),
                                                lr=lr)
            self.opt_decoder = torch.optim.Adam(itertools.chain(
                self.decoder.parameters(), self.postnet.parameters()),
                                                lr=lr)
            if discriminator:
                self.opt_dis = torch.optim.Adam(self.dis.parameters(), lr=lr)
            self.opt_vocoder = torch.optim.Adam(self.vocoder.parameters(),
                                                lr=hparams.voc_lr)
            self.vocoder_loss_func = F.cross_entropy  # Only for RAW

        if multigpu:
            self.encoder = nn.DataParallel(self.encoder)
            self.decoder = nn.DataParallel(self.decoder)
            self.postnet = nn.DataParallel(self.postnet)
            self.vocoder = nn.DataParallel(self.vocoder)
            if self.dis is not None:
                self.dis = nn.DataParallel(self.dis)
Example #11
    def __init__(self,
                 dim_neck,
                 dim_emb,
                 dim_pre,
                 freq,
                 dim_spec=80,
                 is_train=False,
                 lr=0.001,
                 decoder_type='simple',
                 vocoder_type='wavenet',
                 encoder_type='default',
                 separate_encoder=True,
                 loss_content=True,
                 discriminator=False,
                 dis_type='patch',
                 multigpu=False,
                 cycle=False,
                 lambda_cycle=1,
                 num_speakers=-1,
                 idt_type='L2',
                 use_lsgan=True,
                 lambda_gan=0.0001,
                 train_wavenet=False,
                 lambda_wavenet=0.001,
                 args=None,
                 test_path_source=None,
                 test_path_target=None,
                 attention=False,
                 residual=False):
        super(Generator, self).__init__()

        if encoder_type == 'default':
            self.encoder = Encoder(dim_neck, dim_emb, freq)
        elif encoder_type == 'nospeaker' or encoder_type == 'single':
            self.encoder = MyEncoder(dim_neck, freq, num_mel=dim_spec)
        elif encoder_type == 'multiencoder':
            self.encoder = MultiEncoder(num_speakers, dim_neck, freq,
                                        separate_encoder)

        if encoder_type == 'multiencoder' or encoder_type == 'single':
            self.decoder = Decoder(dim_neck, 0, dim_pre, num_mel=dim_spec)
        elif decoder_type == 'simple':
            self.decoder = Decoder(dim_neck, dim_emb, dim_pre)
        elif decoder_type == 'tacotron':
            self.decoder = TacotronDecoder(hparams)
        elif decoder_type == 'multidecoder':
            self.decoder = MultiDecoder(num_speakers, dim_neck, dim_pre,
                                        multigpu)
        elif decoder_type == 'video':
            # self.decoder = VideoGenerator()
            self.decoder = STAGE2_G(residual=residual)
        self.postnet = Postnet(num_mel=dim_spec)
        if discriminator:
            if dis_type == 'patch':
                self.dis = PatchDiscriminator(n_class=num_speakers)
            else:
                self.dis = SpeakerDiscriminator()
            # self.dis_criterion = nn.CrossEntropyLoss(reduction='mean')
            self.dis_criterion = GANLoss(use_lsgan=use_lsgan,
                                         tensor=torch.cuda.FloatTensor)
        else:
            self.dis = None

        self.encoder_type = encoder_type
        self.decoder_type = decoder_type
        self.vocoder_type = vocoder_type
        self.loss_content = loss_content
        self.cycle = cycle
        self.lambda_cycle = lambda_cycle
        self.lambda_gan = lambda_gan
        self.lambda_wavenet = lambda_wavenet
        self.attention = attention

        self.multigpu = multigpu

        self.train_vocoder = train_wavenet
        if self.train_vocoder:
            self.vocoder = WaveRNN(
                rnn_dims=hparams.voc_rnn_dims,
                fc_dims=hparams.voc_fc_dims,
                bits=hparams.bits,
                pad=hparams.voc_pad,
                upsample_factors=hparams.voc_upsample_factors,
                feat_dims=hparams.num_mels,
                compute_dims=hparams.voc_compute_dims,
                res_out_dims=hparams.voc_res_out_dims,
                res_blocks=hparams.voc_res_blocks,
                hop_length=hparams.hop_size,
                sample_rate=hparams.sample_rate,
                mode=hparams.voc_mode)
            self.opt_vocoder = torch.optim.Adam(self.vocoder.parameters(),
                                                lr=hparams.voc_lr)
            self.vocoder_loss_func = F.cross_entropy  # Only for RAW

        if multigpu:
            self.encoder = nn.DataParallel(self.encoder)
            self.decoder = nn.DataParallel(self.decoder)
            self.postnet = nn.DataParallel(self.postnet)
            self.vocoder = nn.DataParallel(self.vocoder)
            if self.dis is not None:
                self.dis = nn.DataParallel(self.dis)
Example #12
    def __init__(self,
                 dim_neck,
                 dim_emb,
                 dim_pre,
                 freq,
                 dim_spec=80,
                 is_train=False,
                 lr=0.001,
                 vocoder_type='wavenet',
                 multigpu=False,
                 num_speakers=-1,
                 idt_type='L2',
                 train_wavenet=False,
                 lambda_wavenet=0.001,
                 args=None,
                 test_path_source=None,
                 test_path_target=None,
                 residual=False,
                 attention_map=None,
                 use_256=False):
        super(VideoAudioGenerator, self).__init__()

        self.encoder = MyEncoder(dim_neck, freq, num_mel=dim_spec)

        self.decoder = Decoder(dim_neck, 0, dim_pre, num_mel=dim_spec)
        self.postnet = Postnet(num_mel=dim_spec)
        if use_256:
            self.video_decoder = VideoGenerator(use_256=True)
        else:
            self.video_decoder = STAGE2_G(residual=residual)
        self.use_256 = use_256
        self.vocoder_type = vocoder_type
        self.lambda_wavenet = lambda_wavenet

        self.multigpu = multigpu
        # self.prepare_test(dim_spec, test_path_source, test_path_target)

        self.train_vocoder = train_wavenet
        if self.train_vocoder:
            if vocoder_type == 'wavenet' or vocoder_type == 'griffin':
                self.vocoder = WaveRNN(
                    rnn_dims=hparams.voc_rnn_dims,
                    fc_dims=hparams.voc_fc_dims,
                    bits=hparams.bits,
                    pad=hparams.voc_pad,
                    upsample_factors=hparams.voc_upsample_factors,
                    feat_dims=hparams.num_mels,
                    compute_dims=hparams.voc_compute_dims,
                    res_out_dims=hparams.voc_res_out_dims,
                    res_blocks=hparams.voc_res_blocks,
                    hop_length=hparams.hop_size,
                    sample_rate=hparams.sample_rate,
                    mode=hparams.voc_mode)
                self.opt_vocoder = torch.optim.Adam(self.vocoder.parameters(),
                                                    lr=hparams.voc_lr)
                self.vocoder_loss_func = F.cross_entropy  # Only for RAW

        if attention_map is not None:
            self.attention_map_large = np.load(attention_map)
            self.attention_map = cv2.resize(self.attention_map_large,
                                            dsize=(128, 128),
                                            interpolation=cv2.INTER_CUBIC)
            # self.attention_map_large = self.attention_map_large.astype(np.float64)
            # self.attention_map = self.attention_map.astype(np.float64)
            self.attention_map_large = torch.from_numpy(
                self.attention_map_large /
                self.attention_map_large.max()).float()
            self.attention_map = torch.from_numpy(
                self.attention_map / self.attention_map.max()).float()
            self.criterionVideo = torch.nn.L1Loss(reduction='none')
        else:
            self.attention_map = None

        if multigpu:
            self.encoder = nn.DataParallel(self.encoder)
            self.decoder = nn.DataParallel(self.decoder)
            self.video_decoder = nn.DataParallel(self.video_decoder)
            self.postnet = nn.DataParallel(self.postnet)
            self.vocoder = nn.DataParallel(self.vocoder)
Example #13
def train(run_id: str, models_dir: Path, metadata_path: Path,
          weights_path: Path, ground_truth: bool, save_every: int,
          backup_every: int, force_restart: bool):
    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    # Instantiate the model
    print("Initializing the model...")
    model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                    fc_dims=hp.voc_fc_dims,
                    bits=hp.bits,
                    pad=hp.voc_pad,
                    upsample_factors=hp.voc_upsample_factors,
                    feat_dims=hp.num_mels,
                    compute_dims=hp.voc_compute_dims,
                    res_out_dims=hp.voc_res_out_dims,
                    res_blocks=hp.voc_res_blocks,
                    hop_length=hp.hop_length,
                    sample_rate=hp.sample_rate,
                    mode=hp.voc_mode).cuda()

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters())
    for p in optimizer.param_groups:
        p["lr"] = hp.voc_lr
    loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss

    # Load the weights
    model_dir = models_dir.joinpath(run_id)
    model_dir.mkdir(exist_ok=True)
    weights_fpath = weights_path
    metadata_fpath = metadata_path

    if force_restart:
        print("\nStarting the training of WaveRNN from scratch\n")
        model.save(weights_fpath, optimizer)
    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(weights_fpath, optimizer)
        print("WaveRNN weights loaded from step %d" % model.step)

    # Initialize the dataset

    dataset = VocoderDataset(metadata_fpath)

    test_loader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True)

    # Begin the training
    simple_table([('Batch size', hp.voc_batch_size), ('LR', hp.voc_lr),
                  ('Sequence Len', hp.voc_seq_len)])

    epoch_start = int(
        (model.step - 428000) * 110 / dataset.get_number_of_samples())
    epoch_end = 200

    log_path = os.path.join(models_dir, "logs")
    if not os.path.isdir(log_path):
        os.mkdir(log_path)

    writer = SummaryWriter(log_path)
    print("Log path : " + log_path)

    print("Starting from epoch: " + str(epoch_start))

    for epoch in range(epoch_start, epoch_start + epoch_end):
        data_loader = DataLoader(dataset,
                                 collate_fn=collate_vocoder,
                                 batch_size=hp.voc_batch_size,
                                 num_workers=2,
                                 shuffle=True,
                                 pin_memory=True)
        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(data_loader, 1):
            x, m, y = x.cuda(), m.cuda(), y.cuda()

            # Forward pass
            y_hat = model(x, m)
            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            y = y.unsqueeze(-1)

            # Backward pass
            loss = loss_func(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            speed = i / (time.time() - start)
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000

            if backup_every != 0 and step % backup_every == 0:
                model.checkpoint(model_dir, optimizer)

            # if save_every != 0 and step % save_every == 0 :
            #     model.save(weights_fpath, optimizer)

            if step % 500 == 0:
                writer.add_scalar('Loss/train', avg_loss,
                                  round(step / 1000, 1))
                msg = f"| Epoch: {epoch} ({i}/{len(data_loader)}) | " \
                    f"Loss: {avg_loss:.4f} | {speed:.1f} " \
                    f"steps/s | Step: {k}k | "
                print(msg, flush=True)

            if step % 15000 == 0:
                gen_testset(model, test_loader, hp.voc_gen_at_checkpoint,
                            hp.voc_gen_batched, hp.voc_target, hp.voc_overlap,
                            model_dir)
                gen_meltest(model, hp.voc_gen_batched, hp.voc_target,
                            hp.voc_overlap, model_dir)
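Example #13 logs the running loss to TensorBoard through SummaryWriter. A self-contained sketch of the calls used above (the log directory and values are illustrative):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("logs")                # creates/opens the log directory
writer.add_scalar("Loss/train", 0.1234, 500)  # tag, scalar value, global step
writer.close()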