Example #1
def eval(model, eval_data, sess):
    mixed_wav, src1_wav, src2_wav, _ = eval_data.next_wavs(
        EvalConfig.SECONDS, EvalConfig.NUM_EVAL)

    mixed_spec = to_spectrogram(mixed_wav)
    mixed_mag = get_magnitude(mixed_spec)

    src1_spec, src2_spec = to_spectrogram(src1_wav), to_spectrogram(src2_wav)
    src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(src2_spec)

    src1_batch, _ = model.spec_to_batch(src1_mag)
    src2_batch, _ = model.spec_to_batch(src2_mag)
    mixed_batch, _ = model.spec_to_batch(mixed_mag)

    pred_src1_mag, pred_src2_mag = sess.run(
        model(), feed_dict={model.x_mixed: mixed_batch})

    mixed_phase = get_phase(mixed_spec)
    seq_len = mixed_phase.shape[-1]
    pred_src1_mag = model.batch_to_spec(pred_src1_mag,
                                        EvalConfig.NUM_EVAL)[:, :, :seq_len]
    pred_src2_mag = model.batch_to_spec(pred_src2_mag,
                                        EvalConfig.NUM_EVAL)[:, :, :seq_len]

    # Time-frequency masking
    mask_src1 = soft_time_freq_mask(pred_src1_mag, pred_src2_mag)
    # mask_src1 = hard_time_freq_mask(pred_src1_mag, pred_src2_mag)
    mask_src2 = 1. - mask_src1
    pred_src1_mag = mixed_mag * mask_src1
    pred_src2_mag = mixed_mag * mask_src2

    # (magnitude, phase) -> spectrogram -> wav
    if EvalConfig.GRIFFIN_LIM:
        pred_src1_wav = to_wav_mag_only(pred_src1_mag,
                                        init_phase=mixed_phase,
                                        num_iters=EvalConfig.GRIFFIN_LIM_ITER)
        pred_src2_wav = to_wav_mag_only(pred_src2_mag,
                                        init_phase=mixed_phase,
                                        num_iters=EvalConfig.GRIFFIN_LIM_ITER)
    else:
        pred_src1_wav = to_wav(pred_src1_mag, mixed_phase)
        pred_src2_wav = to_wav(pred_src2_mag, mixed_phase)

    # Compute BSS metrics
    gnsdr, gsir, gsar = bss_eval_global(mixed_wav, src1_wav, src2_wav,
                                        pred_src1_wav, pred_src2_wav,
                                        EvalConfig.NUM_EVAL)
    return gnsdr, gsir, gsar
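
The `soft_time_freq_mask` / `hard_time_freq_mask` helpers these examples import are not shown on this page. Below is a minimal sketch of the standard ratio and binary masks; the `eps` guard against division by zero is an assumption, not necessarily present in the original module.

import numpy as np

def soft_time_freq_mask(target_mag, remainder_mag, eps=1e-10):
    # Ratio mask: the target's share of the total magnitude, in [0, 1] per bin.
    return target_mag / (target_mag + remainder_mag + eps)

def hard_time_freq_mask(target_mag, remainder_mag):
    # Binary mask: 1 where the target dominates the time-frequency bin.
    return (target_mag > remainder_mag).astype(np.float32)

Applying the mask to the mixture magnitude, rather than using the raw network outputs, keeps the two estimates summing back to the mixture, which is why every example multiplies mixed_mag by the masks.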
Example #2
def process(raw_data, n_frames, n_channels, sample_width, sample_rate):
    # Model
    model = Model()
    global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')

    # channels is int32 type
    channels = interpret_wav(raw_data, n_frames, n_channels, sample_width, True)

    mixed_wav = channels[0]
    with tf.Session(config=EvalConfig.session_conf) as sess:
        # Initialize variables and load the saved state
        sess.run(tf.global_variables_initializer())
        model.load_state(sess, EvalConfig.CKPT_PATH)

        mixed_spec = to_spectrogram(mixed_wav)
        mixed_mag = get_magnitude(mixed_spec)
        mixed_batch, padded_mixed_mag = model.spec_to_batch(mixed_mag)
        mixed_phase = get_phase(mixed_spec)

        (pred_src1_mag, pred_src2_mag) = sess.run(model(), feed_dict={model.x_mixed: mixed_batch})

        seq_len = mixed_phase.shape[-1]
        pred_src1_mag = model.batch_to_spec(pred_src1_mag, EvalConfig.NUM_EVAL)[:, :, :seq_len]
        pred_src2_mag = model.batch_to_spec(pred_src2_mag, EvalConfig.NUM_EVAL)[:, :, :seq_len]

        # Time-frequency masking
        mask_src1 = soft_time_freq_mask(pred_src1_mag, pred_src2_mag)
        mask_src2 = 1. - mask_src1
        pred_src2_mag = mixed_mag * mask_src2

        pred_src2_wav = to_wav(pred_src2_mag, mixed_phase)


    return pred_src2_wav[0].astype(channels.dtype)
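
`to_spectrogram`, `get_magnitude`, `get_phase`, and `to_wav` come from the project's preprocess module and are not shown here. A plausible librosa-based sketch follows; the frame and hop sizes are placeholders for whatever the real ModelConfig defines.

import librosa
import numpy as np

def to_spectrogram(wav, len_frame=1024, len_hop=256):
    # (batch, samples) -> complex STFT per clip: (batch, freq, frames)
    return np.array([librosa.stft(w, n_fft=len_frame, hop_length=len_hop)
                     for w in wav])

def get_magnitude(spec):
    return np.abs(spec)

def get_phase(spec):
    return np.angle(spec)

def to_wav(mag, phase, len_hop=256):
    # Recombine the estimated magnitude with the mixture phase and invert.
    spec = mag * np.exp(1.j * phase)
    return np.array([librosa.istft(s, hop_length=len_hop) for s in spec])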
Example #3
def process(input_signal, processed_signal):
    # Model
    model = Model()
    global_step = tf.Variable(0,
                              dtype=tf.int32,
                              trainable=False,
                              name='global_step')

    with tf.Session(config=EvalConfig.session_conf) as sess:
        # Initialize variables and load the saved state
        sess.run(tf.global_variables_initializer())
        model.load_state(sess, EvalConfig.CKPT_PATH)

        while True:
            start = time.time()
            mixed_wav = input_signal.get()

            mixed_spec = to_spectrogram(mixed_wav)
            mixed_mag = get_magnitude(mixed_spec)
            mixed_batch, padded_mixed_mag = model.spec_to_batch(mixed_mag)
            mixed_phase = get_phase(mixed_spec)

            assert (np.all(
                np.equal(model.batch_to_spec(mixed_batch, EvalConfig.NUM_EVAL),
                         padded_mixed_mag)))

            input_signal.task_done()
            middle = time.time()
            (pred_src1_mag,
             pred_src2_mag) = sess.run(model(),
                                       feed_dict={model.x_mixed: mixed_batch})

            modeltime = time.time()
            seq_len = mixed_phase.shape[-1]
            pred_src1_mag = model.batch_to_spec(
                pred_src1_mag, EvalConfig.NUM_EVAL)[:, :, :seq_len]
            pred_src2_mag = model.batch_to_spec(
                pred_src2_mag, EvalConfig.NUM_EVAL)[:, :, :seq_len]
            middle2 = time.time()

            # Time-frequency masking
            mask_src1 = soft_time_freq_mask(pred_src1_mag, pred_src2_mag)
            mask_src2 = 1. - mask_src1
            pred_src1_mag = mixed_mag * mask_src1
            pred_src2_mag = mixed_mag * mask_src2

            # (magnitude, phase) -> spectrogram -> wav
            pred_src1_wav = to_wav(pred_src1_mag, mixed_phase)
            pred_src2_wav = to_wav(pred_src2_mag, mixed_phase)

            processed_signal.put(pred_src2_wav[0])
            end = time.time()

            print("Process time1 = {0}".format(middle - start))
            print("Process time to start model = {0}".format(modeltime -
                                                             middle))
            print("Process time to models = {0}".format(middle2 - modeltime))
            print("Process time till end = {0}".format(end - middle2))
Example #4
def load_data():
    filenames1 = ModelConfig.Audio_filename
    filenames2 = ModelConfig.Noise_filename

    mixed_wav, src1_wav, src2_wav = get_random_wav_batch(
        filenames1, filenames2, ModelConfig.SECOND, ModelConfig.SR)
    mixed_spec = to_spectrogram(mixed_wav)
    mixed_mag = get_magnitude(mixed_spec)

    src1_spec, src2_spec = to_spectrogram(src1_wav), to_spectrogram(src2_wav)
    src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(src2_spec)

    src1_batch, _ = spec_to_batch(src1_mag)
    src2_batch, _ = spec_to_batch(src2_mag)
    mixed_batch, _ = spec_to_batch(mixed_mag)

    # batch_size, n_frames, n_freq
    return mixed_batch, src1_batch, src2_batch
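
`spec_to_batch` / `batch_to_spec` reshape between full spectrograms and the fixed-length frame sequences the network consumes. A sketch of the round trip, assuming a segment length of seq_len frames; the padded magnitude is returned so callers can sanity-check the inverse, as the asserts in later examples do.

import numpy as np

def spec_to_batch(src, seq_len=4):
    num_wavs, freq, n_frames = src.shape
    # Zero-pad the time axis so it splits evenly into seq_len-frame segments.
    pad_len = (seq_len - n_frames % seq_len) % seq_len
    padded = np.pad(src, ((0, 0), (0, 0), (0, pad_len)), mode='constant')
    # (batch, freq, time) -> (batch * num_segments, seq_len, freq)
    batch = padded.transpose(0, 2, 1).reshape(-1, seq_len, freq)
    return batch, padded

def batch_to_spec(batch, num_wav):
    # Inverse: (num_wav * num_segments, seq_len, freq) -> (num_wav, freq, time)
    freq = batch.shape[2]
    return batch.reshape(num_wav, -1, freq).transpose(0, 2, 1)

This is also why the examples slice `[:, :, :seq_len]` after `batch_to_spec`: it trims the zero padding back to the mixture's true frame count.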
Example #5
def test_run():

    tf.reset_default_graph()
    model = Model()

    with tf.Session(config=EvalConfig.session_conf) as sess:

        # Initialize variables and load the saved state
        sess.run(tf.global_variables_initializer())

        data = Datas(EvalConfig.DATA_PATH)
        model.load_state(sess, EvalConfig.CKPT_PATH)

        mixed_wav, src1_wav, src2_wav = data.next_wav(EvalConfig.SECONDS)

        print(src1_wav)

        mixed_spec = to_spectrogram(mixed_wav)
        mixed_mag = get_magnitude(mixed_spec)
        mixed_batch, padded_mixed_mag = model.spec_to_batch(mixed_mag)
        mixed_phase = get_phase(mixed_spec)

        (pred_src1_mag,
         pred_src2_mag) = sess.run(model(),
                                   feed_dict={model.x_mixed: mixed_batch})

        seq_len = mixed_phase.shape[-1]
        pred_src1_mag = model.batch_to_spec(
            pred_src1_mag, EvalConfig.NUM_EVAL)[:, :, :seq_len]
        pred_src2_mag = model.batch_to_spec(
            pred_src2_mag, EvalConfig.NUM_EVAL)[:, :, :seq_len]

        # Time-frequency masking
        mask_src1 = soft_time_freq_mask(pred_src1_mag, pred_src2_mag)
        # mask_src1 = hard_time_freq_mask(pred_src1_mag, pred_src2_mag)
        mask_src2 = 1. - mask_src1
        pred_src1_mag = mixed_mag * mask_src1
        pred_src2_mag = mixed_mag * mask_src2

        pred_src1_wav = to_wav(pred_src1_mag, mixed_phase)
        pred_src2_wav = to_wav(pred_src2_mag, mixed_phase)

        write_wav(mixed_wav[0], '{}/{}'.format(EvalConfig.RESULT_PATH,
                                               'original'))
        write_wav(pred_src1_wav[0], '{}/{}'.format(EvalConfig.RESULT_PATH,
                                                   'music'))
        write_wav(pred_src2_wav[0], '{}/{}'.format(EvalConfig.RESULT_PATH,
                                                   'voice'))
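
`write_wav` is another small repo helper. A plausible one-liner over the librosa 0.x output API; appending the `.wav` extension and the 16 kHz default (standing in for ModelConfig.SR) are assumptions based on how callers pass extension-less names.

import librosa

def write_wav(wav, path, sr=16000):
    # librosa.output was removed in librosa 0.8; this targets the 0.x API.
    librosa.output.write_wav('{}.wav'.format(path), wav, sr)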
Example #6
def get_drum(filename):

    tf.reset_default_graph()
    model = Model()

    with tf.Session(config=EvalConfig.session_conf) as sess:

        # Initialize variables and load the saved state
        sess.run(tf.global_variables_initializer())

        data = Datas(RunConfig.DATA_ROOT)
        model.load_state(sess, EvalConfig.CKPT_PATH)

        mixed_wav = data.get_mixture(filename)

        print(mixed_wav)

        mixed_spec = to_spectrogram(mixed_wav)
        mixed_mag = get_magnitude(mixed_spec)
        mixed_batch, padded_mixed_mag = model.spec_to_batch(mixed_mag)
        mixed_phase = get_phase(mixed_spec)

        (pred_src1_mag,
         pred_src2_mag) = sess.run(model(),
                                   feed_dict={model.x_mixed: mixed_batch})

        seq_len = mixed_phase.shape[-1]
        pred_src1_mag = model.batch_to_spec(
            pred_src1_mag, EvalConfig.NUM_EVAL)[:, :, :seq_len]
        pred_src2_mag = model.batch_to_spec(
            pred_src2_mag, EvalConfig.NUM_EVAL)[:, :, :seq_len]

        # Time-frequency masking
        mask_src1 = soft_time_freq_mask(pred_src1_mag, pred_src2_mag)
        # mask_src1 = hard_time_freq_mask(pred_src1_mag, pred_src2_mag)
        mask_src2 = 1. - mask_src1
        pred_src2_mag = mixed_mag * mask_src2

        pred_src2_wav = to_wav(pred_src2_mag, mixed_phase)

        filename = filename.replace('.wav', '')
        write_wav(pred_src2_wav[0], '{}/{}'.format(RunConfig.RESULT_PATH,
                                                   filename))
Example #7
def eval(data_path=None, result_path=None):
    # Model
    model = Model()
    global_step = tf.Variable(0,
                              dtype=tf.int32,
                              trainable=False,
                              name='global_step')

    with tf.Session(config=EvalConfig.session_conf) as sess:

        # Initialize variables and load the saved state
        sess.run(tf.global_variables_initializer())
        model.load_state(sess, EvalConfig.CKPT_PATH)

        writer = tf.summary.FileWriter(EvalConfig.GRAPH_PATH, sess.graph)

        data = Data(data_path) if data_path else Data(EvalConfig.DATA_PATH)
        output_path = result_path if result_path else EvalConfig.RESULT_PATH
        mixed_wav, src1_wav, src2_wav, wavfiles = data.next_wavs(
            EvalConfig.SECONDS, EvalConfig.NUM_EVAL)

        mixed_spec = to_spectrogram(mixed_wav)
        mixed_mag = get_magnitude(mixed_spec)
        mixed_batch, padded_mixed_mag = model.spec_to_batch(mixed_mag)
        mixed_phase = get_phase(mixed_spec)

        assert (np.all(
            np.equal(model.batch_to_spec(mixed_batch, EvalConfig.NUM_EVAL),
                     padded_mixed_mag)))

        (pred_src1_mag,
         pred_src2_mag) = sess.run(model(),
                                   feed_dict={model.x_mixed: mixed_batch})

        seq_len = mixed_phase.shape[-1]
        pred_src1_mag = model.batch_to_spec(
            pred_src1_mag, EvalConfig.NUM_EVAL)[:, :, :seq_len]
        pred_src2_mag = model.batch_to_spec(
            pred_src2_mag, EvalConfig.NUM_EVAL)[:, :, :seq_len]

        # Time-frequency masking
        mask_src1 = soft_time_freq_mask(pred_src1_mag, pred_src2_mag)
        # mask_src1 = hard_time_freq_mask(pred_src1_mag, pred_src2_mag)
        mask_src2 = 1. - mask_src1
        pred_src1_mag = mixed_mag * mask_src1
        pred_src2_mag = mixed_mag * mask_src2

        # (magnitude, phase) -> spectrogram -> wav
        if EvalConfig.GRIFFIN_LIM:
            pred_src1_wav = to_wav_mag_only(
                pred_src1_mag,
                init_phase=mixed_phase,
                num_iters=EvalConfig.GRIFFIN_LIM_ITER)
            pred_src2_wav = to_wav_mag_only(
                pred_src2_mag,
                init_phase=mixed_phase,
                num_iters=EvalConfig.GRIFFIN_LIM_ITER)
        else:
            pred_src1_wav = to_wav(pred_src1_mag, mixed_phase)
            pred_src2_wav = to_wav(pred_src2_mag, mixed_phase)

        # Write the result
        tf.summary.audio('GT_mixed',
                         mixed_wav,
                         ModelConfig.SR,
                         max_outputs=EvalConfig.NUM_EVAL)
        tf.summary.audio('Pred_music',
                         pred_src1_wav,
                         ModelConfig.SR,
                         max_outputs=EvalConfig.NUM_EVAL)
        tf.summary.audio('Pred_vocal',
                         pred_src2_wav,
                         ModelConfig.SR,
                         max_outputs=EvalConfig.NUM_EVAL)

        if EvalConfig.EVAL_METRIC:
            # Compute BSS metrics
            gnsdr, gsir, gsar = bss_eval_global(mixed_wav, src1_wav, src2_wav,
                                                pred_src1_wav, pred_src2_wav)

            # Write the score of BSS metrics
            tf.summary.scalar('GNSDR_music', gnsdr[0])
            tf.summary.scalar('GSIR_music', gsir[0])
            tf.summary.scalar('GSAR_music', gsar[0])
            tf.summary.scalar('GNSDR_vocal', gnsdr[1])
            tf.summary.scalar('GSIR_vocal', gsir[1])
            tf.summary.scalar('GSAR_vocal', gsar[1])

        if EvalConfig.WRITE_RESULT:
            # Write the result
            for i in range(len(wavfiles)):
                name = 'video'
                print(output_path)
                write_wav(mixed_wav[i],
                          '{}/{}-{}'.format(output_path, name, 'original'))
                write_wav(pred_src1_wav[i],
                          '{}/{}-{}'.format(output_path, name, 'music'))
                write_wav(pred_src2_wav[i],
                          '{}/{}-{}'.format(output_path, name, 'voice'))

        writer.add_summary(sess.run(tf.summary.merge_all()),
                           global_step=global_step.eval())

        writer.close()
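
`bss_eval_global` aggregates BSS Eval scores over the evaluation batch. A sketch built on mir_eval's bss_eval_sources follows; the mir_eval backend and the length-weighted averaging are assumptions, consistent with how GNSDR is usually reported (NSDR is the SDR gain of the estimate over the raw mixture).

import numpy as np
from mir_eval.separation import bss_eval_sources

def bss_eval_global(mixed_wav, src1_wav, src2_wav,
                    pred_src1_wav, pred_src2_wav, num_eval=None):
    num_eval = num_eval or len(mixed_wav)
    gnsdr, gsir, gsar = np.zeros(2), np.zeros(2), np.zeros(2)
    total_len = 0
    for i in range(num_eval):
        length = pred_src1_wav[i].shape[-1]
        ref = np.array([src1_wav[i][:length], src2_wav[i][:length]])
        est = np.array([pred_src1_wav[i][:length], pred_src2_wav[i][:length]])
        mix = np.tile(mixed_wav[i][:length], (2, 1))
        sdr, sir, sar, _ = bss_eval_sources(ref, est)
        sdr_mix, _, _, _ = bss_eval_sources(ref, mix)
        # NSDR: improvement of the estimate's SDR over the mixture's SDR.
        gnsdr += length * (sdr - sdr_mix)
        gsir += length * sir
        gsar += length * sar
        total_len += length
    return gnsdr / total_len, gsir / total_len, gsar / total_len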
Example #8
def train(model, data, Config, lr, eps, num_wav, len_frame):
    len_hop = closest_power_of_two(len_frame / 4)

    # Loss, Optimizer
    global_step = tf.Variable(0,
                              dtype=tf.int32,
                              trainable=False,
                              name='global_step')
    loss_fn = model.loss()
    optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(
        loss_fn, global_step=global_step)

    summary_op = summaries(model, loss_fn)

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())

        if not os.path.exists(Config.CKPT_PATH):
            os.makedirs(Config.CKPT_PATH)
        model.load_state(sess, Config.CKPT_PATH)

        writer = tf.summary.FileWriter(
            "{}/{}".format(Config.GRAPH_PATH, "train"), sess.graph)

        print("Starting training...")

        loss = Diff()
        for step in range(global_step.eval(), eps):
            mixed_wav, src1_wav, src2_wav, _ = data.next_wavs(num_wav)

            mixed_spec = to_spectrogram(mixed_wav, len_frame, len_hop)
            mixed_mag = get_magnitude(mixed_spec)

            src1_spec = to_spectrogram(src1_wav, len_frame, len_hop)
            src2_spec = to_spectrogram(src2_wav, len_frame, len_hop)
            src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(
                src2_spec)

            src1_batch, _ = model.spec_to_batch(src1_mag)
            src2_batch, _ = model.spec_to_batch(src2_mag)
            mixed_batch, _ = model.spec_to_batch(mixed_mag)

            l, _, summary = sess.run(
                [loss_fn, optimizer, summary_op],
                feed_dict={
                    model.x_mixed: mixed_batch,
                    model.y_src1: src1_batch,
                    model.y_src2: src2_batch
                })

            loss.update(l)
            print('step-{}\td_loss={:2.2f}\tloss={}'.format(
                step, loss.diff * 100, loss.value))

            writer.add_summary(summary, global_step=step)

            # Save state
            if step % Config.CKPT_STEP == 0:
                print("Saved checkpoint.")
                tf.train.Saver().save(sess,
                                      Config.CKPT_PATH + '/checkpoint',
                                      global_step=step)

        writer.close()
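
Two more small utilities the training loops lean on, sketched under assumptions: Diff tracks the relative change of the loss between steps (driving the d_loss printout), and closest_power_of_two rounds a frame length to the nearest power of two for use as a hop size.

import numpy as np

class Diff(object):
    # Tracks a running value and its relative change between updates.
    def __init__(self, v=0.):
        self.value = v
        self.diff = 0.

    def update(self, v):
        if self.value:
            self.diff = (v / self.value) - 1.
        self.value = v

def closest_power_of_two(n):
    # Round n to the nearest power of two, e.g. 1024 / 4 = 256 -> 256.
    return int(2 ** round(np.log2(n)))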
Example #9
def process(input_signal, p):
    # Model
    model = Model()
    global_step = tf.Variable(0,
                              dtype=tf.int32,
                              trainable=False,
                              name='global_step')
    CHUNKSIZE = EvalConfig.CHUNK
    stream = p.open(format=pyaudio.paInt32,
                    channels=1,
                    rate=48000,
                    output=True,
                    frames_per_buffer=CHUNKSIZE)

    with tf.Session(config=EvalConfig.session_conf) as sess:
        # Initialize variables and load the saved state
        sess.run(tf.global_variables_initializer())
        model.load_state(sess, EvalConfig.CKPT_PATH)

        while True:
            #p_sem.acquire()
            #print("     process before get: %d" % input_signal.qsize())
            mixed_wav = input_signal.get()

            mixed_spec = to_spectrogram(mixed_wav)
            mixed_mag = get_magnitude(mixed_spec)
            mixed_batch, padded_mixed_mag = model.spec_to_batch(mixed_mag)
            mixed_phase = get_phase(mixed_spec)

            assert (np.all(
                np.equal(model.batch_to_spec(mixed_batch, EvalConfig.NUM_EVAL),
                         padded_mixed_mag)))

            (pred_src1_mag,
             pred_src2_mag) = sess.run(model(),
                                       feed_dict={model.x_mixed: mixed_batch})

            seq_len = mixed_phase.shape[-1]
            pred_src1_mag = model.batch_to_spec(
                pred_src1_mag, EvalConfig.NUM_EVAL)[:, :, :seq_len]
            pred_src2_mag = model.batch_to_spec(
                pred_src2_mag, EvalConfig.NUM_EVAL)[:, :, :seq_len]

            # Time-frequency masking
            mask_src1 = soft_time_freq_mask(pred_src1_mag, pred_src2_mag)
            mask_src2 = 1. - mask_src1
            pred_src1_mag = mixed_mag * mask_src1
            pred_src2_mag = mixed_mag * mask_src2

            pred_src2_wav = to_wav(pred_src2_mag, mixed_phase)

            # free = stream.get_write_available()
            # print("         free1: %d" % free)
            # if (free - CHUNKSIZE) > CHUNKSIZE * 2:
            #      sleep(0.5)
            data = pred_src2_wav[0].astype(np.int32).tobytes()  # tostring() is a deprecated alias of tobytes()
            stream.write(data)
            # free = stream.get_write_available()
            # print("         free2: %d" % free)
            #r_sem.release()
    stream.stop_stream()
    stream.close()
Example #10
def eval():
    # Model
    model = Model()
    global_step = tf.Variable(0,
                              dtype=tf.int32,
                              trainable=False,
                              name='global_step')

    with tf.Session(config=EvalConfig.session_conf) as sess:

        # Initialize variables and load the saved state
        sess.run(tf.global_variables_initializer())
        model.load_state(sess, EvalConfig.CKPT_PATH)

        writer = tf.summary.FileWriter(EvalConfig.GRAPH_PATH, sess.graph)

        data = Data(EvalConfig.DATA_PATH, TrainConfig.NOISE_DATA_PATH,
                    TrainConfig.VOICE_DATA_PATH)
        mixed_wav, src1_wav, src2_wav, wavfiles = data.next_wavs_eval(
            EvalConfig.SECONDS, EvalConfig.NUM_EVAL)

        start = time.time()

        mixed_spec = to_spectrogram(mixed_wav)
        mixed_mag = get_magnitude(mixed_spec)
        mixed_batch, padded_mixed_mag = model.spec_to_batch(mixed_mag)
        mixed_phase = get_phase(mixed_spec)

        assert (np.all(
            np.equal(model.batch_to_spec(mixed_batch, EvalConfig.NUM_EVAL),
                     padded_mixed_mag)))

        (pred_src1_mag,
         pred_src2_mag) = sess.run(model(),
                                   feed_dict={model.x_mixed: mixed_batch})

        seq_len = mixed_phase.shape[-1]
        pred_src1_mag = model.batch_to_spec(
            pred_src1_mag, EvalConfig.NUM_EVAL)[:, :, :seq_len]
        pred_src2_mag = model.batch_to_spec(
            pred_src2_mag, EvalConfig.NUM_EVAL)[:, :, :seq_len]

        # Time-frequency masking
        mask_src1 = soft_time_freq_mask(pred_src1_mag, pred_src2_mag)
        # mask_src1 = hard_time_freq_mask(pred_src1_mag, pred_src2_mag)
        mask_src2 = 1. - mask_src1
        pred_src1_mag = mixed_mag * mask_src1
        pred_src2_mag = mixed_mag * mask_src2

        # (magnitude, phase) -> spectrogram -> wav
        if EvalConfig.GRIFFIN_LIM:
            pred_src1_wav = to_wav_mag_only(
                pred_src1_mag,
                init_phase=mixed_phase,
                num_iters=EvalConfig.GRIFFIN_LIM_ITER)
            pred_src2_wav = to_wav_mag_only(
                pred_src2_mag,
                init_phase=mixed_phase,
                num_iters=EvalConfig.GRIFFIN_LIM_ITER)
        else:
            pred_src1_wav = to_wav(pred_src1_mag, mixed_phase)
            pred_src2_wav = to_wav(pred_src2_mag, mixed_phase)

        end = time.time()
        print("Time elapsed: {0}".format(end - start))

        # Write the result
        tf.summary.audio('GT_mixed',
                         mixed_wav,
                         ModelConfig.SR,
                         max_outputs=EvalConfig.NUM_EVAL)
        tf.summary.audio('Pred_music',
                         pred_src1_wav,
                         ModelConfig.SR,
                         max_outputs=EvalConfig.NUM_EVAL)
        tf.summary.audio('Pred_vocal',
                         pred_src2_wav,
                         ModelConfig.SR,
                         max_outputs=EvalConfig.NUM_EVAL)

        # Write the result
        for i in range(len(wavfiles)):
            name = wavfiles[i].replace('/', '-').replace('.wav', '')
            write_wav(
                mixed_wav[i], '{}/{}-{}'.format(EvalConfig.RESULT_PATH, name,
                                                'original'))
            write_wav(
                pred_src1_wav[i], '{}/{}-{}'.format(EvalConfig.RESULT_PATH,
                                                    name, 'background'))
            write_wav(pred_src2_wav[i],
                      '{}/{}-{}'.format(EvalConfig.RESULT_PATH, name, 'voice'))

        writer.add_summary(sess.run(tf.summary.merge_all()),
                           global_step=global_step.eval())

        writer.close()
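
`to_wav_mag_only` reconstructs audio from magnitude alone via Griffin-Lim, seeded with the mixture phase. A per-clip sketch, with placeholder STFT parameters; the frame-count trim guards against librosa's analysis/synthesis length drift.

import librosa
import numpy as np

def to_wav_mag_only(mag, init_phase, len_frame=1024, len_hop=256,
                    num_iters=50):
    wavs = []
    for m, p in zip(mag, init_phase):
        phase = p
        for _ in range(num_iters):
            # Enforce the known magnitude, go to the time domain and back,
            # keeping only the re-estimated phase.
            wav = librosa.istft(m * np.exp(1.j * phase), hop_length=len_hop)
            phase = np.angle(librosa.stft(wav, n_fft=len_frame,
                                          hop_length=len_hop))
            # Trim or pad the re-analysis to the original frame count.
            if phase.shape[1] >= m.shape[1]:
                phase = phase[:, :m.shape[1]]
            else:
                phase = np.pad(phase,
                               ((0, 0), (0, m.shape[1] - phase.shape[1])),
                               mode='constant')
        wavs.append(librosa.istft(m * np.exp(1.j * phase), hop_length=len_hop))
    return np.array(wavs)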
Example #11
def train():
    tf.reset_default_graph()
    # Model
    model = Model()

    # Loss, Optimizer
    global_step = tf.Variable(0,
                              dtype=tf.int32,
                              trainable=False,
                              name='global_step')
    loss_fn = model.loss()
    optimizer = tf.train.AdamOptimizer(learning_rate=TrainConfig.LR).minimize(
        loss_fn, global_step=global_step)

    # Summaries
    summary_op = summaries(model, loss_fn)

    with tf.Session(config=TrainConfig.session_conf) as sess:

        # Initialize variables and load the saved state
        sess.run(tf.global_variables_initializer())
        model.load_state(sess, TrainConfig.CKPT_PATH)

        writer = tf.summary.FileWriter(TrainConfig.GRAPH_PATH, sess.graph)

        # Input source
        data = Datas(TrainConfig.DATA_PATH)

        for step in range(
                global_step.eval(),
                TrainConfig.FINAL_STEP):  # changed xrange to range for py3
            mixed_wav, src1_wav, src2_wav = data.next_wav(TrainConfig.SECONDS)

            mixed_spec = to_spectrogram(mixed_wav)
            mixed_mag = get_magnitude(mixed_spec)

            src1_spec, src2_spec = to_spectrogram(src1_wav), to_spectrogram(
                src2_wav)
            src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(
                src2_spec)

            src1_batch, _ = model.spec_to_batch(src1_mag)
            src2_batch, _ = model.spec_to_batch(src2_mag)
            mixed_batch, _ = model.spec_to_batch(mixed_mag)

            l, _, summary = sess.run(
                [loss_fn, optimizer, summary_op],
                feed_dict={
                    model.x_mixed: mixed_batch,
                    model.y_src1: src1_batch,
                    model.y_src2: src2_batch
                })

            print('For iteration {}\t loss={}'.format(step, l))

            writer.add_summary(summary, global_step=step)

            # Save state
            if step % TrainConfig.CKPT_STEP == 0:
                tf.train.Saver().save(sess,
                                      TrainConfig.CKPT_PATH + '/checkpoint',
                                      global_step=step)

        writer.close()
Example #12
def eval(n):
    overall_gnsdr, overall_gsir, overall_gsar = [], [], []
    for i in range(n):
        with tf.Graph().as_default():
            # Model
            model = Model(ModelConfig.HIDDEN_LAYERS, ModelConfig.HIDDEN_UNITS)
            global_step = tf.Variable(0,
                                      dtype=tf.int32,
                                      trainable=False,
                                      name='global_step')

            with tf.Session(config=EvalConfig.session_conf) as sess:

                # Initialize variables and load the saved state
                sess.run(tf.global_variables_initializer())
                model.load_state(sess, EvalConfig.CKPT_PATH)

                print('num trainable parameters: %s' % (np.sum([
                    np.prod(v.get_shape().as_list())
                    for v in tf.trainable_variables()
                ])))

                writer = tf.summary.FileWriter(EvalConfig.GRAPH_PATH,
                                               sess.graph)

                data = Data(EvalConfig.DATA_PATH)
                mixed_wav, src1_wav, src2_wav, wavfiles = data.next_wavs(
                    EvalConfig.SECONDS, EvalConfig.NUM_EVAL)

                mixed_spec = to_spectrogram(mixed_wav)
                mixed_mag = get_magnitude(mixed_spec)
                mixed_batch, padded_mixed_mag = model.spec_to_batch(mixed_mag)
                mixed_phase = get_phase(mixed_spec)

                assert (np.all(
                    np.equal(
                        model.batch_to_spec(mixed_batch, EvalConfig.NUM_EVAL),
                        padded_mixed_mag)))

                (pred_src1_mag, pred_src2_mag) = sess.run(
                    model(), feed_dict={model.x_mixed: mixed_batch})

                seq_len = mixed_phase.shape[-1]
                pred_src1_mag = model.batch_to_spec(
                    pred_src1_mag, EvalConfig.NUM_EVAL)[:, :, :seq_len]
                pred_src2_mag = model.batch_to_spec(
                    pred_src2_mag, EvalConfig.NUM_EVAL)[:, :, :seq_len]

                # Time-frequency masking
                mask_src1 = soft_time_freq_mask(pred_src1_mag, pred_src2_mag)
                # mask_src1 = hard_time_freq_mask(pred_src1_mag, pred_src2_mag)
                mask_src2 = 1. - mask_src1
                pred_src1_mag = mixed_mag * mask_src1
                pred_src2_mag = mixed_mag * mask_src2

                # (magnitude, phase) -> spectrogram -> wav
                if EvalConfig.GRIFFIN_LIM:
                    pred_src1_wav = to_wav_mag_only(
                        pred_src1_mag,
                        init_phase=mixed_phase,
                        num_iters=EvalConfig.GRIFFIN_LIM_ITER)
                    pred_src2_wav = to_wav_mag_only(
                        pred_src2_mag,
                        init_phase=mixed_phase,
                        num_iters=EvalConfig.GRIFFIN_LIM_ITER)
                else:
                    pred_src1_wav = to_wav(pred_src1_mag, mixed_phase)
                    pred_src2_wav = to_wav(pred_src2_mag, mixed_phase)

                # Write the result
                tf.summary.audio('GT_mixed',
                                 mixed_wav,
                                 ModelConfig.SR,
                                 max_outputs=EvalConfig.NUM_EVAL)
                tf.summary.audio('Pred_music',
                                 pred_src1_wav,
                                 ModelConfig.SR,
                                 max_outputs=EvalConfig.NUM_EVAL)
                tf.summary.audio('Pred_vocal',
                                 pred_src2_wav,
                                 ModelConfig.SR,
                                 max_outputs=EvalConfig.NUM_EVAL)

                if EvalConfig.EVAL_METRIC:
                    # Compute BSS metrics
                    gnsdr, gsir, gsar = bss_eval_global(
                        mixed_wav, src1_wav, src2_wav, pred_src1_wav,
                        pred_src2_wav)

                    # Write the score of BSS metrics
                    tf.summary.scalar('GNSDR_music', gnsdr[0])
                    tf.summary.scalar('GSIR_music', gsir[0])
                    tf.summary.scalar('GSAR_music', gsar[0])
                    tf.summary.scalar('GNSDR_vocal', gnsdr[1])
                    tf.summary.scalar('GSIR_vocal', gsir[1])
                    tf.summary.scalar('GSAR_vocal', gsar[1])
                    print('GNSDR: ', gnsdr)
                    print('GSIR: ', gsir)
                    print('GSAR: ', gsar)

                    # Accumulate only when metrics were computed; otherwise
                    # gnsdr/gsir/gsar would be undefined here.
                    overall_gnsdr.append(gnsdr)
                    overall_gsir.append(gsir)
                    overall_gsar.append(gsar)

                if EvalConfig.WRITE_RESULT:
                    # Write the result
                    for i in range(len(wavfiles)):
                        name = wavfiles[i].replace('/',
                                                   '-').replace('.wav', '')
                        write_wav(
                            mixed_wav[i],
                            '{}/{}-{}'.format(EvalConfig.RESULT_PATH, name,
                                              'original'))
                        write_wav(
                            pred_src1_wav[i],
                            '{}/{}-{}'.format(EvalConfig.RESULT_PATH, name,
                                              'music'))
                        write_wav(
                            pred_src2_wav[i],
                            '{}/{}-{}'.format(EvalConfig.RESULT_PATH, name,
                                              'voice'))

                writer.add_summary(sess.run(tf.summary.merge_all()),
                                   global_step=global_step.eval())

                writer.close()

    if n > 1:
        overall_gnsdr = np.array(overall_gnsdr)
        overall_gsir = np.array(overall_gsir)
        overall_gsar = np.array(overall_gsar)
        overall_gnsdr = np.mean(overall_gnsdr, axis=0)
        overall_gsir = np.mean(overall_gsir, axis=0)
        overall_gsar = np.mean(overall_gsar, axis=0)

        print('OVERALL GNSDR: ', overall_gnsdr)
        print('OVERALL GSIR: ', overall_gsir)
        print('OVERALL GSAR: ', overall_gsar)
Example #13
def separate(filename, channel):
    with tf.Graph().as_default():
        # Model
        model = Model(ModelConfig.HIDDEN_LAYERS, ModelConfig.HIDDEN_UNITS)
        global_step = tf.Variable(0,
                                  dtype=tf.int32,
                                  trainable=False,
                                  name='global_step')

        total_samples, origin_samples, samples = decode_input(filename)
        channels = origin_samples.shape[0]
        with tf.Session(config=EvalConfig.session_conf) as sess:

            # Initialized, Load state
            sess.run(tf.global_variables_initializer())
            model.load_state(sess, EvalConfig.CKPT_PATH)

            mixed_wav, src1_wav, src2_wav = samples, samples, samples

            mixed_spec = to_spectrogram(mixed_wav)
            mixed_mag = get_magnitude(mixed_spec)
            mixed_batch, padded_mixed_mag = model.spec_to_batch(mixed_mag)
            mixed_phase = get_phase(mixed_spec)

            (pred_src1_mag,
             pred_src2_mag) = sess.run(model(),
                                       feed_dict={model.x_mixed: mixed_batch})

            seq_len = mixed_phase.shape[-1]
            pred_src1_mag = model.batch_to_spec(pred_src1_mag,
                                                1)[:, :, :seq_len]
            pred_src2_mag = model.batch_to_spec(pred_src2_mag,
                                                1)[:, :, :seq_len]

            # Time-frequency masking
            mask_src1 = soft_time_freq_mask(pred_src1_mag, pred_src2_mag)
            # mask_src1 = hard_time_freq_mask(pred_src1_mag, pred_src2_mag)
            mask_src2 = 1. - mask_src1
            pred_src1_mag = mixed_mag * mask_src1
            pred_src2_mag = mixed_mag * mask_src2

            # (magnitude, phase) -> spectrogram -> wav
            if EvalConfig.GRIFFIN_LIM:
                pred_src1_wav = to_wav_mag_only(
                    pred_src1_mag,
                    init_phase=mixed_phase,
                    num_iters=EvalConfig.GRIFFIN_LIM_ITER)
                pred_src2_wav = to_wav_mag_only(
                    pred_src2_mag,
                    init_phase=mixed_phase,
                    num_iters=EvalConfig.GRIFFIN_LIM_ITER)
            else:
                pred_src1_wav = to_wav(pred_src1_mag, mixed_phase)
                pred_src2_wav = to_wav(pred_src2_mag, mixed_phase)

            def stack(data):
                size = data.shape[0] // channels
                elements = []
                for i in range(channels):
                    elements.append(data[size * i:size * (i + 1)])
                return np.dstack(elements)[0]

            music_data = pred_src1_wav
            voice_data = pred_src2_wav

            if channel >= 0:

                def filter_samples(data):
                    for i in range(origin_samples.shape[0]):
                        if i != channel:
                            data[i, :] = origin_samples[i, 0:data.shape[1]]
                    return data

                music_data = filter_samples(music_data)
                voice_data = filter_samples(voice_data)

            music_wav = np.dstack(music_data)[0]
            voice_wav = np.dstack(voice_data)[0]
            return music_wav, voice_wav
    return None
Example #14
def train():
    # Model
    model = Model(ModelConfig.HIDDEN_LAYERS, ModelConfig.HIDDEN_UNITS)

    # Loss, Optimizer
    global_step = tf.Variable(0,
                              dtype=tf.int32,
                              trainable=False,
                              name='global_step')
    loss_fn = model.loss()
    optimizer = tf.train.AdamOptimizer(learning_rate=TrainConfig.LR).minimize(
        loss_fn, global_step=global_step)
    #optimizer = tf.train.GradientDescentOptimizer(learning_rate=TrainConfig.LR).minimize(loss_fn, global_step=global_step)

    model.gnsdr_music = tf.placeholder(dtype=tf.float32,
                                       shape=(),
                                       name='gnsdr_music')
    model.gsir_music = tf.placeholder(dtype=tf.float32,
                                      shape=(),
                                      name='gsir_music')
    model.gsar_music = tf.placeholder(dtype=tf.float32,
                                      shape=(),
                                      name='gsar_music')

    model.gnsdr_vocal = tf.placeholder(dtype=tf.float32,
                                       shape=(),
                                       name='gnsdr_vocal')
    model.gsir_vocal = tf.placeholder(dtype=tf.float32,
                                      shape=(),
                                      name='gsir_vocal')
    model.gsar_vocal = tf.placeholder(dtype=tf.float32,
                                      shape=(),
                                      name='gsar_vocal')

    # Summaries
    summary_ops = summaries(model, loss_fn)

    with tf.Session(config=TrainConfig.session_conf) as sess:

        # Initialize variables and load the saved state
        sess.run(tf.global_variables_initializer())
        model.load_state(sess, TrainConfig.CKPT_PATH)

        print('num trainable parameters: %s' % (np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ])))

        writer = tf.summary.FileWriter(TrainConfig.GRAPH_PATH, sess.graph)

        # Input source
        data = Data(TrainConfig.DATA_PATH)
        eval_data = Data(EvalConfig.DATA_PATH)

        loss = Diff()
        gnsdr, gsir, gsar = np.array([0, 0]), np.array([0, 0]), np.array([0, 0])
        initial_global_step = global_step.eval()
        for step in range(initial_global_step, TrainConfig.FINAL_STEP):
            start_time = time.time()

            bss_metric = step % 20 == 0 or step == initial_global_step
            bss_eval = ''
            if bss_metric:
                gnsdr, gsir, gsar = eval(model, eval_data, sess)
                bss_eval = 'GNSDR: {} GSIR: {} GSAR: {}'.format(
                    gnsdr, gsir, gsar)

            mixed_wav, src1_wav, src2_wav, _ = data.next_wavs(
                TrainConfig.SECONDS, TrainConfig.NUM_WAVFILE)

            mixed_spec = to_spectrogram(mixed_wav)
            mixed_mag = get_magnitude(mixed_spec)

            src1_spec, src2_spec = to_spectrogram(src1_wav), to_spectrogram(
                src2_wav)
            src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(
                src2_spec)

            src1_batch, _ = model.spec_to_batch(src1_mag)
            src2_batch, _ = model.spec_to_batch(src2_mag)
            mixed_batch, _ = model.spec_to_batch(mixed_mag)

            l, _, summary = sess.run(
                [loss_fn, optimizer, summary_ops],
                feed_dict={
                    model.x_mixed: mixed_batch,
                    model.y_src1: src1_batch,
                    model.y_src2: src2_batch,
                    model.gnsdr_music: gnsdr[0],
                    model.gsir_music: gsir[0],
                    model.gsar_music: gsar[0],
                    model.gnsdr_vocal: gnsdr[1],
                    model.gsir_vocal: gsir[1],
                    model.gsar_vocal: gsar[1]
                })
            loss.update(l)
            writer.add_summary(summary, global_step=step)

            # Save state
            if step % TrainConfig.CKPT_STEP == 0:
                tf.train.Saver().save(sess,
                                      TrainConfig.CKPT_PATH + '/checkpoint',
                                      global_step=step)

            elapsed_time = time.time() - start_time
            print(
                'step-{}\ttime={:2.2f}\td_loss={:2.2f}\tloss={:2.3f}\tbss_eval: {}'
                .format(step, elapsed_time, loss.diff * 100, loss.value,
                        bss_eval))

        writer.close()
Example #15
from config import TrainConfig, ModelConfig
from data import Data

from librosa import amplitude_to_db, stft
from librosa.display import specshow
from preprocess import to_spectrogram, get_magnitude
from pylab import savefig

import matplotlib.pyplot as plt
import numpy as np

data = Data(TrainConfig.DATA_PATH)

mixed_wav, src1_wav, src2_wav, _ = data.next_wavs(TrainConfig.SECONDS,
                                                  TrainConfig.NUM_WAVFILE)
mixed_spec = to_spectrogram(mixed_wav)
mixed_mag = get_magnitude(mixed_spec)
src1_spec, src2_spec = to_spectrogram(src1_wav), to_spectrogram(src2_wav)
src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(src2_spec)

sr = ModelConfig.SR
y = src1_wav[0]


def plot_wav_as_spec(wav, sr=ModelConfig.SR, s=0.5, path='foo.png'):
    """Plots a spectrogram

    Will save as foo.png in script directory

    Arguments:
        wav {array} -- audio data
Example #16
def eval(model, data, sr, len_frame, num_wav, glim, glim_iter, ckpt_path,
         graph_path, result_path):
    len_hop = closest_power_of_two(len_frame / 4)
    global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')

    with tf.Session() as sess:
        if not os.path.exists(result_path):
            os.makedirs(result_path)

        # Initialize variables and load the saved state
        sess.run(tf.global_variables_initializer())
        model.load_state(sess, ckpt_path)

        writer = tf.summary.FileWriter("{}/{}".format(graph_path, "eval"), sess.graph)

        mixed_wav, src1_wav, src2_wav, med_names = data.next_wavs(num_wav)

        mixed_spec = to_spectrogram(mixed_wav, len_frame, len_hop)
        mixed_mag = get_magnitude(mixed_spec)
        mixed_batch, padded_mixed_mag = model.spec_to_batch(mixed_mag)
        mixed_phase = get_phase(mixed_spec)

        assert (np.all(np.equal(model.batch_to_spec(mixed_batch, num_wav),
            padded_mixed_mag)))

        (pred_src1_mag, pred_src2_mag) = sess.run(model(),
            feed_dict={model.x_mixed: mixed_batch})

        seq_len = mixed_phase.shape[-1]
        pred_src1_mag = model.batch_to_spec(pred_src1_mag, num_wav)[:, :, :seq_len]
        pred_src2_mag = model.batch_to_spec(pred_src2_mag, num_wav)[:, :, :seq_len]

        # Time-frequency masking
        mask_src1 = soft_time_freq_mask(pred_src1_mag, pred_src2_mag)
        # mask_src1 = hard_time_freq_mask(pred_src1_mag, pred_src2_mag)
        mask_src2 = 1. - mask_src1
        pred_src1_mag = mixed_mag * mask_src1
        pred_src2_mag = mixed_mag * mask_src2

        # (magnitude, phase) -> spectrogram -> wav
        if glim:
            pred_src1_wav = to_wav_mag_only(pred_src1_mag, mixed_phase, len_frame,
                len_hop, num_iters=glim_iter)
            pred_src2_wav = to_wav_mag_only(pred_src2_mag, mixed_phase, len_frame,
                len_hop, num_iters=glim_iter)
        else:
            pred_src1_wav = to_wav(pred_src1_mag, mixed_phase, len_hop)
            pred_src2_wav = to_wav(pred_src2_mag, mixed_phase, len_hop)

        # Write the result
        tf.summary.audio('GT_mixed', mixed_wav, sr, max_outputs=num_wav)
        tf.summary.audio('Pred_music', pred_src1_wav, sr, max_outputs=num_wav)
        tf.summary.audio('Pred_vocal', pred_src2_wav, sr, max_outputs=num_wav)

        # Compute BSS metrics
        gnsdr, gsir, gsar = bss_eval_global(mixed_wav, src1_wav, src2_wav, pred_src1_wav,
            pred_src2_wav, num_wav)

        # Write the score of BSS metrics
        tf.summary.scalar('GNSDR_music', gnsdr[0])
        tf.summary.scalar('GSIR_music', gsir[0])
        tf.summary.scalar('GSAR_music', gsar[0])
        tf.summary.scalar('GNSDR_vocal', gnsdr[1])
        tf.summary.scalar('GSIR_vocal', gsir[1])
        tf.summary.scalar('GSAR_vocal', gsar[1])

        # Write the result
        for i in range(len(med_names)):
            write_wav(mixed_wav[i], '{}/{}-{}'.format(result_path, med_names[i],
                'all_stems_mixed'), sr)
            write_wav(pred_src1_wav[i], '{}/{}-{}'.format(result_path, med_names[i],
                'target_instrument'), sr)
            write_wav(pred_src2_wav[i], '{}/{}-{}'.format(result_path, med_names[i],
                'other_stems_mixed'), sr)

        writer.add_summary(sess.run(tf.summary.merge_all()), global_step=global_step.eval())

        writer.close()
Example #17
def train():
    # Model
    model = Model()

    # Loss, Optimizer
    global_step = tf.Variable(0,
                              dtype=tf.int32,
                              trainable=False,
                              name='global_step')
    loss_fn = model.loss()
    optimizer = tf.train.AdamOptimizer(learning_rate=TrainConfig.LR).minimize(
        loss_fn, global_step=global_step)

    # Summaries
    summary_op = summaries(model, loss_fn)

    with tf.Session(config=TrainConfig.session_conf) as sess:

        # Initialize variables and load the saved state
        sess.run(tf.global_variables_initializer())
        model.load_state(sess, TrainConfig.CKPT_PATH)

        writer = tf.summary.FileWriter(TrainConfig.GRAPH_PATH, sess.graph)

        # Input source
        data = Data(TrainConfig.DATA_PATH, TrainConfig.NOISE_DATA_PATH,
                    TrainConfig.VOICE_DATA_PATH)

        loss = Diff()
        for step in range(
                global_step.eval(),
                TrainConfig.FINAL_STEP):  # changed xrange to range for py3
            # Retry chunk retrieval until it yields a whole frame.
            while True:
                try:
                    mixed_wav, src1_wav, src2_wav, _ = data.next_wavs(
                        TrainConfig.SECONDS, TrainConfig.NUM_WAVFILE)
                    break
                except Exception:
                    print("Not a whole frame")

            mixed_spec = to_spectrogram(mixed_wav)
            mixed_mag = get_magnitude(mixed_spec)

            src1_spec, src2_spec = to_spectrogram(src1_wav), to_spectrogram(
                src2_wav)
            src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(
                src2_spec)

            src1_batch, _ = model.spec_to_batch(src1_mag)
            src2_batch, _ = model.spec_to_batch(src2_mag)
            mixed_batch, _ = model.spec_to_batch(mixed_mag)

            l, _, summary = sess.run(
                [loss_fn, optimizer, summary_op],
                feed_dict={
                    model.x_mixed: mixed_batch,
                    model.y_src1: src1_batch,
                    model.y_src2: src2_batch
                })

            loss.update(l)
            print('step-{}\td_loss={:2.2f}\tloss={}'.format(
                step, loss.diff * 100, loss.value))

            writer.add_summary(summary, global_step=step)

            # Save state
            if step % TrainConfig.CKPT_STEP == 0:
                tf.train.Saver().save(sess,
                                      TrainConfig.CKPT_PATH + '/checkpoint',
                                      global_step=step)

        writer.close()
Example #18

def train():

    dsd_train, dsd_test = GetData.getDSDFilelist("DSD100.xml")

    dataset = dict()
    # 50 training tracks from DSD100 as the supervised dataset
    dataset["train_sup"] = dsd_train
    dataset["valid"] = dsd_test[:25]
    dataset["test"] = dsd_test[25:]

    with open('dataset.pkl', 'wb') as file:
        pickle.dump(dataset, file)
        print("Created dataset structure")
    # Model
    model = Model()

    # Loss, Optimizer (this model carries its own global_step)
    loss_fn = model.loss()
    lr = ((CONFIG_MAP['flat-R-VAE'].hparams.learning_rate -
           CONFIG_MAP['flat-R-VAE'].hparams.min_learning_rate) *
          tf.pow(CONFIG_MAP['flat-R-VAE'].hparams.decay_rate,
                 tf.to_float(model.global_step)) +
          CONFIG_MAP['flat-R-VAE'].hparams.min_learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(
        loss_fn, global_step=model.global_step)

    # Summaries
    summary_op = summaries(model, loss_fn)

    with tf.Session(config=TrainConfig.session_conf) as sess:

        # Initialize variables (the checkpoint load below is disabled)
        sess.run(tf.global_variables_initializer())
        #model.load_state(sess, TrainConfig.CKPT_PATH)

        writer = tf.summary.FileWriter(TrainConfig.GRAPH_PATH, sess.graph)

        # Input source

        btch_size = CONFIG_MAP['flat-R-VAE'].hparams.batch_size
        loss = Diff()
        i = 0
        for step in range(
                model.global_step.eval(),
                TrainConfig.FINAL_STEP):  # changed xrange to range for py3
            # Wrap the index around the 50 supervised training tracks.
            if i > 50:
                i = 0
            batch_ = dsd_train[i:i + btch_size]
            i += btch_size
            mixed_wav, drums_wav = get_random_wav(batch_, TrainConfig.SECONDS,
                                                  ModelConfig.SR)

            mixed_spec = to_spectrogram(mixed_wav)
            mixed_mag = get_magnitude(mixed_spec)

            drums_spec = to_spectrogram(drums_wav)
            drums_mag = get_magnitude(drums_spec)

            mixed_batch, _ = model.spec_to_batch(mixed_mag)
            drums_batch, _ = model.spec_to_batch(drums_mag)

            l, _, summary = sess.run(
                [loss_fn, optimizer, summary_op],
                feed_dict={
                    model.x_mixed: mixed_batch,
                    model.x_drums: drums_batch,
                    model.y_drums: drums_batch
                })

            loss.update(l)
            print('step-{}\td_loss={:2.2f}\tloss={}'.format(
                step, loss.diff * 100, loss.value))

            writer.add_summary(summary, global_step=step)

            # Save state
            if step % TrainConfig.CKPT_STEP == 0:
                tf.train.Saver().save(sess,
                                      TrainConfig.CKPT_PATH + '/checkpoint',
                                      global_step=step)

        writer.close()