def test_shape(spectrogram, nb_bins, nb_channels, unidirectional, hidden_size):
    unmix = model.OpenUnmix(
        nb_bins=nb_bins,
        nb_channels=nb_channels,
        unidirectional=unidirectional,
        nb_layers=1,  # speed up training
        hidden_size=hidden_size,
    )
    unmix.eval()
    Y = unmix(spectrogram)
    assert spectrogram.shape == Y.shape
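# Illustrative driver (not part of the original test module): one way to call
# test_shape directly, assuming `model` refers to openunmix.model and that
# OpenUnmix consumes spectrograms shaped (nb_samples, nb_channels, nb_bins, nb_frames).
# The concrete sizes below are made up for this example and are not the
# project's pytest fixture parameters.
if __name__ == "__main__":
    import torch
    from openunmix import model

    nb_bins, nb_channels, nb_frames = 2049, 2, 16
    spectrogram = torch.rand(1, nb_channels, nb_bins, nb_frames)
    test_shape(
        spectrogram,
        nb_bins=nb_bins,
        nb_channels=nb_channels,
        unidirectional=True,
        hidden_size=256,
    )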
def load_target_models(targets, model_str_or_path="umxhq", device="cpu", pretrained=True):
    """Core model loader

    The target model path can either be `<target>.pth` or
    `<target>-sha256.pth` (as used on torchhub).

    The loader either loads the models from a known model string
    as registered in the __init__.py or loads from custom configs.
    """
    if isinstance(targets, str):
        targets = [targets]

    model_path = Path(model_str_or_path).expanduser()
    if not model_path.exists():
        # model path does not exist, use pretrained models
        try:
            # disable progress bar
            hub_loader = getattr(openunmix, model_str_or_path + "_spec")
            err = io.StringIO()
            with redirect_stderr(err):
                return hub_loader(targets=targets, device=device, pretrained=pretrained)
            print(err.getvalue())
        except AttributeError:
            raise NameError("Model does not exist on torchhub")
        # assume model_str_or_path is a path to a local model directory
    else:
        models = {}
        for target in targets:
            # load model configuration from disk
            with open(Path(model_path, target + ".json"), "r") as stream:
                results = json.load(stream)

            target_model_path = next(Path(model_path).glob("%s*.pth" % target))
            state = torch.load(target_model_path, map_location=device)

            models[target] = model.OpenUnmix(
                nb_bins=results["args"]["nfft"] // 2 + 1,
                nb_channels=results["args"]["nb_channels"],
                hidden_size=results["args"]["hidden_size"],
                max_bin=state["input_mean"].shape[0],
            )
            if pretrained:
                models[target].load_state_dict(state, strict=False)

            models[target].to(device)
        return models
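# Usage sketch (illustrative, not from the source): fetch the pretrained "umxhq"
# vocals model on CPU. The first call needs network access so torchhub can
# download the weights; the function returns a dict mapping target names to
# OpenUnmix modules. The helper name _example_load_vocals is an assumption.
def _example_load_vocals():
    models = load_target_models(
        "vocals", model_str_or_path="umxhq", device="cpu", pretrained=True
    )
    return models["vocals"]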
def main():
    parser = argparse.ArgumentParser(description="Open Unmix Trainer")

    # which target do we want to train?
    parser.add_argument(
        "--target",
        type=str,
        default="vocals",
        help="target source (will be passed to the dataset)",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        default="musdb",
        choices=[
            "musdb",
            "aligned",
            "sourcefolder",
            "trackfolder_var",
            "trackfolder_fix",
        ],
        help="Name of the dataset.",
    )
    parser.add_argument("--root", type=str, help="root path of dataset")
    parser.add_argument(
        "--output",
        type=str,
        default="open-unmix",
        help="provide output path base folder name",
    )
    parser.add_argument("--model", type=str, help="Path to checkpoint folder")
    parser.add_argument(
        "--audio-backend",
        type=str,
        default="soundfile",
        help="Set torchaudio backend (`sox_io` or `soundfile`)",
    )

    # Training Parameters
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate, defaults to 1e-3")
    parser.add_argument(
        "--patience",
        type=int,
        default=140,
        help="early stopping patience in epochs (default: 140)",
    )
    parser.add_argument(
        "--lr-decay-patience",
        type=int,
        default=80,
        help="lr decay patience for plateau scheduler",
    )
    parser.add_argument(
        "--lr-decay-gamma",
        type=float,
        default=0.3,
        help="gamma of learning rate scheduler decay",
    )
    parser.add_argument("--weight-decay", type=float, default=0.00001, help="weight decay")
    parser.add_argument("--seed", type=int, default=42, metavar="S", help="random seed (default: 42)")

    # Model Parameters
    parser.add_argument(
        "--seq-dur",
        type=float,
        default=6.0,
        help="Sequence duration in seconds; a value of <=0.0 will use full/variable length",
    )
    parser.add_argument(
        "--unidirectional",
        action="store_true",
        default=False,
        help="Use unidirectional LSTM",
    )
    parser.add_argument("--nfft", type=int, default=4096, help="STFT fft size and window size")
    parser.add_argument("--nhop", type=int, default=1024, help="STFT hop size")
    parser.add_argument(
        "--hidden-size",
        type=int,
        default=512,
        help="hidden size parameter of bottleneck layers",
    )
    parser.add_argument(
        "--bandwidth", type=int, default=16000, help="maximum model bandwidth in hertz"
    )
    parser.add_argument(
        "--nb-channels",
        type=int,
        default=2,
        help="set number of channels for model (1, 2)",
    )
    parser.add_argument("--nb-workers", type=int, default=0, help="Number of workers for dataloader.")
    parser.add_argument(
        "--debug",
        action="store_true",
        default=False,
        help="Speed up training init for dev purposes",
    )

    # Misc Parameters
    parser.add_argument(
        "--quiet",
        action="store_true",
        default=False,
        help="less verbose during training",
    )
    parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training")

    args, _ = parser.parse_known_args()

    torchaudio.set_audio_backend(args.audio_backend)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    dataloader_kwargs = {"num_workers": args.nb_workers, "pin_memory": True} if use_cuda else {}

    repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    repo = Repo(repo_dir)
    commit = repo.head.commit.hexsha[:7]

    # fix random seeds for reproducibility
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_dataset, valid_dataset, args = data.load_datasets(parser, args)

    # create output dir if it does not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)

    train_sampler = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        **dataloader_kwargs,
    )
    valid_sampler = torch.utils.data.DataLoader(valid_dataset, batch_size=1, **dataloader_kwargs)

    stft, _ = transforms.make_filterbanks(
        n_fft=args.nfft, n_hop=args.nhop, sample_rate=train_dataset.sample_rate
    )
    encoder = torch.nn.Sequential(stft, model.ComplexNorm(mono=args.nb_channels == 1)).to(device)

    separator_conf = {
        "nfft": args.nfft,
        "nhop": args.nhop,
        "sample_rate": train_dataset.sample_rate,
        "nb_channels": args.nb_channels,
    }

    with open(Path(target_path, "separator.json"), "w") as outfile:
        outfile.write(json.dumps(separator_conf, indent=4, sort_keys=True))

    if args.model or args.debug:
        scaler_mean = None
        scaler_std = None
    else:
        scaler_mean, scaler_std = get_statistics(args, encoder, train_dataset)

    max_bin = utils.bandwidth_to_max_bin(train_dataset.sample_rate, args.nfft, args.bandwidth)

    unmix = model.OpenUnmix(
        input_mean=scaler_mean,
        input_scale=scaler_std,
        nb_bins=args.nfft // 2 + 1,
        nb_channels=args.nb_channels,
        hidden_size=args.hidden_size,
        max_bin=max_bin,
    ).to(device)

    optimizer = torch.optim.Adam(unmix.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10,
    )

    es = utils.EarlyStopping(patience=args.patience)

    # if a model is specified: resume training
    if args.model:
        model_path = Path(args.model).expanduser()
        with open(Path(model_path, args.target + ".json"), "r") as stream:
            results = json.load(stream)

        target_model_path = Path(model_path, args.target + ".chkpnt")
        checkpoint = torch.load(target_model_path, map_location=device)
        unmix.load_state_dict(checkpoint["state_dict"], strict=False)
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])

        # train for another args.epochs, continuing from epochs_trained
        t = tqdm.trange(
            results["epochs_trained"],
            results["epochs_trained"] + args.epochs + 1,
            disable=args.quiet,
        )
        train_losses = results["train_loss_history"]
        valid_losses = results["valid_loss_history"]
        train_times = results["train_time_history"]
        best_epoch = results["best_epoch"]
        es.best = results["best_loss"]
        es.num_bad_epochs = results["num_bad_epochs"]
    # else start training from scratch
    else:
        t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
        train_losses = []
        valid_losses = []
        train_times = []
        best_epoch = 0

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()
        train_loss = train(args, unmix, encoder, device, train_sampler, optimizer)
        valid_loss = valid(args, unmix, encoder, device, valid_sampler)
        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        t.set_postfix(train_loss=train_loss, val_loss=valid_loss)

        stop = es.step(valid_loss)

        if valid_loss == es.best:
            best_epoch = epoch

        utils.save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": unmix.state_dict(),
                "best_loss": es.best,
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            },
            is_best=valid_loss == es.best,
            path=target_path,
            target=args.target,
        )

        # save params
        params = {
            "epochs_trained": epoch,
            "args": vars(args),
            "best_loss": es.best,
            "best_epoch": best_epoch,
            "train_loss_history": train_losses,
            "valid_loss_history": valid_losses,
            "train_time_history": train_times,
            "num_bad_epochs": es.num_bad_epochs,
            "commit": commit,
        }

        with open(Path(target_path, args.target + ".json"), "w") as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        train_times.append(time.time() - end)

        if stop:
            print("Apply Early Stopping")
            break
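# Illustrative sketch (not the repository's actual implementation) of what the
# train() step called in the loop above could look like: the dataloader is
# assumed to yield (mix, target) waveform pairs, the encoder maps waveforms to
# magnitude spectrograms, and the model is optimized with an MSE loss against
# the target spectrogram. The name train_sketch and its internals are
# assumptions; `args` is kept only to mirror the call signature used in main().
def train_sketch(args, unmix, encoder, device, train_sampler, optimizer):
    unmix.train()
    running_loss, n_batches = 0.0, 0
    for x, y in train_sampler:  # x: mixture audio, y: target source audio
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        Y_hat = unmix(encoder(x))  # predicted target magnitude spectrogram
        Y = encoder(y)  # reference target magnitude spectrogram
        loss = torch.nn.functional.mse_loss(Y_hat, Y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1
    return running_loss / max(n_batches, 1)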