Example no. 1
0
def model_train(freq_weighted):
    """Train the separation model on batches from the TF input queue.

    Restores from the latest checkpoint in 'checkpoints/' when one exists,
    otherwise runs the initializer. Training continues until the input
    pipeline is exhausted (tf.errors.OutOfRangeError), logging the loss
    every 10 steps and checkpointing every 500 steps. TensorBoard
    summaries are written under a timestamped log directory.

    Args:
        freq_weighted: currently ignored -- the model is always built with
            freq_weighted=False (see inline comment below).
    """
    # Timestamped log dir so successive runs don't overwrite each other.
    logs_path = "tensorboard/" + strftime("%Y_%m_%d_%H_%M_%S", gmtime()) + model_name

    with tf.Graph().as_default():

        train_inputs, train_targets = prepare_data(True)

        model = SeparationModel(freq_weighted=False)  # don't use freq_weighted for now
        model.run_on_batch(train_inputs, train_targets)

        init = tf.group(tf.initialize_all_variables(), tf.initialize_local_variables())
        saver = tf.train.Saver()

        with tf.Session() as session:
            ckpt = tf.train.get_checkpoint_state('checkpoints/')

            if ckpt:
                print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
                # Local variables (e.g. input-queue epoch counters) are not
                # stored in the checkpoint, so initialize them separately
                # before restoring the trainable variables.
                session.run(tf.initialize_local_variables())
                saver.restore(session, ckpt.model_checkpoint_path)
            else:
                session.run(init)

            train_writer = tf.summary.FileWriter(logs_path + '/train', session.graph)

            # Queue runners feed the input pipeline; the coordinator stops
            # them cleanly when the data is exhausted or an error occurs.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=session, coord=coord)

            print('num trainable parameters: %s' % (np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))
            step_ii = 0

            try:

                while not coord.should_stop():
                    start = time.time()
                    step_ii += 1

                    # Fetching model.optimizer is what performs the training
                    # step; its return value (and the raw model output) are
                    # not needed here, so bind them to `_`.
                    _, batch_cost, masked_loss, summary, _ = session.run([model.output, model.loss, model.masked_loss, model.merged_summary_op, model.optimizer])

                    train_writer.add_summary(summary, step_ii)

                    duration = time.time() - start

                    if step_ii % 10 == 0:
                        print('Step %d: loss = %.5f masked_loss = %.5f (%.3f sec)' % (step_ii, batch_cost, masked_loss, duration))

                    if step_ii % 500 == 0:
                        checkpoint_name = 'checkpoints/' + model_name
                        saver.save(session, checkpoint_name, global_step=model.global_step)

            except tf.errors.OutOfRangeError:
                # Raised by the input queue when all epochs have been consumed.
                print('Done Training for %d epochs, %d steps' % (Config.num_epochs, step_ii))
            finally:
                coord.request_stop()

            coord.join(threads)
Example no. 2
0
def model_test():
    """Evaluate the separation model on the test split.

    Restores the latest checkpoint from 'checkpoints/', runs the model
    over the test input queue, writes the mixed / masked / target audio
    for every batch to 'data/results', and accumulates BSS-eval metrics
    (GNSDR / GSIR / GSAR for song and voice) which are averaged and
    printed when the queue is exhausted.
    """
    with tf.Graph().as_default():
        # prepare_data(False) -- presumably False selects the test split;
        # verify against prepare_data's definition.
        train_inputs, train_targets = prepare_data(False)

        model = SeparationModel(
            freq_weighted=False)  # don't use freq_weighted for now

        model.run_on_batch(train_inputs, train_targets)
        print(train_inputs.get_shape())

        init = tf.group(tf.initialize_all_variables(),
                        tf.initialize_local_variables())
        saver = tf.train.Saver()

        with tf.Session() as session:
            ckpt = tf.train.get_checkpoint_state('checkpoints/')
            if ckpt:
                print("Reading model parameters from %s" %
                      ckpt.model_checkpoint_path)
                saver.restore(session, ckpt.model_checkpoint_path)
                # Local variables (queue epoch counters) are not in the
                # checkpoint; initialize them after restoring.
                session.run(tf.initialize_local_variables())
            else:
                session.run(init)
            global_start = time.time()

            # Queue runners feed the input pipeline; the coordinator stops
            # them when the data runs out.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=session, coord=coord)

            # One [gnsdr_song, gnsdr_voice, gsir_song, gsir_voice,
            # gsar_song, gsar_voice] row per evaluated batch.
            soft_masked_results = []

            print('num trainable parameters: %s' % (np.sum([
                np.prod(v.get_shape().as_list())
                for v in tf.trainable_variables()
            ])))

            try:
                step_ii = 0
                while not coord.should_stop():
                    start = time.time()

                    soft_masked_output, batch_cost, masked_loss, summary, target, mixed_spec = session.run(
                        [
                            model.soft_masked_output, model.loss,
                            model.masked_loss, model.merged_summary_op,
                            model.target, model.input
                        ])

                    step_ii += 1
                    duration = time.time() - start

                    print(
                        'Step %d: loss = %.5f masked_loss = %.5f (%.3f sec)' %
                        (step_ii, batch_cost, masked_loss, duration))

                    # NOTE(review): tf.split on the numpy arrays returned by
                    # session.run builds NEW graph ops on every loop iteration
                    # (graph grows each step) and yields Tensors, not arrays.
                    # Confirm create_audio_from_spectrogram evaluates Tensors;
                    # np.split would avoid the graph growth.
                    soft_song_masked, soft_voice_masked = tf.split(
                        soft_masked_output,
                        [Config.num_freq_bins, Config.num_freq_bins],
                        axis=1)
                    # soft_song_masked *= stats[1][0]
                    # soft_song_masked += stats[0][0]
                    # soft_voice_masked *= stats[1][1]
                    # soft_voice_masked += stats[0][1]
                    song_target, voice_target = tf.split(
                        target, [Config.num_freq_bins, Config.num_freq_bins],
                        axis=1)
                    # song_target *= stats[1][0]
                    # song_target += stats[0][0]
                    # voice_target *= stats[1][1]
                    # voice_target += stats[0][1]

                    # Keep only index 1 of the last axis -- presumably the
                    # mixture channel/feature of model.input; TODO confirm.
                    mixed_spec = mixed_spec[:, :, 1]
                    # mixed_spec *= stats[1][2]
                    # mixed_spec += stats[0][2]

                    result_wav_dir = 'data/results'

                    # Write the (unseparated) mixture for reference.
                    mixed_audio = create_audio_from_spectrogram(mixed_spec)
                    writeWav(
                        os.path.join(result_wav_dir,
                                     'mixed%d.wav' % (step_ii)),
                        Config.sample_rate, mixed_audio)

                    # Reconstruct and write the soft-masked estimates.
                    soft_song_masked_audio = create_audio_from_spectrogram(
                        soft_song_masked)
                    soft_voice_masked_audio = create_audio_from_spectrogram(
                        soft_voice_masked)

                    writeWav(
                        os.path.join(result_wav_dir,
                                     'soft_song_masked%d.wav' % (step_ii)),
                        Config.sample_rate, soft_song_masked_audio)
                    writeWav(
                        os.path.join(result_wav_dir,
                                     'soft_voice_masked%d.wav' % (step_ii)),
                        Config.sample_rate, soft_voice_masked_audio)

                    # Reconstruct and write the ground-truth targets.
                    song_target_audio = create_audio_from_spectrogram(
                        song_target)
                    voice_target_audio = create_audio_from_spectrogram(
                        voice_target)

                    writeWav(
                        os.path.join(result_wav_dir,
                                     'song_target%d.wav' % (step_ii)),
                        Config.sample_rate, song_target_audio)
                    writeWav(
                        os.path.join(result_wav_dir,
                                     'voice_target%d.wav' % (step_ii)),
                        Config.sample_rate, voice_target_audio)

                    # soft_sdr, soft_sir, soft_sar, _ = bss_eval_sources(np.array([song_target_audio, voice_target_audio]), np.array([soft_song_masked_audio, soft_voice_masked_audio]), False)
                    # BSS-eval separation metrics for this batch.
                    soft_gnsdr, soft_gsir, soft_gsar = bss_eval_global(
                        mixed_audio, song_target_audio, voice_target_audio,
                        soft_song_masked_audio, soft_voice_masked_audio)

                    # masked_results.append([soft_sdr[0], soft_sdr[1], soft_sir[0], soft_sir[1], soft_sar[0],soft_sar[1]])
                    print(soft_gnsdr[0], soft_gnsdr[1], soft_gsir[0],
                          soft_gsir[1], soft_gsar[0], soft_gsar[1])
                    soft_masked_results.append([
                        soft_gnsdr[0], soft_gnsdr[1], soft_gsir[0],
                        soft_gsir[1], soft_gsar[0], soft_gsar[1]
                    ])

            except tf.errors.OutOfRangeError:
                # Queue exhausted: report the metric means over all batches.
                soft_masked_results = np.asarray(soft_masked_results)
                print(np.mean(soft_masked_results, axis=0))
            finally:
                coord.request_stop()

            coord.join(threads)