Ejemplo n.º 1
0
 def train(self,
           model: WaveRNN,
           optimizer: Optimizer,
           train_gta=False) -> None:
     """Run every not-yet-finished session of the vocoder training schedule.

     The schedule comes from ``train_cfg['schedule']`` and each entry
     supplies (learning rate, max step, batch size). Sessions the model
     has already trained past are skipped.
     """
     schedule = parse_schedule(self.train_cfg['schedule'])
     for session_index, (lr, max_step, batch_size) in enumerate(schedule, 1):
         # Guard clause: the checkpoint is already beyond this session.
         if model.get_step() >= max_step:
             continue
         train_set, val_set, val_set_samples = get_vocoder_datasets(
             path=self.paths.data,
             batch_size=batch_size,
             train_gta=train_gta,
             max_mel_len=self.train_cfg['max_mel_len'],
             hop_length=self.dsp.hop_length,
             voc_pad=model.pad,
             voc_seq_len=self.train_cfg['seq_len'],
             voc_mode=self.dsp.voc_mode,
             bits=self.dsp.bits,
             num_gen_samples=self.train_cfg['num_gen_samples'])
         session = VocSession(index=session_index,
                              lr=lr,
                              max_step=max_step,
                              bs=batch_size,
                              train_set=train_set,
                              val_set=val_set,
                              val_set_samples=val_set_samples)
         self.train_session(model, optimizer, session, train_gta)
Ejemplo n.º 2
0
 def train(self, model: WaveRNN, optimizer: Optimizer, train_gta=False) -> None:
     """Walk ``hp.voc_schedule`` and train each session the model has not
     yet completed; each schedule entry is (lr, max_step, batch_size)."""
     for idx, (lr, max_step, batch_size) in enumerate(hp.voc_schedule, 1):
         if model.get_step() >= max_step:
             # Checkpoint already past this session — nothing to do here.
             continue
         train_set, val_set, val_set_samples = get_vocoder_datasets(
             path=self.paths.data, batch_size=batch_size, train_gta=train_gta)
         session = VocSession(index=idx,
                              lr=lr,
                              max_step=max_step,
                              bs=batch_size,
                              train_set=train_set,
                              val_set=val_set,
                              val_set_samples=val_set_samples)
         self.train_session(model, optimizer, session, train_gta)
Ejemplo n.º 3
0
def main():
    """CLI entry point: parse arguments, build the WaveRNN vocoder model,
    restore the latest checkpoint and run the training loop."""
    parser = argparse.ArgumentParser(description='Train WaveRNN Vocoder')
    parser.add_argument('--lr', '-l', type=float,
                        help='[float] override hparams.py learning rate')
    parser.add_argument('--batch_size', '-b', type=int,
                        help='[int] override hparams.py batch size')
    parser.add_argument('--force_train', '-f', action='store_true',
                        help='Forces the model to train past total steps')
    parser.add_argument('--gta', '-g', action='store_true',
                        help='train wavernn on GTA features')
    parser.add_argument('--force_cpu', '-c', action='store_true',
                        help='Forces CPU-only training, even when in CUDA capable environment')
    parser.add_argument('--hp_file', metavar='FILE', default='hparams.py',
                        help='The file to use for the hyperparameters')
    args = parser.parse_args()

    hp.configure(args.hp_file)  # load hparams from file
    # CLI flags override the hparams defaults when supplied.
    lr = hp.voc_lr if args.lr is None else args.lr
    batch_size = hp.voc_batch_size if args.batch_size is None else args.batch_size
    force_train = args.force_train
    train_gta = args.gta

    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

    # Prefer CUDA unless the user forced CPU; multi-GPU training requires
    # the batch to split evenly across devices.
    if torch.cuda.is_available() and not args.force_cpu:
        device = torch.device('cuda')
        if batch_size % torch.cuda.device_count() != 0:
            raise ValueError(
                '`batch_size` must be evenly divisible by n_gpus!')
    else:
        device = torch.device('cpu')
    print('Using device:', device)

    print('\nInitialising Model...\n')

    # Instantiate the WaveRNN vocoder from the hparams configuration.
    voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                        fc_dims=hp.voc_fc_dims,
                        bits=hp.bits,
                        pad=hp.voc_pad,
                        upsample_factors=hp.voc_upsample_factors,
                        feat_dims=hp.num_mels,
                        compute_dims=hp.voc_compute_dims,
                        res_out_dims=hp.voc_res_out_dims,
                        res_blocks=hp.voc_res_blocks,
                        hop_length=hp.hop_length,
                        sample_rate=hp.sample_rate,
                        mode=hp.voc_mode).to(device)

    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    optimizer = optim.Adam(voc_model.parameters())
    restore_checkpoint('voc', paths, voc_model, optimizer,
                       create_if_missing=True)

    train_set, test_set = get_vocoder_datasets(paths.data, batch_size,
                                               train_gta)

    # --force_train effectively removes the step cap.
    total_steps = 10_000_000 if force_train else hp.voc_total_steps

    remaining_k = (total_steps - voc_model.get_step()) // 1000
    simple_table([('Remaining', str(remaining_k) + 'k Steps'),
                  ('Batch Size', batch_size),
                  ('LR', lr),
                  ('Sequence Len', hp.voc_seq_len),
                  ('GTA Train', train_gta)])

    # Pick the loss that matches the model's output mode.
    if voc_model.mode == 'RAW':
        loss_func = F.cross_entropy
    else:
        loss_func = discretized_mix_logistic_loss

    voc_train_loop(paths, voc_model, loss_func, optimizer, train_set, test_set,
                   lr, total_steps)

    print('Training Complete.')
    print(
        'To continue training increase voc_total_steps in hparams.py or use --force_train'
    )
Ejemplo n.º 4
0
                  fc_dims=hp.voc_fc_dims,
                  bits=hp.bits,
                  pad=hp.voc_pad,
                  upsample_factors=hp.voc_upsample_factors,
                  feat_dims=hp.num_mels,
                  compute_dims=hp.voc_compute_dims,
                  res_out_dims=hp.voc_res_out_dims,
                  res_blocks=hp.voc_res_blocks,
                  hop_length=hp.hop_length,
                  sample_rate=hp.sample_rate).cuda()

    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

    restore_path = args.weights if args.weights else paths.voc_latest_weights

    model.restore(restore_path)

    simple_table([('Generation Mode', 'Batched' if batched else 'Unbatched'),
                  ('Target Samples', target if batched else 'N/A'),
                  ('Overlap Samples', overlap if batched else 'N/A')])

    _, test_set = get_vocoder_datasets(paths.data, 1, gta)

    if file:
        gen_from_file(model, file, paths.voc_output, batched, target, overlap)
    else:
        gen_testset(model, test_set, samples, batched, target, overlap,
                    paths.voc_output)

    print('\n\nExiting...\n')
Ejemplo n.º 5
0
                        res_out_dims=hp.voc_res_out_dims,
                        res_blocks=hp.voc_res_blocks,
                        hop_length=hp.hop_length,
                        sample_rate=hp.sample_rate,
                        mode=hp.voc_mode).cuda()

    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

    voc_model.restore(paths.voc_latest_weights)

    optimiser = optim.Adam(voc_model.parameters())

    train_set, test_set = get_vocoder_datasets(paths.data, batch_size,
                                               train_gta)

    total_steps = 10_000_000 if force_train else hp.voc_total_steps

    simple_table([
        ('Remaining', str(
            (total_steps - voc_model.get_step()) // 1000) + 'k Steps'),
        ('Batch Size', batch_size), ('LR', lr),
        ('Sequence Len', hp.voc_seq_len), ('GTA Train', train_gta)
    ])

    loss_func = F.cross_entropy if voc_model.mode == 'RAW' else discretized_mix_logistic_loss

    voc_train_loop(voc_model, loss_func, optimiser, train_set, test_set, lr,
                   total_steps)
Ejemplo n.º 6
0
            dtr = json.load(dtr_cfg)
            trans = config_distortions(**dtr)
            print(trans)
    else:
        trans = None

    if hasattr(hp, 'spk2split'):
        with open(hp.spk2split, 'r') as f:
            spk2split = json.load(f)
    else:
        spk2split = None

    train_set, test_set, valid_set = get_vocoder_datasets(
        paths.data,
        batch_size,
        train_gta,
        num_workers=num_workers,
        transforms=trans,
        spk2split=spk2split)

    total_steps = 10_000_000 if force_train else hp.voc_total_steps

    simple_table([
        ('Remaining', str(
            (total_steps - voc_model.get_step()) // 1000) + 'k Steps'),
        ('Batch Size', batch_size), ('LR', lr),
        ('Sequence Len', hp.voc_seq_len), ('GTA Train', train_gta)
    ])

    writer = SummaryWriter(paths.voc_checkpoints)
    hp.writer = writer