import argparse
import os

import cv2
import librosa
import numpy as np
import soundfile as sf
import torch
from tqdm import tqdm

# dataset, nets and spec_utils are the repository's own modules
# (dataset.make_padding, nets.CascadedASPPNet, spec_utils.calc_spec, ...).
import dataset
import nets
import spec_utils


def inference_tta(self, X_spec):
    X_mag, X_phase = self.preprocess(X_spec)

    coef = X_mag.max()
    X_mag_pre = X_mag / coef

    n_frame = X_mag_pre.shape[2]
    pad_l, pad_r, roi_size = dataset.make_padding(
        n_frame, self.window_size, self.offset)
    n_window = int(np.ceil(n_frame / roi_size))

    # first pass: plain windowed inference
    X_mag_pad = np.pad(
        X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
    pred = self._execute(X_mag_pad, roi_size, n_window)
    pred = pred[:, :, :n_frame]

    # second pass (test-time augmentation): shift the windows by half a roi
    pad_l += roi_size // 2
    pad_r += roi_size // 2
    n_window += 1
    X_mag_pad = np.pad(
        X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
    pred_tta = self._execute(X_mag_pad, roi_size, n_window)
    pred_tta = pred_tta[:, :, roi_size // 2:]
    pred_tta = pred_tta[:, :, :n_frame]

    # average the two passes and undo the magnitude normalization
    return (pred + pred_tta) * 0.5 * coef, X_mag, np.exp(1.j * X_phase)
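# The method above reads like a member of a separator-style wrapper class.
# Below is a minimal sketch of what that enclosing class might look like; the
# class name "SeparatorSketch" and the bodies of preprocess/_execute are
# assumptions pieced together from the attributes inference_tta actually uses
# (window_size, offset, preprocess, _execute) and from the windowed
# model.predict loops in the scripts further down, not the repository's
# definitive implementation.
class SeparatorSketch:
    def __init__(self, model, device, window_size, offset):
        self.model = model
        self.device = device
        self.window_size = window_size
        self.offset = offset

    def preprocess(self, X_spec):
        # split the complex spectrogram into magnitude and phase angle
        return np.abs(X_spec), np.angle(X_spec)

    def _execute(self, X_mag_pad, roi_size, n_window):
        # run the model window by window and stitch the predictions along frames
        preds = []
        with torch.no_grad():
            for i in range(n_window):
                start = i * roi_size
                X_window = X_mag_pad[None, :, :, start:start + self.window_size]
                pred = self.model.predict(
                    torch.from_numpy(X_window).to(self.device))
                preds.append(pred.detach().cpu().numpy()[0])
        return np.concatenate(preds, axis=2)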
def main():
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', '-g', type=int, default=-1)
    p.add_argument(
        '--model', '-m', type=str,
        default='/content/drive/My Drive/vocal-remover/models/MultiGenreModelNP.pth')
    p.add_argument('--input', '-i', required=True)
    p.add_argument('--sr', '-r', type=int, default=44100)
    p.add_argument('--hop_length', '-l', type=int, default=1024)
    p.add_argument('--window_size', '-w', type=int, default=512)
    p.add_argument('--out_mask', '-M', action='store_true')
    p.add_argument('--postprocess', '-p', action='store_true')
    args = p.parse_args()

    print('loading model...', end=' ')
    device = torch.device('cpu')
    model = nets.CascadedASPPNet()
    model.load_state_dict(torch.load(args.model, map_location=device))
    if torch.cuda.is_available() and args.gpu >= 0:
        device = torch.device('cuda:{}'.format(args.gpu))
        model.to(device)
    print('done')

    print('loading wave source...', end=' ')
    X, sr = librosa.load(
        args.input, sr=args.sr, mono=False,
        dtype=np.float32, res_type='kaiser_fast')
    print('done')

    print('stft of wave source...', end=' ')
    X = spec_utils.calc_spec(X, args.hop_length)
    X, phase = np.abs(X), np.exp(1.j * np.angle(X))
    coeff = X.max()
    X /= coeff
    print('done')

    offset = model.offset
    l, r, roi_size = dataset.make_padding(X.shape[2], args.window_size, offset)
    X_pad = np.pad(X, ((0, 0), (0, 0), (l, r)), mode='constant')
    # second input for test-time augmentation: the same spectrogram shifted
    # by half a window so the two mask estimates can be averaged later
    X_roll = np.roll(X_pad, roi_size // 2, axis=2)

    model.eval()
    with torch.no_grad():
        masks = []
        masks_roll = []
        for i in tqdm(range(int(np.ceil(X.shape[2] / roi_size)))):
            start = i * roi_size
            X_window = torch.from_numpy(np.asarray([
                X_pad[:, :, start:start + args.window_size],
                X_roll[:, :, start:start + args.window_size]
            ])).to(device)
            pred = model.predict(X_window)
            pred = pred.detach().cpu().numpy()
            masks.append(pred[0])
            masks_roll.append(pred[1])

    mask = np.concatenate(masks, axis=2)[:, :, :X.shape[2]]
    mask_roll = np.concatenate(masks_roll, axis=2)[:, :, :X.shape[2]]
    # undo the half-window shift and average the two estimates
    mask = (mask + np.roll(mask_roll, -roi_size // 2, axis=2)) / 2

    if args.postprocess:
        vocal = X * (1 - mask) * coeff
        mask = spec_utils.mask_uninformative(mask, vocal)

    inst = X * mask * coeff
    vocal = X * (1 - mask) * coeff

    basename = os.path.splitext(os.path.basename(args.input))[0]

    print('inverse stft of instruments...', end=' ')
    wav = spec_utils.spec_to_wav(inst, phase, args.hop_length)
    print('done')
    sf.write('{}_Instruments.wav'.format(basename), wav.T, sr)

    print('inverse stft of vocals...', end=' ')
    wav = spec_utils.spec_to_wav(vocal, phase, args.hop_length)
    print('done')
    sf.write('{}_Vocals.wav'.format(basename), wav.T, sr)

    if args.out_mask:
        norm_mask = np.uint8((1 - mask) * 255).transpose(1, 2, 0)
        norm_mask = np.concatenate([
            np.max(norm_mask, axis=2, keepdims=True),
            norm_mask
        ], axis=2)[::-1]
        _, bin_mask = cv2.imencode('.png', norm_mask)
        with open('{}_Mask.png'.format(basename), mode='wb') as f:
            bin_mask.tofile(f)
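# A toy check, purely illustrative and not part of the script above, showing
# why rolling the second mask back by -roi_size // 2 lines it up with the
# first one: np.roll is exactly inverted by rolling the same amount the other
# way, so with an identity "model" the two branches agree bin for bin.
def _roll_tta_sanity_check():
    roi_size = 4
    X = np.arange(2 * 3 * 12, dtype=np.float32).reshape(2, 3, 12)  # (ch, bins, frames)
    X_roll = np.roll(X, roi_size // 2, axis=2)

    mask = X            # identity "model" applied to the plain input
    mask_roll = X_roll  # identity "model" applied to the rolled input

    restored = np.roll(mask_roll, -roi_size // 2, axis=2)
    assert np.array_equal(mask, restored)  # the half-window shift is exactly undone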
def main():
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', '-g', type=int, default=-1)
    p.add_argument('--model', '-m', type=str, default='models/baseline.pth')
    p.add_argument('--input', '-i', required=True)
    p.add_argument('--sr', '-r', type=int, default=44100)
    p.add_argument('--hop_length', '-l', type=int, default=1024)
    p.add_argument('--window_size', '-w', type=int, default=512)
    p.add_argument('--out_mask', '-M', action='store_true')
    p.add_argument('--postprocess', '-p', action='store_true')
    args = p.parse_args()

    print('loading model...', end=' ')
    device = torch.device('cpu')
    model = nets.CascadedASPPNet()
    model.load_state_dict(torch.load(args.model, map_location=device))
    if torch.cuda.is_available() and args.gpu >= 0:
        device = torch.device('cuda:{}'.format(args.gpu))
        model.to(device)
    print('done')

    print('loading wave source...', end=' ')
    X, sr = librosa.load(
        args.input, sr=args.sr, mono=False,
        dtype=np.float32, res_type='kaiser_fast')
    print('done')

    print('wave source stft...', end=' ')
    X = spec_utils.calc_spec(X, args.hop_length)
    X, phase = np.abs(X), np.exp(1.j * np.angle(X))
    coeff = X.max()
    X /= coeff
    print('done')

    offset = model.offset
    l, r, roi_size = dataset.make_padding(X.shape[2], args.window_size, offset)
    X_pad = np.pad(X, ((0, 0), (0, 0), (l, r)), mode='constant')

    masks = []
    model.eval()
    with torch.no_grad():
        for j in tqdm(range(int(np.ceil(X.shape[2] / roi_size)))):
            start = j * roi_size
            X_window = X_pad[None, :, :, start:start + args.window_size]
            pred = model.predict(torch.from_numpy(X_window).to(device))
            pred = pred.detach().cpu().numpy()
            masks.append(pred[0])

    mask = np.concatenate(masks, axis=2)[:, :, :X.shape[2]]

    if args.postprocess:
        vocal_pred = X * (1 - mask) * coeff
        mask = spec_utils.mask_uninformative(mask, vocal_pred)

    inst_pred = X * mask * coeff
    vocal_pred = X * (1 - mask) * coeff

    if args.out_mask:
        # render the vocal mask (1 - mask) as an RGB image; the vertical flip
        # puts low frequencies at the bottom
        norm_mask = np.uint8((1 - mask) * 255)
        canvas = np.zeros(
            (norm_mask.shape[1], norm_mask.shape[2], 3), dtype=np.uint8)
        canvas[:, :, 1] = norm_mask[0]
        canvas[:, :, 2] = norm_mask[1]
        canvas[:, :, 0] = np.max(norm_mask, axis=0)
        cv2.imwrite('mask.png', canvas[::-1])

    basename = os.path.splitext(os.path.basename(args.input))[0]

    print('instrumental inverse stft...', end=' ')
    wav = spec_utils.spec_to_wav(inst_pred, phase, args.hop_length)
    print('done')
    sf.write('{}_Instrumental.wav'.format(basename), wav.T, sr)

    print('vocal inverse stft...', end=' ')
    wav = spec_utils.spec_to_wav(vocal_pred, phase, args.hop_length)
    print('done')
    sf.write('{}_Vocal.wav'.format(basename), wav.T, sr)
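# Example invocation (a sketch: the script name inference.py and song.wav are
# assumptions; the flags come straight from the parser defined in main() above):
#
#   python inference.py --input song.wav --gpu 0 --postprocess --out_mask
#
# which writes <basename>_Instrumental.wav, <basename>_Vocal.wav and, with
# --out_mask, mask.png into the current directory.
if __name__ == '__main__':
    main()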