def runProcess(val_test, plot):
    """Compute oracle GOP distortion features for every student line.

    For each student recording of the chosen split:
    - the student phone annotation of every line is aligned with the teacher's;
    - the CNN stored at ``kerasModels_path`` produces frame-wise observations
      for the student line (``getObsLine``);
    - ``GOP_phn_level`` scores every aligned phone.

    Per-line mean scores are kept for three phone groups: all phones
    ("total"), syllable-head phones ("head") and the remaining phones
    ("belly").

    :param val_test: ``'val'`` selects the validation recordings; any other
        value selects the test recordings. Only ``'test'`` also writes the
        JSON rating files under ``./data``.
    :param plot: if truthy, plot the per-phone scores of every line with
        ``disLinePlot``.

    Side effects: always writes
    ``./data/training_features/GOP_oracle_<val_test>_{total,head,belly}.pkl``.
    """
    model_keras_cnn_0 = load_model(kerasModels_path)

    # The scaler was pickled under Python 2, so latin1 decoding is required in
    # Python 3. The context manager guarantees the file handle is closed.
    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open(kerasScaler_path, 'rb') as scaler_file:
        scaler = pickle.load(scaler_file, encoding='latin1')

    # the dataset filenames
    primarySchool_val_recordings, primarySchool_test_recordings = getTestRecordingsJoint()

    if val_test == 'val':
        recordings = primarySchool_val_recordings
    else:
        recordings = primarySchool_test_recordings

    # Line ratings: artist -> "<fn>_<teacher line index>" -> mean score.
    dict_total = {}
    dict_head = {}
    dict_belly = {}

    # Phone-level features, same keys:
    # {'distortion_phns': ndarray, 'num_tails_missing': ...}
    dict_feature_phns_total = {}
    dict_feature_phns_head = {}
    dict_feature_phns_belly = {}

    for artist, fn in recordings:

        teacher_textgrid_file = os.path.join(primarySchool_textgrid_path, artist, 'teacher.TextGrid')
        student_textgrid_file = os.path.join(primarySchool_textgrid_path, artist, fn + '.TextGrid')

        # parse the textgrids into per-line syllable and phoneme lists
        teacherSyllableLists, teacherPhonemeLists = textgridSyllablePhonemeParser(
            teacher_textgrid_file, 'dianSilence', 'details')
        studentSyllableLists, studentPhonemeLists = textgridSyllablePhonemeParser(
            student_textgrid_file, 'dianSilence', 'details')

        student_wav_file = os.path.join(primarySchool_wav_path, artist, fn + '.wav')

        # scaled log-mel features of the student recording, reshaped for the CNN
        log_mel = getMFCCBands2DMadmom(student_wav_file, fs, hopsize_t, channel=1)
        log_mel_scaled = scaler.transform(log_mel)
        log_mel_reshaped = featureReshape(log_mel_scaled, nlen=7)

        for result_dict in (dict_total, dict_head, dict_belly,
                            dict_feature_phns_total, dict_feature_phns_head,
                            dict_feature_phns_belly):
            result_dict.setdefault(artist, {})

        for ii_line in range(len(studentPhonemeLists)):  # iterate each line

            # find the right line index for the teacher's textgrid,
            # ``student02_first_half'' only corresponds to a part of the teacher's textgrid,
            # we need to shift the index of the teacher's textgrid to find the right line
            ii_aug = findShiftOffset(gtSyllableLists=studentSyllableLists,
                                     scoreSyllableLists=teacherSyllableLists,
                                     ii_line=ii_line)
            line_key = fn + '_' + str(ii_line + ii_aug)

            list_phn_teacher, list_phn_student, _, list_syl_onsets_time_teacher = \
                getListsSylPhn(teacherSyllableLists=teacherSyllableLists,
                               teacherPhonemeLists=teacherPhonemeLists,
                               studentPhonemeLists=studentPhonemeLists,
                               ii_line=ii_line,
                               ii_aug=ii_aug)

            phns_teacher = [lpt[2] for lpt in list_phn_teacher]
            phns_student = [lpt[2] for lpt in list_phn_student]

            insertion_indices_student, deletion_indices_teacher, teacher_student_indices_pair, _ = \
                phnSequenceAlignment(phns_teacher=phns_teacher, phns_student=phns_student)

            list_phn_teacher_pair, list_phn_student_pair, idx_syl_heads, phn_tails_missing, num_tails_missing = \
                getIdxHeadsMissingTails(teacher_student_indices_pair=teacher_student_indices_pair,
                                        list_phn_teacher=list_phn_teacher,
                                        list_phn_student=list_phn_student,
                                        list_syl_onsets_time_teacher=list_syl_onsets_time_teacher,
                                        deletion_indices_teacher=deletion_indices_teacher,
                                        phns_tails=phns_tails)

            print('these phone indices are inserted in student phone list', insertion_indices_student)
            print('these phone indices are deleted in teacher phone list', deletion_indices_teacher)
            print('these phone tails are deleted in teacher phone list', phn_tails_missing)

            obs_line = getObsLine(studentPhonemeLists=studentPhonemeLists,
                                  ii_line=ii_line,
                                  hopsize_t=hopsize_t,
                                  log_mel_reshaped=log_mel_reshaped,
                                  model_keras_cnn_0=model_keras_cnn_0)

            # entries: [phone index, GOP score, teacher phone label]
            GOP_line = []
            for ii_phn, phn_student in enumerate(list_phn_student_pair):

                # frame indices relative to the first paired student phone
                phn_start_frame = int(round((phn_student[0] - list_phn_student_pair[0][0]) / hopsize_t))
                phn_end_frame = int(round((phn_student[1] - list_phn_student_pair[0][0]) / hopsize_t))

                phn_label = list_phn_teacher_pair[ii_phn][2]

                # A phone that spans zero frames after rounding cannot be
                # scored; mark it with -inf so it is filtered out below.
                if phn_end_frame == phn_start_frame:
                    GOP_line.append([ii_phn, -np.inf, phn_label])
                    continue

                obs_line_phn = obs_line[phn_start_frame:phn_end_frame]

                GOP_phn = GOP_phn_level(phn_label=phn_label, obs_line_phn=obs_line_phn)
                GOP_line.append([ii_phn, GOP_phn, phn_label])

            # keep only finite scores, then split into head / belly phones
            gop_valid = [gop for gop in GOP_line if not np.isinf(gop[1])]
            gop_valid_head = [gop for gop in gop_valid if gop[0] in idx_syl_heads]
            gop_valid_belly = [gop for gop in gop_valid if gop[0] not in idx_syl_heads]

            gop_total = [gop[1] for gop in gop_valid]
            gop_head = [gop[1] for gop in gop_valid_head]
            gop_belly = [gop[1] for gop in gop_valid_belly]

            if plot:
                disLinePlot(gop_total, [gop[2] for gop in gop_valid])
                disLinePlot(gop_head, [gop[2] for gop in gop_valid_head])
                disLinePlot(gop_belly, [gop[2] for gop in gop_valid_belly])

            # NOTE: np.mean of an empty list is nan (with a RuntimeWarning),
            # e.g. for a line without any finite head-phone score.
            dict_total[artist][line_key] = np.mean(gop_total)
            dict_head[artist][line_key] = np.mean(gop_head)
            dict_belly[artist][line_key] = np.mean(gop_belly)

            dict_feature_phns_total[artist][line_key] = {'distortion_phns': np.array(gop_total),
                                                         'num_tails_missing': num_tails_missing}
            dict_feature_phns_head[artist][line_key] = {'distortion_phns': np.array(gop_head),
                                                        'num_tails_missing': num_tails_missing}
            dict_feature_phns_belly[artist][line_key] = {'distortion_phns': np.array(gop_belly),
                                                         'num_tails_missing': num_tails_missing}

    if val_test == 'test':
        with open('./data/rating_GOP_oracle_total.json', 'w') as savefile:
            json.dump(dict_total, savefile)
        with open('./data/rating_GOP_oracle_head.json', 'w') as savefile:
            json.dump(dict_head, savefile)
        with open('./data/rating_GOP_oracle_belly.json', 'w') as savefile:
            json.dump(dict_belly, savefile)

    with open('./data/training_features/GOP_oracle_'+val_test+'_total.pkl', 'wb') as savefile:
        pickle.dump(dict_feature_phns_total, savefile)
    with open('./data/training_features/GOP_oracle_'+val_test+'_head.pkl', 'wb') as savefile:
        pickle.dump(dict_feature_phns_head, savefile)
    with open('./data/training_features/GOP_oracle_'+val_test+'_belly.pkl', 'wb') as savefile:
        pickle.dump(dict_feature_phns_belly, savefile)
Exemplo n.º 2
0
def _sr_distortion(mat_teacher, mat_student):
    """Return ``(line distortion, per-phone distortion)`` for two distance matrices.

    Rows/columns containing NaN are first removed with ``removeNanRowCol``.
    The line distortion is the (Frobenius) norm of the difference divided by
    the number of rows. The per-phone distortion is the row-wise L2 norm of
    the difference divided by the square root of the number of rows.
    """
    mat_teacher, mat_student, _ = removeNanRowCol(mat_teacher, mat_student)
    diff = mat_teacher - mat_student
    num_rows = mat_teacher.shape[0]
    distortion = np.linalg.norm(diff) / num_rows
    distortion_phns = np.linalg.norm(diff, axis=1) / np.sqrt(num_rows)
    return distortion, distortion_phns


def runProcess(val_test, mode):
    """Compute oracle distance-matrix distortions between teacher and student lines.

    For each student recording of the chosen split:
    - 39-dimensional MFCC features (``mfccDeltaDelta``) are extracted for the
      teacher and the student recording;
    - the phone annotation of each student line is aligned with the teacher's;
    - ``gaussianPipeline`` fits a (mean, covariance) pair to every aligned
      phone, and ``BDDistanceMat`` turns each side's pairs into a pairwise
      distance matrix.

    A line's distortion is the normalized difference between the teacher and
    student matrices (see ``_sr_distortion``).

    :param val_test: ``'val'`` selects the validation recordings; any other
        value selects the test recordings. Only ``'test'`` also writes the
        JSON rating file.
    :param mode: ``'_total'`` (all phones), ``'_head'`` (syllable-head
        phones) or ``'_belly'`` (the remaining phones). It is also used as the
        suffix of the output file names.
    :raises ValueError: if ``mode`` is not one of the values above, or if a
        line's distortion is NaN.
    """
    if mode not in ('_total', '_head', '_belly'):
        # Validate before the expensive feature extraction; any other value
        # would otherwise leave ``distortion`` unbound further down.
        raise ValueError("mode must be '_total', '_head' or '_belly', got %r" % (mode,))

    # the dataset filenames
    primarySchool_val_recordings, primarySchool_test_recordings = getTestRecordingsJoint()

    if val_test == 'val':
        recordings = primarySchool_val_recordings
    else:
        recordings = primarySchool_test_recordings

    # artist -> "<fn>_<teacher line index>" -> line distortion
    dict_distortion = {}
    # same keys -> {'distortion_phns': per-phone distortion, 'num_tails_missing': ...}
    dict_feature_phns = {}

    for artist, fn in recordings:

        teacher_textgrid_file = os.path.join(primarySchool_textgrid_path, artist, 'teacher.TextGrid')
        student_textgrid_file = os.path.join(primarySchool_textgrid_path, artist, fn + '.TextGrid')

        # parse the textgrids into per-line syllable and phoneme lists
        teacherSyllableLists, teacherPhonemeLists = textgridSyllablePhonemeParser(
            teacher_textgrid_file, 'dianSilence', 'details')
        studentSyllableLists, studentPhonemeLists = textgridSyllablePhonemeParser(
            student_textgrid_file, 'dianSilence', 'details')

        teacher_wav_file = os.path.join(primarySchool_wav_path, artist, 'teacher.wav')
        student_wav_file = os.path.join(primarySchool_wav_path, artist, fn + '.wav')

        print(teacher_wav_file)

        audio_data_teacher, fs_teacher = sf.read(teacher_wav_file)
        audio_data_student, fs_student = sf.read(student_wav_file)

        audio_data_teacher = sterero2Mono(audio_data_teacher)
        audio_data_student = sterero2Mono(audio_data_student)

        # 39 dimensions mfcc
        mfccs_teacher = mfccDeltaDelta(audio_data=audio_data_teacher, fs=fs_teacher, framesize=framesize,
                                       hopsize=hopsize)
        mfccs_student = mfccDeltaDelta(audio_data=audio_data_student, fs=fs_student, framesize=framesize,
                                       hopsize=hopsize)

        dict_distortion.setdefault(artist, {})
        dict_feature_phns.setdefault(artist, {})

        for ii_line in range(len(studentPhonemeLists)):  # iterate each line

            # find the right line index for the teacher's textgrid,
            # ``student02_first_half'' only corresponds to a part of the teacher's textgrid,
            # we need to shift the index of the teacher's textgrid to find the right line
            ii_aug = findShiftOffset(gtSyllableLists=studentSyllableLists,
                                     scoreSyllableLists=teacherSyllableLists,
                                     ii_line=ii_line)
            line_key = fn + '_' + str(ii_line + ii_aug)

            # trim the MFCCs to the current line
            mfccs_teacher_line = segmentMfccLine(line=teacherPhonemeLists[ii_line + ii_aug][0],
                                                 hopsize_t=hopsize_t, mfccs=mfccs_teacher)
            mfccs_student_line = segmentMfccLine(line=studentPhonemeLists[ii_line][0],
                                                 hopsize_t=hopsize_t, mfccs=mfccs_student)

            list_phn_teacher, list_phn_student, _, list_syl_onsets_time_teacher = \
                getListsSylPhn(teacherSyllableLists=teacherSyllableLists,
                               teacherPhonemeLists=teacherPhonemeLists,
                               studentPhonemeLists=studentPhonemeLists,
                               ii_line=ii_line,
                               ii_aug=ii_aug)

            phns_teacher = [lpt[2] for lpt in list_phn_teacher]
            phns_student = [lpt[2] for lpt in list_phn_student]

            insertion_indices_student, deletion_indices_teacher, teacher_student_indices_pair, _ = \
                phnSequenceAlignment(phns_teacher=phns_teacher, phns_student=phns_student)

            list_phn_teacher_pair, list_phn_student_pair, idx_syl_heads, phn_tails_missing, num_tails_missing = \
                getIdxHeadsMissingTails(teacher_student_indices_pair=teacher_student_indices_pair,
                                        list_phn_teacher=list_phn_teacher,
                                        list_phn_student=list_phn_student,
                                        list_syl_onsets_time_teacher=list_syl_onsets_time_teacher,
                                        deletion_indices_teacher=deletion_indices_teacher,
                                        phns_tails=phns_tails)

            print('these phone indices are inserted in student phone list', insertion_indices_student)
            print('these phone indices are deleted in teacher phone list', deletion_indices_teacher)
            print('these phone tails are deleted in teacher phone list', phn_tails_missing)

            # one (mean, covariance) pair per aligned phone, for each side
            mu_cov_teacher = []
            mu_cov_student = []
            for ii_phn_pair in range(len(list_phn_teacher_pair)):
                mu_teacher, cov_teacher, _ = gaussianPipeline(list_phn=list_phn_teacher_pair,
                                                              ii=ii_phn_pair,
                                                              hopsize_t=hopsize_t,
                                                              mfccs_line=mfccs_teacher_line)
                mu_cov_teacher.append([mu_teacher, cov_teacher])

                mu_student, cov_student, _ = gaussianPipeline(list_phn=list_phn_student_pair,
                                                              ii=ii_phn_pair,
                                                              hopsize_t=hopsize_t,
                                                              mfccs_line=mfccs_student_line)
                mu_cov_student.append([mu_student, cov_student])

            distance_mat_teacher = BDDistanceMat(mu_cov_teacher)
            distance_mat_student = BDDistanceMat(mu_cov_student)

            # select the sub-matrix of the phones covered by ``mode``
            if mode == '_total':
                mat_teacher, mat_student = distance_mat_teacher, distance_mat_student
            else:
                if mode == '_head':
                    print(idx_syl_heads)
                    idx_selected = idx_syl_heads
                else:  # '_belly'
                    idx_selected = [ii_entire_idx for ii_entire_idx in range(distance_mat_teacher.shape[0])
                                    if ii_entire_idx not in idx_syl_heads]
                mat_teacher = distance_mat_teacher[np.ix_(idx_selected, idx_selected)]
                mat_student = distance_mat_student[np.ix_(idx_selected, idx_selected)]

            distortion, distortion_phns = _sr_distortion(mat_teacher, mat_student)

            if np.isnan(distortion):
                raise ValueError('NaN distortion for artist %r, line %r (mode %r)'
                                 % (artist, line_key, mode))

            dict_distortion[artist][line_key] = distortion
            dict_feature_phns[artist][line_key] = {'distortion_phns': distortion_phns,
                                                   'num_tails_missing': num_tails_missing}

    if val_test == 'test':
        with open('./data/rating_SR_oracle' + mode + '.json', 'w') as savefile:
            json.dump(dict_distortion, savefile)

    with open('./data/training_features/SR_oracle_'+ val_test + mode + '.pkl', 'wb') as savefile:
        pickle.dump(dict_feature_phns, savefile)
Exemplo n.º 3
0
    ax2.set_ylabel('VAD', fontsize=12)
    ax2.axis('tight')
    plt.show()


if __name__ == '__main__':

    plot = False

    model_keras_cnn_0 = load_model(kerasModels_path)

    # The scaler was pickled under Python 2, so latin1 decoding is required in
    # Python 3. The context manager guarantees the file handle is closed.
    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open(kerasScaler_path, 'rb') as scaler_file:
        scaler = pickle.load(scaler_file, encoding='latin1')

    # the validation and test dataset filenames
    primarySchool_val_recordings, primarySchool_test_recordings = getTestRecordingsJoint()

    for artist, fn in primarySchool_val_recordings + primarySchool_test_recordings:

        # textgrid path, to get the line onset offset
        groundtruth_textgrid_file = os.path.join(primarySchool_textgrid_path,
                                                 artist, fn + '.TextGrid')

        # parse the 'line' tier of the TextGrid
        list_line = textGrid2WordList(groundtruth_textgrid_file,
                                      whichTier='line')

        wav_file = os.path.join(primarySchool_wav_path, artist, fn + '.wav')

        # NOTE(review): plot, model_keras_cnn_0, scaler, list_line and
        # vad_results are not used in the visible code; the snippet appears
        # truncated here.
        vad_results = VAD(wav_file)
Exemplo n.º 4
0
        for ii, line_list in enumerate(nestedLists):
            print(artist_path, recording ,ii, len(line_list[1]))

            if childTierName=='details':
                for phn in line_list[1]:
                    try:
                        key = dic_pho_map[phn[2]]
                    except:
                        print(artist_path, ii, recording, phn[2])
                        raise KeyError

if __name__ == '__main__':
    # Sanity-check the test recordings: every annotated 'line' should contain
    # a reasonable number of syllables ('dianSilence' tier) and of phonemes
    # ('details' tier). The validation recordings are not checked.
    _, testPrimarySchool = getTestRecordingsJoint()

    for child_tier_name in ('dianSilence', 'details'):
        s_check(textgrid_path=primarySchool_textgrid_path,
                parentTierName='line',
                childTierName=child_tier_name,
                recordings=testPrimarySchool)



def runProcess(val_test, plot):
    """Compute oracle embedding-dissimilarity features for every student line.

    The weights of the model stored at ``kerasModels_emb_cla_path`` are
    copied into a fresh ``model_select(config=[1, 1])`` model built with
    ``input_shape=[1, None, 80]``. For every student recording of the chosen
    split:
    - the student phone annotation of each line is aligned with the teacher's;
    - ``measureEmbDissimilarity`` compares the scaled log-mel segments of each
      aligned teacher/student phone pair.

    Per-line mean dissimilarities are kept for three phone groups: all phones
    ("total"), syllable-head phones ("head") and the remaining phones
    ("belly").

    :param val_test: ``'val'`` selects the validation recordings; any other
        value selects the test recordings. Only ``'test'`` also writes the
        JSON rating files under ``./data``.
    :param plot: if truthy, plot the per-phone dissimilarities of every line
        with ``disLinePlot``.

    Side effects: always writes
    ``./data/training_features/emb_cla_oracle_<val_test>_{total,head,belly}.pkl``.
    """
    # Load the trained weights and transfer them into a model whose second
    # input dimension is None (presumably so phone segments of any length can
    # be fed — TODO confirm against model_select).
    trained_model = load_model(kerasModels_emb_cla_path)
    weights = trained_model.get_weights()

    input_shape = [1, None, 80]
    model_keras_cnn_0 = model_select(config=[1, 1], input_shape=input_shape)
    model_keras_cnn_0.compile(optimizer='adam',
                              loss='categorical_crossentropy',
                              metrics=['accuracy'])
    model_keras_cnn_0.set_weights(weights=weights)

    # The scaler was pickled under Python 2, so latin1 decoding is required in
    # Python 3. The context manager guarantees the file handle is closed.
    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open(kerasScaler_emb_cla_path, 'rb') as scaler_file:
        scaler = pickle.load(scaler_file, encoding='latin1')

    # the dataset filenames
    primarySchool_val_recordings, primarySchool_test_recordings = getTestRecordingsJoint()

    if val_test == 'val':
        recordings = primarySchool_val_recordings
    else:
        recordings = primarySchool_test_recordings

    # Line ratings: artist -> "<fn>_<teacher line index>" -> mean dissimilarity.
    dict_total = {}
    dict_head = {}
    dict_belly = {}

    # Phone-level features, same keys:
    # {'distortion_phns': ndarray, 'num_tails_missing': ...}
    dict_feature_phns_total = {}
    dict_feature_phns_head = {}
    dict_feature_phns_belly = {}

    for artist, fn in recordings:

        teacher_textgrid_file = os.path.join(primarySchool_textgrid_path,
                                             artist, 'teacher.TextGrid')
        student_textgrid_file = os.path.join(primarySchool_textgrid_path,
                                             artist, fn + '.TextGrid')

        # parse the textgrids into per-line syllable and phoneme lists
        teacherSyllableLists, teacherPhonemeLists = textgridSyllablePhonemeParser(
            teacher_textgrid_file, 'dianSilence', 'details')
        studentSyllableLists, studentPhonemeLists = textgridSyllablePhonemeParser(
            student_textgrid_file, 'dianSilence', 'details')

        teacher_wav_file = os.path.join(primarySchool_wav_path, artist,
                                        'teacher.wav')
        student_wav_file = os.path.join(primarySchool_wav_path, artist,
                                        fn + '.wav')

        # scaled log-mel features of both recordings
        log_mel_scaled_teacher = scaler.transform(
            getMFCCBandsMadmom(teacher_wav_file, fs, hopsize_t))
        log_mel_scaled_student = scaler.transform(
            getMFCCBandsMadmom(student_wav_file, fs, hopsize_t))

        for result_dict in (dict_total, dict_head, dict_belly,
                            dict_feature_phns_total, dict_feature_phns_head,
                            dict_feature_phns_belly):
            result_dict.setdefault(artist, {})

        for ii_line in range(len(studentPhonemeLists)):  # iterate each line

            # find the right line index for the teacher's textgrid,
            # ``student02_first_half'' only corresponds to a part of the teacher's textgrid,
            # we need to shift the index of the teacher's textgrid to find the right line
            ii_aug = findShiftOffset(gtSyllableLists=studentSyllableLists,
                                     scoreSyllableLists=teacherSyllableLists,
                                     ii_line=ii_line)
            line_key = fn + '_' + str(ii_line + ii_aug)

            list_phn_teacher, list_phn_student, _, list_syl_onsets_time_teacher = \
                getListsSylPhn(teacherSyllableLists=teacherSyllableLists,
                               teacherPhonemeLists=teacherPhonemeLists,
                               studentPhonemeLists=studentPhonemeLists,
                               ii_line=ii_line,
                               ii_aug=ii_aug)

            phns_teacher = [lpt[2] for lpt in list_phn_teacher]
            phns_student = [lpt[2] for lpt in list_phn_student]

            insertion_indices_student, deletion_indices_teacher, teacher_student_indices_pair, _ = \
                phnSequenceAlignment(phns_teacher=phns_teacher, phns_student=phns_student)

            list_phn_teacher_pair, list_phn_student_pair, idx_syl_heads, phn_tails_missing, num_tails_missing = \
                getIdxHeadsMissingTails(teacher_student_indices_pair=teacher_student_indices_pair,
                                        list_phn_teacher=list_phn_teacher,
                                        list_phn_student=list_phn_student,
                                        list_syl_onsets_time_teacher=list_syl_onsets_time_teacher,
                                        deletion_indices_teacher=deletion_indices_teacher,
                                        phns_tails=phns_tails)

            print('these phone indices are inserted in student phone list',
                  insertion_indices_student)
            print('these phone indices are deleted in teacher phone list',
                  deletion_indices_teacher)
            print('these phone tails are deleted in teacher phone list',
                  phn_tails_missing)

            # log-mel features of the current line; the teacher's line index is shifted by ii_aug
            log_mel_reshaped_line_teacher = getLogMelLine(
                studentPhonemeLists=teacherPhonemeLists,
                ii_line=ii_line + ii_aug,
                hopsize_t=hopsize_t,
                log_mel_reshaped=log_mel_scaled_teacher)

            log_mel_reshaped_line_student = getLogMelLine(
                studentPhonemeLists=studentPhonemeLists,
                ii_line=ii_line,
                hopsize_t=hopsize_t,
                log_mel_reshaped=log_mel_scaled_student)

            # entries: [phone index, cosine dissimilarity, teacher phone label]
            cos_dis_line = []
            for ii_phn, phn_student in enumerate(list_phn_student_pair):
                phn_teacher = list_phn_teacher_pair[ii_phn]

                # frame indices relative to the first paired phone of each side
                phn_start_frame_teacher = int(round(
                    (phn_teacher[0] - list_phn_teacher_pair[0][0]) / hopsize_t))
                phn_end_frame_teacher = int(round(
                    (phn_teacher[1] - list_phn_teacher_pair[0][0]) / hopsize_t))

                phn_start_frame_student = int(round(
                    (phn_student[0] - list_phn_student_pair[0][0]) / hopsize_t))
                phn_end_frame_student = int(round(
                    (phn_student[1] - list_phn_student_pair[0][0]) / hopsize_t))

                phn_label = phn_teacher[2]

                # A phone spanning zero frames on either side cannot be
                # embedded; assign a fixed dissimilarity of 1.0 instead.
                if phn_end_frame_teacher == phn_start_frame_teacher or \
                        phn_end_frame_student == phn_start_frame_student:
                    cos_dis_line.append([ii_phn, 1.0, phn_label])
                    continue

                log_mel_line_phn_teacher = log_mel_reshaped_line_teacher[
                    phn_start_frame_teacher:phn_end_frame_teacher]
                log_mel_line_phn_student = log_mel_reshaped_line_student[
                    phn_start_frame_student:phn_end_frame_student]

                # embedding dissimilarity between the teacher and student phone
                cos_dis_phn = measureEmbDissimilarity(
                    model_keras_cnn_0=model_keras_cnn_0,
                    log_mel_phn_teacher=log_mel_line_phn_teacher,
                    log_mel_phn_student=log_mel_line_phn_student)
                cos_dis_line.append([ii_phn, cos_dis_phn, phn_label])

            dis_line_head = [dis for dis in cos_dis_line if dis[0] in idx_syl_heads]
            dis_line_belly = [dis for dis in cos_dis_line if dis[0] not in idx_syl_heads]

            dis_total = [dis[1] for dis in cos_dis_line]
            dis_head = [dis[1] for dis in dis_line_head]
            dis_belly = [dis[1] for dis in dis_line_belly]

            if plot:
                disLinePlot(dis_total, [dis[2] for dis in cos_dis_line])
                disLinePlot(dis_head, [dis[2] for dis in dis_line_head])
                disLinePlot(dis_belly, [dis[2] for dis in dis_line_belly])

            # NOTE: np.mean of an empty list is nan (with a RuntimeWarning),
            # e.g. for a line without head phones.
            dict_total[artist][line_key] = np.mean(dis_total)
            dict_head[artist][line_key] = np.mean(dis_head)
            dict_belly[artist][line_key] = np.mean(dis_belly)

            dict_feature_phns_total[artist][line_key] = {
                'distortion_phns': np.array(dis_total),
                'num_tails_missing': num_tails_missing
            }
            dict_feature_phns_head[artist][line_key] = {
                'distortion_phns': np.array(dis_head),
                'num_tails_missing': num_tails_missing
            }
            dict_feature_phns_belly[artist][line_key] = {
                'distortion_phns': np.array(dis_belly),
                'num_tails_missing': num_tails_missing
            }

    if val_test == 'test':
        with open('./data/rating_emb_cla_oracle_total.json', 'w') as savefile:
            json.dump(dict_total, savefile)
        with open('./data/rating_emb_cla_oracle_head.json', 'w') as savefile:
            json.dump(dict_head, savefile)
        with open('./data/rating_emb_cla_oracle_belly.json', 'w') as savefile:
            json.dump(dict_belly, savefile)

    with open(
            './data/training_features/emb_cla_oracle_' + val_test +
            '_total.pkl', 'wb') as savefile:
        pickle.dump(dict_feature_phns_total, savefile)
    with open(
            './data/training_features/emb_cla_oracle_' + val_test +
            '_head.pkl', 'wb') as savefile:
        pickle.dump(dict_feature_phns_head, savefile)
    with open(
            './data/training_features/emb_cla_oracle_' + val_test +
            '_belly.pkl', 'wb') as savefile:
        pickle.dump(dict_feature_phns_belly, savefile)