예제 #1
0
def main(_):
    """Evaluate the DOA CNN on generated moving-source audio and save plots.

    Reads ``./config.json``, builds the CNN graph under the ``CNN`` variable
    scope, optionally restores weights from ``config['start_checkpoint']``,
    and then, for every combination of ``config['room_idx']`` and
    ``config['reverb']``:

    * generates a two-channel wav with ``input_data.gen_moving_direct_wav``,
    * runs the network on its spectrogram features,
    * writes the audio to ``./moving.wav`` (overwritten on every iteration),
    * saves a three-panel figure (channel X, channel Y, predicted vs. ground
      truth DOA on voiced frames) under ``./figures/v4_voiced/<basename>``.

    Args:
        _: Unused positional argument (presumably the argv list handed over
            by ``tf.app.run`` -- confirm against the caller).

    Raises:
        Exception: If no ``*.wav`` file is found in the testing directory or
            in the RIR directory.
    """
    json_dir = './config.json'
    with open(json_dir) as config_json:
        config = json.load(config_json)

    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    sess = tf.InteractiveSession()

    phase_specs = tf.placeholder(
        tf.float32,
        shape=[None, config['context_window_width'], 129, 4],
        name='phase_specs')

    model_settings = model.create_model_settings(
        dim_direction_label=config['dim_direction_label'],
        sample_rate=config["sample_rate"],
        win_len=config['win_len'],
        win_shift=config['win_shift'],
        nDFT=config['nDFT'],
        context_window_width=config['context_window_width'])

    with tf.variable_scope('CNN'):
        predict_logits = model.doa_cnn(phase_specs=phase_specs,
                                       model_settings=model_settings,
                                       is_training=True)
    CNN_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='CNN')

    # Report the size of every trainable variable and the overall total.
    print('-' * 80)
    print('CNN vars')
    nparams = 0
    for v in CNN_vars:
        v_shape = v.get_shape().as_list()
        # The initial value 1 keeps scalar (rank-0) variables from raising.
        v_n = reduce(lambda x, y: x * y, v_shape, 1)
        nparams += v_n
        print('{} ({}): {}'.format(v_shape, v_n, v.name))
    # 4 bytes per parameter (the placeholder above is float32).
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    tf.global_variables_initializer().run()
    init_local_variable = tf.local_variables_initializer()
    init_local_variable.run()
    if config['start_checkpoint']:
        model.load_variables_from_checkpoint(sess,
                                             config['start_checkpoint'],
                                             var_list=CNN_vars)

    rir_data_dir = config['rir_data_dir']
    rir_file_list = glob.glob(os.path.join(rir_data_dir, "*.wav"))

    # NOTE: list.sort() sorts the lists held by `config` in place.
    reverb = config['reverb']
    reverb.sort()

    room_index = config['room_idx']
    room_index.sort()

    # Find testing files.
    testing_data_dir = config['testing_data_dir']
    testing_file_list = glob.glob(os.path.join(testing_data_dir, "*.wav"))

    # Fail early when there is nothing to process. The exceptions must be
    # raised (not merely constructed), and the message needs the directory
    # string rather than the (empty) list of matches.
    if not testing_file_list:
        raise Exception("No wav files found at " + testing_data_dir)
    if not rir_file_list:
        raise Exception("No wav files found at " + rir_data_dir)

    for room in room_index:

        for reverb_percent in reverb:

            reverb_wav = input_data.gen_moving_direct_wav(
                wav_dir=config['testing_data_dir'],
                rir_dir=config['rir_data_dir'],
                doa_interval=config['direction_range'],
                deg_per_sec=config['deg_per_sec'],
                reverb_percent=reverb_percent,
                room_index=room)

            # Only the voiced mask is used below; the percentage is ignored.
            voiced_idx, _voiced_percent = (
                input_data.get_dual_channel_voiced_idx(
                    reverb_wav,
                    win_len=config['win_len'],
                    win_shift=config['win_shift'],
                    nDFT=config['nDFT'],
                    context_window_width=config['context_window_width'],
                    rms_thre=3e-1))

            # NOTE(review): 16 kHz is hard-coded here and in wavfile.write
            # below instead of using config['sample_rate'] -- confirm the
            # generated audio is always 16 kHz.
            duration = reverb_wav.shape[1] / 16e3

            testing_specs = input_data.get_reverb_specs(
                reverb_wav=reverb_wav,
                win_len=config['win_len'],
                win_shift=config['win_shift'],
                nDFT=config['nDFT'],
                context_window_width=config['context_window_width'])

            num_frames = testing_specs.shape[0]
            print(num_frames)

            # Only the second return value is used for plotting.
            _label, label_argmax = input_data.get_moving_wav_labels(
                num_frames, config['win_shift'], config['deg_per_sec'],
                config['direction_range'])

            logits = sess.run(predict_logits,
                              feed_dict={phase_specs: testing_specs})

            testing_predict = eval.get_deg_from_logits(
                logits,
                doa_interval=config['direction_range'],
                num_doa_class=config['dim_direction_label'])

            wavfile.write(filename='./moving.wav',
                          data=np.transpose(reverb_wav),
                          rate=16000)

            # Tick positions (in samples) and their labels (in seconds) for
            # the two waveform panels: roughly five ticks per axis. The step
            # is clamped to 1 so very short signals do not produce a zero
            # step for np.arange.
            num_samples = len(reverb_wav[0, :])
            time_idx = np.arange(0, num_samples,
                                 max(1, math.floor(num_samples / 5)))
            time_text = time_idx * duration / num_samples
            time_text = [str(round(float(t), 2)) for t in time_text]
            idx = range(len(label_argmax))

            # Same for the DOA panel, whose x axis is in frames.
            num_labels = len(label_argmax)
            label_idx = np.arange(0, num_labels,
                                  max(1, math.floor(num_labels / 5)))
            label_text = label_idx * duration / num_labels
            label_text = [str(round(float(t), 2)) for t in label_text]

            plt.figure(figsize=(20, 10))

            plt.subplot(311)
            plt.xlabel('time (s)')
            plt.xticks(time_idx, time_text)
            plt.ylabel('X')
            plt.ylim(-1, 1)
            plt.plot(reverb_wav[0, :])

            ax = plt.gca()
            ax.xaxis.set_label_coords(1.05, -0.025)

            plt.subplot(312)
            plt.xlabel('time (s)')
            plt.xticks(time_idx, time_text)
            plt.ylim(-1, 1)
            plt.ylabel('Y')
            plt.plot(reverb_wav[1, :])

            ax = plt.gca()
            ax.xaxis.set_label_coords(1.05, -0.025)

            # Only plot results for the voiced part: NaN entries are not
            # drawn, so silent frames are blanked in both curves.
            testing_predict = testing_predict.astype(float)
            silent_idx = np.logical_not(voiced_idx)
            testing_predict[silent_idx] = np.nan

            label_argmax = label_argmax.astype(float)
            label_argmax[silent_idx] = np.nan

            plt.subplot(313)
            plt.ylim(0, 140)
            plt.ylabel('DOA / degree')
            plt.xlabel('time (s)')
            plt.xticks(label_idx, label_text)
            plt.plot(idx,
                     label_argmax,
                     'bs',
                     label='ground truth',
                     markersize=2.15)
            plt.plot(idx, testing_predict, 'r.', label='predict', markersize=2)
            plt.legend(loc='upper left')
            plt.grid(True)

            ax = plt.gca()
            ax.xaxis.set_label_coords(1.05, -0.025)

            fig_save_path = os.path.join(
                './figures', 'v4_voiced',
                os.path.basename(config['testing_data_dir']))
            if not os.path.exists(fig_save_path):
                os.makedirs(fig_save_path)
            file_name = 'moving_plot_reverb' + str(
                reverb_percent) + '_room' + str(room) + '.png'
            save_path = os.path.join(fig_save_path, file_name)
            plt.savefig(save_path)
            # Release the figure; otherwise one figure per (room, reverb)
            # pair stays alive until the process exits.
            plt.close()
예제 #2
0
def main(_):
    """Run the SE-FCN model over the test set and save the resulting audio.

    Reads ``./config.json``, builds the network under the ``SEFCN`` variable
    scope, restores its weights from ``config['model_dir']`` /
    ``config['param_file']``, and then, for every test wav file, noise type,
    SNR and utterance index, processes the noisy signal segment by segment
    and writes three wav files (model output, noisy mixture, clean utterance)
    below ``config['save_testing_results_dir']`` in the ``test``, ``mix`` and
    ``clean`` sub-trees.

    Args:
        _: Unused positional argument (presumably the argv list handed over
            by ``tf.app.run`` -- confirm against the caller).
    """
    # Import config.
    json_dir = './config.json'
    with open(json_dir) as config_json:
        config = json.load(config_json)

    # Define noisy specs.
    input_specs = tf.placeholder(
        tf.float32,
        shape=[None, config['context_window_width'], 257, 2],
        name='specs')

    # Define clean specs.
    target_specs = tf.placeholder(tf.float32,
                                  shape=[None, 257, 2],
                                  name='ground_truth')

    # Create SE-FCN.
    with tf.variable_scope('SEFCN'):
        model_out = model.se_fcn(input_specs, config['nDFT'],
                                 config['context_window_width'])
    model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='SEFCN')

    # Report the size of every trainable variable and the overall total.
    print('-' * 80)
    print('SE-FCN vars')
    nparams = 0
    for v in model_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v_shape, v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    # Loss reported per segment; no optimizer is created in this script.
    mse = tf.losses.mean_squared_error(target_specs, model_out)

    sess = tf.InteractiveSession()
    model_path = os.path.join(config['model_dir'], config['param_file'])

    # Load model parameters from checkpoint.
    model.load_variables_from_checkpoint(sess, model_path)

    # Run the test & save the test results.
    tf.logging.set_verbosity(tf.logging.ERROR)

    testing_file_list = glob.glob(
        os.path.join(config['test_tedlium_wav_dir'], "*.wav"))
    print('testing set size: ', len(testing_file_list))
    test_snr = config['test_snr']
    sampling_rate = config['sampling_rate']

    for testing_file in testing_file_list:

        _, clean_wav = input_data.read_wav(testing_file, sampling_rate)
        # File name without the ".wav" suffix; reused for the .stm lookup and
        # for every output file name.
        file_stem = os.path.basename(testing_file).split(".wav")[0]
        stm_path = os.path.join(config['test_stm_path'], file_stem + '.stm')
        utter_pos = input_data.get_utter_pos(stm_path, sampling_rate)

        for noise_name in config['test_noise']:

            noise_wav_path = os.path.join(config['test_noise_path'],
                                          noise_name + '.wav')
            _, noise_wav = input_data.read_wav(noise_wav_path, sampling_rate)

            for snr in test_snr:

                for utter_index in range(config['how_many_testing_utter']):
                    utter_wav, noisy_wav, _, _ = input_data.get_noisy_wav_tedlium(
                        clean_wav=clean_wav,
                        noise_wav=noise_wav,
                        utter_pos=utter_pos,
                        pos_index=utter_index,
                        snr=snr,
                        utter_percentage=config['speech_percentage'])

                    # Number of fixed-length segments needed to cover the
                    # whole noisy signal.
                    segment = int(
                        math.ceil(
                            len(noisy_wav) / (config['wav_length_per_batch'] *
                                              sampling_rate)))

                    # Per-segment outputs, joined once after the loop.
                    real_parts = []
                    imag_parts = []

                    for segment_idx in range(segment):
                        noisy_specs, clean_specs = input_data.get_seg_specs(
                            mix_wav=noisy_wav,
                            utter_wav=utter_wav,
                            wav_length_per_seg=config['wav_length_per_batch'],
                            seg_idx=segment_idx,
                            win_len=config['win_len'],
                            win_shift=config['win_shift'],
                            context_window_width=config[
                                'context_window_width'],
                            fs=sampling_rate,
                            nDFT=config['nDFT'])

                        seg_specs, seg_mse = sess.run([model_out, mse],
                                                      feed_dict={
                                                          input_specs:
                                                          noisy_specs,
                                                          target_specs:
                                                          clean_specs
                                                      })

                        # Report the size of the batch that was actually fed;
                        # the placeholder's own batch dimension is declared
                        # as None and carries no information.
                        print("processing file: " + testing_file, " " * 5,
                              "seg:", "{}/{}".format(segment_idx + 1, segment),
                              " " * 5, "proc num batch:", len(noisy_specs),
                              " " * 5, "seg mse:", format(seg_mse, '.5f'))

                        seg_specs = np.vstack(seg_specs)
                        # The last axis holds two channels that are passed on
                        # separately to output_data.rec_wav below.
                        real_parts.append(seg_specs[:, :, 0])
                        imag_parts.append(seg_specs[:, :, 1])

                    if not real_parts:
                        # Nothing was processed (empty signal). Skip instead
                        # of reusing results left over from a previous
                        # utterance.
                        print("no segments for file: " + testing_file,
                              "utterance:", utter_index)
                        continue

                    rec_wav = output_data.rec_wav(
                        mag_spec=np.concatenate(real_parts, axis=0),
                        spec_imag=np.concatenate(imag_parts, axis=0),
                        win_len=config['win_len'],
                        win_shift=config['win_shift'],
                        nDFT=config['nDFT'])

                    # Save model output, noisy mixture and clean utterance in
                    # parallel directory trees under the same file name.
                    out_name = file_stem + '_U' + str(utter_index) + '.wav'
                    outputs = (('test', rec_wav),
                               ('mix', noisy_wav),
                               ('clean', utter_wav))
                    for sub_dir, wav in outputs:
                        save_path = os.path.join(
                            config['save_testing_results_dir'], sub_dir,
                            str(noise_name), str(snr))
                        if not os.path.exists(save_path):
                            os.makedirs(save_path)
                        comp_save_path = os.path.join(save_path, out_name)
                        output_data.save_wav_file(comp_save_path, wav,
                                                  sampling_rate)

    np.set_printoptions(precision=3, suppress=True)
예제 #3
0
파일: train.py 프로젝트: ziippy/vai-101
def main(_):
    """Train, validate, test and run prediction for an image classifier.

    The flow is:

    1. Build the input pipelines (training / validation / testing) and the
       model graph selected by ``FLAGS.model_architecture``.
    2. Optionally restore ``FLAGS.start_checkpoint``; the epoch to resume
       from is parsed from the text after the last ``-`` in that path.
    3. For every epoch: train on all training batches, evaluate on the whole
       validation set, and save a checkpoint every
       ``FLAGS.save_step_interval`` epochs and at the final epoch.
    4. Evaluate on the testing set.
    5. Predict labels for the files in ``FLAGS.prediction_data_dir`` and
       write them to ``submission_<architecture>_<epochs>.csv`` in
       ``FLAGS.result_dir``.

    Args:
        _: Unused positional argument (presumably the argv list handed over
            by ``tf.app.run`` -- confirm against the caller).

    Raises:
        Exception: If ``--how_many_training_epochs`` and ``--learning_rate``
            do not contain the same number of comma-separated entries.
    """

    def _count_batches(data_size, batch_size):
        """Return the number of batches needed to cover ``data_size`` items."""
        batches = int(data_size / batch_size)
        if data_size % batch_size > 0:
            batches += 1  # final, partially filled batch
        return batches

    # We want to see all the logging messages for this tutorial.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    sess = tf.InteractiveSession()

    labels = FLAGS.labels.split(',')
    label_count = len(labels)

    # Place data loading and preprocessing on the cpu.
    with tf.device('/cpu:0'):
        raw_data = Data(FLAGS.data_dir, labels, FLAGS.validation_percentage,
                        FLAGS.testing_percentage)
        label_to_index = raw_data.get_label_to_index

        tr_data = ImageDataGenerator(raw_data.get_data('training'),
                                     label_to_index(),
                                     FLAGS.batch_size)

        val_data = ImageDataGenerator(raw_data.get_data('validation'),
                                      label_to_index(),
                                      FLAGS.batch_size)

        te_data = ImageDataGenerator(raw_data.get_data('testing'),
                                     label_to_index(),
                                     FLAGS.batch_size)

        # Create a reinitializable iterator given the dataset structure.
        iterator = tf.data.Iterator.from_structure(
            tr_data.dataset.output_types, tr_data.dataset.output_shapes)
        next_batch = iterator.get_next()

    # Ops that point the shared iterator at one of the three datasets.
    training_init_op = iterator.make_initializer(tr_data.dataset)
    validation_init_op = iterator.make_initializer(val_data.dataset)
    testing_init_op = iterator.make_initializer(te_data.dataset)

    # Figure out the learning rates for each training phase. Since it's often
    # effective to have high learning rates at the start of training, followed
    # by lower levels towards the end, the number of epochs and learning rates
    # can be specified as comma-separated lists to define the rate at each
    # stage. For example --how_many_training_epochs=10000,3000
    # --learning_rate=0.001,0.0001 will run 13,000 epochs in total, with a
    # rate of 0.001 for the first 10,000, and 0.0001 for the final 3,000.
    training_epochs_list = list(
        map(int, FLAGS.how_many_training_epochs.split(',')))
    learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
    if len(training_epochs_list) != len(learning_rates_list):
        raise Exception(
            '--how_many_training_epochs and --learning_rate must be equal length '
            'lists, but are %d and %d long instead' %
            (len(training_epochs_list), len(learning_rates_list)))

    input_xs = tf.placeholder(tf.float32,
                              [None, FLAGS.image_hw, FLAGS.image_hw, 3],
                              name='input_xs')
    logits, dropout_prob = model.create_model(input_xs,
                                              label_count,
                                              FLAGS.model_architecture,
                                              is_training=True)

    # Define loss and optimizer.
    ground_truth_input = tf.placeholder(tf.int64, [None],
                                        name='groundtruth_input')

    # Optionally we can add runtime checks to spot when NaNs or other symptoms
    # of numerical errors start occurring during training.
    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    # Create the back propagation and training evaluation machinery in the
    # graph.
    with tf.name_scope('cross_entropy'):
        cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
            labels=ground_truth_input, logits=logits)
    tf.summary.scalar('cross_entropy', cross_entropy_mean)
    with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
        learning_rate_input = tf.placeholder(tf.float32, [],
                                             name='learning_rate_input')
        momentum = tf.placeholder(tf.float32, [], name='momentum')
        # NOTE: the second positional argument of RMSPropOptimizer is `decay`,
        # so the value fed through the `momentum` placeholder acts as the
        # decay rate here.
        train_step = tf.train.RMSPropOptimizer(
            learning_rate_input, momentum).minimize(cross_entropy_mean)

    predicted_indices = tf.argmax(logits, 1)
    correct_prediction = tf.equal(predicted_indices, ground_truth_input)
    confusion_matrix = tf.confusion_matrix(ground_truth_input,
                                           predicted_indices,
                                           num_classes=label_count)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', evaluation_step)

    global_step = tf.train.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    saver = tf.train.Saver(tf.global_variables())

    # Merge all the summaries and write them out below FLAGS.summaries_dir.
    merged_summaries = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)
    validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir +
                                              '/validation')

    tf.global_variables_initializer().run()

    start_epoch = 1
    start_checkpoint_epoch = 0
    if FLAGS.start_checkpoint:
        model.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
        # The epoch number is the text after the last '-' of the checkpoint
        # path (the saver below appends it as "<path>-<epoch>").
        start_checkpoint_epoch = int(FLAGS.start_checkpoint.split('-')[-1])
        start_epoch = start_checkpoint_epoch + 1

    # Total number of epochs over all learning-rate stages.
    training_epochs_max = np.sum(training_epochs_list)

    if start_checkpoint_epoch != training_epochs_max:
        tf.logging.info('Training from epoch: %d ', start_epoch)

    # Save the graph definition as a binary Protocol Buffer (pb).
    tf.train.write_graph(sess.graph_def,
                         FLAGS.train_dir,
                         FLAGS.model_architecture + '.pb',
                         as_text=False)

    # Save the list of labels.
    with gfile.GFile(
            os.path.join(FLAGS.train_dir,
                         FLAGS.model_architecture + '_labels.txt'), 'w') as f:
        f.write('\n'.join(raw_data.labels_list))

    # Number of training/validation/testing steps per epoch.
    tr_batches_per_epoch = _count_batches(tr_data.data_size, FLAGS.batch_size)
    val_batches_per_epoch = _count_batches(val_data.data_size,
                                           FLAGS.batch_size)
    te_batches_per_epoch = _count_batches(te_data.data_size, FLAGS.batch_size)

    ############################
    # Training loop.
    ############################
    for training_epoch in range(start_epoch, training_epochs_max + 1):
        # Figure out what the current learning rate is: pick the first stage
        # whose cumulative epoch count covers the current epoch.
        training_epochs_sum = 0
        for stage_epochs, stage_rate in zip(training_epochs_list,
                                            learning_rates_list):
            training_epochs_sum += stage_epochs
            if training_epoch <= training_epochs_sum:
                learning_rate_value = stage_rate
                break

        # Initialize iterator with the training dataset.
        sess.run(training_init_op)
        for step in range(tr_batches_per_epoch):
            # Pull the image samples we'll use for training.
            train_batch_xs, train_batch_ys = sess.run(next_batch)
            # Run the graph with this batch of training data. The increment
            # op yields the updated global step, which keeps growing across
            # epochs and is therefore used as the summary step below.
            (train_summary, train_accuracy, cross_entropy_value, _,
             global_step_value) = sess.run(
                 [
                     merged_summaries, evaluation_step, cross_entropy_mean,
                     train_step, increment_global_step
                 ],
                 feed_dict={
                     input_xs: train_batch_xs,
                     ground_truth_input: train_batch_ys,
                     learning_rate_input: learning_rate_value,
                     momentum: 0.95,
                     dropout_prob: 0.5
                 })

            # `step` restarts at 0 every epoch; using it as the summary step
            # would make later epochs overlap earlier ones in the event file.
            train_writer.add_summary(train_summary, global_step_value)
            tf.logging.info(
                'Epoch #%d, Step #%d: rate %f, accuracy %.1f%%, cross entropy %f'
                % (training_epoch, step, learning_rate_value,
                   train_accuracy * 100, cross_entropy_value))

        # Validate the model on the entire validation set.
        print("{} Start validation".format(datetime.datetime.now()))
        # Reinitialize iterator with the validation dataset.
        sess.run(validation_init_op)
        total_val_accuracy = 0
        validation_count = 0
        total_conf_matrix = None
        for _batch in range(val_batches_per_epoch):
            validation_batch_xs, validation_batch_ys = sess.run(next_batch)
            # Run a validation step and capture summaries for TensorBoard
            # with the merged summary op.
            validation_summary, validation_accuracy, conf_matrix = sess.run(
                [merged_summaries, evaluation_step, confusion_matrix],
                feed_dict={
                    input_xs: validation_batch_xs,
                    ground_truth_input: validation_batch_ys,
                    dropout_prob: 1.0
                })

            # NOTE: every validation batch is recorded at the same step (the
            # epoch number).
            validation_writer.add_summary(validation_summary, training_epoch)

            total_val_accuracy += validation_accuracy
            validation_count += 1
            if total_conf_matrix is None:
                total_conf_matrix = conf_matrix
            else:
                total_conf_matrix += conf_matrix

        # Mean of the per-batch accuracies; the guard avoids a division by
        # zero when the validation split is empty.
        if validation_count:
            total_val_accuracy /= validation_count

        tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
        tf.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
                        (training_epoch, total_val_accuracy * 100,
                         raw_data.get_size('validation')))

        # Save the model checkpoint periodically.
        if (training_epoch % FLAGS.save_step_interval == 0
                or training_epoch == training_epochs_max):
            checkpoint_path = os.path.join(FLAGS.train_dir,
                                           FLAGS.model_architecture + '.ckpt')
            tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                            training_epoch)
            saver.save(sess, checkpoint_path, global_step=training_epoch)

    # No more summaries are written past this point; flush and release the
    # event files.
    train_writer.close()
    validation_writer.close()

    ############################
    # For Evaluate
    ############################
    start = datetime.datetime.now()
    print("{} Start testing".format(start))
    # Reinitialize iterator with the testing dataset.
    sess.run(testing_init_op)

    total_test_accuracy = 0
    test_count = 0
    total_conf_matrix = None
    for _batch in range(te_batches_per_epoch):
        test_batch_xs, test_batch_ys = sess.run(next_batch)
        test_accuracy, conf_matrix = sess.run(
            [evaluation_step, confusion_matrix],
            feed_dict={
                input_xs: test_batch_xs,
                ground_truth_input: test_batch_ys,
                dropout_prob: 1.0
            })

        total_test_accuracy += test_accuracy
        test_count += 1

        if total_conf_matrix is None:
            total_conf_matrix = conf_matrix
        else:
            total_conf_matrix += conf_matrix

    # Same empty-split guard as for validation.
    if test_count:
        total_test_accuracy /= test_count

    tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
    tf.logging.info('Final test accuracy = %.1f%% (N=%d)' %
                    (total_test_accuracy * 100, raw_data.get_size('testing')))

    end = datetime.datetime.now()
    print('End testing: ', end)
    print('total testing time: ', end - start)

    ############################
    # start prediction
    ############################
    print("{} Start prediction".format(datetime.datetime.now()))
    id2name = {i: name for i, name in enumerate(labels)}
    submission = dict()

    raw_data2 = prediction_data.Data(FLAGS.prediction_data_dir)
    pre_data = prediction_data.ImageDataGenerator(raw_data2.get_data(),
                                                  FLAGS.prediction_batch_size)

    # The prediction dataset yields a different structure (inputs and file
    # names), so it gets its own reinitializable iterator.
    iterator = tf.data.Iterator.from_structure(pre_data.dataset.output_types,
                                               pre_data.dataset.output_shapes)
    next_batch = iterator.get_next()

    # Op that points the new iterator at the prediction dataset.
    prediction_init_op = iterator.make_initializer(pre_data.dataset)

    # Number of prediction steps.
    pre_batches_per_epoch = _count_batches(pre_data.data_size,
                                           FLAGS.prediction_batch_size)

    count = 0
    # Initialize iterator with the prediction dataset.
    sess.run(prediction_init_op)
    for _batch in range(pre_batches_per_epoch):
        fingerprints, fnames = sess.run(next_batch)
        batch_predictions = sess.run(predicted_indices,
                                     feed_dict={
                                         input_xs: fingerprints,
                                         dropout_prob: 1.0
                                     })
        # File names are decoded from UTF-8 bytes before being used as keys.
        for fname, class_index in zip(fnames, batch_predictions):
            submission[fname.decode('UTF-8')] = id2name[class_index]

        count += len(fnames)
        print(count, ' completed')

    # Make submission.csv.
    if not os.path.exists(FLAGS.result_dir):
        os.makedirs(FLAGS.result_dir)

    submission_path = os.path.join(
        FLAGS.result_dir, 'submission_' + FLAGS.model_architecture + '_' +
        FLAGS.how_many_training_epochs + '.csv')
    # The context manager closes the file even if writing a row fails.
    with open(submission_path, 'w', encoding='utf-8', newline='') as fout:
        writer = csv.writer(fout)
        writer.writerow(['file', 'species'])
        for key in sorted(submission.keys()):
            writer.writerow([key, submission[key]])
예제 #4
0
def main(_):
    """Predict labels for the prediction set with an ensemble of checkpoints.

    Builds the classification graph, then for every prediction batch restores
    each checkpoint named in ``FLAGS.ckpt_list``, averages the softmax
    probabilities over the checkpoints and stores the arg-max label for each
    file. The result is written as a two-column CSV (``file``, ``species``)
    under ``FLAGS.result_dir``.

    Args:
        _: Unused positional argument.

    Raises:
        Exception: If ``--how_many_training_epochs`` and ``--learning_rate``
            do not contain the same number of comma-separated entries.
    """
    # We want to see all the logging messages for this tutorial.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    sess = tf.InteractiveSession()

    labels = FLAGS.labels.split(',')
    label_count = len(labels)

    training_epochs_list = list(map(int, FLAGS.how_many_training_epochs.split(',')))
    learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
    if len(training_epochs_list) != len(learning_rates_list):
        raise Exception(
            '--how_many_training_epochs and --learning_rate must be equal length '
            'lists, but are %d and %d long instead' % (len(training_epochs_list),
                                                       len(learning_rates_list)))

    input_xs = tf.placeholder(
        tf.float32, [None, IMAGE_HEIGHT, IMAGE_WIDTH, 3], name='input_xs')

    logits, dropout_prob = models.create_model(
        input_xs,
        label_count,
        FLAGS.model_architecture,
        is_training=True)

    # Define loss and optimizer
    ground_truth_input = tf.placeholder(tf.int64, [None], name='groundtruth_input')

    # Optionally we can add runtime checks to spot when NaNs or other symptoms of
    # numerical errors start occurring during training.
    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    # The training machinery below is not executed during prediction; it is
    # kept so the graph (and its variables) stays the same as before.
    with tf.name_scope('cross_entropy'):
        cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
            labels=ground_truth_input, logits=logits)
    tf.summary.scalar('cross_entropy', cross_entropy_mean)
    with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
        learning_rate_input = tf.placeholder(tf.float32, [], name='learning_rate_input')
        momentum = tf.placeholder(tf.float32, [], name='momentum')
        train_step = tf.train.MomentumOptimizer(learning_rate_input, momentum, use_nesterov=True).minimize(cross_entropy_mean)

    predicted_indices = tf.argmax(logits, 1)
    correct_prediction = tf.equal(predicted_indices, ground_truth_input)
    confusion_matrix = tf.confusion_matrix(
        ground_truth_input, predicted_indices, num_classes=label_count)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', evaluation_step)

    # Class probabilities. Created once here so that the prediction loop does
    # not add a new softmax op to the graph on every iteration.
    predicted_probabilities = tf.nn.softmax(logits)

    global_step = tf.train.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    saver = tf.train.Saver(tf.global_variables())

    merged_summaries = tf.summary.merge_all()

    tf.global_variables_initializer().run()

    ############################
    # start prediction
    ############################
    print("{} Start prediction".format(datetime.datetime.now()))

    id2name = {i: name for i, name in enumerate(labels)}
    submission = dict()

    # Place data loading and preprocessing on the cpu
    raw_data2 = prediction_data.Data(FLAGS.prediction_data_dir)
    pre_data = prediction_data.ImageDataGenerator(raw_data2.get_data(),
                                                  FLAGS.prediction_batch_size)

    # create an reinitializable iterator given the dataset structure
    iterator = tf.data.Iterator.from_structure(pre_data.dataset.output_types,
                                               pre_data.dataset.output_shapes)
    next_batch = iterator.get_next()

    # Op that points the iterator at the prediction dataset
    prediction_init_op = iterator.make_initializer(pre_data.dataset)

    # Number of batches needed to cover the prediction set: ceil(size / batch).
    # A plain "floor + 1" asks for one batch too many whenever the set size is
    # an exact multiple of the batch size.
    batch_size = FLAGS.prediction_batch_size
    pre_batches_per_epoch = int(
        (pre_data.data_size + batch_size - 1) // batch_size)

    print("Test Size : {}".format(raw_data2.get_size()))

    count = 0
    sess.run(prediction_init_op)
    ckpt_list = FLAGS.ckpt_list.split(',')

    for _ in range(pre_batches_per_epoch):
        pred_xs, fnames = sess.run(next_batch)

        # One probability matrix per checkpoint for the current batch.
        # NOTE: every checkpoint is restored again for every batch.
        pred_labels = []
        for ckpt_path in ckpt_list:
            models.load_variables_from_checkpoint(sess, ckpt_path)
            pred_labels.append(
                sess.run(predicted_probabilities,
                         feed_dict={
                             input_xs: pred_xs,
                             dropout_prob: 1.0
                         }))

        # Average the probabilities over the checkpoints, then pick the
        # most probable class for each example.
        ensemble_pred_labels = np.mean(np.array(pred_labels), axis=0)
        ensemble_class_pred = np.argmax(ensemble_pred_labels, axis=1)

        for fname, class_idx in zip(fnames, ensemble_class_pred):
            submission[fname.decode('UTF-8')] = id2name[class_idx]

        count += len(fnames)
        print(count, ' completed')

    # make submission.csv
    if not os.path.exists(FLAGS.result_dir):
        os.makedirs(FLAGS.result_dir)

    submission_path = os.path.join(
        FLAGS.result_dir,
        'submission_' + FLAGS.model_architecture + '_ensemble_1_3.csv')
    # The context manager closes the file even if a row fails to write.
    with open(submission_path, 'w', encoding='utf-8', newline='') as fout:
        writer = csv.writer(fout)
        writer.writerow(['file', 'species'])
        for key in sorted(submission):
            writer.writerow([key, submission[key]])
예제 #5
0
def main(_):
    """Run the SE-FCN model over every ``*.wav`` file in the recording tree.

    Settings are read from ``./config.json``. Each recording is processed in
    segments, the per-segment model outputs are joined along the first axis,
    a waveform is rebuilt from them, rescaled by the input's peak amplitude
    and saved under ``config['save_processed_recordings_dir']`` in a
    sub-directory named after the recording's parent directory.

    Args:
        _: Unused positional argument.
    """
    # import config
    with open('./config.json') as config_json:
        config = json.load(config_json)

    context_width = config['context_window_width']

    # define noisy specs
    input_specs = tf.placeholder(tf.float32,
                                 shape=[None, context_width, 257, 2],
                                 name='specs')

    # create SE-FCN
    with tf.variable_scope('SEFCN'):
        model_out = model.se_fcn(input_specs, config['nDFT'], context_width)
    model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='SEFCN')

    # Report each trainable variable and the total parameter count.
    separator = '-' * 80
    print(separator)
    print('SE-FCN vars')
    nparams = 0
    for var in model_vars:
        var_size = reduce(lambda a, b: a * b, var.get_shape().as_list())
        nparams += var_size
        print('{} ({}): {}'.format(var.get_shape().as_list(), var_size,
                                   var.name))
    size_mb = (float(nparams) * 4) / (1024 * 1024)
    print('Total params: {} ({:.2f} MB)'.format(nparams, size_mb))
    print(separator)

    sess = tf.InteractiveSession()

    # load model parameters from checkpoint
    model.load_variables_from_checkpoint(
        sess, os.path.join(config['model_dir'], config['param_file']))

    for root, _dirs, files in os.walk(config['recording_dir']):
        # Output sub-directory is named after the directory being walked.
        subdir = os.path.basename(os.path.normpath(root))

        for basename in fnmatch.filter(files, '*.wav'):
            filename = os.path.join(root, basename)
            fs, mix_wav = wavfile.read(filename)
            mix_wav = mix_wav / (2**15 - 1)
            max_amp = np.max(np.abs(mix_wav))

            samples_per_segment = config['seg_recording_length'] * fs
            segment = int(math.ceil(len(mix_wav) / samples_per_segment))

            # Per-segment outputs, joined once after the loop.
            real_parts = []
            imag_parts = []
            for segment_idx in range(segment):
                testing_specs = input_data.get_seg_testing_specs(
                    mix_wav=mix_wav,
                    fs=fs,
                    wav_length_per_seg=config['seg_recording_length'],
                    seg_idx=segment_idx,
                    win_len=config['win_len'],
                    win_shift=config['win_shift'],
                    nDFT=config['nDFT'],
                    context_window=context_width)

                fetched = sess.run([model_out],
                                   feed_dict={input_specs: testing_specs})
                print("processing file: " + filename, " " * 5, "seg:",
                      "{}/{}".format(segment_idx + 1, segment), " " * 5,
                      "proc num batch:", testing_specs.shape[0])

                seg_specs = np.vstack(fetched)
                real_parts.append(seg_specs[:, :, 0])
                imag_parts.append(seg_specs[:, :, 1])

            rec_test_out_real = np.concatenate(real_parts, axis=0)
            rec_test_out_imag = np.concatenate(imag_parts, axis=0)

            rec_wav = output_data.rec_wav(mag_spec=rec_test_out_real,
                                          spec_imag=rec_test_out_imag,
                                          win_len=config['win_len'],
                                          win_shift=config['win_shift'],
                                          nDFT=config['nDFT'])
            # Scale the rebuilt waveform by the input's peak amplitude.
            rec_wav = rec_wav * max_amp

            save_path = os.path.join(config['save_processed_recordings_dir'],
                                     subdir)
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            output_data.save_wav_file(os.path.join(save_path, basename),
                                      rec_wav, fs)

    np.set_printoptions(precision=3, suppress=True)
예제 #6
0
파일: train.py 프로젝트: phpstorm1/DOA-FCN
def main(_):
    """Train the DOA CNN configured in ``./config.json``, then evaluate it.

    The function builds the graph, optionally restores a checkpoint, runs the
    training loop (one training file convolved with one RIR file per step,
    with periodic checkpoints), evaluates every testing file against every
    RIR file whose reverb percentage and room number are listed in the
    config, prints accuracy tables and saves one accuracy-vs-degree plot per
    (room, reverb) pair.

    Args:
        _: Unused positional argument.

    Raises:
        Exception: If no ``*.wav`` file is found in the training data
            directory or in the RIR data directory.
    """
    json_dir = './config.json'
    with open(json_dir) as config_json:
        config = json.load(config_json)

    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    sess = tf.InteractiveSession()

    phase_specs = tf.placeholder(
        tf.float32,
        shape=[None, config['context_window_width'], 129, 4],
        name='phase_specs')
    ground_truth_doa_label = tf.placeholder(
        tf.float32,
        shape=[None, config['dim_direction_label']],
        name='ground_truth_input')

    model_settings = model.create_model_settings(
        dim_direction_label=config['dim_direction_label'],
        sample_rate=config["sample_rate"],
        win_len=config['win_len'],
        win_shift=config['win_shift'],
        nDFT=config['nDFT'],
        context_window_width=config['context_window_width'])

    with tf.variable_scope('CNN'):
        predict_logits = model.doa_cnn(phase_specs=phase_specs,
                                       model_settings=model_settings,
                                       is_training=True)
    CNN_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='CNN')

    # Report each trainable variable and the total parameter count.
    print('-' * 80)
    print('CNN vars')
    nparams = 0
    for v in CNN_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v_shape, v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    cross_entropy = tf.losses.softmax_cross_entropy(
        onehot_labels=ground_truth_doa_label, logits=predict_logits)

    mean_cross_entropy = tf.reduce_mean(cross_entropy)
    acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(ground_truth_doa_label,
                                                       1),
                                      predictions=tf.argmax(predict_logits, 1))
    pc_acc, pc_acc_op = tf.metrics.mean_per_class_accuracy(
        labels=tf.argmax(ground_truth_doa_label, 1),
        predictions=tf.argmax(predict_logits, 1),
        num_classes=config['dim_direction_label'])
    tf.summary.scalar('cross_entropy', mean_cross_entropy)
    tf.summary.scalar('class_accuracy', acc_op)
    tf.summary.histogram('per_class_accuracy', pc_acc_op)

    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    global_step = tf.train.get_or_create_global_step()
    with tf.name_scope('train'), tf.control_dependencies(extra_update_ops):
        adam = tf.train.AdamOptimizer(config['Adam_learn_rate'])
        train_step = adam.minimize(cross_entropy,
                                   global_step=global_step,
                                   var_list=CNN_vars)

    saver = tf.train.Saver(tf.global_variables())

    # Merge all the summaries and write them out to config['summaries_dir'].
    merged_summaries = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(config['summaries_dir'], sess.graph)

    start_step = 1

    tf.global_variables_initializer().run()
    # The streaming metrics keep their accumulators in local variables.
    init_local_variable = tf.local_variables_initializer()
    init_local_variable.run()

    if config['start_checkpoint']:
        model.load_variables_from_checkpoint(sess, config['start_checkpoint'])
        start_step = global_step.eval(session=sess)

    tf.logging.info('Training from step: %d ', start_step)

    # Save graph.pbtxt.
    tf.train.write_graph(sess.graph_def, config['train_dir'], 'model.pbtxt')

    # find training files
    training_data_dir = config['training_data_dir']
    training_file_list = glob.glob(os.path.join(training_data_dir, "*.wav"))

    training_speech_dir = config['training_speech_dir']
    training_speech_list = glob.glob(os.path.join(training_speech_dir, "**",
                                                  "*.wav"),
                                     recursive=True)

    rir_data_dir = config['rir_data_dir']
    rir_file_list = glob.glob(os.path.join(rir_data_dir, "*.wav"))

    # Sorted so that list positions can serve as result-array indices below.
    reverb = config['reverb']
    reverb.sort()

    room_index = config['room_idx']
    room_index.sort()

    # find testing files
    testing_file_list = glob.glob(
        os.path.join(config['testing_data_dir'], "*.wav"))

    # Fail early with a clear message. The exceptions used to be constructed
    # without `raise`, so an empty directory only surfaced later as an
    # unrelated error inside the training loop.
    if not training_file_list:
        raise Exception("No wav files found at " + training_data_dir)
    if not rir_file_list:
        raise Exception("No wav files found at " + rir_data_dir)

    tf.logging.info("Number of training wav files: %d",
                    len(training_file_list))

    # Training loop.
    how_many_training_steps = config['how_many_training_steps']
    for training_step in range(start_step, int(how_many_training_steps + 1)):

        training_file_idx = random.randint(0, len(training_file_list) - 1)
        # Cycle through the RIR files in order.
        rir_idx = training_step % len(rir_file_list)

        training_filename = training_file_list[training_file_idx]
        rir_filename = rir_file_list[rir_idx]

        reverb_percent = int(
            rir_filename.split('reverb_')[1].split('Percent_')[0])
        # For the 65 % and 75 % reverb RIRs, swap in a random file from the
        # speech list half of the time.
        if reverb_percent == 75 or reverb_percent == 65:
            if random.randint(0, 1):
                speech_file_idx = random.randint(0,
                                                 len(training_speech_list) - 1)
                training_filename = training_speech_list[speech_file_idx]

        reverb_wav, training_phase_specs = input_data.get_input_specs(
            training_filename, rir_filename, config['win_len'],
            config['win_shift'], config['nDFT'],
            config['context_window_width'], config['max_wav_length'])
        num_frames = training_phase_specs.shape[0]

        training_doa_label = input_data.get_direction_label(
            rir_filename, config['dim_direction_label'],
            config['direction_range'], num_frames)

        # Reset the metric accumulators in a separate call. Fetching the
        # initializer together with the metric update ops leaves their
        # relative execution order undefined.
        sess.run(init_local_variable)
        training_summary, training_cross_entropy, _ = sess.run(
            [merged_summaries, mean_cross_entropy, train_step],
            feed_dict={
                phase_specs: training_phase_specs,
                ground_truth_doa_label: training_doa_label
            })

        print("Step: ", training_step, " " * 10, "cross entropy: ",
              format(training_cross_entropy, '.5f'), " " * 10, "rir: ",
              format(reverb_percent, '2.0f'), " " * 10, "training file: ",
              os.path.basename(training_filename))
        train_writer.add_summary(training_summary, training_step)

        # Save the model checkpoint periodically.
        if (training_step % config['save_step_interval'] == 0
                or training_step == how_many_training_steps):
            checkpoint_path = os.path.join(config['train_dir'], 'model.ckpt')
            tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                            training_step)
            saver.save(sess, checkpoint_path, global_step=training_step)

    set_size = len(testing_file_list)
    tf.logging.info('testing set size=%d', set_size)

    doa_per_reverb = int(
        max(config['direction_range']) - min(config['direction_range']) + 1)

    # Result arrays indexed as [testing file, degree, reverb, room].
    result_shape = [
        len(testing_file_list), doa_per_reverb,
        len(reverb),
        len(room_index)
    ]
    test_acc = np.zeros(result_shape)
    test_adj_acc = np.zeros(result_shape)
    test_frame_acc = np.zeros(result_shape)
    test_adj_frame_acc = np.zeros(result_shape)

    # Parameters passed to the evaluation helpers.
    adjacent_class = 2
    how_many_previous_frame = 15

    for testing_file_idx, testing_file in enumerate(testing_file_list):
        print("testing file:", os.path.basename(testing_file))
        for rir_file_idx, rir_file in enumerate(rir_file_list):
            rir_filename = os.path.basename(rir_file)
            degree = int(rir_filename.split('angle_')[1].split('deg_')[0])
            reverb_percent = int(
                rir_filename.split('reverb_')[1].split('Percent_')[0])
            room_num = int(rir_filename.split('_ROOM')[1].split('.wav')[0])

            # Only evaluate the reverb/room combinations named in the config.
            if reverb_percent not in reverb or room_num not in room_index:
                continue
            reverb_idx = reverb.index(reverb_percent)
            room_idx = room_index.index(room_num)

            reverb_wav, testing_phase_specs = input_data.get_input_specs(
                testing_file, rir_file, config['win_len'], config['win_shift'],
                config['nDFT'], config['context_window_width'],
                config['max_wav_length'])
            num_frames = testing_phase_specs.shape[0]

            testing_doa_label = input_data.get_direction_label(
                rir_file, config['dim_direction_label'],
                config['direction_range'], num_frames)

            # Reset the metric accumulators first (see the training loop) so
            # the accuracy reflects this file/RIR pair only.
            sess.run(init_local_variable)
            logits, class_acc = sess.run(
                [predict_logits, acc_op],
                feed_dict={
                    phase_specs: testing_phase_specs,
                    ground_truth_doa_label: testing_doa_label
                })

            testing_predict = eval.get_label_from_logits(logits)
            adj_acc = eval.eval_adjacent_accuracy(testing_predict,
                                                  testing_doa_label,
                                                  adjacent_class)
            frame_acc = eval.eval_frame_accuracy(testing_predict,
                                                 testing_doa_label,
                                                 how_many_previous_frame)
            adj_frame_acc = eval.eval_joint_deg_frame(testing_predict,
                                                      testing_doa_label,
                                                      adjacent_class,
                                                      how_many_previous_frame)

            # NOTE(review): `degree` is used directly as an index, which
            # presumably assumes min(direction_range) == 0 - confirm.
            test_acc[testing_file_idx, degree, reverb_idx,
                     room_idx] = class_acc
            test_adj_acc[testing_file_idx, degree, reverb_idx,
                         room_idx] = adj_acc
            test_frame_acc[testing_file_idx, degree, reverb_idx,
                           room_idx] = frame_acc
            test_adj_frame_acc[testing_file_idx, degree, reverb_idx,
                               room_idx] = adj_frame_acc

            print("degree:", format(degree, '5.1f'), " " * 6, "reverb:",
                  format(reverb_percent, '5.0f'), " " * 6, "room:",
                  format(room_num, '5.0f'), " " * 6, "acc:",
                  format(class_acc, '5.5f'), " " * 6, "adj acc:",
                  format(adj_acc, '5.5f'), " " * 6, "frame acc:",
                  format(frame_acc, '5.5f'), " " * 6, "adj frame acc:",
                  format(adj_frame_acc, '5.5f'))

    # Overall means.
    print("overall acc:", np.mean(test_acc))
    print("overall adj_acc:", np.mean(test_adj_acc))
    print("overall frame_acc:", np.mean(test_frame_acc))
    print("overall adj frame acc:", np.mean(test_adj_frame_acc))
    print("-" * 30)

    # Per-degree means over files, reverbs and rooms.
    print("Degree accuracy")
    print(format("deg", '10.10s'), format("acc", '10.10s'),
          format("deg acc", '15.10s'), format("frame acc", '15.10s'),
          format("deg frame acc", "15.10s"))
    for i in range(doa_per_reverb):
        print(format(i, '.1f'), " " * 5,
              format(np.mean(test_acc[:, i, :]), '.4f'), " " * 6,
              format(np.mean(test_adj_acc[:, i, :]), '.4f'), " " * 6,
              format(np.mean(test_frame_acc[:, i, :]), '.4f'), " " * 6,
              format(np.mean(test_adj_frame_acc[:, i, :]), '.4f'))

    deg_idx = range(doa_per_reverb)

    print("-" * 30)

    # Per (room, reverb) summary followed by the per-degree breakdown.
    for room in range(len(room_index)):
        for i in range(len(reverb)):
            print("reverb: ", reverb[i])
            print("room: ", room_index[room])
            print("acc:", np.mean(test_acc[:, :, i, room]))
            print("adj_acc:", np.mean(test_adj_acc[:, :, i, room]))
            print("frame_acc:", np.mean(test_frame_acc[:, :, i, room]))
            print("adj frame acc:", np.mean(test_adj_frame_acc[:, :, i, room]))

            for j in range(doa_per_reverb):
                print(
                    format(j, '.1f'), " " * 5,
                    format(np.mean(test_acc[:, j, i, room]), '.4f'), " " * 6,
                    format(np.mean(test_adj_acc[:, j, i, room]),
                           '.4f'), " " * 6,
                    format(np.mean(test_frame_acc[:, j, i, room]), '.4f'),
                    " " * 6,
                    format(np.mean(test_adj_frame_acc[:, j, i, room]), '.4f'))

            print("-" * 30)

    # One accuracy-vs-degree plot per (room, reverb) pair.
    fig_save_path = os.path.join('./figures', 'v4',
                                 os.path.basename(config['testing_data_dir']))
    for room in range(len(room_index)):
        for i in range(len(reverb)):
            plt.figure(i)
            plt.plot(deg_idx, np.mean(test_adj_acc[:, :, i, room], axis=0),
                     '.')
            plt.yscale('linear')
            plt.xlabel('degree')
            plt.ylabel('accuracy')
            plt.title('room ' + str(room_index[room]) + ', reverb ' +
                      str(reverb[i]) + ' percent')
            plt.grid(True)
            filename = 'room_' + str(room_index[room]) + '_reverb_' + str(
                reverb[i]) + '_acc.png'
            if not os.path.exists(fig_save_path):
                os.makedirs(fig_save_path)
            filename = os.path.join(fig_save_path, filename)
            plt.savefig(filename)
            plt.show()
            # Figure number `i` is reused for the next room; close it so the
            # next room's curve is not drawn on top of this one when show()
            # does not block.
            plt.close()
예제 #7
0
파일: train.py 프로젝트: GeorgyZhou/TSRC
def main(_):
  """Train the convolutional audio model, validating and testing as it goes.

  Builds the graph, optionally restores ``FLAGS.start_checkpoint``, trains
  with gradient descent using a staged learning-rate schedule, evaluates on
  the validation split every ``FLAGS.eval_step_interval`` steps (and on the
  last step), saves checkpoints every ``FLAGS.save_step_interval`` steps (and
  on the last step), and finally reports accuracy and the confusion matrix on
  the testing split.

  Args:
    _: Unused positional argument.

  Raises:
    Exception: If ``--how_many_training_steps`` and ``--learning_rate`` do not
      contain the same number of comma-separated entries.
  """
  # Show every INFO-level log message.
  tf.logging.set_verbosity(tf.logging.INFO)

  # Session used for everything below.
  sess = tf.InteractiveSession()

  # Model settings and the audio processor that supplies the data.
  word_list_size = len(
      input_data.prepare_words_list(FLAGS.wanted_words.split(',')))
  model_settings = model.prepare_model_settings(
      word_list_size, FLAGS.sample_rate, FLAGS.clip_duration_ms,
      FLAGS.window_size_ms, FLAGS.window_stride_ms,
      FLAGS.dct_coefficient_count)
  audio_processor = input_data.AudioProcessor(
      FLAGS.data_url, FLAGS.data_dir, FLAGS.silence_percentage,
      FLAGS.unknown_percentage, FLAGS.wanted_words.split(','),
      FLAGS.validation_percentage, FLAGS.testing_percentage, model_settings)
  fingerprint_size = model_settings['fingerprint_size']
  label_count = model_settings['label_count']
  time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)

  # Staged schedule: both flags are comma-separated lists of equal length,
  # e.g. --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
  # trains 10,000 steps at 0.001 followed by 3,000 steps at 0.0001.
  training_steps_list = [
      int(steps) for steps in FLAGS.how_many_training_steps.split(',')
  ]
  learning_rates_list = [
      float(rate) for rate in FLAGS.learning_rate.split(',')
  ]
  if len(training_steps_list) != len(learning_rates_list):
    raise Exception(
        '--how_many_training_steps and --learning_rate must be equal length '
        'lists, but are %d and %d long instead' % (len(training_steps_list),
                                                   len(learning_rates_list)))

  fingerprint_input = tf.placeholder(
      tf.float32, [None, fingerprint_size], name='fingerprint_input')

  logits, dropout_prob = model.create_conv_model(
      fingerprint_input, model_settings, is_training=True)

  # Ground-truth labels fed alongside each batch.
  ground_truth_input = tf.placeholder(
      tf.float32, [None, label_count], name='groundtruth_input')

  # With --check_nans, the train ops depend on numeric checks for every
  # tensor so that numerical errors are reported when they occur.
  control_dependencies = []
  if FLAGS.check_nans:
    control_dependencies = [tf.add_check_numerics_ops()]

  # Loss, optimizer and evaluation ops.
  with tf.name_scope('cross_entropy'):
    per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=ground_truth_input, logits=logits)
    cross_entropy_mean = tf.reduce_mean(per_example_loss)
  tf.summary.scalar('cross_entropy', cross_entropy_mean)
  with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
    learning_rate_input = tf.placeholder(
        tf.float32, [], name='learning_rate_input')
    optimizer = tf.train.GradientDescentOptimizer(learning_rate_input)
    train_step = optimizer.minimize(cross_entropy_mean)
  predicted_indices = tf.argmax(logits, 1)
  expected_indices = tf.argmax(ground_truth_input, 1)
  correct_prediction = tf.equal(predicted_indices, expected_indices)
  confusion_matrix = tf.confusion_matrix(
      expected_indices, predicted_indices, num_classes=label_count)
  evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  tf.summary.scalar('accuracy', evaluation_step)

  global_step = tf.contrib.framework.get_or_create_global_step()
  increment_global_step = tf.assign(global_step, global_step + 1)

  saver = tf.train.Saver(tf.global_variables())

  # Summaries go to <summaries_dir>/train and <summaries_dir>/validation.
  merged_summaries = tf.summary.merge_all()
  train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                       sess.graph)
  validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')

  tf.global_variables_initializer().run()

  # Resume from the checkpoint's global step when one is given.
  start_step = 1
  if FLAGS.start_checkpoint:
    model.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
    start_step = global_step.eval(session=sess)

  tf.logging.info('Training from step: %d ', start_step)

  # Write the graph definition next to the checkpoints.
  tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
                       FLAGS.model_architecture + '.pbtxt')

  # Write the label list, one word per line.
  labels_path = os.path.join(FLAGS.train_dir,
                             FLAGS.model_architecture + '_labels.txt')
  with gfile.GFile(labels_path, 'w') as f:
    f.write('\n'.join(audio_processor.words_list))

  # Training loop.
  training_steps_max = np.sum(training_steps_list)
  for training_step in xrange(start_step, training_steps_max + 1):
    # The first stage whose cumulative step count reaches the current step
    # supplies the learning rate.
    steps_covered = 0
    for stage_steps, stage_rate in zip(training_steps_list,
                                       learning_rates_list):
      steps_covered += stage_steps
      if training_step <= steps_covered:
        learning_rate_value = stage_rate
        break

    # Fetch a training batch.
    train_fingerprints, train_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
        FLAGS.background_volume, time_shift_samples, 'training', sess)

    # One optimisation step; this also advances the global step.
    train_fetches = [
        merged_summaries, evaluation_step, cross_entropy_mean, train_step,
        increment_global_step
    ]
    train_feed = {
        fingerprint_input: train_fingerprints,
        ground_truth_input: train_ground_truth,
        learning_rate_input: learning_rate_value,
        dropout_prob: 0.5
    }
    train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
        train_fetches, feed_dict=train_feed)
    train_writer.add_summary(train_summary, training_step)
    tf.logging.info('Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
                    (training_step, learning_rate_value, train_accuracy * 100,
                     cross_entropy_value))

    is_last_step = (training_step == training_steps_max)

    # Validation pass.
    if is_last_step or (training_step % FLAGS.eval_step_interval) == 0:
      set_size = audio_processor.set_size('validation')
      total_accuracy = 0
      total_conf_matrix = None
      for offset in xrange(0, set_size, FLAGS.batch_size):
        validation_fingerprints, validation_ground_truth = (
            audio_processor.get_data(FLAGS.batch_size, offset, model_settings,
                                     0.0, 0.0, 0, 'validation', sess))
        # The merged summaries are fetched too so the validation curves are
        # recorded for TensorBoard.
        validation_summary, validation_accuracy, conf_matrix = sess.run(
            [merged_summaries, evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: validation_fingerprints,
                ground_truth_input: validation_ground_truth,
                dropout_prob: 1.0
            })
        validation_writer.add_summary(validation_summary, training_step)
        # Weight each batch by its size; the last one may be short.
        batch_size = min(FLAGS.batch_size, set_size - offset)
        total_accuracy += (validation_accuracy * batch_size) / set_size
        if total_conf_matrix is not None:
          total_conf_matrix += conf_matrix
        else:
          total_conf_matrix = conf_matrix
      tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
      tf.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
                      (training_step, total_accuracy * 100, set_size))

    # Periodic (and final) checkpoint.
    if is_last_step or training_step % FLAGS.save_step_interval == 0:
      checkpoint_path = os.path.join(FLAGS.train_dir,
                                     FLAGS.model_architecture + '.ckpt')
      tf.logging.info('Saving to "%s-%d"', checkpoint_path, training_step)
      saver.save(sess, checkpoint_path, global_step=training_step)

  # Final pass over the testing split.
  set_size = audio_processor.set_size('testing')
  tf.logging.info('set_size=%d', set_size)
  total_accuracy = 0
  total_conf_matrix = None
  for offset in xrange(0, set_size, FLAGS.batch_size):
    test_fingerprints, test_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, offset, model_settings, 0.0, 0.0, 0, 'testing', sess)
    test_accuracy, conf_matrix = sess.run(
        [evaluation_step, confusion_matrix],
        feed_dict={
            fingerprint_input: test_fingerprints,
            ground_truth_input: test_ground_truth,
            dropout_prob: 1.0
        })
    batch_size = min(FLAGS.batch_size, set_size - offset)
    total_accuracy += (test_accuracy * batch_size) / set_size
    if total_conf_matrix is not None:
      total_conf_matrix += conf_matrix
    else:
      total_conf_matrix = conf_matrix
  tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
  tf.logging.info('Final test accuracy = %.1f%% (N=%d)' % (total_accuracy * 100,
                                                           set_size))
예제 #8
0
파일: train.py 프로젝트: phpstorm1/SE-FCN
def main(_):
    """Train the SE-FCN model configured by ``./config.json`` and run the test pass.

    The function builds the graph (input/target placeholders, the network
    under the ``SEFCN`` variable scope, an MSE loss and an Adam train op),
    optionally restores ``config['start_checkpoint']``, trains up to
    ``config['how_many_training_steps']`` steps while writing summaries and
    checkpoints, and finally enhances every wav in
    ``config['testing_data_dir']`` for each configured test noise / test SNR,
    saving the reconstructed, mixed and clean signals under
    ``config['save_testing_results_dir']``.

    Args:
        _: Unused positional argument (presumably the argv list handed over
            by ``tf.app.run`` -- confirm against the caller).

    Raises:
        ValueError: If no training speech or noise wav files are found.
    """

    # random seed
    RANDOM_SEED = 3233

    # import config
    json_dir = './config.json'
    with open(json_dir) as config_json:
        config = json.load(config_json)

    # define noisy specs
    input_specs = tf.placeholder(tf.float32, shape=[None, config['context_window_width'], 257, 2], name='specs')

    # define clean specs
    train_target = tf.placeholder(tf.float32, shape=[None, 257, 2], name='ground_truth')

    # create SE-FCN
    with tf.variable_scope('SEFCN'):
        model_out = model.se_fcn(input_specs, config['nDFT'], config['context_window_width'])
    model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='SEFCN')

    # Report every trainable variable and the total parameter count
    # (size in MB assumes 4 bytes per parameter).
    print('-' * 80)
    print('SE-FCN vars')
    nparams = 0
    for v in model_vars:
        v_shape = v.get_shape().as_list()
        # The initial value 1 keeps a scalar variable (empty shape) from raising.
        v_n = reduce(lambda x, y: x * y, v_shape, 1)
        nparams += v_n
        print('{} ({}): {}'.format(v_shape, v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))
    print('-' * 80)

    # define loss and the optimizer
    mse = tf.losses.mean_squared_error(train_target, model_out)

    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    global_step = tf.train.get_or_create_global_step()

    # Run the collected update ops together with every optimizer step.
    with tf.name_scope('train'), tf.control_dependencies(extra_update_ops):
        adam = tf.train.AdamOptimizer(config['Adam_learn_rate'])
        train_op = adam.minimize(mse, global_step=global_step, var_list=model_vars)

    # make summaries
    tf.summary.scalar('mse', mse)

    # train the model
    sess = tf.InteractiveSession()
    saver = tf.train.Saver(tf.global_variables())

    # Merge all the summaries and write them out to config['summaries_dir']
    merged_summaries = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(config['summaries_dir'], sess.graph)
    tf.train.write_graph(sess.graph_def, config['train_dir'], 'model.pbtxt')

    tf.global_variables_initializer().run()
    init_local_variable = tf.local_variables_initializer()
    init_local_variable.run()

    # Resume from a checkpoint when one is configured.
    start_step = 0
    if config['start_checkpoint']:
        model.load_variables_from_checkpoint(sess, config['start_checkpoint'])
        start_step = global_step.eval(session=sess)
    print('Training from step:', start_step)

    tf.logging.set_verbosity(tf.logging.ERROR)

    # Candidate SNR values, both range ends inclusive.
    snr = range(config['snr_range'][0], config['snr_range'][1] + 1)
    sensor_snr = range(config['sensor_snr_range'][0], config['sensor_snr_range'][1] + 1)

    speech_file_list = glob.glob(os.path.join(config['training_data_dir'], "**", "*.wav"), recursive=True)
    noise_file_list = glob.glob(os.path.join(config['noise_dir'], "**", "*.wav"), recursive=True)

    # Fail early with a clear message; the original built these exceptions
    # without raising them and crashed later inside random.randint instead.
    if not speech_file_list:
        raise ValueError("No wav files found at " + config['training_data_dir'])
    if not noise_file_list:
        raise ValueError("No wav files found at " + config['noise_dir'])

    # load the sensor noise recording when a path is configured
    if config['sensor_noise_path']:
        _, sensor_wav = input_data.read_wav(config['sensor_noise_path'], config['sampling_rate'])
    else:
        sensor_wav = None

    print("Number of training speech wav files: ", len(speech_file_list))
    print("Number of training noise wav files: ", len(noise_file_list))

    how_many_training_steps = config['how_many_training_steps']

    # One pre-drawn index per training sample, so a resumed run sees the same
    # sequence as an uninterrupted one.
    num_samples = int(how_many_training_steps + 1) * config['batch_size']

    # NOTE(review): every list below is re-seeded with the same RANDOM_SEED, so
    # the speech / noise / SNR draws are correlated with each other. Left as-is
    # to keep the data order of existing runs; consider seeding only once.
    random.seed(RANDOM_SEED)
    rand_noise_file_idx_list = [random.randint(0, len(noise_file_list) - 1) for _ in range(num_samples)]

    random.seed(RANDOM_SEED)
    rand_speech_file_idx_list = [random.randint(0, len(speech_file_list) - 1) for _ in range(num_samples)]

    random.seed(RANDOM_SEED)
    snr_idx_list = [random.randint(0, len(snr) - 1) for _ in range(num_samples)]

    # Indices into sensor_snr must be bounded by len(sensor_snr), not len(snr).
    random.seed(RANDOM_SEED)
    sensor_snr_idx_list = [random.randint(0, len(sensor_snr) - 1) for _ in range(num_samples)]

    for training_step in range(start_step + 1, int(how_many_training_steps + 1)):

        # get training data
        _, batch_noisy_specs, speech_specs, _ = input_data.get_training_specs(
            speech_file_list,
            noise_file_list,
            snr,
            rand_speech_file_idx_list,
            rand_noise_file_idx_list,
            snr_idx_list,
            sensor_noise_wav=sensor_wav,
            sensor_snr=sensor_snr,
            sensor_snr_idx_list=sensor_snr_idx_list,
            training_step=training_step,
            batch_size=config['batch_size'],
            context_window_width=config['context_window_width'])

        # train the model
        _, training_summary, train_mse = sess.run([train_op, merged_summaries, mse],
                                                  feed_dict={input_specs: batch_noisy_specs, train_target: speech_specs})

        print("training step:", training_step, " "*10, "mse:", format(train_mse, '.5f'))
        train_writer.add_summary(training_summary, training_step)

        # Save the model checkpoint periodically.
        if training_step % config['save_checkpoint_steps'] == 0 or training_step == how_many_training_steps:
            checkpoint_path = os.path.join(config['train_dir'], 'sefcn.ckpt')
            tf.logging.info('Saving to "%s-%d"', checkpoint_path, training_step)
            saver.save(sess, checkpoint_path, global_step=training_step)

    # run the test & save the test results
    # find testing files
    testing_file_list = glob.glob(os.path.join(config['testing_data_dir'], "*.wav"))
    print('testing set size: ', len(testing_file_list))

    test_snr = config['test_snr']
    # Rows follow test_snr (not the training SNR range), columns follow test_noise.
    overall_testing_mse = np.zeros([len(test_snr), len(config['test_noise'])])

    for file_idx, testing_file in enumerate(testing_file_list):

        _, clean_wav = input_data.read_wav(testing_file, config['sampling_rate'])
        testing_file_name = os.path.basename(testing_file)

        for noise_idx, noise_name in enumerate(config['test_noise']):

            noise_wav_path = os.path.join(config['test_noise_path'], noise_name + '.wav')
            _, noise_wav = input_data.read_wav(noise_wav_path, config['sampling_rate'])

            for snr_idx, snr_value in enumerate(test_snr):

                utter_wav, noisy_wav = input_data.get_noisy_wav(clean_wav=clean_wav,
                                                                noise_wav=noise_wav,
                                                                snr=snr_value)

                _, batched_noisy_specs, speech_specs, _ = input_data.get_testing_specs(
                    utter_wav,
                    noisy_wav,
                    context_window_width=config['context_window_width'])

                estimate_specs, test_mse = sess.run([model_out, mse],
                                                    feed_dict={input_specs: batched_noisy_specs, train_target: speech_specs})

                rec_wav = output_data.rec_wav(mag_spec=estimate_specs[:, :, 0],
                                              spec_imag=estimate_specs[:, :, 1],
                                              win_len=config['win_len'],
                                              win_shift=config['win_shift'],
                                              nDFT=config['nDFT'])

                # Save the enhanced ('test'), noisy ('mix') and reference ('clean')
                # signals under <results>/<kind>/<noise>/<snr>/<file name>. The
                # directory is labelled with the SNR actually used for the mix.
                for kind, wav in (('test', rec_wav), ('mix', noisy_wav), ('clean', utter_wav)):
                    save_path = os.path.join(config['save_testing_results_dir'],
                                             kind,
                                             str(noise_name),
                                             str(snr_value))
                    os.makedirs(save_path, exist_ok=True)
                    comp_save_path = os.path.join(save_path, testing_file_name)
                    output_data.save_wav_file(comp_save_path, wav, config['sampling_rate'])

                print("Testing file #", file_idx, testing_file_name,
                      "SNR :", format(snr_value, '5.1f'), " "*10,
                      "noise:", format(noise_name, '10.10s'), " "*10,
                      "mse:", format(test_mse, '.5f'))

                # Accumulate so the matrix ends up holding the mean over all
                # testing files instead of only the last file's share.
                overall_testing_mse[snr_idx, noise_idx] += test_mse / len(testing_file_list)

    np.set_printoptions(precision=3, suppress=True)