Example #1
0
def main(argv):
    """Train a ConvTasNet model, resuming from the newest checkpoint if any.

    Checkpoints are written once per epoch as zero-padded ``<epoch>.ckpt``
    files, so a plain lexicographic sort matches numeric epoch order.
    Training runs for ``FLAGS.epochs`` epochs, or forever when that flag
    is ``None``.
    """
    model = ConvTasNet.make(get_model_param())
    dataset = Dataset(FLAGS.dataset_path, max_decoded=FLAGS.max_decoded)
    checkpoint_dir = FLAGS.checkpoint

    epoch = 0
    if path.exists(checkpoint_dir):
        checkpoints = [name for name in listdir(checkpoint_dir)
                       if "ckpt" in name]
        # Guard: the directory may exist but hold no checkpoints yet
        # (the original indexed checkpoints[-1] unconditionally -> IndexError).
        if checkpoints:
            checkpoints.sort()
            # "00042.ckpt.index" -> "00042"
            checkpoint_name = checkpoints[-1].split(".")[0]
            epoch = int(checkpoint_name) + 1  # resume after the saved epoch
            model.load_weights(f"{checkpoint_dir}/{checkpoint_name}.ckpt")

    epochs_to_inc = FLAGS.epochs  # None => train indefinitely
    while epochs_to_inc is None or epochs_to_inc > 0:
        print(f"Epoch: {epoch}")
        model.fit(dataset.make_dataset(get_dataset_param()))
        model.save_weights(f"{checkpoint_dir}/{epoch:05d}.ckpt")
        epoch += 1
        if epochs_to_inc is not None:
            epochs_to_inc -= 1
        # Persist config and full model every epoch so an interrupted run
        # still leaves a loadable snapshot behind.
        model.param.save(f"{checkpoint_dir}/config.txt")
        model.save(f"{checkpoint_dir}/model")
Example #2
0
def main(argv):
    """Separate a YouTube video's audio into stems with a trained ConvTasNet.

    Loads the newest checkpoint from ``FLAGS.checkpoint``, downloads the
    audio for ``FLAGS.video_id`` as a 44.1 kHz mono wav, slices it into
    overlapping windows of ``param.L`` samples, runs the model, optionally
    cross-fades the overlapping window edges (``FLAGS.interpolate``), and
    writes one ``<title>_<stem>.wav`` file per stem.

    Raises:
        ValueError: if the checkpoint directory does not exist, contains no
            checkpoint, or the video's title cannot be determined.
    """
    checkpoint_dir = FLAGS.checkpoint
    if not path.exists(checkpoint_dir):
        raise ValueError(f"'{checkpoint_dir}' does not exist")

    checkpoints = [name for name in listdir(checkpoint_dir) if "ckpt" in name]
    if not checkpoints:
        raise ValueError("No checkpoint exists")
    checkpoints.sort()
    # "00042.ckpt.index" -> "00042"; zero-padding keeps sort order numeric.
    checkpoint_name = checkpoints[-1].split(".")[0]

    param = ConvTasNetParam.load(f"{checkpoint_dir}/config.txt")
    model = ConvTasNet.make(param)
    model.load_weights(f"{checkpoint_dir}/{checkpoint_name}.ckpt")

    video_id = FLAGS.video_id

    # Download best-quality audio and re-encode it to wav via ffmpeg.
    ydl_opts = {
        "format": "bestaudio/best",
        "postprocessors": [{
            "key": "FFmpegExtractAudio",
            "preferredcodec": "wav",
            "preferredquality": "44100",
        }],
        "outtmpl": "%(title)s.wav",
        "progress_hooks": [youtube_dl_hook],
    }

    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(video_id, download=False)
        ydl.download([video_id])

    title = info.get("title", None)
    if title is None:
        # Original crashed with an opaque TypeError on `None + ".wav"`.
        raise ValueError("Could not determine video title")
    filename = title + ".wav"
    audio, sr = librosa.load(filename, sr=44100, mono=True)

    num_samples = audio.shape[0]

    # Each model input portion holds param.That windows of param.L samples,
    # advancing by (param.L - param.overlap) samples per window.
    num_portions = (num_samples - param.overlap) // (param.That *
                                                     (param.L - param.overlap))

    num_samples_output = num_portions * param.That * (param.L - param.overlap)

    # Trim the audio to the span the windowing can cover exactly.
    num_samples = num_samples_output + param.overlap

    if FLAGS.interpolate:

        def filter_gen(n):
            # Linear fade-in over the first `overlap` samples, fade-out over
            # the last `overlap` samples, unity gain in between.
            if n < param.overlap:
                return n / param.overlap
            elif n > param.L - param.overlap:
                return (param.L - n) / param.overlap
            else:
                return 1

        output_filter = np.array([filter_gen(n) for n in range(param.L)])

    print("predicting...")

    audio = audio[:num_samples]

    model_input = np.zeros((num_portions, param.That, param.L))

    # Fill each (portion, window) slot with its overlapping audio slice.
    for i in range(num_portions):
        for j in range(param.That):
            begin = (i * param.That + j) * (param.L - param.overlap)
            end = begin + param.L
            model_input[i][j] = audio[begin:end]

    # NOTE(review): model output is assumed (portions, C, That, L) before
    # this transpose to (C, portions, That, L) — confirm against ConvTasNet.
    separated = model.predict(model_input)
    separated = np.transpose(separated, (1, 0, 2, 3))

    if FLAGS.interpolate:
        # Apply the fade envelope, then build the faded-out tails shifted so
        # that each window's tail lines up with the next window's head.
        separated = output_filter * separated
        overlapped = separated[:, :, :, (param.L - param.overlap):]
        overlapped = np.pad(overlapped,
                            pad_width=((0, 0), (0, 0), (0, 0),
                                       (0, param.L - 2 * param.overlap)),
                            mode="constant",
                            constant_values=0)
        overlapped = np.reshape(overlapped, (param.C, num_samples_output))
        overlapped[:, 1:] = overlapped[:, :-1]
        overlapped[:, 0] = 0

    # Keep only each window's non-overlapping head and flatten per source.
    separated = separated[:, :, :, :(param.L - param.overlap)]
    separated = np.reshape(separated, (param.C, num_samples_output))

    if FLAGS.interpolate:
        separated += overlapped  # cross-fade: head + previous window's tail

    print("saving...")

    for idx, stem in enumerate(Dataset.STEMS):
        sf.write(f"{title}_{stem}.wav", separated[idx], sr)