def test_der():
    """Check per-file and global DER on the ref/sys RTTM test pair.

    Both the ``FILE1`` entry of the per-file mapping and the global value
    returned by ``der`` are compared against the same expected value.
    """
    ref_path = os.path.join(TEST_DIR, 'ref.rttm')
    reference, _, _ = load_rttm(ref_path)
    sys_path = os.path.join(TEST_DIR, 'sys.rttm')
    hypothesis, _, _ = load_rttm(sys_path)

    target = 26.3931
    per_file, overall = der(reference, hypothesis)

    # Only one file is scored, so its DER and the global DER must agree.
    for observed in (per_file['FILE1'], overall):
        assert_almost_equal(observed, target, 3)
# Example #2
# 0
def test_score():
    """Check ``score`` output for the ref/sys RTTM test pair.

    The single per-file ``Scores`` record and the global record are both
    compared against the same expected metric values; only their
    ``file_id`` fields differ.
    """
    # Expected values for some real data.
    want_file = Scores('FILE1', 26.39309, 33.24631, 0.71880, 0.72958,
                       0.72415, 0.60075, 0.58534, 0.80471, 0.72543,
                       0.96810, 0.55872)

    reference, _, _ = load_rttm(os.path.join(TEST_DIR, 'ref.rttm'))
    hypothesis, _, _ = load_rttm(os.path.join(TEST_DIR, 'sys.rttm'))
    scoring_regions = UEM({'FILE1': [(0, 43)]})
    per_file, overall = score(reference, hypothesis, scoring_regions)

    # Per-file results: exactly one record is expected.
    assert len(per_file) == 1
    got_file = per_file[-1]
    assert got_file.file_id == want_file.file_id
    assert_almost_equal(got_file[1:], want_file[1:], 3)

    # Global results: same metric values under the overall label.
    want_overall = want_file._replace(file_id='*** OVERALL ***')
    assert overall.file_id == want_overall.file_id
    assert_almost_equal(overall[1:], want_overall[1:], 3)
def test_jer():
    """Check ``jer`` input validation, edge cases, and a real-data value.

    Covers:

    - mismatched keys between the three input dicts raise ``ValueError``
    - no reference speech, no system speech, and neither
    - JER for the ref/sys RTTM test pair
    """
    # Check input validation. Each mismatched combination gets its own
    # context manager: the first raise exits a ``with`` block, so calls
    # sharing one block after the first would never be executed.
    mismatch_msg = 'All passed dicts must have same'
    with assert_raises_regex(ValueError, mismatch_msg):
        jer(dict(), dict(F1=np.zeros(1)), dict())
    with assert_raises_regex(ValueError, mismatch_msg):
        jer(dict(F1=np.zeros(1)), dict(), dict())
    with assert_raises_regex(ValueError, mismatch_msg):
        jer(dict(), dict(), dict(F1=np.zeros(1)))

    # Edge case: no reference speech.
    ref_durs = np.array([], dtype='int64')
    sys_durs = np.array([5, 5], dtype='int64')
    cm = np.zeros((0, 2), dtype='int64')
    file_to_jer, global_jer = jer(dict(F=ref_durs), dict(F=sys_durs),
                                  dict(F=cm))
    assert file_to_jer['F'] == 100.
    assert global_jer == 100.

    # Edge case: no system speech.
    ref_durs = np.array([5, 5], dtype='int64')
    sys_durs = np.array([], dtype='int64')
    cm = np.zeros((2, 0), dtype='int64')
    file_to_jer, global_jer = jer(dict(F=ref_durs), dict(F=sys_durs),
                                  dict(F=cm))
    assert file_to_jer['F'] == 100.
    assert global_jer == 100.

    # Edge case: no reference OR system speech.
    ref_durs = np.array([], dtype='int64')
    sys_durs = np.array([], dtype='int64')
    cm = np.zeros((0, 0), dtype='int64')
    file_to_jer, global_jer = jer(dict(F=ref_durs), dict(F=sys_durs),
                                  dict(F=cm))
    assert file_to_jer['F'] == 0.
    assert global_jer == 0.

    # Real data.
    ref_turns, _, _ = load_rttm(os.path.join(TEST_DIR, 'ref.rttm'))
    sys_turns, _, _ = load_rttm(os.path.join(TEST_DIR, 'sys.rttm'))
    # Score a single region starting at 0 and ending one unit past the
    # largest turn offset, so that every turn falls inside it.
    dur = 1 + max(turn.offset
                  for turn in itertools.chain(ref_turns, sys_turns))
    ref_labels = turns_to_frames(ref_turns, [(0, dur)], step=0.01)
    sys_labels = turns_to_frames(sys_turns, [(0, dur)], step=0.01)
    cm = contingency_matrix(ref_labels, sys_labels)
    ref_durs = ref_labels.sum(axis=0)
    sys_durs = sys_labels.sum(axis=0)
    file_to_jer, global_jer = jer(dict(FILE1=ref_durs), dict(FILE1=sys_durs),
                                  dict(FILE1=cm))
    assert_almost_equal(file_to_jer['FILE1'], 33.24631, 3)
    assert_almost_equal(global_jer, 33.24631, 3)
# Example #4
# 0
def load_rttms(rttm_fns):
    """Collect the speaker turns and file ids of several RTTM files.

    A missing or invalid file is reported through ``error`` and terminates
    the process with exit status 1.

    Parameters
    ----------
    rttm_fns : list of str
        Paths to the RTTM files to read.

    Returns
    -------
    turns : list of Turn
        Speaker turns of all files, in the order the files were given.

    file_ids : set
        File ids encountered across ``rttm_fns``.
    """
    all_turns = []
    all_file_ids = set()
    for path in rttm_fns:
        # Report a missing file separately from an unparsable one.
        if not os.path.exists(path):
            error('Unable to open RTTM file: %s' % path)
            sys.exit(1)
        try:
            loaded_turns, _, loaded_ids = load_rttm(path)
            all_turns += loaded_turns
            all_file_ids.update(loaded_ids)
        except IOError as exc:
            error('Invalid RTTM file: %s. %s' % (path, exc))
            sys.exit(1)
    return all_turns, all_file_ids
def get_args():
    """Parse the command-line arguments of the RTTM truncation script.

    Returns
    -------
    argparse.Namespace
        Namespace with the string attributes ``rttm_file``, ``uem_file``
        and ``rttm_file_write``, taken from the three positional arguments
        in that order.
    """
    cli = argparse.ArgumentParser(
        description="""This script truncates the rttm file
                       using UEM file""")

    # Positional 1: the RTTM file to truncate.
    cli.add_argument(
        "rttm_file", type=str,
        help="""Input RTTM file.
                            The format of the RTTM file is
                            <type> <file-id> <channel-id> <begin-time> """
             """<end-time> <NA> <NA> <speaker> <conf>""")

    # Positional 2: the UEM file used for the truncation.
    cli.add_argument(
        "uem_file", type=str,
        help="""Input UEM file.
                            The format of the UEM file is
                            <file-id> <channel-id> <begin-time> <end-time>""")

    # Positional 3: where the truncated RTTM is written.
    cli.add_argument(
        "rttm_file_write", type=str,
        help="""output RTTM file.""")

    return cli.parse_args()


if __name__ == '__main__':
    # Script entry point: load the turns of the input RTTM file, trim them
    # with the loaded UEM, and write the trimmed turns to the output RTTM
    # path.
    args = get_args()
    # The output file is deliberately not opened here. ``write_rttm`` is
    # handed the output path itself, and opening that path in 'w' mode
    # before the input is loaded would truncate it up front (destroying
    # the input when both paths name the same file) and leave an unused,
    # unclosed file handle behind.
    # Only the turns are needed; the other two values returned by
    # ``load_rttm`` are discarded.
    turns, _, _ = rttm_func.load_rttm(args.rttm_file)
    loaded_uem = load_uem(args.uem_file)
    truncated_turns = trim_turns(turns, loaded_uem)
    rttm_func.write_rttm(args.rttm_file_write, truncated_turns)