import os

import numpy as np
import soundfile as sf
import torch

# ConvTasNet, EvalDataset, EvalDataLoader and remove_pad come from this
# project's own modules; import them from the repository's packages.


def separate(args):
    if args.mix_dir is None and args.mix_json is None:
        print("Must provide mix_dir or mix_json! When providing mix_dir, "
              "mix_json is ignored.")
        return

    # Load model
    model = ConvTasNet(256, 20, 256, 512, 3, 8, 4, 2,
                       norm_type="gLN", causal=0, mask_nonlinear="relu")
    model.load_state_dict(
        torch.load(args.model_path, map_location='cpu')['sep_state_dict'])
    if args.use_cuda:
        model.cuda()
    print(model)
    model.eval()

    # Load data
    eval_dataset = EvalDataset(args.mix_dir, args.mix_json,
                               batch_size=args.batch_size,
                               sample_rate=args.sample_rate)
    eval_loader = EvalDataLoader(eval_dataset, batch_size=1)

    os.makedirs(args.out_dir, exist_ok=True)

    def write(inputs, filename, sr=args.sample_rate):
        # Peak-normalize to avoid clipping, then write as WAV.
        inputs = inputs / np.max(np.abs(inputs))
        sf.write(filename, inputs, sr)

    with torch.no_grad():
        for data in eval_loader:
            # Get batch data
            mixture, mix_lengths, filenames = data
            if args.use_cuda:
                mixture, mix_lengths = mixture.cuda(), mix_lengths.cuda()
            # Forward
            estimate_source = model(mixture)  # [B, C, T]
            # Remove padding and flatten
            flat_estimate = remove_pad(estimate_source, mix_lengths)
            mixture = remove_pad(mixture, mix_lengths)
            # Write results: the mixture plus one file per estimated source
            for i, filename in enumerate(filenames):
                filename = os.path.join(
                    args.out_dir,
                    os.path.splitext(os.path.basename(filename))[0])
                write(mixture[i], filename + '.wav')
                C = flat_estimate[i].shape[0]
                for c in range(C):
                    write(flat_estimate[i][c],
                          filename + '_s{}.wav'.format(c + 1))
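# The loop above relies on remove_pad to undo the collate padding. Its real
# definition lives elsewhere in this project; the sketch below only illustrates
# the assumed behaviour (trim each padded [B, C, T] or [B, T] tensor back to
# its true length and return per-utterance NumPy arrays), not the project's
# actual implementation.
def _remove_pad_sketch(inputs, inputs_lengths):
    """inputs: [B, C, T] or [B, T] tensor; inputs_lengths: [B] tensor of
    valid lengths. Returns a list of NumPy arrays, one per utterance."""
    results = []
    dim = inputs.dim()
    for inp, length in zip(inputs, inputs_lengths):
        if dim == 3:   # [C, T] -> [C, length]
            results.append(inp[:, :length.item()].cpu().numpy())
        else:          # [T] -> [length]
            results.append(inp[:length.item()].cpu().numpy())
    return results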
    'causal': args.causal,
    'mask_nonlinear': args.mask_nonlinear
}
train_args = {
    'lr': args.lr,
    'batch_size': args.batch_size,
    'epochs': args.epochs
}

model = ConvTasNet(**model_args)

if args.evaluate == 0 and args.separate == 0:
    # Train from scratch on the training split.
    dataset = AudioDataset(args.data_dir, sr=args.sr, mode='train',
                           seq_len=args.seq_len, verbose=0,
                           voice_only=args.voice_only)
    print('DataLoading Done')
    train(model, dataset, **train_args)
elif args.evaluate == 1:
    # Evaluate a saved checkpoint on the test split.
    model.load_state_dict(torch.load(args.model, map_location='cpu'))
    dataset = AudioDataset(args.data_dir, sr=args.sr, mode='test',
                           seq_len=args.seq_len, verbose=0,
                           voice_only=args.voice_only)
    evaluate(model, dataset, args.batch_size, 0, args.cal_sdr)
else:
    # Separate the test split with a saved checkpoint.
    # Note: the output sample rate is hardcoded to 8 kHz here, not args.sr.
    model.load_state_dict(torch.load(args.model, map_location='cpu'))
    dataset = AudioDataset(args.data_dir, sr=args.sr, mode='test',
                           seq_len=args.seq_len, verbose=0,
                           voice_only=args.voice_only)
    separate(model, dataset, args.output_dir, sr=8000)
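# For reference, the dispatch above reads a number of attributes from args
# (data_dir, sr, seq_len, voice_only, causal, mask_nonlinear, lr, batch_size,
# epochs, evaluate, separate, model, cal_sdr, output_dir). The actual parser
# is not shown in this section; the sketch below is a hypothetical argparse
# setup with illustrative defaults, not the project's definitive CLI.
import argparse


def _build_arg_parser_sketch():
    parser = argparse.ArgumentParser(
        description='Conv-TasNet training / evaluation / separation')
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--sr', type=int, default=8000)
    parser.add_argument('--seq_len', type=float, default=4.0)
    parser.add_argument('--voice_only', type=int, default=0)
    parser.add_argument('--causal', type=int, default=0)
    parser.add_argument('--mask_nonlinear', type=str, default='relu')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--evaluate', type=int, default=0)
    parser.add_argument('--separate', type=int, default=0)
    parser.add_argument('--model', type=str, default='')
    parser.add_argument('--cal_sdr', type=int, default=1)
    parser.add_argument('--output_dir', type=str, default='separated')
    return parser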