Пример #1
0
def test_forward(ver, gene, n_warmup=11, n_work=121):
    """Benchmark the mean wall-clock time of one forward pass.

    Builds the network in a fresh graph on ``/device:GPU:0`` with a random
    ``(1, 64, cfg.frame_size // 2, 1)`` input, executes ``n_warmup`` untimed
    runs and then ``n_work`` timed runs of the clipped, input-masked
    prediction.

    Args:
        ver: Forwarded to ``infer`` as its ``ver`` keyword.
        gene: Forwarded to ``infer`` as its ``gene`` keyword.
        n_warmup: Number of untimed warm-up runs.
        n_work: Number of timed runs; must be positive.

    Returns:
        Mean seconds per timed ``sess.run`` call.

    Raises:
        ValueError: If ``n_work`` is not positive.
    """
    # Fail fast: without timed runs there is nothing to average.
    if n_work <= 0:
        raise ValueError("n_work must be positive, got %r" % (n_work, ))

    import time

    import tensorflow as tf
    from mir_util import infer
    import config as cfg

    batch_size = 1
    n_feature = cfg.frame_size // 2

    graph = tf.Graph()
    with graph.as_default():
        # Model
        print("Initialize network")
        with tf.device("/device:GPU:0"):
            p_input = tf.random.uniform((batch_size, 64, n_feature, 1),
                                        dtype=tf.float32,
                                        name="p_input")
            v_pred = tf.clip_by_value(
                infer(p_input, 2, False, ver=ver, gene=gene), 0.0,
                1.0) * p_input

        with tf.compat.v1.Session(config=cfg.sess_cfg) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            # Warm-up runs are excluded from the measurement.
            for _ in range(n_warmup):
                sess.run(v_pred)
            # Time exactly n_work runs, and stop the clock before the
            # session is torn down so that close() is not counted.
            t_start = time.perf_counter()
            for _ in range(n_work):
                sess.run(v_pred)
            elapsed = time.perf_counter() - t_start
    return elapsed / n_work
Пример #2
0
def test_backward(ver, gene, n_warmup=11, n_work=121):
    """Benchmark the mean wall-clock time of one training step.

    Builds the network in a fresh graph with random input and target
    tensors, an L1 loss between the target and the input-masked prediction,
    and an Adam optimizer (learning rate ``1e-4``). Executes ``n_warmup``
    untimed steps and then ``n_work`` timed steps.

    Args:
        ver: Forwarded to ``infer`` as its ``ver`` keyword.
        gene: Forwarded to ``infer`` as its ``gene`` keyword.
        n_warmup: Number of untimed warm-up steps.
        n_work: Number of timed steps; must be positive.

    Returns:
        Mean seconds per timed training step.

    Raises:
        ValueError: If ``n_work`` is not positive.
    """
    # Fail fast: without timed steps there is nothing to average.
    if n_work <= 0:
        raise ValueError("n_work must be positive, got %r" % (n_work, ))

    import time

    import tensorflow as tf
    from mir_util import infer
    import config as cfg

    batch_size = 1
    n_feature = cfg.frame_size // 2

    graph = tf.Graph()
    with graph.as_default():
        # Model
        p_input = tf.random.uniform((batch_size, 64, n_feature, 1),
                                    dtype=tf.float32,
                                    name="p_input")
        p_target = tf.random.uniform((batch_size, 64, n_feature, 2),
                                     dtype=tf.float32,
                                     name="p_target")
        v_pred = infer(p_input, 2, True, ver=ver, gene=gene) * p_input
        v_loss = tf.reduce_mean(input_tensor=tf.abs(p_target - v_pred),
                                name="loss0")
        op_optim = tf.compat.v1.train.AdamOptimizer(
            learning_rate=1e-4).minimize(v_loss)
        fetches = [v_loss, op_optim]

        with tf.compat.v1.Session(config=cfg.sess_cfg) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            # Warm-up steps are excluded from the measurement.
            for _ in range(n_warmup):
                sess.run(fetches)
            # Time exactly n_work steps, and stop the clock before the
            # session is torn down so that close() is not counted.
            t_start = time.perf_counter()
            for _ in range(n_work):
                sess.run(fetches)
            elapsed = time.perf_counter() - t_start
    return elapsed / n_work
Пример #3
0
def test_backward(ver, gene, n_warmup=11, n_work=121):
    """Profile the training graph and benchmark one training step.

    Builds the network in a fresh graph with random input and target
    tensors (feature width ``5644 // 2``), an L1 loss and an Adam optimizer
    (learning rate ``1e-4``). Prints the parameter count reported by
    ``netop.count_parameter()`` and the float-operation count reported by
    the TensorFlow profiler for the whole graph, then executes ``n_warmup``
    untimed steps followed by ``n_work`` timed steps.

    Side effects:
        Sets ``sys.is_train = True`` before the graph is built.

    Args:
        ver: Forwarded to ``infer`` as its ``ver`` keyword.
        gene: Forwarded to ``infer`` as its ``gene`` keyword.
        n_warmup: Number of untimed warm-up steps.
        n_work: Number of timed steps; must be positive.

    Returns:
        Mean seconds per timed training step.

    Raises:
        ValueError: If ``n_work`` is not positive.
    """
    # Fail fast: without timed steps there is nothing to average.
    if n_work <= 0:
        raise ValueError("n_work must be positive, got %r" % (n_work, ))

    import sys
    import time

    import tensorflow as tf
    from mir_util import infer
    import config as cfg
    import netop
    sys.is_train = True

    batch_size = 1
    n_feature = 5644 // 2

    graph = tf.Graph()
    with graph.as_default():
        # Model
        p_input = tf.random.uniform((batch_size, 64, n_feature, 1),
                                    dtype=tf.float32,
                                    name="p_input")
        p_target = tf.random.uniform((batch_size, 64, n_feature, 2),
                                     dtype=tf.float32,
                                     name="p_target")
        v_pred = infer(p_input, 2, True, ver=ver, gene=gene)
        v_loss = tf.reduce_mean(input_tensor=tf.abs(p_target - v_pred),
                                name="loss0")
        op_optim = tf.compat.v1.train.AdamOptimizer(
            learning_rate=1e-4).minimize(v_loss)
        fetches = [v_loss, op_optim]

        n_param = netop.count_parameter()
        # Profiled after minimize(), so the count covers the forward and
        # the backward ops present in the graph.
        n_total_flop = tf.compat.v1.profiler.profile(
            graph,
            options=tf.compat.v1.profiler.ProfileOptionBuilder.float_operation(
            )).total_float_ops
        print(" BWD :Total {:,} parameters in total".format(n_param))
        print(" BWD :Forward + backward operation needs {:,} FLOPS".format(
            n_total_flop))

        with tf.compat.v1.Session(config=cfg.sess_cfg) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            # Warm-up steps are excluded from the measurement.
            for _ in range(n_warmup):
                sess.run(fetches)
            # Time exactly n_work steps, and stop the clock before the
            # session is torn down so that close() is not counted.
            t_start = time.perf_counter()
            for _ in range(n_work):
                sess.run(fetches)
            elapsed = time.perf_counter() - t_start
    return elapsed / n_work
Пример #4
0
def test_forward(ver, gene, n_warmup=11, n_work=121):
    """Profile the inference graph and benchmark one forward pass.

    Builds the network in a fresh graph on ``/device:GPU:0`` with a random
    ``(1, 64, 5644 // 2, 1)`` input. Prints the parameter count reported by
    ``netop.count_parameter()`` and the float-operation count reported by
    the TensorFlow profiler, then executes ``n_warmup`` untimed runs
    followed by ``n_work`` timed runs.

    Side effects:
        Appends ``"../lib"`` to ``sys.path`` and sets
        ``sys.is_train = False`` before the graph is built.

    Args:
        ver: Forwarded to ``infer`` as its ``ver`` keyword.
        gene: Forwarded to ``infer`` as its ``gene`` keyword.
        n_warmup: Number of untimed warm-up runs.
        n_work: Number of timed runs; must be positive.

    Returns:
        Mean seconds per timed ``sess.run`` call.

    Raises:
        ValueError: If ``n_work`` is not positive.
    """
    # Fail fast: without timed runs there is nothing to average.
    if n_work <= 0:
        raise ValueError("n_work must be positive, got %r" % (n_work, ))

    import sys
    import time

    import tensorflow as tf
    from mir_util import infer
    import config as cfg
    import netop
    sys.path.append("../lib")
    sys.is_train = False

    batch_size = 1
    n_feature = 5644 // 2

    graph = tf.Graph()
    with graph.as_default():
        # Model
        print("Initialize network")
        with tf.device("/device:GPU:0"):
            p_input = tf.random.uniform((batch_size, 64, n_feature, 1),
                                        dtype=tf.float32,
                                        name="p_input")
            v_pred = infer(p_input, 2, False, ver=ver, gene=gene)

        n_param = netop.count_parameter()
        n_forward_flop = tf.compat.v1.profiler.profile(
            graph,
            options=tf.compat.v1.profiler.ProfileOptionBuilder.float_operation(
            )).total_float_ops
        print(" FWD :Total {:,} parameters in total".format(n_param))
        print(
            " FWD :Forward operation needs {:,} FLOPS".format(n_forward_flop))

        with tf.compat.v1.Session(config=cfg.sess_cfg) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            # Warm-up runs are excluded from the measurement.
            for _ in range(n_warmup):
                sess.run(v_pred)
            # Time exactly n_work runs, and stop the clock before the
            # session is torn down so that close() is not counted.
            t_start = time.perf_counter()
            for _ in range(n_work):
                sess.run(v_pred)
            elapsed = time.perf_counter() - t_start
    return elapsed / n_work
Пример #5
0
def count(ver, gene_int, n_ch):
    """Build the training graph for one configuration and measure its size.

    Inside ``cfg.ConfigBoundary(gene_ver=ver, gene_value=gene_int)`` a fresh
    graph is populated with the network, an L1 loss and an Adam optimizer.
    The TensorFlow profiler is queried once right after ``infer`` and once
    after the optimizer has been added; only the first figure is returned.

    Args:
        ver: Passed to ``cfg.ConfigBoundary`` as ``gene_ver``.
        gene_int: Passed to ``cfg.ConfigBoundary`` as ``gene_value``.
        n_ch: Size of the last axis of the target placeholder; also the
            second positional argument of ``infer``.

    Returns:
        Tuple ``(total_parameters, n_forward_flop)``: the value of
        ``netop.count_parameter()`` and the float-op count profiled
        immediately after ``infer``.
    """
    import tensorflow as tf
    from mir_util import infer
    import config as cfg
    import netop

    v1 = tf.compat.v1

    def _profile_float_ops(target_graph, meta):
        # One profiler query restricted to float operations, in "op" view.
        return v1.profiler.profile(
            target_graph,
            run_meta=meta,
            cmd="op",
            options=v1.profiler.ProfileOptionBuilder.float_operation(),
        ).total_float_ops

    with cfg.ConfigBoundary(gene_ver=ver, gene_value=gene_int):
        n_batch = 1
        net_graph = tf.Graph()
        meta = v1.RunMetadata()
        with net_graph.as_default():
            mix_in = v1.placeholder(
                tf.float32,
                shape=(n_batch, 64, cfg.frame_size // 2, 1),
                name="x_mixed",
            )
            target_in = v1.placeholder(
                tf.float32,
                shape=(n_batch, 64, cfg.frame_size // 2, n_ch),
                name="y_mixed",
            )
            mask = infer(mix_in, n_ch, True)

            # Snapshot taken before the loss/optimizer ops exist.
            forward_flops = _profile_float_ops(net_graph, meta)

            masked = tf.multiply(mix_in, mask)
            l1_loss = tf.reduce_mean(input_tensor=tf.abs(target_in - masked),
                                     name="loss0")
            step_var = tf.Variable(0,
                                   dtype=tf.int32,
                                   trainable=False,
                                   name="global_step")
            train_op = v1.train.AdamOptimizer(learning_rate=1e-4).minimize(
                l1_loss, global_step=step_var)

            # Second profiler query on the full graph; its value is not
            # part of the return value.
            n_total_flop = _profile_float_ops(net_graph, meta)
            param_total = netop.count_parameter()
            return param_total, forward_flops
Пример #6
0
    assert os.path.exists(cache_path), "Dataset cache not found"
    print("* Read cached spectrograms")
    mixed_list, vocal_list, inst_list, n_sample = pickle.load(
        open(cache_path, "rb"))
    print("* Number of training examples: %d" % (n_sample, ))

    # Model
    print("* Initialize network")
    tf.compat.v1.random.set_random_seed(0x41526941)
    p_input = tf.compat.v1.placeholder(tf.float32,
                                       shape=(batch_size, 64, n_feature, 1),
                                       name="p_input")
    p_target = tf.compat.v1.placeholder(tf.float32,
                                        shape=(batch_size, 64, n_feature, 2),
                                        name="p_target")
    v_pred = infer(p_input, 2, True)
    if isinstance(v_pred, list):
        v_loss = 0
        for y in v_pred:
            v_loss += tf.reduce_mean(input_tensor=tf.abs(p_target -
                                                         (y * p_input)))
    else:
        v_pred *= p_input
        v_loss = tf.reduce_mean(input_tensor=tf.abs(p_target - v_pred))
    # Loss, Optimizer
    v_global_step = tf.Variable(0,
                                dtype=tf.int32,
                                trainable=False,
                                name="v_global_step")
    p_lr_fac = tf.compat.v1.placeholder(tf.float32, name="p_lr_fac")
    v_lr = p_lr_fac * tf.compat.v1.train.cosine_decay_restarts(
Пример #7
0
def main():
    """Evaluate a restored checkpoint on the wav files under ``cfg.mir_wav_path``.

    Reads the ``step`` option (checkpoint index) and the optional ``first``
    option (per-directory file limit) through ``simpleopt``. Builds the
    inference graph, restores ``checkpoint-<step>`` from
    ``cfg.MIR2Config.ckpt_path``, estimates two output channels for every
    selected file window by window, and hands each result to ``bss_eval`` in
    a worker pool. The per-file scores and their ``lens``-weighted means are
    written through ``redirect.ConsoleAndFile`` to
    ``./eval_output/mir2_<gene_ver>_<gene_value>_step<step>.txt``.

    ``worker_main`` (the pool initializer) is not defined in this function.
    """
    import multiprocessing as mp
    import numpy as np
    import tensorflow as tf
    import os, sys
    from mir_util import infer, to_spec, to_wav_file
    import scipy.signal as sp
    import config as cfg
    sys.path.append("../lib")
    from eval_util import bss_eval
    from common import loadWav
    import redirect, simpleopt

    step_idx = int(simpleopt.get("step"))
    n_eval = simpleopt.get("first", None)
    with cfg.ConfigBoundary():
        batch_size = 1
        n_feature = cfg.frame_size // 2

        # Model
        print("* Initialize network")
        p_input = tf.compat.v1.placeholder(tf.float32,
                                           shape=(batch_size, 64, n_feature,
                                                  1),
                                           name="p_input")
        v_pred = infer(p_input, 2, False)
        # When infer returns a list, only its last element is evaluated.
        if isinstance(v_pred, list):
            v_pred = v_pred[-1]
        # Clip the network output to [0, 1] and multiply it with the input.
        v_pred = tf.clip_by_value(v_pred, 0.0, 1.0) * p_input

        # Feed buffer reused for every window of every file.
        x_input = np.zeros((batch_size, 64, n_feature, 1), dtype=np.float32)
        with tf.compat.v1.Session(config=cfg.sess_cfg) as sess:
            # Initialized, Load state
            sess.run(tf.compat.v1.global_variables_initializer())

            print("* Load checkpoint")
            ckpt_path = os.path.join(cfg.MIR2Config.ckpt_path,
                                     "checkpoint-%d" % (step_idx, ))
            tf.compat.v1.train.Saver().restore(sess, ckpt_path)
            print(" :Loaded: `%s`" % (ckpt_path, ))

            os.makedirs("./eval_output", exist_ok=True)
            # name_list[i] and ret_list[i] describe the same file.
            name_list = []
            ret_list = []
            with mp.Pool(processes=1, initializer=worker_main) as pool:
                for (root, _, file_list) in os.walk(cfg.mir_wav_path):
                    # Files whose names start with "abjones" or "amy" are
                    # excluded from evaluation.
                    file_list = sorted(f for f in file_list if not (
                        f.startswith("abjones") or f.startswith("amy")))
                    # The "first" limit is applied to each directory
                    # visited by os.walk separately.
                    if n_eval is not None:
                        file_list = file_list[:int(n_eval)]
                    for i_file, filename in enumerate(file_list):
                        print("[%03d/%03d] SEND: `%s`" % (
                            i_file + 1,
                            len(file_list),
                            filename,
                        ))
                        name_list.append(filename)
                        path = os.path.join(root, filename)

                        # Column 1 is used as the vocal reference, column 0
                        # as the instrumental reference, and their sum as
                        # the mixture.
                        mixed_wav, sr_orig = loadWav(path)
                        gt_wav_vocal = mixed_wav[:, 1]
                        gt_wav_inst = mixed_wav[:, 0]
                        mixed_wav = np.sum(mixed_wav, axis=1)

                        # Keep the signals at the original sample rate for
                        # scoring.
                        mixed_wav_orig = mixed_wav
                        gt_wav_vocal_orig = gt_wav_vocal
                        gt_wav_inst_orig = gt_wav_inst

                        # Resample from sr_orig to cfg.sr.
                        # NOTE(review): the resampled gt_wav_vocal and
                        # gt_wav_inst are not read again in this function.
                        gt_wav_vocal = sp.resample_poly(
                            gt_wav_vocal, cfg.sr, sr_orig).astype(np.float32)
                        gt_wav_inst = sp.resample_poly(
                            gt_wav_inst, cfg.sr, sr_orig).astype(np.float32)
                        mixed_wav = sp.resample_poly(
                            mixed_wav, cfg.sr, sr_orig).astype(np.float32)

                        # Magnitude is scaled by its maximum; the scale is
                        # restored on the estimates further below.
                        mixed_spec = to_spec(mixed_wav)
                        mixed_spec_mag = np.abs(mixed_spec)
                        mixed_spec_phase = np.angle(mixed_spec)
                        max_tmp = np.max(mixed_spec_mag)
                        mixed_spec_mag = mixed_spec_mag / max_tmp

                        src_len = mixed_spec_mag.shape[0]
                        start_idx = 0
                        y_est_inst = np.zeros((src_len, n_feature),
                                              dtype=np.float32)
                        y_est_vocal = np.zeros((src_len, n_feature),
                                               dtype=np.float32)
                        # Slide a 64-row window along axis 0 in steps of 32.
                        # The first window is copied in full; every later
                        # window contributes only its rows 16..47.
                        while start_idx + 64 < src_len:
                            x_input[0, :, :,
                                    0] = mixed_spec_mag[start_idx:start_idx +
                                                        64, :n_feature]
                            y_output = sess.run(v_pred,
                                                feed_dict={p_input: x_input})
                            if start_idx == 0:
                                y_est_inst[start_idx:start_idx +
                                           64, :] = y_output[0, :, :, 0]
                                y_est_vocal[start_idx:start_idx +
                                            64, :] = y_output[0, :, :, 1]
                            else:
                                y_est_inst[start_idx + 16:start_idx +
                                           48, :] = y_output[0, 16:48, :, 0]
                                y_est_vocal[start_idx + 16:start_idx +
                                            48, :] = y_output[0, 16:48, :, 1]
                            start_idx += 32

                        # Tail: run the last 64 rows once more and copy the
                        # final src_start rows of that output into the rows
                        # from start_idx + 16 up to src_len.
                        # NOTE(review): this appears to assume src_len > 64;
                        # with fewer rows the slice below has a negative
                        # start, and with exactly 64 rows the first 16 rows
                        # are never written -- confirm inputs are long
                        # enough.
                        x_input[0, :, :,
                                0] = mixed_spec_mag[src_len -
                                                    64:src_len, :n_feature]
                        y_output = sess.run(v_pred,
                                            feed_dict={p_input: x_input})
                        src_start = src_len - start_idx - 16
                        y_est_inst[start_idx +
                                   16:src_len, :] = y_output[0, 64 -
                                                             src_start:64, :,
                                                             0]
                        y_est_vocal[start_idx +
                                    16:src_len, :] = y_output[0, 64 -
                                                              src_start:64, :,
                                                              1]

                        # Undo the max-normalisation and rebuild waveforms
                        # using the phase of the mixture.
                        y_est_inst *= max_tmp
                        y_est_vocal *= max_tmp
                        y_wav_inst = to_wav_file(
                            y_est_inst, mixed_spec_phase[:, :n_feature])
                        y_wav_vocal = to_wav_file(
                            y_est_vocal, mixed_spec_phase[:, :n_feature])
                        #saveWav("inst.wav", y_wav_inst, cfg.sr)
                        #saveWav("vocal.wav", y_wav_vocal, cfg.sr)

                        # upsample to original samprate
                        y_wav_inst_orig = sp.resample_poly(
                            y_wav_inst, sr_orig, cfg.sr).astype(np.float32)
                        y_wav_vocal_orig = sp.resample_poly(
                            y_wav_vocal, sr_orig, cfg.sr).astype(np.float32)
                        # Scoring runs asynchronously in the worker pool;
                        # results are collected after all files are queued.
                        ret_list.append(
                            pool.apply_async(bss_eval, (
                                mixed_wav_orig,
                                gt_wav_inst_orig,
                                gt_wav_vocal_orig,
                                y_wav_inst_orig,
                                y_wav_vocal_orig,
                            )))
                with redirect.ConsoleAndFile(
                        "./eval_output/mir2_%s_%d_step%d.txt" %
                    (cfg.gene_ver, cfg.gene_value, step_idx)) as r:
                    # Accumulators for the lens-weighted means.
                    gnsdr = 0.0
                    gsir = 0.0
                    gsar = 0.0
                    total_len = 0
                    for name, ret in zip(name_list, ret_list):
                        # ret.get() blocks until the worker has finished.
                        nsdr, sir, sar, lens = ret.get()
                        printstr = name + " " + str(nsdr) + " " + str(
                            sir) + " " + str(sar)
                        r.print(printstr)
                        total_len += lens
                        gnsdr += nsdr * lens
                        gsir += sir * lens
                        gsar += sar * lens
                    # NOTE(review): total_len stays 0 when no file was
                    # processed, so the divisions below would fail.
                    r.print("Final results")
                    r.print("GNSDR [Accompaniments, voice]")
                    r.print(gnsdr / total_len)
                    r.print("GSIR [Accompaniments, voice]")
                    r.print(gsir / total_len)
                    r.print("GSAR [Accompaniments, voice]")
                    r.print(gsar / total_len)
Пример #8
0
def main():
    """Evaluate a restored checkpoint on the tracks under ``cfg.dsd_path``.

    Reads the ``step`` option (checkpoint index) and the optional ``first``
    option (directory limit) through ``simpleopt``. Builds the inference
    graph, restores ``checkpoint-<step>`` from ``cfg.DSD2Config.ckpt_path``,
    estimates two output channels for every selected track window by window,
    and hands each result to ``bss_eval_sdr_framewise`` in a worker pool.
    The per-frame SDR/SIR/SAR values are collected into a pandas DataFrame
    with the columns ``method, track, target, metric, score, time`` and
    written as JSON to ``./old_fw/dsd2_<gene_ver>_<gene_value>_step<step>.json``.

    ``worker_main`` (the pool initializer) is not defined in this function.
    """
    import multiprocessing as mp
    import numpy as np
    import tensorflow as tf
    import os, sys
    import config as cfg
    #import librosa
    from mir_util import infer, to_spec, to_wav_file
    import scipy.signal as sp
    sys.path.append("../lib")
    from eval_util import bss_eval_sdr_framewise
    from common import loadWav
    import redirect, simpleopt
    import pandas as pd

    step_idx = int(simpleopt.get("step"))
    n_eval = simpleopt.get("first", None)
    with cfg.ConfigBoundary():
        batch_size = 1
        n_feature = cfg.frame_size // 2

        # Model
        print("* Initialize network")
        p_input = tf.compat.v1.placeholder(tf.float32,
                                           shape=(batch_size, 64, n_feature,
                                                  1),
                                           name="p_input")
        v_pred = infer(p_input, 2, False)
        # When infer returns a list, only its last element is evaluated.
        if isinstance(v_pred, list):
            v_pred = v_pred[-1]
        # Clip the network output to [0, 1] and multiply it with the input.
        v_pred = tf.clip_by_value(v_pred, 0.0, 1.0) * p_input

        # Feed buffer reused for every window of every track.
        x_input = np.zeros((batch_size, 64, n_feature, 1), dtype=np.float32)
        with tf.compat.v1.Session(config=cfg.sess_cfg) as sess:
            # Initialized, Load state
            sess.run(tf.compat.v1.global_variables_initializer())

            print("* Load checkpoint")
            ckpt_path = os.path.join(cfg.DSD2Config.ckpt_path,
                                     "checkpoint-%d" % (step_idx, ))
            tf.compat.v1.train.Saver().restore(sess, ckpt_path)
            print(" :Loaded: `%s`" % (ckpt_path, ))

            os.makedirs("./eval_output", exist_ok=True)
            # name_list[i] and ret_list[i] describe the same track.
            name_list = []
            ret_list = []
            with mp.Pool(processes=8, initializer=worker_main) as pool:
                # NOTE(review): os.walk also visits every sub-directory, and
                # `root` is not used when the file paths are built below --
                # confirm the track directories contain no directories of
                # their own.
                for (root, dir_list, _) in os.walk(
                        os.path.join(cfg.dsd_path, "Mixtures", "Test")):
                    dir_list = sorted(dir_list)
                    if n_eval is not None:
                        dir_list = dir_list[:int(n_eval)]
                    for i_dir, d in enumerate(dir_list):
                        print("[%02d/%02d] STG1: `%s`" % (
                            i_dir + 1,
                            len(dir_list),
                            d,
                        ))
                        name_list.append(d)

                        filename_vocal = os.path.join(cfg.dsd_path, "Sources",
                                                      "Test", d, "vocals.wav")
                        filename_mix = os.path.join(cfg.dsd_path, "Mixtures",
                                                    "Test", d, "mixture.wav")

                        # NOTE(review): `t` is not read again in this
                        # function.
                        import time
                        t = time.time()
                        # Both files are reduced to one channel by summing
                        # over axis 1; the instrumental reference is the
                        # mixture minus the vocal reference.
                        mixed_wav_orig, sr_orig = loadWav(
                            filename_mix
                        )  #librosa.load(filename_mix, sr=None, mono=True)
                        mixed_wav_orig = np.sum(mixed_wav_orig, axis=1)
                        gt_wav_vocal_orig, _ = loadWav(
                            filename_vocal
                        )  #librosa.load(filename_vocal, sr=None, mono=True)[0]
                        gt_wav_vocal_orig = np.sum(gt_wav_vocal_orig, axis=1)
                        gt_wav_inst_orig = mixed_wav_orig - gt_wav_vocal_orig

                        # Resample from sr_orig to cfg.sr.
                        mixed_wav = sp.resample_poly(
                            mixed_wav_orig, cfg.sr, sr_orig
                        ).astype(
                            np.float32
                        )  #librosa.load(filename_mix, sr=cfg.sr, mono=True)[0]
                        gt_wav_vocal = sp.resample_poly(
                            gt_wav_vocal_orig, cfg.sr, sr_orig
                        ).astype(
                            np.float32
                        )  #librosa.load(filename_vocal, sr=cfg.sr, mono=True)[0]
                        # NOTE(review): gt_wav_inst is not read again in
                        # this function.
                        gt_wav_inst = mixed_wav - gt_wav_vocal
                        # Magnitude is scaled by its maximum; the scale is
                        # restored on the estimates further below.
                        mixed_spec = to_spec(mixed_wav)
                        mixed_spec_mag = np.abs(mixed_spec)
                        mixed_spec_phase = np.angle(mixed_spec)
                        max_tmp = np.max(mixed_spec_mag)
                        mixed_spec_mag = mixed_spec_mag / max_tmp

                        src_len = mixed_spec_mag.shape[0]
                        start_idx = 0
                        y_est_inst = np.zeros((src_len, n_feature),
                                              dtype=np.float32)
                        y_est_vocal = np.zeros((src_len, n_feature),
                                               dtype=np.float32)
                        # Slide a 64-row window along axis 0 in steps of 32.
                        # The first window is copied in full; every later
                        # window contributes only its rows 16..47.
                        while start_idx + 64 < src_len:
                            x_input[0, :, :,
                                    0] = mixed_spec_mag[start_idx:start_idx +
                                                        64, :n_feature]
                            y_output = sess.run(v_pred,
                                                feed_dict={p_input: x_input})
                            if start_idx == 0:
                                y_est_inst[start_idx:start_idx +
                                           64, :] = y_output[0, :, :, 0]
                                y_est_vocal[start_idx:start_idx +
                                            64, :] = y_output[0, :, :, 1]
                            else:
                                y_est_inst[start_idx + 16:start_idx +
                                           48, :] = y_output[0, 16:48, :, 0]
                                y_est_vocal[start_idx + 16:start_idx +
                                            48, :] = y_output[0, 16:48, :, 1]
                            start_idx += 32

                        # Tail: run the last 64 rows once more and copy the
                        # final src_start rows of that output into the rows
                        # from start_idx + 16 up to src_len.
                        # NOTE(review): this appears to assume src_len > 64
                        # -- confirm inputs are long enough.
                        x_input[0, :, :,
                                0] = mixed_spec_mag[src_len -
                                                    64:src_len, :n_feature]
                        y_output = sess.run(v_pred,
                                            feed_dict={p_input: x_input})
                        src_start = src_len - start_idx - 16
                        y_est_inst[start_idx +
                                   16:src_len, :] = y_output[0, 64 -
                                                             src_start:64, :,
                                                             0]
                        y_est_vocal[start_idx +
                                    16:src_len, :] = y_output[0, 64 -
                                                              src_start:64, :,
                                                              1]

                        # Undo the max-normalisation and rebuild waveforms
                        # using the phase of the mixture.
                        y_est_inst *= max_tmp
                        y_est_vocal *= max_tmp
                        y_wav_inst = to_wav_file(
                            y_est_inst, mixed_spec_phase[:, :n_feature])
                        y_wav_vocal = to_wav_file(
                            y_est_vocal, mixed_spec_phase[:, :n_feature])
                        #saveWav("inst.wav", y_wav_inst, cfg.sr)
                        #saveWav("vocal.wav", y_wav_vocal, cfg.sr)

                        #upsample to original SR
                        y_wav_inst_orig = sp.resample_poly(
                            y_wav_inst, sr_orig, cfg.sr).astype(
                                np.float32
                            )  #librosa.resample(y_wav_inst, cfg.sr, sr_orig)
                        y_wav_vocal_orig = sp.resample_poly(
                            y_wav_vocal, sr_orig, cfg.sr).astype(
                                np.float32
                            )  #librosa.resample(y_wav_vocal, cfg.sr, sr_orig)

                        # References and estimates are stacked in the same
                        # order: index 0 instrumental, index 1 vocal.
                        # NOTE(review): np.array over the pairs presumably
                        # requires equal lengths within each pair -- verify.
                        ret_list.append(
                            pool.apply_async(bss_eval_sdr_framewise, (
                                np.array([gt_wav_inst_orig, gt_wav_vocal_orig],
                                         dtype=np.float32),
                                np.array([y_wav_inst_orig, y_wav_vocal_orig],
                                         dtype=np.float32),
                            )))

                head_list = [
                    "method", "track", "target", "metric", "score", "time"
                ]
                row_list = []
                # NOTE(review): only ./eval_output is created above; confirm
                # that ./old_fw exists before to_json is reached.
                out_path = "./old_fw/dsd2_%s_%d_step%d.json" % (
                    cfg.gene_ver, cfg.gene_value, step_idx)
                method_name = "dsd2_%s_%d_step%d" % (cfg.gene_ver,
                                                     cfg.gene_value, step_idx)
                # One row per (track, target, metric, frame); the "time"
                # column holds the frame index from enumerate.
                for name, ret in zip(name_list, ret_list):
                    print(name)
                    # ret.get() blocks until the worker has finished.
                    sdr, sir, sar = ret.get()
                    # Index 0 of each metric: accompaniment.
                    for i, v in enumerate(sdr[0]):
                        row_list.append((
                            method_name,
                            name,
                            "accompaniment",
                            "SDR",
                            v,
                            i,
                        ))
                    for i, v in enumerate(sir[0]):
                        row_list.append((
                            method_name,
                            name,
                            "accompaniment",
                            "SIR",
                            v,
                            i,
                        ))
                    for i, v in enumerate(sar[0]):
                        row_list.append((
                            method_name,
                            name,
                            "accompaniment",
                            "SAR",
                            v,
                            i,
                        ))

                    # Index 1 of each metric: vocals.
                    for i, v in enumerate(sdr[1]):
                        row_list.append((
                            method_name,
                            name,
                            "vocals",
                            "SDR",
                            v,
                            i,
                        ))
                    for i, v in enumerate(sir[1]):
                        row_list.append((
                            method_name,
                            name,
                            "vocals",
                            "SIR",
                            v,
                            i,
                        ))
                    for i, v in enumerate(sar[1]):
                        row_list.append((
                            method_name,
                            name,
                            "vocals",
                            "SAR",
                            v,
                            i,
                        ))
                out = pd.DataFrame(row_list, columns=head_list).reset_index()
                print(out)
                out.to_json(out_path)
Пример #9
0
# Script fragment: builds the inference graph, restores a checkpoint and
# iterates over the mixture paths. `dataset_type`, `n_feature`, `ckpt_step`
# and `mix_path_list` are expected to be defined before this point.
# Neither `ver` nor `gene` is read again within this fragment.
ver = simpleopt.get("ver")
gene = simpleopt.get("gene")

with cfg.ConfigBoundary():
    # Select the network configuration and the channel names for the
    # dataset; an unknown dataset_type raises KeyError here.
    net_config, ch_list = {
        "mus2": (cfg.MUS2FConfig, ("inst", "vocal")),
    }[dataset_type]
    n_ch = len(ch_list)

    # Model
    print("* Initialize network model")
    p_input = tf.compat.v1.placeholder(tf.float32,
                                       shape=(1, 64, n_feature, 1),
                                       name="p_input")
    v_pred = infer(p_input, 2, False)
    # When infer returns a list, only its last element is kept.
    if isinstance(v_pred, list):
        v_pred = v_pred[-1]

    # Feed buffer matching the placeholder shape.
    x_input = np.zeros((1, 64, n_feature, 1), dtype=np.float32)
    with tf.compat.v1.Session(config=cfg.sess_cfg) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())

        print("* Load checkpoint")
        ckpt_path = os.path.join(net_config.ckpt_path,
                                 "checkpoint-%d" % (ckpt_step, ))
        tf.compat.v1.train.Saver().restore(sess, ckpt_path)
        print(" :Loaded: `%s`" % (ckpt_path, ))

        for mix_path in mix_path_list:
            print("* Compute `%s`" % (mix_path, ))
Пример #10
0
def main_estimate(pool):
    """Estimate sources for musdb test tracks and save them to disk.

    Describes the portion of the function visible here. Reads the options
    ``step``, ``first``, ``sound-out`` and ``source`` through ``simpleopt``,
    builds the inference graph, restores ``checkpoint-<step>``, and for each
    selected track processes the two audio channels independently, window by
    window. The per-channel estimates taken from output index 1 are
    concatenated along axis 1 and passed to ``mus.save_estimates``.

    Args:
        pool: Not used in the portion of the function visible here.
    """
    import numpy as np
    import tensorflow as tf
    import os, sys, pathlib
    sys.path.append("../lib")
    import config as cfg
    #import librosa
    from mir_util import infer, to_spec, to_wav_file
    import scipy.signal as sp
    import musdb, museval
    import simpleopt
    sys.is_train = False

    step_idx = int(simpleopt.get("step"))
    n_eval = simpleopt.get("first", None)
    if n_eval is not None:
        n_eval = int(n_eval)
        assert n_eval > 0

    sound_sample_root = simpleopt.get("sound-out", None)
    source = simpleopt.get("source")
    # "vocals" is mapped to None from here on.
    if source == "vocals":
        source = None

    with cfg.ConfigBoundary():
        # The source name is appended to model_name only when it is set.
        if source is None:
            model_name = "mus2f_%s_%d_step%d" % (
                cfg.gene_ver,
                cfg.gene_value,
                step_idx,
            )
        else:
            model_name = "mus2f_%s_%d_step%d_%s" % (
                cfg.gene_ver,
                cfg.gene_value,
                step_idx,
                source,
            )
        model_name_nosrc = "mus2f_%s_%d_step%d" % (
            cfg.gene_ver,
            cfg.gene_value,
            step_idx,
        )
        if sound_sample_root is None:
            sound_sample_root = "./sound_output_mus2f/{}".format(
                model_name_nosrc)
        pathlib.Path(sound_sample_root).mkdir(parents=True, exist_ok=True)
        ckpt_path = cfg.MUS2FConfig.ckpt_path
        # NOTE(review): source can no longer equal "vocals" at this point
        # (it was replaced by None above), so this branch always runs and
        # yields "<ckpt_path>_None" for the vocals case. Likewise the
        # `estimates` dict below is keyed by None in that case. Confirm
        # whether `source is not None` / a "vocals" key was intended.
        if source != "vocals":
            ckpt_path = "{}_{}".format(ckpt_path, source)

        batch_size = 1
        n_feature = 5644 // 2

        # Model
        print("* Initialize network")
        p_input = tf.compat.v1.placeholder(tf.float32,
                                           shape=(batch_size, 64, n_feature,
                                                  1),
                                           name="p_input")
        v_pred = infer(p_input, 2, False)
        # When infer returns a list, only its last element is evaluated.
        if isinstance(v_pred, list):
            v_pred = v_pred[-1]

        with tf.compat.v1.Session(config=cfg.sess_cfg) as sess:
            # Initialized, Load state
            sess.run(tf.compat.v1.global_variables_initializer())

            print("* Load checkpoint")
            ckpt_path = os.path.join(ckpt_path, "checkpoint-%d" % (step_idx, ))
            tf.compat.v1.train.Saver().restore(sess, ckpt_path)
            print(" :Loaded: `%s`" % (ckpt_path, ))

            os.makedirs("./eval_output", exist_ok=True)
            name_list = []
            ret_list = []

            # Test subset of musdb, sorted by track name so that the
            # "first" limit selects a stable prefix.
            mus = musdb.DB(root=cfg.mus_root_path,
                           download=False,
                           subsets="test",
                           is_wav=True)
            mus_trk_list = list(mus.tracks)
            mus_trk_list.sort(key=lambda x: x.name)
            assert len(mus_trk_list) > 0
            if n_eval is not None:
                mus_trk_list = mus_trk_list[:n_eval]

            # NOTE(review): `results`, `name_list` and `ret_list` are not
            # read in the portion of the function visible here.
            results = museval.EvalStore()

            for i_song, track in enumerate(mus_trk_list):
                print("[%02d/%02d] Estimate: `%s`" % (
                    i_song + 1,
                    len(mus_trk_list),
                    track.name,
                ))
                voc_ch_list = []
                inst_ch_list = []
                # The two audio channels are processed independently.
                for i_channel in range(2):
                    print(" :Channel #%d" % (i_channel, ))
                    name_list.append(track.name + " Channel %d" %
                                     (i_channel, ))

                    mixed_wav = track.audio[:, i_channel]
                    mixed_spec = to_spec(mixed_wav,
                                         len_frame=5644,
                                         len_hop=5644 // 4)
                    # Magnitude is scaled by its maximum; the scale is
                    # restored on the estimates further below.
                    mixed_spec_mag = np.abs(mixed_spec)
                    mixed_spec_phase = np.angle(mixed_spec)
                    max_tmp = np.max(mixed_spec_mag)
                    mixed_spec_mag = mixed_spec_mag / max_tmp

                    src_len = mixed_spec_mag.shape[0]
                    start_idx = 0
                    y_est_inst = np.zeros((src_len, n_feature),
                                          dtype=np.float32)
                    y_est_vocal = np.zeros((src_len, n_feature),
                                           dtype=np.float32)
                    x_input = np.zeros((batch_size, 64, n_feature, 1),
                                       dtype=np.float32)
                    # Slide a 64-row window along axis 0 in steps of 32.
                    # The first window is copied in full; every later
                    # window contributes only its rows 16..47.
                    while start_idx + 64 < src_len:
                        x_input[0, :, :,
                                0] = mixed_spec_mag[start_idx:start_idx +
                                                    64, :n_feature]
                        y_output = sess.run(v_pred,
                                            feed_dict={p_input: x_input})
                        if start_idx == 0:
                            y_est_inst[start_idx:start_idx +
                                       64, :] = y_output[0, :, :, 0]
                            y_est_vocal[start_idx:start_idx +
                                        64, :] = y_output[0, :, :, 1]
                        else:
                            y_est_inst[start_idx + 16:start_idx +
                                       48, :] = y_output[0, 16:48, :, 0]
                            y_est_vocal[start_idx + 16:start_idx +
                                        48, :] = y_output[0, 16:48, :, 1]
                        start_idx += 32

                    # Tail: run the last 64 rows once more and copy the
                    # final src_start rows of that output into the rows
                    # from start_idx + 16 up to src_len.
                    # NOTE(review): this appears to assume src_len > 64 --
                    # confirm inputs are long enough.
                    x_input[0, :, :,
                            0] = mixed_spec_mag[src_len -
                                                64:src_len, :n_feature]
                    y_output = sess.run(v_pred, feed_dict={p_input: x_input})
                    src_start = src_len - start_idx - 16
                    y_est_inst[start_idx +
                               16:src_len, :] = y_output[0,
                                                         64 - src_start:64, :,
                                                         0]
                    y_est_vocal[start_idx +
                                16:src_len, :] = y_output[0,
                                                          64 - src_start:64, :,
                                                          1]

                    # Undo the max-normalisation and rebuild waveforms
                    # using the phase of the mixture.
                    y_est_inst *= max_tmp
                    y_est_vocal *= max_tmp
                    y_wav_inst = to_wav_file(y_est_inst,
                                             mixed_spec_phase[:, :n_feature],
                                             len_hop=5644 // 4)
                    y_wav_vocal = to_wav_file(y_est_vocal,
                                              mixed_spec_phase[:, :n_feature],
                                              len_hop=5644 // 4)

                    # Store each channel as a column vector so the two
                    # channels can be concatenated along axis 1.
                    voc_ch_list.append(y_wav_vocal.reshape(
                        y_wav_vocal.size, 1))
                    inst_ch_list.append(y_wav_inst.reshape(y_wav_inst.size, 1))
                    # Drop the references to the per-channel buffers before
                    # the next channel is processed.
                    del y_wav_inst, y_wav_vocal, y_est_inst, y_est_vocal, src_start, x_input, y_output, mixed_spec_mag, max_tmp, mixed_spec_phase, mixed_spec, mixed_wav
                # Only the output-index-1 estimate is saved; inst_ch_list
                # is discarded without being used.
                estimates = {
                    source: np.concatenate(voc_ch_list, axis=1),
                }
                del voc_ch_list, inst_ch_list
                if sound_sample_root:
                    mus.save_estimates(estimates, track, sound_sample_root)
                del estimates, i_song, track