예제 #1
0
파일: validation.py 프로젝트: swshon/sidnet
    def __init__(self, args):
        """Load the validation trial list and map its utterances to indices.

        Reads ``<data_root>/<test_trials>``, whose whitespace-separated lines
        are ``<enroll_utt> <test_utt> <label>``; a label equal to ``'target'``
        becomes 1, anything else 0. Blank lines are skipped. The data list in
        ``<data_root>/<test_data>`` is read with ``kd.read_data_list`` and each
        trial's enroll/test utterance id is converted to its position in that
        list.

        Args:
            args: parsed options; ``data_root``, ``test_data`` and
                ``test_trials`` are read here and the object is kept as
                ``self.args``.

        Raises:
            KeyError: if a trial names an utterance id that is not in the
                data list.
        """
        validation_data_name = args.data_root + '/' + args.test_data
        trials_filename = args.data_root + '/' + args.test_trials

        # Load validation trials; the context manager closes the file
        # deterministically instead of leaving it to garbage collection.
        tst_trials = []
        tst_enrolls = []
        tst_tests = []
        with open(trials_filename, 'r') as trials_file:
            for line in trials_file:
                row = line.split()
                if not row:
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                tst_enrolls.append(row[0])
                tst_tests.append(row[1])
                tst_trials.append(1 if row[2] == 'target' else 0)
        tst_trials = np.array(tst_trials)

        wavlist, utt_label, spk_label = kd.read_data_list(validation_data_name,
                                                          utt2spk=True,
                                                          utt2lang=False)
        # Position of every utterance id in the data list.
        tst_dict = {key: value for value, key in enumerate(utt_label)}

        # Build the index arrays in one pass; calling np.append per trial
        # copied the whole array each time (quadratic in the trial count).
        tst_enrolls_idx = np.array([tst_dict[utt] for utt in tst_enrolls],
                                   dtype=int)
        tst_tests_idx = np.array([tst_dict[utt] for utt in tst_tests],
                                 dtype=int)

        self.wavlist = wavlist
        self.tst_enrolls_idx = tst_enrolls_idx
        self.tst_tests_idx = tst_tests_idx
        self.data = validation_data_name
        self.tst_trials = tst_trials
        self.args = args
예제 #2
0
파일: split_data.py 프로젝트: swshon/sidnet
                    help="source data")
# Number of parts to split the data list into. argparse converts the value to
# int; a non-numeric default could never pass int(), so the option is required.
parser.add_argument("--split",
                    type=int,
                    required=True,
                    help="number of splits")
parser.add_argument("--utt2spk", action='store_true', help="for utt2spk file")
parser.add_argument("--utt2lang",
                    action='store_true',
                    help="for utt2lang file")
args = parser.parse_known_args()[0]

SOURCE_FOLDER = args.source
TOTAL_SPLIT = args.split

# Read the data list once per requested label type. When both flags are set
# the utt2lang read runs last, so its wavlist/utt_label stay bound afterwards.
if args.utt2spk:
    wavlist, utt_label, spk_label = kd.read_data_list(SOURCE_FOLDER,
                                                      utt2spk=True)
if args.utt2lang:
    wavlist, utt_label, lang_label = kd.read_data_list(SOURCE_FOLDER,
                                                       utt2lang=True)

if args.utt2spk:
    kd.split_data(SOURCE_FOLDER,
                  wavlist,
                  utt_label,
                  spk_label=spk_label,
                  total_split=TOTAL_SPLIT)

if args.utt2lang:
    kd.split_data(SOURCE_FOLDER,
                  wavlist,
                  utt_label,
예제 #3
0
파일: merge_data.py 프로젝트: swshon/sidnet
parser.add_argument("--utt2lang",
                    action='store_true',
                    help="for utt2lang file")
args = parser.parse_known_args()[0]

# At least one label type must be requested: with neither flag no data list
# is read and the merge loop below would fail on undefined names.
if not (args.utt2spk or args.utt2lang):
    parser.error("at least one of --utt2spk or --utt2lang is required")

wavlist = []
utt_label = []
spk_label = []
lang_label = []

# Concatenate the data lists of both source folders, collecting whichever
# label types were requested.
for name in [args.source1, args.source2]:
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print(name)
    if args.utt2spk:
        if args.utt2lang:
            wav, utt, spk, lang = kd.read_data_list(name,
                                                    utt2spk=True,
                                                    utt2lang=True)
            lang_label.extend(lang)
        else:
            wav, utt, spk = kd.read_data_list(name, utt2spk=True)
        spk_label.extend(spk)
    else:
        # Only --utt2lang is set here (guaranteed by the check above).
        wav, utt, lang = kd.read_data_list(name, utt2lang=True)
        lang_label.extend(lang)

    wavlist.extend(wav)
    utt_label.extend(utt)

# if args.utt2lang:
#     wavlist,utt_label,lang_label = kd.read_data_list(args.source1,utt2lang=True)
#
예제 #4
0
                    type=str,
                    default="data/test_segments/utt2lang_sorted",
                    help="source data")
parser.add_argument("--target",
                    type=str,
                    default="result_test_sorted.csv",
                    help="target data")
parser.add_argument("--utt2spk", action='store_true', help="for utt2spk file")
parser.add_argument("--utt2lang",
                    action='store_true',
                    help="for utt2lang file")
args = parser.parse_known_args()[0]

SOURCE_FOLDER = args.source
TARGET_FOLDER = args.target

# makedirs also creates missing parent folders of the target.
if not os.path.exists(TARGET_FOLDER):
    os.makedirs(TARGET_FOLDER)

# NOTE(review): with --utt2lang the third value presumably holds language
# labels even though it is bound to spk_label — confirm against kd.
wavlist, utt_label, spk_label = kd.read_data_list(SOURCE_FOLDER,
                                                  utt2spk=args.utt2spk,
                                                  utt2lang=args.utt2lang)

# A shuffled index array, applied identically to all three lists so they stay
# aligned. Shuffling range() in place fails on Python 3 (range is immutable).
idx = np.random.permutation(len(wavlist))

wavlist = wavlist[idx]
utt_label = utt_label[idx]
spk_label = spk_label[idx]
kd.write_data(TARGET_FOLDER, wavlist, utt_label, spk_label)
import sys
sys.path.insert(0,'scripts/')
import kaldi_data as kd

# Usage: <script> BASE_FOLDER TOTAL_SPLIT
# Reads the utt2lang-labelled data list in BASE_FOLDER and splits it into
# TOTAL_SPLIT parts with kd.split_data.
if len(sys.argv) < 3:
    sys.exit("usage: %s BASE_FOLDER TOTAL_SPLIT" % sys.argv[0])

BASE_FOLDER = sys.argv[1]
TOTAL_SPLIT = int(sys.argv[2])

wavlist, utt_label, lang_label = kd.read_data_list(BASE_FOLDER, utt2lang=True)
kd.split_data(BASE_FOLDER, wavlist, utt_label,
              lang_label=lang_label, total_split=TOTAL_SPLIT)



예제 #6
0
# Model / training settings taken from the parsed command line.
WIN_LENGTH = int(args.win_len)
# FIXED_LEN = int(args.fixed_len) #298
SOFTMAX_NUM = args.softmax_num
RESUME_STARTPOINT = args.resume_startpoint
NN_MODEL = args.model_name
EMBEDDING_LAYER = args.embedding_layer

# Only the literal string 'False' becomes a real boolean False; any other
# value is passed through unchanged.
VAD = False if VAD == 'False' else VAD
CMVN = False if CMVN == 'False' else CMVN
is_batchnorm = True

# Read the data list: the whole DATA_FOLDER when TOTAL_SPLIT is 1, otherwise
# the split<TOTAL_SPLIT>/<CURRRENT_SPLIT> sub-folder (TOTAL_SPLIT is
# concatenated as a string; CURRRENT_SPLIT is spelled as it is defined
# earlier in the script). Then extract features from the wav or segment list.
if args.segments_format:
    if int(TOTAL_SPLIT) == 1:
        (wavlist, utt_label, seg_wavlist, seg_segid, seg_uttid,
         seg_windows) = kd.read_data_list(DATA_FOLDER,
                                          utt2spk=False,
                                          segments=True)
    else:
        (wavlist, utt_label, seg_wavlist, seg_segid, seg_uttid,
         seg_windows) = kd.read_data_list(
             DATA_FOLDER + '/split' + TOTAL_SPLIT + '/' + CURRRENT_SPLIT,
             utt2spk=False,
             segments=True)
    feat, _, utt_shape, tffilename = ft.feat_extract(
        seg_wavlist, FEAT_TYPE, N_FFT, HOP, VAD, CMVN, EXCLUDE_SHORT,
        seg_windows=seg_windows)
else:
    if int(TOTAL_SPLIT) == 1:
        wavlist, utt_label, spk_label = kd.read_data_list(DATA_FOLDER,
                                                          utt2spk=True)
    else:
        wavlist, utt_label, spk_label = kd.read_data_list(
            DATA_FOLDER + '/split' + TOTAL_SPLIT + '/' + CURRRENT_SPLIT,
            utt2spk=True)
    feat, _, utt_shape, tffilename = ft.feat_extract(
        wavlist, FEAT_TYPE, N_FFT, HOP, VAD, CMVN, EXCLUDE_SHORT)


SAVER_FOLDERNAME = 'saver/' + NN_MODEL + '_' + tffilename
nn_model = __import__(NN_MODEL)

x = tf.placeholder(tf.float32, [None, None, FEAT_DIM])
예제 #7
0
import numpy as np
import os, sys
sys.path.insert(0, 'scripts/')
import kaldi_data as kd
import argparse
parser = argparse.ArgumentParser(description="Keep one utterance per speaker",
                                 add_help=True)
parser.add_argument("--source",
                    type=str,
                    default="data/voxceleb1_dev",
                    help="source data")
parser.add_argument("--target",
                    type=str,
                    default="data/voxceleb1_dev_1utt",
                    help="target data")
args = parser.parse_known_args()[0]

# makedirs also creates missing parent folders of the target.
if not os.path.exists(args.target):
    os.makedirs(args.target)

wavlist, utt_label, spk_label = kd.read_data_list(args.source,
                                                  utt2spk=True,
                                                  utt2lang=False)
# np.unique returns, for each distinct speaker, the index of its first
# utterance; the result is ordered by sorted speaker label.
_, idx = np.unique(spk_label, return_index=True)

wavlist = wavlist[idx]
utt_label = utt_label[idx]
spk_label = spk_label[idx]
kd.write_data(args.target, wavlist, utt_label, spk_label)