def get_datasets(args):
    """Build the preprocessing pipeline and return the train/validation split.

    A text preprocessor is created by binding the symbol list, phonemizer,
    phonemizer checkpoint and CMUdict root from ``args`` onto
    ``text_to_sequence``. The audio transform chains a magnitude
    (``power=1``) mel spectrogram with ``SpectralNormalization``. Both are
    passed, together with the dataset name, dataset path and validation
    ratio, to ``split_process_dataset``.

    Args:
        args: Options namespace. Attributes read: ``text_preprocessor``,
            ``phonemizer``, ``phonemizer_checkpoint``, ``cmudict_root``,
            ``sample_rate``, ``n_fft``, ``win_length``, ``hop_length``,
            ``mel_fmin``, ``mel_fmax``, ``n_mels``, ``dataset``,
            ``dataset_path`` and ``val_ratio``.

    Returns:
        A ``(trainset, valset)`` tuple as produced by
        ``split_process_dataset``.
    """
    # Text side: text_to_sequence with its configuration pre-bound.
    to_sequence = partial(
        text_to_sequence,
        symbol_list=args.text_preprocessor,
        phonemizer=args.phonemizer,
        checkpoint=args.phonemizer_checkpoint,
        cmudict_root=args.cmudict_root,
    )

    # Audio side: mel spectrogram settings collected in one place.
    mel_options = {
        "sample_rate": args.sample_rate,
        "n_fft": args.n_fft,
        "win_length": args.win_length,
        "hop_length": args.hop_length,
        "f_min": args.mel_fmin,
        "f_max": args.mel_fmax,
        "n_mels": args.n_mels,
        "mel_scale": 'slaney',
        "normalized": False,
        "power": 1,
        "norm": 'slaney',
    }
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(**mel_options)
    audio_pipeline = torch.nn.Sequential(
        mel_spectrogram,
        SpectralNormalization(),
    )

    train_split, val_split = split_process_dataset(
        args.dataset,
        args.dataset_path,
        args.val_ratio,
        audio_pipeline,
        to_sequence,
    )
    return train_split, val_split
def main(args):
    """Train a WaveRNN model, validating and checkpointing periodically.

    Steps, in order:

    1. Build the mel-spectrogram + ``NormalizeDB`` transform and obtain the
       train/validation datasets from ``split_process_dataset``.
    2. Wrap both datasets in ``DataLoader`` objects (training shuffled,
       validation not).
    3. Construct ``WaveRNN``, optionally script it with ``torch.jit``, wrap
       it in ``DataParallel`` and move it to the selected device.
    4. Create the Adam optimizer and the loss selected by ``args.loss``.
    5. Resume from ``args.checkpoint`` when that file exists; otherwise pass
       the initial state to ``save_checkpoint``.
    6. Train from ``args.start_epoch`` up to ``args.epochs``; every
       ``args.print_freq`` epochs and on the final epoch, validate and pass
       the current state to ``save_checkpoint``.

    Args:
        args: Options namespace. Attributes read: ``n_fft``, ``hop_length``,
            ``win_length``, ``sample_rate``, ``n_freq``, ``f_min``,
            ``min_level_db``, ``normalization``, ``workers``, ``batch_size``,
            ``n_bits``, ``loss``, ``upsample_scales``, ``n_res_block``,
            ``n_rnn``, ``n_fc``, ``kernel_size``, ``n_hidden_melresnet``,
            ``n_output_melresnet``, ``jit``, ``learning_rate``,
            ``checkpoint``, ``start_epoch``, ``epochs`` and ``print_freq``.
            ``args.start_epoch`` is overwritten when a checkpoint is loaded.
    """
    # Single-entry list: CUDA when available, CPU otherwise.
    devices = ["cuda" if torch.cuda.is_available() else "cpu"]

    logging.info(f"Start time: {datetime.now()}")

    # --- Data -------------------------------------------------------------
    melkwargs = {
        "n_fft": args.n_fft,
        "power": 1,
        "hop_length": args.hop_length,
        "win_length": args.win_length,
    }

    transforms = torch.nn.Sequential(
        torchaudio.transforms.MelSpectrogram(
            sample_rate=args.sample_rate,
            n_mels=args.n_freq,
            f_min=args.f_min,
            mel_scale='slaney',
            norm='slaney',
            **melkwargs,
        ),
        NormalizeDB(min_level_db=args.min_level_db, normalization=args.normalization),
    )

    train_dataset, val_dataset = split_process_dataset(args, transforms)

    loader_training_params = {
        "num_workers": args.workers,
        "pin_memory": False,
        "shuffle": True,
        "drop_last": False,
    }
    # Validation uses the same loader settings, only without shuffling.
    loader_validation_params = loader_training_params.copy()
    loader_validation_params["shuffle"] = False

    collate_fn = collate_factory(args)

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        **loader_training_params,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        **loader_validation_params,
    )

    # --- Model ------------------------------------------------------------
    # Cross-entropy predicts one class per quantization level (2**n_bits);
    # any other loss setting uses 30 output values.
    n_classes = 2**args.n_bits if args.loss == "crossentropy" else 30

    model = WaveRNN(
        upsample_scales=args.upsample_scales,
        n_classes=n_classes,
        hop_length=args.hop_length,
        n_res_block=args.n_res_block,
        n_rnn=args.n_rnn,
        n_fc=args.n_fc,
        kernel_size=args.kernel_size,
        n_freq=args.n_freq,
        n_hidden=args.n_hidden_melresnet,
        n_output=args.n_output_melresnet,
    )

    if args.jit:
        model = torch.jit.script(model)

    model = torch.nn.DataParallel(model)
    model = model.to(devices[0], non_blocking=True)

    n = count_parameters(model)
    logging.info(f"Number of parameters: {n}")

    # --- Optimizer and loss -------------------------------------------------
    optimizer_params = {
        "lr": args.learning_rate,
    }
    optimizer = Adam(model.parameters(), **optimizer_params)

    criterion = LongCrossEntropyLoss() if args.loss == "crossentropy" else MoLLoss()

    # NOTE(review): with this starting value a validation loss >= 10.0 is
    # never flagged as "best"; confirm the threshold is intentional.
    best_loss = 10.0

    # --- Checkpoint: resume or initialise -----------------------------------
    if args.checkpoint and os.path.isfile(args.checkpoint):
        logging.info(f"Checkpoint: loading '{args.checkpoint}'")
        # map_location lets a checkpoint saved on a GPU be resumed on a
        # CPU-only host (and vice versa) instead of failing to deserialize.
        checkpoint = torch.load(args.checkpoint, map_location=devices[0])

        args.start_epoch = checkpoint["epoch"]
        best_loss = checkpoint["best_loss"]

        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])

        logging.info(
            f"Checkpoint: loaded '{args.checkpoint}' at epoch {checkpoint['epoch']}"
        )
    else:
        logging.info("Checkpoint: not found")

        save_checkpoint(
            {
                "epoch": args.start_epoch,
                "state_dict": model.state_dict(),
                "best_loss": best_loss,
                "optimizer": optimizer.state_dict(),
            },
            False,
            args.checkpoint,
        )

    # --- Training loop ------------------------------------------------------
    for epoch in range(args.start_epoch, args.epochs):

        train_one_epoch(
            model, criterion, optimizer, train_loader, devices[0], epoch,
        )

        # Validate every `print_freq` epochs and always on the last epoch.
        if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:

            sum_loss = validate(model, criterion, val_loader, devices[0], epoch)

            is_best = sum_loss < best_loss
            best_loss = min(sum_loss, best_loss)
            save_checkpoint(
                {
                    # Stored as the next epoch to run, so a resume continues
                    # after the epoch that just finished.
                    "epoch": epoch + 1,
                    "state_dict": model.state_dict(),
                    "best_loss": best_loss,
                    "optimizer": optimizer.state_dict(),
                },
                is_best,
                args.checkpoint,
            )

    logging.info(f"End time: {datetime.now()}")