def main(rank, args):

    # Distributed setup
    if args.distributed:
        setup_distributed(rank, args.world_size)

    not_main_rank = args.distributed and rank != 0

    logging.info("Start time: %s", datetime.now())

    # Explicitly set seed to make sure models created in separate processes
    # start from same random weights and biases
    torch.manual_seed(args.seed)

    # Empty CUDA cache
    torch.cuda.empty_cache()

    # Change backend for flac files
    torchaudio.set_audio_backend("soundfile")

    # Transforms
    melkwargs = {
        "n_fft": args.win_length,
        "n_mels": args.n_bins,
        "hop_length": args.hop_length,
    }

    sample_rate_original = 16000

    if args.type == "mfcc":
        transforms = torch.nn.Sequential(
            torchaudio.transforms.MFCC(
                sample_rate=sample_rate_original,
                n_mfcc=args.n_bins,
                melkwargs=melkwargs,
            ),
        )
        num_features = args.n_bins
    elif args.type == "waveform":
        transforms = torch.nn.Sequential(UnsqueezeFirst())
        num_features = 1
    else:
        raise ValueError("Model type not supported")

    if args.normalize:
        transforms = torch.nn.Sequential(transforms, Normalize())

    augmentations = torch.nn.Sequential()
    if args.freq_mask:
        augmentations = torch.nn.Sequential(
            augmentations,
            torchaudio.transforms.FrequencyMasking(freq_mask_param=args.freq_mask),
        )
    if args.time_mask:
        augmentations = torch.nn.Sequential(
            augmentations,
            torchaudio.transforms.TimeMasking(time_mask_param=args.time_mask),
        )

    # Text preprocessing
    char_blank = "*"
    char_space = " "
    char_apostrophe = "'"
    labels = char_blank + char_space + char_apostrophe + string.ascii_lowercase
    language_model = LanguageModel(labels, char_blank, char_space)

    # Dataset
    training, validation = split_process_librispeech(
        [args.dataset_train, args.dataset_valid],
        [transforms, transforms],
        language_model,
        root=args.dataset_root,
        folder_in_archive=args.dataset_folder_in_archive,
    )

    # Decoder
    if args.decoder == "greedy":
        decoder = GreedyDecoder()
    else:
        raise ValueError("Selected decoder not supported")

    # Model
    model = Wav2Letter(
        num_classes=language_model.length,
        input_type=args.type,
        num_features=num_features,
    )

    if args.jit:
        model = torch.jit.script(model)

    if args.distributed:
        n = torch.cuda.device_count() // args.world_size
        devices = list(range(rank * n, (rank + 1) * n))
        model = model.to(devices[0])
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=devices)
    else:
        devices = ["cuda" if torch.cuda.is_available() else "cpu"]
        model = model.to(devices[0], non_blocking=True)
        model = torch.nn.DataParallel(model)

    n = count_parameters(model)
    logging.info("Number of parameters: %s", n)

    # Optimizer
    if args.optimizer == "adadelta":
        optimizer = Adadelta(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
            eps=args.eps,
            rho=args.rho,
        )
    elif args.optimizer == "sgd":
        optimizer = SGD(
            model.parameters(),
            lr=args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer == "adam":
        # torch.optim.Adam has no momentum parameter; passing args.momentum here
        # would raise a TypeError, so only lr and weight_decay are forwarded.
        optimizer = Adam(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer == "adamw":
        # Same as Adam: AdamW does not accept a momentum argument.
        optimizer = AdamW(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
    else:
        raise ValueError("Selected optimizer not supported")

    if args.scheduler == "exponential":
        scheduler = ExponentialLR(optimizer, gamma=args.gamma)
    elif args.scheduler == "reduceonplateau":
        scheduler = ReduceLROnPlateau(optimizer, patience=10, threshold=1e-3)
    else:
        raise ValueError("Selected scheduler not supported")

    criterion = torch.nn.CTCLoss(
        blank=language_model.mapping[char_blank], zero_infinity=False
    )

    # Data Loader
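    # Note (assumption, these helpers are defined elsewhere in the example):
    # collate_factory presumably builds a collate_fn that pads the variable-length
    # utterances in a batch, encodes the transcripts with language_model, and, for
    # the training loader only, applies the masking augmentations defined above;
    # model_length_function is presumably used to report per-example input lengths,
    # which CTCLoss needs alongside the target lengths.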
    collate_fn_train = collate_factory(model_length_function, augmentations)
    collate_fn_valid = collate_factory(model_length_function)

    loader_training_params = {
        "num_workers": args.workers,
        "pin_memory": True,
        "shuffle": True,
        "drop_last": True,
    }
    loader_validation_params = loader_training_params.copy()
    loader_validation_params["shuffle"] = False

    loader_training = DataLoader(
        training,
        batch_size=args.batch_size,
        collate_fn=collate_fn_train,
        **loader_training_params,
    )
    loader_validation = DataLoader(
        validation,
        batch_size=args.batch_size,
        collate_fn=collate_fn_valid,
        **loader_validation_params,
    )

    # Setup checkpoint
    best_loss = 1.0

    load_checkpoint = args.checkpoint and os.path.isfile(args.checkpoint)

    if args.distributed:
        torch.distributed.barrier()

    if load_checkpoint:
        logging.info("Checkpoint: loading %s", args.checkpoint)
        checkpoint = torch.load(args.checkpoint)

        args.start_epoch = checkpoint["epoch"]
        best_loss = checkpoint["best_loss"]

        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])

        logging.info(
            "Checkpoint: loaded '%s' at epoch %s", args.checkpoint, checkpoint["epoch"]
        )
    else:
        logging.info("Checkpoint: not found")

        save_checkpoint(
            {
                "epoch": args.start_epoch,
                "state_dict": model.state_dict(),
                "best_loss": best_loss,
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            },
            False,
            args.checkpoint,
            not_main_rank,
        )

    if args.distributed:
        torch.distributed.barrier()

    torch.autograd.set_detect_anomaly(False)

    for epoch in range(args.start_epoch, args.epochs):

        logging.info("Epoch: %s", epoch)

        train_one_epoch(
            model,
            criterion,
            optimizer,
            scheduler,
            loader_training,
            decoder,
            language_model,
            devices[0],
            epoch,
            args.clip_grad,
            not_main_rank,
            not args.reduce_lr_valid,
        )

        loss = evaluate(
            model,
            criterion,
            loader_validation,
            decoder,
            language_model,
            devices[0],
            epoch,
            not_main_rank,
        )

        if args.reduce_lr_valid and isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(loss)

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_loss": best_loss,
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            },
            is_best,
            args.checkpoint,
            not_main_rank,
        )

    logging.info("End time: %s", datetime.now())

    if args.distributed:
        torch.distributed.destroy_process_group()
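
# The device-sharding rule used in main() above, pulled out here as a tiny
# self-contained helper purely for illustration (it is not part of the pipeline):
# each of the world_size processes gets an equal, contiguous slice of the visible
# GPUs, and DistributedDataParallel wraps the model over that slice. In a
# distributed run, main(rank, args) is typically launched once per process, e.g.
# via torch.multiprocessing.spawn(main, args=(args,), nprocs=args.world_size),
# so that the first positional argument is the process rank; treat that launch
# detail as an assumption, since the launcher is not shown in this file.
def shard_devices(rank: int, world_size: int, num_gpus: int) -> list:
    """Return the GPU ids assigned to `rank` when `num_gpus` are split evenly."""
    per_rank = num_gpus // world_size
    return list(range(rank * per_rank, (rank + 1) * per_rank))

# Example: 8 visible GPUs split across 2 processes.
# shard_devices(0, 2, 8) -> [0, 1, 2, 3]
# shard_devices(1, 2, 8) -> [4, 5, 6, 7]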
def main(args):

    devices = ["cuda" if torch.cuda.is_available() else "cpu"]

    logging.info(f"Start time: {datetime.now()}")

    melkwargs = {
        "n_fft": args.n_fft,
        "power": 1,
        "hop_length": args.hop_length,
        "win_length": args.win_length,
    }
    transforms = torch.nn.Sequential(
        torchaudio.transforms.Spectrogram(**melkwargs),
        LinearToMel(
            sample_rate=args.sample_rate,
            n_fft=args.n_fft,
            n_mels=args.n_freq,
            fmin=args.f_min,
        ),
        NormalizeDB(min_level_db=args.min_level_db),
    )

    train_dataset, val_dataset = split_process_ljspeech(args, transforms)

    loader_training_params = {
        "num_workers": args.workers,
        "pin_memory": False,
        "shuffle": True,
        "drop_last": False,
    }
    loader_validation_params = loader_training_params.copy()
    loader_validation_params["shuffle"] = False

    collate_fn = collate_factory(args)

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        **loader_training_params,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        **loader_validation_params,
    )

    n_classes = 2 ** args.n_bits if args.loss == "crossentropy" else 30

    model = WaveRNN(
        upsample_scales=args.upsample_scales,
        n_classes=n_classes,
        hop_length=args.hop_length,
        n_res_block=args.n_res_block,
        n_rnn=args.n_rnn,
        n_fc=args.n_fc,
        kernel_size=args.kernel_size,
        n_freq=args.n_freq,
        n_hidden=args.n_hidden_melresnet,
        n_output=args.n_output_melresnet,
    )

    if args.jit:
        model = torch.jit.script(model)

    model = torch.nn.DataParallel(model)
    model = model.to(devices[0], non_blocking=True)

    n = count_parameters(model)
    logging.info(f"Number of parameters: {n}")

    # Optimizer
    optimizer_params = {
        "lr": args.learning_rate,
    }
    optimizer = Adam(model.parameters(), **optimizer_params)

    criterion = LongCrossEntropyLoss() if args.loss == "crossentropy" else MoLLoss()

    best_loss = 10.0

    if args.checkpoint and os.path.isfile(args.checkpoint):
        logging.info(f"Checkpoint: loading '{args.checkpoint}'")
        checkpoint = torch.load(args.checkpoint)

        args.start_epoch = checkpoint["epoch"]
        best_loss = checkpoint["best_loss"]

        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])

        logging.info(
            f"Checkpoint: loaded '{args.checkpoint}' at epoch {checkpoint['epoch']}"
        )
    else:
        logging.info("Checkpoint: not found")

        save_checkpoint(
            {
                "epoch": args.start_epoch,
                "state_dict": model.state_dict(),
                "best_loss": best_loss,
                "optimizer": optimizer.state_dict(),
            },
            False,
            args.checkpoint,
        )

    for epoch in range(args.start_epoch, args.epochs):

        train_one_epoch(
            model,
            criterion,
            optimizer,
            train_loader,
            devices[0],
            epoch,
        )

        if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:

            sum_loss = validate(model, criterion, val_loader, devices[0], epoch)

            is_best = sum_loss < best_loss
            best_loss = min(sum_loss, best_loss)
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": model.state_dict(),
                    "best_loss": best_loss,
                    "optimizer": optimizer.state_dict(),
                },
                is_best,
                args.checkpoint,
            )

    logging.info(f"End time: {datetime.now()}")
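
# Why n_classes depends on the loss in main() above: with the cross-entropy
# objective the target waveform is quantized into 2 ** n_bits discrete levels and
# WaveRNN predicts one class per audio sample, while the 30 outputs of the
# mixture-of-logistics (MoL) branch correspond to the mixture parameters
# (e.g. 10 components x weight/mean/log-scale). The sketch below illustrates the
# class-target idea with torchaudio's mu-law encoder; whether this pipeline uses
# mu-law or plain linear quantization is not visible in this file, so treat the
# encoder choice as an assumption for illustration only.
def _demo_quantized_targets(n_bits: int = 8) -> None:
    import torch
    import torchaudio

    waveform = torch.rand(1, 16000) * 2 - 1  # dummy audio in [-1, 1]
    targets = torchaudio.functional.mu_law_encoding(
        waveform, quantization_channels=2 ** n_bits
    )
    # Integer class labels in [0, 2 ** n_bits), i.e. 0..255 for n_bits = 8,
    # which is the range a 2 ** n_bits-way cross-entropy head expects.
    print(int(targets.min()), int(targets.max()))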