Example #1
def process_score(scores):
    # keep only the points where the running maximum score increases,
    # recording the window end time and the new maximum score
    min_score = 0
    res = []
    for (window, score) in scores:
        if score[0] > min_score:
            res.append([window.end, score[0]])
            min_score = score[0]
    return res
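# Illustrative sketch (not part of the original snippet): a toy call to
# process_score. `Window` is a stand-in for the sliding-window segments
# produced by the spotting system; each score is a one-element array.
from collections import namedtuple
import numpy as np

Window = namedtuple('Window', ['start', 'end'])
toy_scores = [(Window(0.0, 0.5), np.array([0.2])),
              (Window(0.5, 1.0), np.array([0.1])),
              (Window(1.0, 1.5), np.array([0.7]))]
# only points where the running maximum increases are kept:
print(process_score(toy_scores))  # [[0.5, 0.2], [1.5, 0.7]]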


def process_trial(trial, scores):
    res = {}
    pscores = process_score(scores)
    res['uri'] = trial['uri']
    res['model_id'] = trial['model_id']
    res['scores'] = pscores
    return res


import numpy as np
from pyannote.metrics.spotting import LowLatencySpeakerSpotting

# `protocol` and `speaker_spotting_try` are assumed to be defined earlier
metric = LowLatencySpeakerSpotting(thresholds=np.linspace(0, 2, 50))
llss = []
for current_trial in protocol.development_trial():
    reference = current_trial.pop('reference')
    hypothesis = speaker_spotting_try(current_trial)
    llss.append(process_trial(current_trial, hypothesis))
    metric(reference, hypothesis)

import simplejson as json
with open('ivector_dev_no_clustering_pyannote_sad.json', 'w') as outfile:
    json.dump(llss, outfile)
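# Possible follow-up (sketch, not part of the original code): the accumulated
# metric exposes a DET curve, as used in Example #2, from which the EER is read.
thresholds, fpr, fnr, eer, _ = metric.det_curve(return_latency=False)
print('EER% = {eer:.2f}'.format(eer=100 * eer))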
Example #2
import numpy as np
import pandas as pd
from pyannote.metrics.spotting import LowLatencySpeakerSpotting
from tabulate import tabulate


def spotting(protocol,
             subset,
             latencies,
             hypotheses,
             output_prefix,
             filter_func=None):
    if not latencies:
        Scores = []

    protocol.diarization = False

    trials = getattr(protocol, '{subset}_trial'.format(subset=subset))()
    for i, (current_trial, hypothesis) in enumerate(zip(trials, hypotheses)):

        # check trial/hypothesis target consistency
        try:
            assert current_trial['model_id'] == hypothesis['model_id']
        except AssertionError:
            msg = ('target mismatch in trial #{i} '
                   '(found: {found}, should be: {should_be})')
            raise ValueError(
                msg.format(i=i,
                           found=hypothesis['model_id'],
                           should_be=current_trial['model_id']))

        # check trial/hypothesis file consistency
        try:
            assert current_trial['uri'] == hypothesis['uri']
        except AssertionError:
            msg = ('file mismatch in trial #{i} '
                   '(found: {found}, should be: {should_be})')
            raise ValueError(
                msg.format(i=i,
                           found=hypothesis['uri'],
                           should_be=current_trial['uri']))

        # check at least one score is provided
        try:
            assert len(hypothesis['scores']) > 0
        except AssertionError:
            msg = ('empty list of scores in trial #{i}.')
            raise ValueError(msg.format(i=i))

        timestamps, scores = zip(*hypothesis['scores'])

        if not latencies:
            Scores.append(scores)

        # check trial/hypothesis timerange consistency
        try_with = current_trial['try_with']
        try:
            assert min(timestamps) >= try_with.start
        except AssertionError:
            msg = ('incorrect timestamp in trial #{i} '
                   '(found: {found:g}, should be: >= {should_be:g})')
            raise ValueError(
                msg.format(i=i,
                           found=min(timestamps),
                           should_be=try_with.start))

    if not latencies:
        # estimate candidate thresholds from the pooled scores: percentiles are
        # sampled densely near 0% and 100% (the epsilons) and every 0.1% in
        # between, so the DET curve is well resolved at its tails
        scores = np.concatenate(Scores)
        epsilons = np.array(
            [n * 10**(-e) for e in range(4, 1, -1) for n in range(1, 10)])
        percentile = np.concatenate(
            [epsilons,
             np.arange(0.1, 100., 0.1), 100 - epsilons[::-1]])
        thresholds = np.percentile(scores, percentile)

    if not latencies:
        metric = LowLatencySpeakerSpotting(thresholds=thresholds)

    else:
        metric = LowLatencySpeakerSpotting(latencies=latencies)

    trials = getattr(protocol, '{subset}_trial'.format(subset=subset))()
    for i, (current_trial, hypothesis) in enumerate(zip(trials, hypotheses)):

        if filter_func is not None:
            # skip target trials for which filter_func(target speech duration) is True
            speech = current_trial['reference'].duration()
            target_trial = speech > 0
            if target_trial and filter_func(speech):
                continue

        reference = current_trial['reference']
        metric(reference, hypothesis['scores'])

    if not latencies:

        thresholds, fpr, fnr, eer, _ = metric.det_curve(return_latency=False)

        # save DET curve to hypothesis.det.txt
        det_path = '{output_prefix}.det.txt'.format(
            output_prefix=output_prefix)
        det_tmpl = '{t:.9f} {p:.9f} {n:.9f}\n'
        with open(det_path, mode='w') as fp:
            fp.write('# threshold false_positive_rate false_negative_rate\n')
            for t, p, n in zip(thresholds, fpr, fnr):
                line = det_tmpl.format(t=t, p=p, n=n)
                fp.write(line)

        print('> {det_path}'.format(det_path=det_path))

        thresholds, fpr, fnr, _, _, speaker_lcy, absolute_lcy = \
            metric.det_curve(return_latency=True)

        # save latency-annotated DET curve to hypothesis.lcy.txt
        lcy_path = '{output_prefix}.lcy.txt'.format(
            output_prefix=output_prefix)
        lcy_tmpl = '{t:.9f} {p:.9f} {n:.9f} {s:.6f} {a:.6f}\n'
        with open(lcy_path, mode='w') as fp:
            fp.write(
                '# threshold false_positive_rate false_negative_rate speaker_latency absolute_latency\n'
            )
            for t, p, n, s, a in zip(thresholds, fpr, fnr, speaker_lcy,
                                     absolute_lcy):
                if p == 1:
                    continue
                if np.isnan(s):
                    continue
                line = lcy_tmpl.format(t=t, p=p, n=n, s=s, a=a)
                fp.write(line)

        print('> {lcy_path}'.format(lcy_path=lcy_path))

        print()
        print('EER% = {eer:.2f}'.format(eer=100 * eer))

    else:

        results = metric.det_curve()
        logs = []
        for key in sorted(results):

            result = results[key]
            log = {'latency': key}
            for latency in latencies:
                thresholds, fpr, fnr, eer, _ = result[latency]
                # print('EER @ {latency}s = {eer:.2f}%'.format(latency=latency,
                #                                             eer=100 * eer))
                log[latency] = eer
                # save DET curve to hypothesis.det.{key}.{latency}s.txt
                det_path = '{output_prefix}.det.{key}.{latency:g}s.txt'.format(
                    output_prefix=output_prefix, key=key, latency=latency)
                det_tmpl = '{t:.9f} {p:.9f} {n:.9f}\n'
                with open(det_path, mode='w') as fp:
                    fp.write(
                        '# threshold false_positive_rate false_negative_rate\n'
                    )
                    for t, p, n in zip(thresholds, fpr, fnr):
                        line = det_tmpl.format(t=t, p=p, n=n)
                        fp.write(line)
            logs.append(log)
            det_path = '{output_prefix}.det.{key}.XXs.txt'.format(
                output_prefix=output_prefix, key=key)
            print('> {det_path}'.format(det_path=det_path))

        print()
        df = 100 * pd.DataFrame.from_dict(logs).set_index('latency')[latencies]
        print(
            tabulate(df,
                     tablefmt="simple",
                     headers=['latency'] +
                     ['EER% @ {l:g}s'.format(l=l) for l in latencies],
                     floatfmt=".2f",
                     numalign="decimal",
                     stralign="left",
                     missingval="",
                     showindex="default",
                     disable_numparse=False))
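# Hypothetical usage (illustrative only; `protocol`, `hypotheses` and the output
# prefixes are assumptions, not part of the original snippet):
#
#     # fixed-latency operating points
#     spotting(protocol, 'test', [1, 5, 10], hypotheses, 'ivector_tst')
#
#     # no latency constraint: thresholds are estimated from the pooled scores
#     spotting(protocol, 'development', [], hypotheses, 'ivector_dev')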
Example #3
def process_score(scores):
    # keep only the points where the running maximum score increases,
    # recording the window end time and the new maximum score
    min_score = 0
    res = []
    for (window, score) in scores:
        if score[0] > min_score:
            res.append([window.end, score[0]])
            min_score = score[0]
    return res


def process_trial(trial, scores):
    res = {}
    pscores = process_score(scores)
    res['uri'] = trial['uri']
    res['model_id'] = trial['model_id']
    res['scores'] = pscores
    return res



import numpy as np
from pyannote.metrics.spotting import LowLatencySpeakerSpotting

# `protocol` and `speaker_spotting_try_system2` are assumed to be defined earlier
metric = LowLatencySpeakerSpotting(thresholds=np.linspace(0, 2, 50))
llss = []
for current_trial in protocol.test_trial():
    reference = current_trial.pop('reference')
    hypothesis = speaker_spotting_try_system2(current_trial)
    llss.append(process_trial(current_trial, hypothesis))
    metric(reference, hypothesis)


import simplejson as json
with open('llss.txt', 'w') as outfile:
    json.dump(llss, outfile)

# ## Results