Example #1
File: run.py Project: blumix/drummer
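Note: the snippets below omit run.py's import header. A plausible set covering what they reference (an assumption; the project's actual header is not shown, and bss_eval_sources most likely comes from mir_eval):

import copy
import math
import os
import pickle
import random
import time
from time import gmtime, strftime

import h5py
import numpy as np
import stft
import tensorflow as tf
from mir_eval.separation import bss_eval_sources
from scipy.io import wavfile

Project-local names such as SeparationModel, Config, DIR, model_name, prepare_data, create_batch, create_mask, createSpectrogram, create_audio_from_spectrogram, SpectrogramArray, bss_eval_global, and writeWav are defined elsewhere in the repository and used here as-is.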
def model_train(freq_weighted):
    logs_path = "tensorboard/" + strftime("%Y_%m_%d_%H_%M_%S", gmtime()) + model_name 
    
    with tf.Graph().as_default():
        
        train_inputs, train_targets = prepare_data(True)

        model = SeparationModel(freq_weighted=False)  # don't use freq_weighted for now
        model.run_on_batch(train_inputs, train_targets)
        
        init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        saver = tf.train.Saver()

        with tf.Session() as session:
            ckpt = tf.train.get_checkpoint_state('checkpoints/')
            
            if ckpt:
                print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
                session.run(tf.local_variables_initializer())
                saver.restore(session, ckpt.model_checkpoint_path)
            else:
                session.run(init)

            train_writer = tf.summary.FileWriter(logs_path + '/train', session.graph)
            global_start = time.time()
            
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=session, coord=coord)

            print('num trainable parameters: %s' % (np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))
            step_ii = 0

            try:
                
                while not coord.should_stop():
                    start = time.time()
                    step_ii += 1
                    
                    output, batch_cost, masked_loss, summary, _ = session.run(
                        [model.output, model.loss, model.masked_loss,
                         model.merged_summary_op, model.optimizer])
                    
                    # total_train_cost += batch_cost * curr_batch_size
                    train_writer.add_summary(summary, step_ii)

                    duration = time.time() - start

                    if step_ii % 10 == 0:
                        print('Step %d: loss = %.5f masked_loss = %.5f (%.3f sec)' % (step_ii, batch_cost, masked_loss, duration))

                    if step_ii % 500 == 0:
                        checkpoint_name = 'checkpoints/' + model_name
                        saver.save(session, checkpoint_name, global_step=model.global_step)

            except tf.errors.OutOfRangeError:
                print('Done Training for %d epochs, %d steps' % (Config.num_epochs, step_ii))
            finally:
                coord.request_stop()

            coord.join(threads)
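Example #1 never feeds data explicitly: prepare_data is expected to return already-batched tensors wired into a TF1 input queue, so every session.run pulls a fresh batch until the queue is exhausted and raises OutOfRangeError, which is what ends the training loop above. A minimal sketch of such a pipeline; the file pattern, feature names, and shapes (including Config.num_steps) are assumptions, not taken from the project:

def prepare_data(is_train):
    # Read serialized examples from TFRecord shards and batch them via a queue.
    pattern = 'data/train-*.tfrecord' if is_train else 'data/test-*.tfrecord'
    filename_queue = tf.train.string_input_producer(
        tf.train.match_filenames_once(pattern), num_epochs=Config.num_epochs)
    _, serialized = tf.TFRecordReader().read(filename_queue)
    features = tf.parse_single_example(serialized, {
        'input': tf.FixedLenFeature([Config.num_steps, Config.num_freq_bins], tf.float32),
        'target': tf.FixedLenFeature([Config.num_steps, 2 * Config.num_freq_bins], tf.float32),
    })
    # The caller's tf.train.start_queue_runners starts the threads filling this queue.
    return tf.train.batch([features['input'], features['target']],
                          batch_size=Config.batch_size)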
Example #2
def model_train(freq_weighted):
    logs_path = "tensorboard/" + strftime("%Y_%m_%d_%H_%M_%S", gmtime())

    TESTING_MODE = True

    data = h5py.File('%sdata%d' % (DIR, 0), 'r')['data'][()]
    # np.append returns a new array instead of modifying its argument, so the
    # result must be reassigned or the second file is silently dropped; axis=1
    # assumes each file stores the three rows (combined, clean, noise) along
    # axis 0 with examples along axis 1.
    data = np.append(data, h5py.File('%sdata%d' % (DIR, 1), 'r')['data'][()], axis=1)

    # unpack the three parallel example arrays directly
    combined, clean, noise = data

    target = np.concatenate((clean, noise), axis=2)

    num_data = len(combined)
    random.seed(1)
    dev_ix = set(random.sample(range(num_data), num_data // 5))

    train_input = [s for i, s in enumerate(combined) if i not in dev_ix]
    train_target = [s for i, s in enumerate(target) if i not in dev_ix]
    dev_input = [s for i, s in enumerate(combined) if i in dev_ix]
    dev_target = [s for i, s in enumerate(target) if i in dev_ix]

    train_input_batch, train_target_batch = create_batch(
        train_input, train_target, Config.batch_size)
    dev_input_batch, dev_target_batch = create_batch(dev_input, dev_target,
                                                     Config.batch_size)

    # use the builtin sum: np.sum over a generator is deprecated behaviour
    num_data = sum(len(batch) for batch in train_input_batch)
    num_batches_per_epoch = int(math.ceil(num_data / Config.batch_size))
    num_dev_data = sum(len(batch) for batch in dev_input_batch)
    num_dev_batches_per_epoch = int(math.ceil(num_dev_data /
                                              Config.batch_size))

    with tf.Graph().as_default():
        model = SeparationModel(freq_weighted=freq_weighted)
        init = tf.global_variables_initializer()

        saver = tf.train.Saver(tf.trainable_variables())

        with tf.Session() as session:
            session.run(init)

            # if args.load_from_file is not None:
            #     new_saver = tf.train.import_meta_graph('%s.meta' % args.load_from_file, clear_devices=True)
            #     new_saver.restore(session, args.load_from_file)

            train_writer = tf.summary.FileWriter(logs_path + '/train',
                                                 session.graph)

            global_start = time.time()

            step_ii = 0

            for curr_epoch in range(Config.num_epochs):
                total_train_cost = 0
                total_train_examples = 0

                start = time.time()

                # visit the batch indices in a shuffled order each epoch
                for batch in random.sample(range(num_batches_per_epoch),
                                           num_batches_per_epoch):
                    cur_batch_size = len(train_target_batch[batch])
                    total_train_examples += cur_batch_size

                    _, batch_cost, summary = model.train_on_batch(
                        session,
                        train_input_batch[batch],
                        train_target_batch[batch],
                        train=True)

                    total_train_cost += batch_cost * cur_batch_size
                    train_writer.add_summary(summary, step_ii)

                    step_ii += 1

                train_cost = total_train_cost / total_train_examples

                num_dev_batches = len(dev_target_batch)
                total_batch_cost = 0
                total_batch_examples = 0

                # val_batch_cost, _ = model.train_on_batch(session, dev_input_batch[0], dev_target_batch[0], train=False)
                for batch in random.sample(range(num_dev_batches_per_epoch),
                                           num_dev_batches_per_epoch):
                    cur_batch_size = len(dev_target_batch[batch])
                    total_batch_examples += cur_batch_size

                    _, _val_batch_cost, _ = model.train_on_batch(
                        session,
                        dev_input_batch[batch],
                        dev_target_batch[batch],
                        train=False)

                    total_batch_cost += cur_batch_size * _val_batch_cost

                try:
                    val_batch_cost = total_batch_cost / total_batch_examples
                except ZeroDivisionError:
                    val_batch_cost = 0

                log = "Epoch {}/{}, train_cost = {:.3f}, val_cost = {:.3f}, time = {:.3f}"
                print(
                    log.format(curr_epoch + 1, Config.num_epochs, train_cost,
                               val_batch_cost,
                               time.time() - start))

                # if args.print_every is not None and (curr_epoch + 1) % args.print_every == 0:
                #     batch_ii = 0
                #     model.print_results(train_feature_minibatches[batch_ii], train_labels_minibatches[batch_ii])

                if (curr_epoch + 1) % 10 == 0:
                    checkpoint_name = 'checkpoints/%dlayer_%flr_model' % (
                        Config.num_layers, Config.lr)
                    if freq_weighted:
                        checkpoint_name = checkpoint_name + '_freq_weighted'
                    saver.save(session,
                               checkpoint_name,
                               global_step=curr_epoch + 1)
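Examples #2 and #3 both rely on a create_batch helper that is not shown. From the call sites it takes parallel lists of variable-length spectrogram arrays plus a batch size and returns parallel lists of batches, and Example #3's use of [-orig_length:] to recover each example suggests sequences are padded at the front. A sketch under those assumptions:

def create_batch(inputs, targets, batch_size):
    # Hypothetical sketch: group parallel example lists into batches,
    # front-padding every sequence to the longest one in its batch.
    def pad_front(x, length):
        return np.pad(x, [(length - len(x), 0)] + [(0, 0)] * (x.ndim - 1),
                      mode='constant')

    input_batches, target_batches = [], []
    for start in range(0, len(inputs), batch_size):
        in_group = inputs[start:start + batch_size]
        tgt_group = targets[start:start + batch_size]
        max_len = max(len(x) for x in in_group)
        input_batches.append(np.array([pad_front(x, max_len) for x in in_group]))
        target_batches.append(np.array([pad_front(x, max_len) for x in tgt_group]))
    return input_batches, target_batches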
Example #3
def model_batch_test():
    test_batch = h5py.File('%stest_batch' % (DIR), 'r')
    data = test_batch['data'][()]

    with open('%stest_settings.pkl' % (DIR), 'rb') as f:
        settings = pickle.load(f)

    # print(settings[:2])

    # unpack the three parallel example arrays directly
    combined, clean, noise = data
    target = np.concatenate((clean, noise), axis=2)

    # test_rate, test_audio = wavfile.read('data/test_combined/combined.wav')
    # test_spec = stft.spectrogram(test_audio)

    combined_batch, target_batch = create_batch(combined, target, 50)

    original_combined_batch = [
        copy.deepcopy(batch) for batch in combined_batch
    ]

    with tf.Graph().as_default():
        model = SeparationModel()
        saver = tf.train.Saver(tf.trainable_variables())

        with tf.Session() as session:
            ckpt = tf.train.get_checkpoint_state('checkpoints/')
            if ckpt:
                print("Reading model parameters from %s" %
                      ckpt.model_checkpoint_path)
                saver.restore(session, ckpt.model_checkpoint_path)
            else:
                print("Created model with fresh parameters.")
                session.run(tf.global_variables_initializer())

            curr_mask_array = []
            prev_mask_array = None
            diff = float('inf')
            iters = 0

            while True:
                iters += 1
                output, _, _ = model.train_on_batch(session,
                                                    combined_batch[0],
                                                    target_batch[0],
                                                    train=False)

                num_freq_bin = output.shape[2] // 2  # integer division: the output stacks the clean and noise halves
                clean_outputs = output[:, :, :num_freq_bin]
                noise_outputs = output[:, :, num_freq_bin:]

                # clean = [target[:,:num_freq_bin] for target in target_batch]
                # noise = [target[:,num_freq_bin:] for target in target_batch]

                num_outputs = len(clean_outputs)

                results = []

                for i in range(num_outputs):
                    orig_clean_output = clean_outputs[i]
                    orig_noise_output = noise_outputs[i]

                    stft_settings = copy.deepcopy(settings[i])
                    orig_length = stft_settings['orig_length']
                    stft_settings.pop('orig_length', None)
                    clean_output = orig_clean_output[-orig_length:]
                    noise_output = orig_noise_output[-orig_length:]

                    clean_mask, noise_mask = create_mask(
                        clean_output, noise_output)
                    orig_clean_mask, orig_noise_mask = create_mask(
                        orig_clean_output, orig_noise_output)

                    curr_mask_array.append(clean_mask)
                    # if i == 0:
                    # print clean_mask[10:20,10:20]
                    curr_mask_array.append(noise_mask)

                    clean_spec = createSpectrogram(
                        np.multiply(
                            clean_mask.transpose(),
                            original_combined_batch[0][i][-orig_length:].transpose()),
                        settings[i])
                    noise_spec = createSpectrogram(
                        np.multiply(
                            noise_mask.transpose(),
                            original_combined_batch[0][i][-orig_length:].transpose()),
                        settings[i])

                    # print '-' * 20
                    # print original_combined_batch[0][i]
                    # print '=' * 20
                    combined_batch[0][i] += np.multiply(
                        orig_clean_mask, original_combined_batch[0][i]) * 0.1
                    # print combined_batch[0][i]
                    # print '=' * 20
                    # print original_combined_batch[0][i]
                    # print '-' * 20

                    estimated_clean_wav = stft.ispectrogram(clean_spec)
                    estimated_noise_wav = stft.ispectrogram(noise_spec)

                    reference_clean_wav = stft.ispectrogram(
                        SpectrogramArray(clean[i][-orig_length:],
                                         stft_settings).transpose())
                    reference_noise_wav = stft.ispectrogram(
                        SpectrogramArray(noise[i][-orig_length:],
                                         stft_settings).transpose())

                    try:
                        sdr, sir, sar, _ = bss_eval_sources(
                            np.array(
                                [reference_clean_wav, reference_noise_wav]),
                            np.array(
                                [estimated_clean_wav, estimated_noise_wav]),
                            False)
                        results.append(
                            (sdr[0], sdr[1], sir[0], sir[1], sar[0], sar[1]))
                        # print('%f, %f, %f, %f, %f, %f' % (sdr[0], sdr[1], sir[0], sir[1], sar[0], sar[1]))
                    except ValueError:
                        print('bss_eval_sources failed on example %d' % i)
                        continue
                break  # single pass only; the iterative mask refinement below is left disabled

                # diff = 1
                # if prev_mask_array is not None:
                #     # print curr_mask_array[0]
                #     # print prev_mask_array[0]
                #     diff = sum(np.sum(np.abs(curr_mask_array[i] - prev_mask_array[i])) for i in xrange(len(prev_mask_array)))
                #     print('Changes after iteration %d: %d' % (iters, diff))

                # sdr_cleans, sdr_noises, sir_cleans, sir_noises, sar_cleans, sar_noises = zip(*results)
                # print('Avg sdr_cleans: %f, sdr_noises: %f, sir_cleans: %f, sir_noises: %f, sar_cleans: %f, sar_noises: %f' % (np.mean(sdr_cleans), np.mean(sdr_noises), np.mean(sir_cleans), np.mean(sir_noises), np.mean(sar_cleans), np.mean(sar_noises)))

                # prev_mask_array = [copy.deepcopy(mask[:,:]) for mask in curr_mask_array]

                # if diff == 0:
                #     break

            results_filename = '%sresults_%d_%f' % (
                'data/results/', Config.num_layers, Config.lr)
            # results_filename += 'freq_weighted'

            with open(results_filename + '.csv', 'w+') as f:
                for sdr_1, sdr_2, sir_1, sir_2, sar_1, sar_2 in results:
                    f.write('%f,%f,%f,%f,%f,%f\n' %
                            (sdr_1, sdr_2, sir_1, sir_2, sar_1, sar_2))
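create_mask appears in Examples #3 and #4 but is defined elsewhere in the project. Since its outputs are multiplied element-wise into the mixture spectrogram, it plausibly derives time-frequency masks from the two magnitude estimates. A sketch assuming a hard binary mask (the real helper may use a soft ratio mask instead):

def create_mask(clean_output, noise_output):
    # Hypothetical sketch: assign each time-frequency bin to the louder source.
    clean_mask = (np.abs(clean_output) >= np.abs(noise_output)).astype(np.float32)
    noise_mask = 1.0 - clean_mask
    return clean_mask, noise_mask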
Example #4
def model_test(test_input):
    test_rate, test_audio = wavfile.read(test_input)
    clean_rate, clean_audio = wavfile.read(CLEAN_FILE)
    noise_rate, noise_audio = wavfile.read(NOISE_FILE)

    length = len(clean_audio)
    noise_audio = noise_audio[:length]

    clean_spec = stft.spectrogram(clean_audio)
    noise_spec = stft.spectrogram(noise_audio)
    test_spec = stft.spectrogram(test_audio)

    reverted_clean = stft.ispectrogram(clean_spec)
    reverted_noise = stft.ispectrogram(noise_spec)

    test_data = np.array([test_spec.transpose() / 100000])  # make data a batch of 1

    with tf.Graph().as_default():
        model = SeparationModel()
        saver = tf.train.Saver(tf.trainable_variables())

        with tf.Session() as session:
            ckpt = tf.train.get_checkpoint_state('checkpoints/')
            if ckpt:
                print("Reading model parameters from %s" %
                      ckpt.model_checkpoint_path)
                saver.restore(session, ckpt.model_checkpoint_path)
            else:
                print("Created model with fresh parameters.")
                session.run(tf.global_variables_initializer())

            test_data_shape = np.shape(test_data)
            dummy_target = np.zeros((test_data_shape[0], test_data_shape[1],
                                     2 * test_data_shape[2]))

            output, _, _ = model.train_on_batch(session,
                                                test_data,
                                                dummy_target,
                                                train=False)

            num_freq_bin = output.shape[2] // 2  # integer division: the output stacks the clean and noise halves
            clean_output = output[0, :, :num_freq_bin]
            noise_output = output[0, :, num_freq_bin:]

            clean_mask, noise_mask = create_mask(clean_output, noise_output)

            clean_spec = createSpectrogram(
                np.multiply(clean_mask.transpose(), test_spec),
                test_spec.stft_settings)
            noise_spec = createSpectrogram(
                np.multiply(noise_mask.transpose(), test_spec),
                test_spec.stft_settings)

            clean_wav = stft.ispectrogram(clean_spec)
            noise_wav = stft.ispectrogram(noise_spec)

            sdr, sir, sar, _ = bss_eval_sources(
                np.array([reverted_clean, reverted_noise]),
                np.array([clean_wav, noise_wav]), False)
            print(sdr, sir, sar)

            writeWav('data/test_combined/output_clean.wav', 44100, clean_wav)
            writeWav('data/test_combined/output_noise.wav', 44100, noise_wav)
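writeWav is another project helper; Example #4 calls it as writeWav(path, rate, samples), which matches a thin wrapper over scipy.io.wavfile.write. A sketch under that assumption (the peak normalization is a guess):

def writeWav(filename, sample_rate, audio):
    # Hypothetical sketch: peak-normalize and write 16-bit PCM via SciPy.
    audio = audio / max(np.max(np.abs(audio)), 1e-8)
    wavfile.write(filename, sample_rate, (audio * 32767).astype(np.int16))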
Example #5
def model_test():
    with tf.Graph().as_default():
        train_inputs, train_targets = prepare_data(False)

        model = SeparationModel(
            freq_weighted=False)  # don't use freq_weighted for now

        model.run_on_batch(train_inputs, train_targets)
        print(train_inputs.get_shape())

        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        saver = tf.train.Saver()

        with tf.Session() as session:
            ckpt = tf.train.get_checkpoint_state('checkpoints/')
            if ckpt:
                print("Reading model parameters from %s" %
                      ckpt.model_checkpoint_path)
                saver.restore(session, ckpt.model_checkpoint_path)
                session.run(tf.local_variables_initializer())
            else:
                session.run(init)
            global_start = time.time()

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=session, coord=coord)

            soft_masked_results = []

            print('num trainable parameters: %s' % (np.sum([
                np.prod(v.get_shape().as_list())
                for v in tf.trainable_variables()
            ])))

            try:
                step_ii = 0
                while not coord.should_stop():
                    start = time.time()

                    soft_masked_output, batch_cost, masked_loss, summary, target, mixed_spec = session.run(
                        [
                            model.soft_masked_output, model.loss,
                            model.masked_loss, model.merged_summary_op,
                            model.target, model.input
                        ])

                    step_ii += 1
                    duration = time.time() - start

                    print(
                        'Step %d: loss = %.5f masked_loss = %.5f (%.3f sec)' %
                        (step_ii, batch_cost, masked_loss, duration))

                    # session.run returned NumPy arrays, so split them with
                    # np.split; tf.split here would add new ops to the graph on
                    # every step and hand back unevaluated tensors, which the
                    # create_audio_from_spectrogram calls below cannot use.
                    soft_song_masked, soft_voice_masked = np.split(
                        soft_masked_output, 2, axis=1)
                    # soft_song_masked *= stats[1][0]
                    # soft_song_masked += stats[0][0]
                    # soft_voice_masked *= stats[1][1]
                    # soft_voice_masked += stats[0][1]
                    song_target, voice_target = np.split(target, 2, axis=1)
                    # song_target *= stats[1][0]
                    # song_target += stats[0][0]
                    # voice_target *= stats[1][1]
                    # voice_target += stats[0][1]

                    mixed_spec = mixed_spec[:, :, 1]
                    # mixed_spec *= stats[1][2]
                    # mixed_spec += stats[0][2]

                    result_wav_dir = 'data/results'

                    mixed_audio = create_audio_from_spectrogram(mixed_spec)
                    writeWav(
                        os.path.join(result_wav_dir,
                                     'mixed%d.wav' % (step_ii)),
                        Config.sample_rate, mixed_audio)

                    soft_song_masked_audio = create_audio_from_spectrogram(
                        soft_song_masked)
                    soft_voice_masked_audio = create_audio_from_spectrogram(
                        soft_voice_masked)

                    writeWav(
                        os.path.join(result_wav_dir,
                                     'soft_song_masked%d.wav' % (step_ii)),
                        Config.sample_rate, soft_song_masked_audio)
                    writeWav(
                        os.path.join(result_wav_dir,
                                     'soft_voice_masked%d.wav' % (step_ii)),
                        Config.sample_rate, soft_voice_masked_audio)

                    song_target_audio = create_audio_from_spectrogram(
                        song_target)
                    voice_target_audio = create_audio_from_spectrogram(
                        voice_target)

                    writeWav(
                        os.path.join(result_wav_dir,
                                     'song_target%d.wav' % (step_ii)),
                        Config.sample_rate, song_target_audio)
                    writeWav(
                        os.path.join(result_wav_dir,
                                     'voice_target%d.wav' % (step_ii)),
                        Config.sample_rate, voice_target_audio)

                    # soft_sdr, soft_sir, soft_sar, _ = bss_eval_sources(np.array([song_target_audio, voice_target_audio]), np.array([soft_song_masked_audio, soft_voice_masked_audio]), False)
                    soft_gnsdr, soft_gsir, soft_gsar = bss_eval_global(
                        mixed_audio, song_target_audio, voice_target_audio,
                        soft_song_masked_audio, soft_voice_masked_audio)

                    # masked_results.append([soft_sdr[0], soft_sdr[1], soft_sir[0], soft_sir[1], soft_sar[0],soft_sar[1]])
                    print(soft_gnsdr[0], soft_gnsdr[1], soft_gsir[0],
                          soft_gsir[1], soft_gsar[0], soft_gsar[1])
                    soft_masked_results.append([
                        soft_gnsdr[0], soft_gnsdr[1], soft_gsir[0],
                        soft_gsir[1], soft_gsar[0], soft_gsar[1]
                    ])

            except tf.errors.OutOfRangeError:
                soft_masked_results = np.asarray(soft_masked_results)
                print(np.mean(soft_masked_results, axis=0))
            finally:
                coord.request_stop()

            coord.join(threads)
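bss_eval_global (Example #5) is also defined outside these snippets. The g-prefixed result names and the mixture argument suggest the common GNSDR/GSIR/GSAR convention, where the reported SDR is the improvement of each estimate over using the raw mixture itself as the estimate. A sketch under that assumption:

def bss_eval_global(mixed_wav, src1_wav, src2_wav, pred1_wav, pred2_wav):
    # Hypothetical sketch: SDR/SIR/SAR of the estimates, with SDR reported as
    # the gain over the unprocessed mixture (normalized SDR).
    length = min(len(mixed_wav), len(src1_wav), len(src2_wav),
                 len(pred1_wav), len(pred2_wav))
    refs = np.array([src1_wav[:length], src2_wav[:length]])
    preds = np.array([pred1_wav[:length], pred2_wav[:length]])
    mixture = np.array([mixed_wav[:length], mixed_wav[:length]])
    sdr, sir, sar, _ = bss_eval_sources(refs, preds, False)
    sdr_mixed, _, _, _ = bss_eval_sources(refs, mixture, False)
    return sdr - sdr_mixed, sir, sar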