from os import listdir, path


def main(argv):
    # Build a fresh model and training dataset from the command-line flags.
    model = ConvTasNet.make(get_model_param())
    dataset = Dataset(FLAGS.dataset_path, max_decoded=FLAGS.max_decoded)

    # Resume from the most recent checkpoint, if any. Checkpoints are named
    # by zero-padded epoch number, so a lexicographic sort is chronological.
    checkpoint_dir = FLAGS.checkpoint
    epoch = 0
    if path.exists(checkpoint_dir):
        checkpoints = [name for name in listdir(checkpoint_dir) if "ckpt" in name]
        if checkpoints:
            checkpoints.sort()
            checkpoint_name = checkpoints[-1].split(".")[0]
            epoch = int(checkpoint_name) + 1
            model.load_weights(f"{checkpoint_dir}/{checkpoint_name}.ckpt")

    # Train for FLAGS.epochs epochs, saving a checkpoint after each one.
    # When the flag is unset, the loop runs until the process is interrupted,
    # and the final saves below are never reached.
    epochs_left = FLAGS.epochs
    while epochs_left is None or epochs_left > 0:
        print(f"Epoch: {epoch}")
        model.fit(dataset.make_dataset(get_dataset_param()))
        model.save_weights(f"{checkpoint_dir}/{epoch:05d}.ckpt")
        epoch += 1
        if epochs_left is not None:
            epochs_left -= 1

    model.param.save(f"{checkpoint_dir}/config.txt")
    model.save(f"{checkpoint_dir}/model")
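# The training entry point above assumes module-level scaffolding that is not
# shown in this section: flag definitions and an app runner. The main(argv)
# signature and the FLAGS object suggest absl-py, so here is a minimal sketch
# under that assumption; the default values are illustrative, not the
# project's actual defaults.
from absl import app, flags

flags.DEFINE_string("dataset_path", None, "Path to the decoded dataset.")
flags.DEFINE_integer("max_decoded", 100, "Maximum number of decoded tracks kept in memory.")
flags.DEFINE_string("checkpoint", "./checkpoint", "Directory for model checkpoints.")
flags.DEFINE_integer("epochs", None, "Epochs to train; unset means train until interrupted.")
FLAGS = flags.FLAGS

if __name__ == "__main__":
    app.run(main)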
import librosa
import numpy as np
import soundfile as sf
import youtube_dl


def main(argv):
    checkpoint_dir = FLAGS.checkpoint
    if not path.exists(checkpoint_dir):
        raise ValueError(f"'{checkpoint_dir}' does not exist")
    checkpoints = [name for name in listdir(checkpoint_dir) if "ckpt" in name]
    if not checkpoints:
        raise ValueError("no checkpoint exists")
    checkpoints.sort()
    checkpoint_name = checkpoints[-1].split(".")[0]

    # Rebuild the model from its saved hyperparameters and load the latest
    # checkpoint weights.
    param = ConvTasNetParam.load(f"{checkpoint_dir}/config.txt")
    model = ConvTasNet.make(param)
    model.load_weights(f"{checkpoint_dir}/{checkpoint_name}.ckpt")

    # Download the video's audio track as a 44.1 kHz WAV file.
    video_id = FLAGS.video_id
    ydl_opts = {
        "format": "bestaudio/best",
        "postprocessors": [{
            "key": "FFmpegExtractAudio",
            "preferredcodec": "wav",
            "preferredquality": "44100",
        }],
        "outtmpl": "%(title)s.wav",
        "progress_hooks": [youtube_dl_hook],
    }
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(video_id, download=False)
        ydl.download([video_id])
    title = info.get("title", None)
    filename = f"{title}.wav"

    audio, sr = librosa.load(filename, sr=44100, mono=True)

    # Chop the waveform into frames of length L with hop L - overlap; each
    # model input holds That consecutive frames. Trim the tail so the audio
    # divides evenly into portions.
    hop = param.L - param.overlap
    num_samples = audio.shape[0]
    num_portions = (num_samples - param.overlap) // (param.That * hop)
    num_samples_output = num_portions * param.That * hop
    num_samples = num_samples_output + param.overlap

    if FLAGS.interpolate:
        # Trapezoidal window: linear fade-in over the first `overlap` samples,
        # fade-out over the last `overlap`, flat in between. The fade-in of
        # one frame and the fade-out of the previous frame sum to one, giving
        # a unity-gain cross-fade.
        def filter_gen(n):
            if n < param.overlap:
                return n / param.overlap
            elif n > param.L - param.overlap:
                return (param.L - n) / param.overlap
            else:
                return 1

        output_filter = np.array([filter_gen(n) for n in range(param.L)])

    print("predicting...")
    audio = audio[:num_samples]
    model_input = np.zeros((num_portions, param.That, param.L))
    for i in range(num_portions):
        for j in range(param.That):
            begin = (i * param.That + j) * hop
            model_input[i, j] = audio[begin:begin + param.L]
    separated = model.predict(model_input)
    # (num_portions, C, That, L) -> (C, num_portions, That, L)
    separated = np.transpose(separated, (1, 0, 2, 3))

    if FLAGS.interpolate:
        separated = output_filter * separated
        # Collect the faded-out tail of every frame, pad each tail to one hop,
        # and flatten. Shifting the result right by one hop aligns each
        # frame's tail with the head of the frame that follows it.
        overlapped = separated[:, :, :, hop:]
        overlapped = np.pad(overlapped,
                            pad_width=((0, 0), (0, 0), (0, 0),
                                       (0, param.L - 2 * param.overlap)),
                            mode="constant", constant_values=0)
        overlapped = np.reshape(overlapped, (param.C, num_samples_output))
        overlapped[:, hop:] = overlapped[:, :-hop]
        overlapped[:, :hop] = 0

    # Keep the first hop samples of every frame and flatten to (C, samples).
    separated = separated[:, :, :, :hop]
    separated = np.reshape(separated, (param.C, num_samples_output))
    if FLAGS.interpolate:
        separated += overlapped  # complete the cross-fade

    print("saving...")
    for idx, stem in enumerate(Dataset.STEMS):
        sf.write(f"{title}_{stem}.wav", separated[idx], sr)
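# The separation entry point above references youtube_dl_hook and three flags
# (checkpoint, video_id, interpolate) without defining them. A minimal sketch
# of that scaffolding, again assuming absl-py for the flags; the hook follows
# youtube_dl's progress_hooks contract, which calls the hook with a dict whose
# "status" key is "downloading", "finished", or "error".
from absl import app, flags

flags.DEFINE_string("checkpoint", "./checkpoint", "Directory holding config.txt and ckpt files.")
flags.DEFINE_string("video_id", None, "YouTube video ID to download and separate.")
flags.DEFINE_boolean("interpolate", False, "Cross-fade overlapping frames instead of discarding the overlap.")
FLAGS = flags.FLAGS


def youtube_dl_hook(progress):
    # Called repeatedly during the download; report only the terminal state.
    if progress["status"] == "finished":
        print(f"downloaded: {progress['filename']}")


if __name__ == "__main__":
    app.run(main)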