# Module-level imports assumed by this prediction script (following the
# Wave-U-Net-PyTorch repository layout; adjust if your layout differs).
import argparse
import os

import data.utils
import model.utils as model_utils
from model.waveunet import Waveunet
from test import predict_song


def main(args):
    # MODEL: derive the per-level feature counts and build the Wave-U-Net
    num_features = [args.features * i for i in range(1, args.levels + 1)] if args.feature_growth == "add" else \
                   [args.features * 2 ** i for i in range(0, args.levels)]
    target_outputs = int(args.output_size * args.sr)
    model = Waveunet(args.channels, num_features, args.channels, args.instruments,
                     kernel_size=args.kernel_size, target_output_size=target_outputs,
                     depth=args.depth, strides=args.strides,
                     conv_type=args.conv_type, res=args.res, separate=args.separate)

    if args.cuda:
        model = model_utils.DataParallel(model)
        print("move model to gpu")
        model.cuda()

    print("Loading model from checkpoint " + str(args.load_model))
    state = model_utils.load_model(model, None, args.load_model, args.cuda)
    print('Step', state['step'])

    # Separate the input mixture and write one wav file per instrument
    preds = predict_song(args, args.input, model)

    output_folder = os.path.dirname(args.input) if args.output is None else args.output
    for inst in preds.keys():
        data.utils.write_wav(os.path.join(output_folder,
                                          os.path.basename(args.input) + "_" + inst + ".wav"),
                             preds[inst].T, args.sr)
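# Usage sketch (hypothetical, not part of the repository): main() can be driven
# without a command line by building the args namespace directly. The input
# path below is a placeholder.
if __name__ == "__main__":
    example_args = argparse.Namespace(
        instruments=["bass", "drums", "other", "vocals"], cuda=False,
        features=32, load_model="checkpoints/waveunet/model", batch_size=4,
        levels=6, depth=1, sr=44100, channels=2, kernel_size=5,
        output_size=2.0, strides=4, conv_type="gn", res="fixed", separate=1,
        feature_growth="double", input="songs/mix.wav", output=None)
    main(example_args)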
def setup(self):
    """Initialise the Wave-U-Net model."""
    # Relies on module-level imports of argparse, model.utils (as model_utils)
    # and Waveunet, as in the scripts above.
    parser = argparse.ArgumentParser()
    parser.add_argument("--instruments", type=str, nargs="+",
                        default=["bass", "drums", "other", "vocals"],
                        help='List of instruments to separate (default: "bass drums other vocals")')
    parser.add_argument("--cuda", action="store_true",
                        help="Use CUDA (default: False)")
    parser.add_argument("--features", type=int, default=32,
                        help="Number of feature channels per layer")
    parser.add_argument("--load_model", type=str, default="checkpoints/waveunet/model",
                        help="Reload a previously trained model")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="Batch size")
    parser.add_argument("--levels", type=int, default=6,
                        help="Number of DS/US blocks")
    parser.add_argument("--depth", type=int, default=1,
                        help="Number of convs per block")
    parser.add_argument("--sr", type=int, default=44100,
                        help="Sampling rate")
    parser.add_argument("--channels", type=int, default=2,
                        help="Number of input audio channels")
    parser.add_argument("--kernel_size", type=int, default=5,
                        help="Filter width of kernels; has to be an odd number")
    parser.add_argument("--output_size", type=float, default=2.0,
                        help="Output duration")
    parser.add_argument("--strides", type=int, default=4,
                        help="Strides in Waveunet")
    parser.add_argument("--conv_type", type=str, default="gn",
                        help="Type of convolution (normal, BN-normalised, GN-normalised): normal/bn/gn")
    parser.add_argument("--res", type=str, default="fixed",
                        help="Resampling strategy: fixed sinc-based lowpass filtering or learned conv layer: fixed/learned")
    parser.add_argument("--separate", type=int, default=1,
                        help="Train a separate model for each source (1) or only one (0)")
    parser.add_argument("--feature_growth", type=str, default="double",
                        help="How the features in each layer should grow: either (add) the initial number of features each time, or multiply by 2 (double)")
    # Input/output paths are provided by the caller at prediction time instead:
    # parser.add_argument('--input', type=str, default=str(input),
    #                     help="Path to input mixture to be separated")
    # parser.add_argument('--output', type=str, default=out_path,
    #                     help="Output path (same folder as input path if not set)")

    # Parse defaults only; no CLI arguments are expected in this context
    args = parser.parse_args([])
    self.args = args

    num_features = ([args.features * i for i in range(1, args.levels + 1)]
                    if args.feature_growth == "add" else
                    [args.features * 2 ** i for i in range(0, args.levels)])
    target_outputs = int(args.output_size * args.sr)
    self.model = Waveunet(args.channels, num_features, args.channels, args.instruments,
                          kernel_size=args.kernel_size, target_output_size=target_outputs,
                          depth=args.depth, strides=args.strides,
                          conv_type=args.conv_type, res=args.res, separate=args.separate)

    if args.cuda:
        self.model = model_utils.DataParallel(self.model)
        print("move model to gpu")
        self.model.cuda()

    print("Loading model from checkpoint " + str(args.load_model))
    state = model_utils.load_model(self.model, None, args.load_model, args.cuda)
    print("Step", state["step"])
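# Illustration of the two feature_growth strategies with the defaults above
# (features=32, levels=6): "add" grows each level by the initial feature count,
# "double" multiplies by two per level.
features, levels = 32, 6
assert [features * i for i in range(1, levels + 1)] == [32, 64, 96, 128, 160, 192]   # add
assert [features * 2 ** i for i in range(levels)] == [32, 64, 128, 256, 512, 1024]   # double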
# Module-level imports assumed by this training script (following the
# Wave-U-Net-PyTorch repository layout; adjust if your layout differs).
import os
import pickle
import time
from functools import partial

import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import utils
import model.utils as model_utils
from data.dataset import SeparationDataset
from data.musdb import get_musdb_folds
from data.utils import crop_targets, random_amplify
from model.waveunet import Waveunet
from test import evaluate, validate


def main(args):
    # torch.backends.cudnn.benchmark = True  # This makes dilated conv much faster for cuDNN 7.5

    # MODEL
    num_features = [args.features * i for i in range(1, args.levels + 1)] if args.feature_growth == "add" else \
                   [args.features * 2 ** i for i in range(0, args.levels)]
    target_outputs = int(args.output_size * args.sr)
    model = Waveunet(args.channels, num_features, args.channels, args.instruments,
                     kernel_size=args.kernel_size, target_output_size=target_outputs,
                     depth=args.depth, strides=args.strides,
                     conv_type=args.conv_type, res=args.res, separate=args.separate)

    if args.cuda:
        model = model_utils.DataParallel(model)
        print("move model to gpu")
        model.cuda()

    print('model: ', model)
    print('parameter count: ', str(sum(p.numel() for p in model.parameters())))

    writer = SummaryWriter(args.log_dir)

    ### DATASET
    musdb = get_musdb_folds(args.dataset_dir)
    # Without data augmentation, at least crop targets to fit the model output shape
    crop_func = partial(crop_targets, shapes=model.shapes)
    # Data augmentation function for training
    augment_func = partial(random_amplify, shapes=model.shapes, min=0.7, max=1.0)
    train_data = SeparationDataset(musdb, "train", args.instruments, args.sr, args.channels, model.shapes, True,
                                   args.hdf_dir, audio_transform=augment_func)
    val_data = SeparationDataset(musdb, "val", args.instruments, args.sr, args.channels, model.shapes, False,
                                 args.hdf_dir, audio_transform=crop_func)
    test_data = SeparationDataset(musdb, "test", args.instruments, args.sr, args.channels, model.shapes, False,
                                  args.hdf_dir, audio_transform=crop_func)

    dataloader = torch.utils.data.DataLoader(train_data,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.num_workers,
                                             worker_init_fn=utils.worker_init_fn)

    ##### TRAINING ####

    # Set up the loss function
    if args.loss == "L1":
        criterion = nn.L1Loss()
    elif args.loss == "L2":
        criterion = nn.MSELoss()
    else:
        raise NotImplementedError("Couldn't find this loss!")

    # Set up optimiser
    optimizer = Adam(params=model.parameters(), lr=args.lr)

    # Set up training state dict that will also be saved into checkpoints
    state = {"step": 0,
             "worse_epochs": 0,
             "epochs": 0,
             "best_loss": np.inf}

    # LOAD MODEL CHECKPOINT IF DESIRED
    if args.load_model is not None:
        print("Continuing training full model from checkpoint " + str(args.load_model))
        state = model_utils.load_model(model, optimizer, args.load_model, args.cuda)

    print('TRAINING START')
    while state["worse_epochs"] < args.patience:
        print("Training one epoch from iteration " + str(state["step"]))
        avg_time = 0.
        model.train()
        with tqdm(total=len(train_data) // args.batch_size) as pbar:
            np.random.seed()
            for example_num, (x, targets) in enumerate(dataloader):
                if args.cuda:
                    x = x.cuda()
                    for k in list(targets.keys()):
                        targets[k] = targets[k].cuda()

                t = time.time()

                # Set LR for this iteration
                utils.set_cyclic_lr(optimizer, example_num, len(train_data) // args.batch_size,
                                    args.cycles, args.min_lr, args.lr)
                writer.add_scalar("lr", utils.get_lr(optimizer), state["step"])

                # Compute loss for each instrument/model; with compute_grad=True
                # the backward pass happens inside compute_loss
                optimizer.zero_grad()
                outputs, avg_loss = model_utils.compute_loss(model, x, targets, criterion, compute_grad=True)
                optimizer.step()

                state["step"] += 1

                t = time.time() - t
                avg_time += (1. / float(example_num + 1)) * (t - avg_time)

                writer.add_scalar("train_loss", avg_loss, state["step"])

                if example_num % args.example_freq == 0:
                    input_centre = torch.mean(
                        x[0, :, model.shapes["output_start_frame"]:model.shapes["output_end_frame"]],
                        0)  # Stereo not supported for logs yet
                    writer.add_audio("input", input_centre, state["step"], sample_rate=args.sr)

                    for inst in outputs.keys():
                        writer.add_audio(inst + "_pred", torch.mean(outputs[inst][0], 0), state["step"],
                                         sample_rate=args.sr)
                        writer.add_audio(inst + "_target", torch.mean(targets[inst][0], 0), state["step"],
                                         sample_rate=args.sr)

                pbar.update(1)

        # VALIDATE
        val_loss = validate(args, model, criterion, val_data)
        print("VALIDATION FINISHED: LOSS: " + str(val_loss))
        writer.add_scalar("val_loss", val_loss, state["step"])

        # EARLY STOPPING CHECK
        checkpoint_path = os.path.join(args.checkpoint_dir, "checkpoint_" + str(state["step"]))
        if val_loss >= state["best_loss"]:
            state["worse_epochs"] += 1
        else:
            print("MODEL IMPROVED ON VALIDATION SET!")
            state["worse_epochs"] = 0
            state["best_loss"] = val_loss
            state["best_checkpoint"] = checkpoint_path

        # CHECKPOINT
        print("Saving model...")
        model_utils.save_model(model, optimizer, state, checkpoint_path)

        state["epochs"] += 1

    #### TESTING ####
    # Test loss
    print("TESTING")

    # Load the best model based on validation loss
    state = model_utils.load_model(model, None, state["best_checkpoint"], args.cuda)
    test_loss = validate(args, model, criterion, test_data)
    print("TEST FINISHED: LOSS: " + str(test_loss))
    writer.add_scalar("test_loss", test_loss, state["step"])

    # mir_eval metrics
    test_metrics = evaluate(args, musdb["test"], model, args.instruments)

    # Dump all metrics into a pickle file for later analysis if needed
    with open(os.path.join(args.checkpoint_dir, "results.pkl"), "wb") as f:
        pickle.dump(test_metrics, f)

    # Write the most important metrics into the TensorBoard log
    avg_SDRs = {inst: np.mean([np.nanmean(song[inst]["SDR"]) for song in test_metrics])
                for inst in args.instruments}
    avg_SIRs = {inst: np.mean([np.nanmean(song[inst]["SIR"]) for song in test_metrics])
                for inst in args.instruments}
    for inst in args.instruments:
        writer.add_scalar("test_SDR_" + inst, avg_SDRs[inst], state["step"])
        writer.add_scalar("test_SIR_" + inst, avg_SIRs[inst], state["step"])
    overall_SDR = np.mean([v for v in avg_SDRs.values()])
    writer.add_scalar("test_SDR", overall_SDR, state["step"])
    print("SDR: " + str(overall_SDR))

    writer.close()