예제 #1
0
def main():
    """Adapt an out-of-domain PLDA with CORAL statistics, then interpolate it
    with an in-domain PLDA via CIPReg, and write the result.

    Command line::

        <script> <plda-out-domain> <adapt-ivector-rspecifier> \\
                 <plda-in-domain> <plda-adapt>

    Steps performed below:
      1. Read ``<plda-out-domain>`` into a ``CORAL`` object.
      2. Feed every vector read from ``<adapt-ivector-rspecifier>`` (via
         ``kaldi_io.read_vec_flt_auto``) to ``CORAL.add_stats`` and call
         ``update_plda``.
      3. Read ``<plda-in-domain>`` into a ``CIPReg`` object and call its
         ``interpolation`` with the CORAL object.
      4. Copy ``mean`` / ``within_var`` / ``between_var`` from the CIPReg
         object into a fresh ``PLDA``, call ``get_output()`` and write it to
         ``<plda-adapt>`` with ``plda_trans_write``.

    Exits with status 1 (after printing usage to stderr) if the number of
    arguments is wrong.
    """
    if len(sys.argv) != 5:
        print(
            'Usage: ' + sys.argv[0] +
            ' <plda-out-domain> <adapt-ivector-rspecifier>'
            ' <plda-in-domain> <plda-adapt>',
            file=sys.stderr,
        )
        # Non-zero status so calling shell scripts (e.g. under `set -e`)
        # notice the mis-invocation; a bare sys.exit() would report success.
        sys.exit(1)

    plda_out_domain = sys.argv[1]
    train_vecs_adapt = sys.argv[2]
    plda_in_domain = sys.argv[3]
    plda_adapt = sys.argv[4]

    # Accumulate adaptation statistics on top of the out-of-domain model.
    coral = CORAL()
    coral.plda_read(plda_out_domain)
    for _, vec in kaldi_io.read_vec_flt_auto(train_vecs_adapt):
        # First argument is presumably a per-vector weight (always 1 here).
        coral.add_stats(1, vec)
    coral.update_plda()

    # Interpolate the in-domain model with the CORAL-updated one.
    cipreg = CIPReg()
    cipreg.plda_read(plda_in_domain)
    cipreg.interpolation(coral)

    # Package the interpolated parameters as a PLDA model and write it out.
    plda_new = PLDA()
    plda_new.mean = cipreg.mean
    plda_new.within_var = cipreg.within_var
    plda_new.between_var = cipreg.between_var
    # NOTE(review): the return value is discarded, so get_output() presumably
    # updates plda_new in place before it is written -- confirm.
    plda_new.get_output()
    plda_new.plda_trans_write(plda_adapt)
예제 #2
0
def main():
    """Train a PLDA model from speaker-labelled i-vectors and write it out.

    Command line::

        <script> <spk2utt-rspecifier> <ivector-rspecifier> <plda>

    ``<spk2utt-rspecifier>`` is opened as a plain text file in which each
    non-blank line is a speaker ID followed by that speaker's utterance IDs.
    Every vector read from ``<ivector-rspecifier>`` (via
    ``kaldi_io.read_vec_flt_auto``) is grouped under its speaker. Vectors
    whose utterance ID is not listed in spk2utt are skipped with a warning.
    Each speaker's group is added to ``PldaStats`` with weight 1.0, and the
    model is estimated with ``PldaEstimation.estimate``.

    Files written:
      * ``<plda>.ori``: the estimator's parameters (``plda_write``), logged
        as the parameters used for PLDA adaptation.
      * ``<plda>``: the object returned by ``get_output()``
        (``plda_trans_write``), logged as the model used for scoring.

    Exits with status 1 on a usage error, or when no i-vector could be
    matched to a speaker.
    """
    if len(sys.argv) != 4:

        print("Usage: " + sys.argv[0] +
              " <spk2utt-rspecifier> <ivector-rspecifier> <plda>\n",
              file=sys.stderr)
        print("e.g.: " + sys.argv[0] + " spk2utt ivectors.ark plda",
              file=sys.stderr)

        # Non-zero status so calling shell scripts notice the mis-invocation;
        # a bare sys.exit() would report success.
        sys.exit(1)

    spk2utt = sys.argv[1]
    ivectors_reader = sys.argv[2]
    plda_out = sys.argv[3]

    logger.info('Load vecs and accumulate the stats of vecs.....')
    # Invert spk2utt into an utterance -> speaker lookup.
    utt2spk_dict = {}
    with open(spk2utt, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines instead of failing with IndexError.
                continue
            spk, *utts = fields
            for utt in utts:
                utt2spk_dict[utt] = spk

    # Group vectors by speaker, preserving first-seen speaker order.
    spk2vectors = {}
    dim = None
    num_skipped = 0
    for key, vector in kaldi_io.read_vec_flt_auto(ivectors_reader):
        spk = utt2spk_dict.get(key)
        if spk is None:
            # Skip unknown utterances with a warning (as Kaldi's
            # ivector-compute-plda does) rather than crash with a bare KeyError.
            logger.warning('Utterance %s is not listed in %s; skipping it.',
                           key, spk2utt)
            num_skipped += 1
            continue
        dim = vector.shape[0]
        spk2vectors.setdefault(spk, []).append(vector)

    if num_skipped:
        logger.warning('Skipped %d vectors with no speaker in %s.',
                       num_skipped, spk2utt)
    if not spk2vectors:
        # Without this check `dim` would be unbound and PldaStats(dim) would
        # fail with a confusing NameError.
        logger.error('No vectors from %s matched any utterance in %s.',
                     ivectors_reader, spk2utt)
        sys.exit(1)

    # Every speaker contributes with the same weight.
    weight = 1.0
    plda_stats = PldaStats(dim)
    for vectors in spk2vectors.values():
        plda_stats.add_samples(weight, np.array(vectors, dtype=float))

    logger.info('Estimate the parameters of PLDA by EM algorithm...')
    plda_stats.sort()
    plda_estimator = PldaEstimation(plda_stats)
    plda_estimator.estimate()
    logger.info('Save the parameters for the PLDA adaptation...')
    plda_estimator.plda_write(plda_out + '.ori')
    plda_trans = plda_estimator.get_output()
    logger.info(
        'Save the parameters for scoring directly, which is the same with the plda in kaldi...'
    )
    plda_trans.plda_trans_write(plda_out)
def main():
    """Adapt a PLDA model with CORALPlus statistics and write the result.

    Command line::

        <script> <plda> <adapt-ivector-rspecifier> <plda-adapt>

    Steps performed below:
      1. Read ``<plda>`` into a ``CORALPlus`` object.
      2. Feed every vector read from ``<adapt-ivector-rspecifier>`` (via
         ``kaldi_io.read_vec_flt_auto``) to ``CORALPlus.add_stats`` and call
         ``update_plda``.
      3. Copy ``mean`` / ``within_var`` / ``between_var`` from the CORALPlus
         object into a fresh ``PLDA``, call ``get_output()`` and write it to
         ``<plda-adapt>`` with ``plda_trans_write``.

    Exits with status 1 (after printing usage to stderr) if the number of
    arguments is wrong.
    """
    if len(sys.argv) != 4:
        print(
            'Usage: ' + sys.argv[0] +
            ' <plda> <adapt-ivector-rspecifier> <plda-adapt>',
            file=sys.stderr,
        )
        # Non-zero status so calling shell scripts (e.g. under `set -e`)
        # notice the mis-invocation; a bare sys.exit() would report success.
        sys.exit(1)

    plda = sys.argv[1]
    train_vecs_adapt = sys.argv[2]
    plda_adapt = sys.argv[3]

    # Accumulate adaptation statistics on top of the input model.
    coralplus = CORALPlus()
    coralplus.plda_read(plda)
    for _, vec in kaldi_io.read_vec_flt_auto(train_vecs_adapt):
        # First argument is presumably a per-vector weight (always 1 here).
        coralplus.add_stats(1, vec)
    coralplus.update_plda()

    # Package the adapted parameters as a PLDA model and write it out.
    plda_new = PLDA()
    plda_new.mean = coralplus.mean
    plda_new.within_var = coralplus.within_var
    plda_new.between_var = coralplus.between_var
    # NOTE(review): the return value is discarded, so get_output() presumably
    # updates plda_new in place before it is written -- confirm.
    plda_new.get_output()
    plda_new.plda_trans_write(plda_adapt)