コード例 #1
0
def peak_picking_subroutine(test_nacta_2017, test_nacta, th, obs_cal, scaler,
                            len_seq, model_name, full_path_model, architecture,
                            detection_results_path, jingju_eval_results_path):
    """Run peak-picking onset detection and scoring for five model folds.

    For every fold index ``k`` in ``0..4``:

    * when ``obs_cal == 'tocal'`` a ``jan_original`` network is built and its
      weights are loaded from ``full_path_model + str(k) + '.h5'``; otherwise
      ``None`` is passed on as the model;
    * detection with threshold ``th`` is run on the nacta 2017 recordings
      (skipped when ``varin['dataset'] == 'ismir'``) and then on the nacta
      recordings, with ``detection_results_path + str(k)`` as output path;
    * ``th`` is written as one CSV row to the per-model results file (open
      mode chosen by ``append_or_write``), and the nacta detections are
      scored with ``eval_write_2_txt``.

    Returns:
        A tuple of twelve per-fold lists: onset precision, onset recall,
        onset F1, precision, recall and F1 taken from element 0 of each
        metric returned by ``eval_write_2_txt``, followed by the same six
        taken from element 1. (Elements 0/1 presumably correspond to two
        onset tolerance windows -- confirm against ``eval_write_2_txt``.)
    """
    from src.utilFunctions import append_or_write
    import csv

    eval_result_file_name = join(
        jingju_eval_results_path, varin['sample_weighting'],
        model_name + '_peakPicking_threshold_results.txt')

    # One list per metric; "first" collects element 0, "second" element 1.
    first_scores = [[] for _ in range(6)]
    second_scores = [[] for _ in range(6)]

    for fold in range(5):
        suffix = str(fold)
        stateful = not varin['overlap']

        if obs_cal != 'tocal':
            model_keras_cnn_0 = None
        else:
            model_keras_cnn_0 = jan_original(filter_density=1,
                                             dropout=0.5,
                                             input_shape=(1, len_seq, 1, 80, 15),
                                             batchNorm=False,
                                             dense_activation='sigmoid',
                                             channel=1,
                                             stateful=stateful,
                                             training=False,
                                             bidi=varin['bidi'])
            model_keras_cnn_0.load_weights(full_path_model + suffix + '.h5')

        # Arguments shared by the nacta 2017 and nacta detection runs.
        shared_kwargs = dict(model_keras_cnn_0=model_keras_cnn_0,
                             cnnModel_name=model_name + suffix,
                             detection_results_path=detection_results_path + suffix,
                             scaler=scaler,
                             threshold=th,
                             obs_cal=obs_cal,
                             decoding_method='peakPicking',
                             architecture=architecture,
                             len_seq=len_seq,
                             stateful=stateful)

        if varin['dataset'] != 'ismir':
            batch_process_onset_detection(wav_path=nacta2017_wav_path,
                                          textgrid_path=nacta2017_textgrid_path,
                                          score_path=nacta2017_score_pinyin_path,
                                          test_recordings=test_nacta_2017,
                                          **shared_kwargs)

        # Only the nacta run's output path is evaluated below.
        eval_results_decoding_path = batch_process_onset_detection(
            wav_path=nacta_wav_path,
            textgrid_path=nacta_textgrid_path,
            score_path=nacta_score_pinyin_path,
            test_recordings=test_nacta,
            **shared_kwargs)

        # The threshold gets its own row ahead of this fold's scores.
        with open(eval_result_file_name,
                  append_or_write(eval_result_file_name)) as testfile:
            csv.writer(testfile).writerow([th])

        (precision_onset, recall_onset, F1_onset,
         precision, recall, F1) = eval_write_2_txt(eval_result_file_name,
                                                   eval_results_decoding_path,
                                                   label=False,
                                                   decoding_method='peakPicking')

        fold_metrics = (precision_onset, recall_onset, F1_onset,
                        precision, recall, F1)
        for metric, first, second in zip(fold_metrics, first_scores,
                                          second_scores):
            first.append(metric[0])
            second.append(metric[1])

    return tuple(first_scores + second_scores)
def peak_picking_subroutine(test_nacta_2017, test_nacta, th, obs_cal,
                            architecture, model_name, full_path_model,
                            full_path_scaler, detection_results_path,
                            jingju_eval_results_path):
    """Run peak-picking onset detection and scoring for five model folds.

    For every fold index ``k`` in ``0..4``:

    * when ``obs_cal == 'tocal'``:
        - the feature scaler is loaded from
          ``full_path_scaler + str(k) + '.pickle.gz'`` when ``architecture``
          contains ``'pretrained'``, and from ``full_path_scaler`` itself
          otherwise;
        - the Keras model is loaded from ``full_path_model + str(k) + '.h5'``;
        - any existing ``detection_results_path + str(k)`` directory is
          deleted;
    * otherwise both the model and the scaler are passed on as ``None``;
    * detection with threshold ``th`` is run on the nacta 2017 recordings
      (skipped when ``varin['dataset'] == 'ismir'``) and then on the nacta
      recordings;
    * ``th`` is written as one CSV row to the per-model results file (open
      mode chosen by ``append_or_write``), and the nacta detections are
      scored with ``eval_write_2_txt``.

    Scaler files are unpickled, so they must come from a trusted source.

    Returns:
        A tuple of twelve per-fold lists: onset precision, onset recall,
        onset F1, precision, recall and F1 taken from element 0 of each
        metric returned by ``eval_write_2_txt``, followed by the same six
        taken from element 1. (Elements 0/1 presumably correspond to two
        onset tolerance windows -- confirm against ``eval_write_2_txt``.)
    """
    from src.utilFunctions import append_or_write
    import csv

    eval_result_file_name = join(
        jingju_eval_results_path, varin['sample_weighting'],
        model_name + '_peakPicking_threshold_results.txt')

    # One list per metric; scores_0 collects element 0, scores_1 element 1.
    scores_0 = [[] for _ in range(6)]
    scores_1 = [[] for _ in range(6)]

    for ii in range(5):
        fold_results_path = detection_results_path + str(ii)

        if obs_cal == 'tocal':
            # Use context managers so the scaler files are closed after loading.
            # Pickles are binary, so the plain file must be opened with 'rb'.
            if 'pretrained' in architecture:
                with gzip.open(full_path_scaler + str(ii) + '.pickle.gz',
                               'rb') as scaler_file:
                    scaler = cPickle.load(scaler_file)
            else:
                with open(full_path_scaler, 'rb') as scaler_file:
                    scaler = pickle.load(scaler_file)

            model_keras_cnn_0 = load_model(full_path_model + str(ii) + '.h5')

            # Start this fold from an empty output directory, presumably so
            # files from an earlier run are not picked up -- isdir() is False
            # for missing paths, so no separate exists() check is needed.
            if os.path.isdir(fold_results_path):
                shutil.rmtree(fold_results_path)
        else:
            model_keras_cnn_0 = None
            scaler = None

        # Arguments shared by the nacta 2017 and nacta detection runs.
        shared_kwargs = dict(model_keras_cnn_0=model_keras_cnn_0,
                             cnnModel_name=model_name + str(ii),
                             detection_results_path=fold_results_path,
                             scaler=scaler,
                             architecture=architecture,
                             threshold=th,
                             obs_cal=obs_cal,
                             decoding_method='peakPicking')

        if varin['dataset'] != 'ismir':
            # nacta2017
            batch_process_onset_detection(wav_path=nacta2017_wav_path,
                                          textgrid_path=nacta2017_textgrid_path,
                                          score_path=nacta2017_score_pinyin_path,
                                          test_recordings=test_nacta_2017,
                                          **shared_kwargs)

        # nacta -- only this run's output path is evaluated below.
        eval_results_decoding_path = batch_process_onset_detection(
            wav_path=nacta_wav_path,
            textgrid_path=nacta_textgrid_path,
            score_path=nacta_score_pinyin_path,
            test_recordings=test_nacta,
            **shared_kwargs)

        # The threshold gets its own row ahead of this fold's scores.
        append_write = append_or_write(eval_result_file_name)
        with open(eval_result_file_name, append_write) as testfile:
            csv_writer = csv.writer(testfile)
            csv_writer.writerow([th])

        (precision_onset, recall_onset, F1_onset,
         precision, recall, F1) = eval_write_2_txt(eval_result_file_name,
                                                   eval_results_decoding_path,
                                                   label=False,
                                                   decoding_method='peakPicking')

        fold_metrics = (precision_onset, recall_onset, F1_onset,
                        precision, recall, F1)
        for metric, col_0, col_1 in zip(fold_metrics, scores_0, scores_1):
            col_0.append(metric[0])
            col_1.append(metric[1])

    return tuple(scores_0 + scores_1)
コード例 #3
0
def viterbi_subroutine(test_nacta_2017, test_nacta, eval_label, obs_cal,
                       len_seq, model_name, architecture, scaler,
                       full_path_model, detection_results_path):
    """Run Viterbi-decoded onset detection and scoring for five model folds.

    For every fold index ``k`` in ``0..4``, with
    ``results_dir = detection_results_path + str(k)``:

    * when ``obs_cal == 'tocal'``:
        - a ``jan_original`` network is built and its weights are loaded
          from ``full_path_model + str(k) + '.h5'``;
        - any existing ``results_dir`` directory is deleted;
        - detection is run on the nacta 2017 recordings (skipped when
          ``varin['dataset'] == 'ismir'``) and then on the nacta recordings;
        - the path returned by the nacta run is the one evaluated;
    * otherwise ``results_dir`` is evaluated as it already exists on disk.

    Evaluation calls ``eval_write_2_txt`` with
    ``<evaluated path>/results.csv`` as its file name and ``eval_label``
    forwarded as ``label``.

    Returns:
        A tuple of twelve per-fold lists: onset precision, onset recall,
        onset F1, precision, recall and F1 taken from element 0 of each
        metric returned by ``eval_write_2_txt``, followed by the same six
        taken from element 1. (Elements 0/1 presumably correspond to two
        onset tolerance windows -- confirm against ``eval_write_2_txt``.)
    """
    # One list per metric; "first" collects element 0, "second" element 1.
    first_scores = [[] for _ in range(6)]
    second_scores = [[] for _ in range(6)]

    for fold in range(5):
        suffix = str(fold)
        results_dir = detection_results_path + suffix

        if obs_cal != 'tocal':
            # Nothing to compute: score the detections already on disk.
            eval_results_decoding_path = results_dir
        else:
            stateful = not varin['overlap']

            model_keras_cnn_0 = jan_original(filter_density=1,
                                             dropout=0.5,
                                             input_shape=(1, len_seq, 1, 80, 15),
                                             batchNorm=False,
                                             dense_activation='sigmoid',
                                             channel=1,
                                             stateful=stateful,
                                             training=False,
                                             bidi=varin['bidi'])
            model_keras_cnn_0.load_weights(full_path_model + suffix + '.h5')

            # Clear this fold's output directory before detecting again;
            # isdir() is already False for a missing path.
            if os.path.isdir(results_dir):
                shutil.rmtree(results_dir)

            # Arguments shared by the nacta 2017 and nacta detection runs.
            shared_kwargs = dict(model_keras_cnn_0=model_keras_cnn_0,
                                 len_seq=len_seq,
                                 cnnModel_name=model_name + suffix,
                                 detection_results_path=results_dir,
                                 scaler=scaler,
                                 obs_cal=obs_cal,
                                 decoding_method='viterbi',
                                 architecture=architecture,
                                 stateful=stateful)

            if varin['dataset'] != 'ismir':
                batch_process_onset_detection(
                    wav_path=nacta2017_wav_path,
                    textgrid_path=nacta2017_textgrid_path,
                    score_path=nacta2017_score_unified_path,
                    test_recordings=test_nacta_2017,
                    **shared_kwargs)

            eval_results_decoding_path = batch_process_onset_detection(
                wav_path=nacta_wav_path,
                textgrid_path=nacta_textgrid_path,
                score_path=nacta_score_unified_path,
                test_recordings=test_nacta,
                **shared_kwargs)

        (precision_onset, recall_onset, F1_onset,
         precision, recall, F1) = eval_write_2_txt(
            eval_result_file_name=join(eval_results_decoding_path, 'results.csv'),
            segSyllable_path=eval_results_decoding_path,
            label=eval_label,
            decoding_method='viterbi')

        fold_metrics = (precision_onset, recall_onset, F1_onset,
                        precision, recall, F1)
        for metric, first, second in zip(fold_metrics, first_scores,
                                          second_scores):
            first.append(metric[0])
            second.append(metric[1])

    return tuple(first_scores + second_scores)
def viterbi_subroutine(test_nacta_2017, test_nacta, eval_label, obs_cal,
                       architecture, model_name, full_path_model,
                       full_path_scaler, detection_results_path):
    """Run Viterbi-decoded onset detection and scoring five times (one per fold).

    For every fold index ``k`` in ``0..4``, with
    ``results_dir = detection_results_path + str(k)``:

    * when ``obs_cal == 'tocal'``:
        - the feature scaler is loaded from
          ``full_path_scaler + str(k) + '.pickle.gz'`` when ``architecture``
          contains ``'pretrained'``, and from ``full_path_scaler`` itself
          otherwise;
        - the Keras model is loaded from ``full_path_model + str(k) + '.h5'``;
        - any existing ``results_dir`` directory is deleted;
        - detection is run on the nacta 2017 recordings (skipped when
          ``varin['dataset'] == 'ismir'``) and then on the nacta recordings;
        - the path returned by the nacta run is the one evaluated;
    * otherwise ``results_dir`` is evaluated as it already exists on disk.

    Evaluation calls ``eval_write_2_txt`` with
    ``<evaluated path>/results.csv`` as its file name and ``eval_label``
    forwarded as ``label``.

    Scaler files are unpickled, so they must come from a trusted source.

    Returns:
        A tuple of twelve per-fold lists: onset precision, onset recall,
        onset F1, precision, recall and F1 taken from element 0 of each
        metric returned by ``eval_write_2_txt``, followed by the same six
        taken from element 1. (Elements 0/1 presumably correspond to two
        onset tolerance windows -- confirm against ``eval_write_2_txt``.)
    """
    # One list per metric; scores_0 collects element 0, scores_1 element 1.
    scores_0 = [[] for _ in range(6)]
    scores_1 = [[] for _ in range(6)]

    for ii in range(5):
        fold_results_path = detection_results_path + str(ii)

        if obs_cal == 'tocal':
            # Use context managers so the scaler files are closed after loading.
            # Pickles are binary, so the plain file must be opened with 'rb'.
            if 'pretrained' in architecture:
                with gzip.open(full_path_scaler + str(ii) + '.pickle.gz',
                               'rb') as scaler_file:
                    scaler = cPickle.load(scaler_file)
            else:
                with open(full_path_scaler, 'rb') as scaler_file:
                    scaler = pickle.load(scaler_file)

            model_keras_cnn_0 = load_model(full_path_model + str(ii) + '.h5')
            print('Model name:', full_path_model)

            # Start this fold from an empty output directory; isdir() is
            # False for missing paths, so no separate exists() check is needed.
            if os.path.isdir(fold_results_path):
                shutil.rmtree(fold_results_path)

            # Arguments shared by the nacta 2017 and nacta detection runs.
            shared_kwargs = dict(model_keras_cnn_0=model_keras_cnn_0,
                                 cnnModel_name=model_name + str(ii),
                                 detection_results_path=fold_results_path,
                                 scaler=scaler,
                                 architecture=architecture,
                                 obs_cal=obs_cal,
                                 decoding_method='viterbi')

            if varin['dataset'] != 'ismir':
                # nacta2017
                batch_process_onset_detection(
                    wav_path=nacta2017_wav_path,
                    textgrid_path=nacta2017_textgrid_path,
                    score_path=nacta2017_score_unified_path,
                    test_recordings=test_nacta_2017,
                    **shared_kwargs)

            # nacta
            eval_results_decoding_path = batch_process_onset_detection(
                wav_path=nacta_wav_path,
                textgrid_path=nacta_textgrid_path,
                score_path=nacta_score_unified_path,
                test_recordings=test_nacta,
                **shared_kwargs)
        else:
            # Nothing to compute: score the detections already on disk.
            eval_results_decoding_path = fold_results_path

        (precision_onset, recall_onset, F1_onset,
         precision, recall, F1) = eval_write_2_txt(
            eval_result_file_name=join(eval_results_decoding_path, 'results.csv'),
            segSyllable_path=eval_results_decoding_path,
            label=eval_label,
            decoding_method='viterbi')

        fold_metrics = (precision_onset, recall_onset, F1_onset,
                        precision, recall, F1)
        for metric, col_0, col_1 in zip(fold_metrics, scores_0, scores_1):
            col_0.append(metric[0])
            col_1.append(metric[1])

    return tuple(scores_0 + scores_1)