Ejemplo n.º 1
0
def print_cm(cm, ref_classes, sys_classes, norm=True):
    """Log a confusion matrix as a formatted table.

    The table has one row per reference class (labelled in the first
    column) and one column per system class (labelled in the header row).
    It is emitted through ``logger.info``.

    Parameters
    ----------
    cm : ndarray, (n_ref_classes, n_sys_classes)
        Contingency table between reference and system labelings.

    ref_classes : ndarray, (n_ref_classes,)
        Reference classes; used as row labels.

    sys_classes : ndarray, (n_sys_classes,)
        System classes; used as column headers.

    norm : bool, optional
        If True, normalize rows of confusion matrix to sum to 1.
        (Default: True)
    """
    # Work only with the matrix and class labels supplied by the caller;
    # the function must not depend on module-level label arrays.
    if norm:
        # Row totals as floats so the division below is true division.
        marginals = cm.sum(axis=1, dtype='float64')
        cm = cm / np.expand_dims(marginals, axis=1)

    # Convert to nested lists so each row can be prefixed with its label.
    rows = cm.tolist()
    for ii, label in enumerate(ref_classes):
        rows[ii].insert(0, label)

    # Empty first header lines up with the row-label column.
    logger.info(tabulate(rows, headers=[''] + list(sys_classes)))
Ejemplo n.º 2
0
def test_conditional_entropy():
    """Check ``conditional_entropy`` against a fixed reference value.

    The value is computed twice: once from the labelings returned by
    ``make_labels`` and once from a precomputed contingency matrix passed
    as the third argument.
    """
    ref, hyp = make_labels()
    expected = 1.5824

    # Computed directly from the two labelings.
    assert_almost_equal(conditional_entropy(ref, hyp), expected, 3)

    # Computed from the contingency matrix, with no labelings supplied.
    table = contingency_matrix(ref, hyp)
    assert_almost_equal(conditional_entropy(None, None, table), expected, 3)
Ejemplo n.º 3
0
def test_goodman_kruskal_tau():
    """Check both values returned by ``goodman_kruskal_tau``.

    The pair is computed from the labelings returned by ``make_labels``
    and again from a precomputed contingency matrix; every value is
    compared against the same fixed reference.
    """
    ref, hyp = make_labels()
    expected = 0.001223

    # Computed directly from the two labelings.
    first, second = goodman_kruskal_tau(ref, hyp)
    assert_almost_equal(first, expected, 5)
    assert_almost_equal(second, expected, 5)

    # Computed from the contingency matrix, with no labelings supplied.
    table = contingency_matrix(ref, hyp)
    first, second = goodman_kruskal_tau(None, None, table)
    assert_almost_equal(first, expected, 5)
    assert_almost_equal(second, expected, 5)
Ejemplo n.º 4
0
def test_mutual_information():
    """Check the two values returned by ``mutual_information``.

    The pair (named ``mi`` and ``nmi`` by the original test) is computed
    from the labelings returned by ``make_labels`` and again from a
    precomputed contingency matrix, then compared to fixed references.
    """
    ref, hyp = make_labels()
    expected_mi = 0.001767
    expected_nmi = 0.001116

    # Computed directly from the two labelings.
    mi_value, nmi_value = mutual_information(ref, hyp)
    assert_almost_equal(mi_value, expected_mi, 5)
    assert_almost_equal(nmi_value, expected_nmi, 5)

    # Computed from the contingency matrix, with no labelings supplied.
    table = contingency_matrix(ref, hyp)
    mi_value, nmi_value = mutual_information(None, None, table)
    assert_almost_equal(mi_value, expected_mi, 5)
    assert_almost_equal(nmi_value, expected_nmi, 5)
Ejemplo n.º 5
0
def test_bcubed():
    """Check the three values returned by ``bcubed``.

    The triple (named ``p``, ``r`` and ``f1`` by the original test) is
    computed from the labelings returned by ``make_labels`` and again
    from a precomputed contingency matrix, then compared to fixed
    references.
    """
    ref, hyp = make_labels()
    expected = (0.3345, 0.3356, 0.3351)

    # Computed directly from the two labelings.
    first, second, third = bcubed(ref, hyp)
    for observed, target in zip((first, second, third), expected):
        assert_almost_equal(observed, target, 3)

    # Computed from the contingency matrix, with no labelings supplied.
    table = contingency_matrix(ref, hyp)
    first, second, third = bcubed(None, None, table)
    for observed, target in zip((first, second, third), expected):
        assert_almost_equal(observed, target, 3)
Ejemplo n.º 6
0
def test_jer():
    """Exercise ``jer`` on invalid input, edge cases, and real RTTM data.

    Covers:

    - input validation when the three dicts passed do not share keys
    - no reference speech, no system speech, and neither
    - a real reference/system RTTM pair loaded from ``TEST_DIR``
    """
    # Check input validation. Each mismatched combination gets its own
    # context manager: once a call raises, the rest of that ``with`` body
    # is skipped, so sharing one block would only ever test the first call.
    mismatched_inputs = [
        (dict(), dict(F1=np.zeros(1)), dict()),
        (dict(F1=np.zeros(1)), dict(), dict()),
        (dict(), dict(), dict(F1=np.zeros(1))),
    ]
    for ref_arg, sys_arg, cm_arg in mismatched_inputs:
        with assert_raises_regex(
                ValueError, 'All passed dicts must have same'):
            jer(ref_arg, sys_arg, cm_arg)

    # Edge case: no reference speech.
    ref_durs = np.array([], dtype='int64')
    sys_durs = np.array([5, 5], dtype='int64')
    cm = np.zeros((0, 2), dtype='int64')
    file_to_jer, global_jer = jer(dict(F=ref_durs), dict(F=sys_durs),
                                  dict(F=cm))
    assert file_to_jer['F'] == 100.
    assert global_jer == 100.

    # Edge case: no system speech.
    ref_durs = np.array([5, 5], dtype='int64')
    sys_durs = np.array([], dtype='int64')
    cm = np.zeros((2, 0), dtype='int64')
    file_to_jer, global_jer = jer(dict(F=ref_durs), dict(F=sys_durs),
                                  dict(F=cm))
    assert file_to_jer['F'] == 100.
    assert global_jer == 100.

    # Edge case: no reference OR system speech.
    ref_durs = np.array([], dtype='int64')
    sys_durs = np.array([], dtype='int64')
    cm = np.zeros((0, 0), dtype='int64')
    file_to_jer, global_jer = jer(dict(F=ref_durs), dict(F=sys_durs),
                                  dict(F=cm))
    assert file_to_jer['F'] == 0.
    assert global_jer == 0.

    # Real data.
    ref_turns, _, _ = load_rttm(os.path.join(TEST_DIR, 'ref.rttm'))
    sys_turns, _, _ = load_rttm(os.path.join(TEST_DIR, 'sys.rttm'))
    # Region passed to turns_to_frames runs from 0 to one past the
    # largest turn offset found in either file.
    dur = 1 + max(turn.offset
                  for turn in itertools.chain(ref_turns, sys_turns))
    ref_labels = turns_to_frames(ref_turns, [(0, dur)], step=0.01)
    sys_labels = turns_to_frames(sys_turns, [(0, dur)], step=0.01)
    cm = contingency_matrix(ref_labels, sys_labels)
    # Column sums of the frame labelings serve as the duration inputs.
    ref_durs = ref_labels.sum(axis=0)
    sys_durs = sys_labels.sum(axis=0)
    file_to_jer, global_jer = jer(dict(FILE1=ref_durs), dict(FILE1=sys_durs),
                                  dict(FILE1=cm))
    # With a single file, the per-file and global values are expected to
    # match the same reference number.
    assert_almost_equal(file_to_jer['FILE1'], 33.24631, 3)
    assert_almost_equal(global_jer, 33.24631, 3)
Ejemplo n.º 7
0
def test_contingency_matrix():
    """Check ``contingency_matrix`` error handling and output values.

    Verifies that mismatched dimensionality and mismatched sizes raise
    ``ValueError``, then compares the matrix built from ``make_labels``
    output (default and ``one_hot=True`` forms) to a fixed expected table.
    """
    # Mixed 1-D / 2-D inputs must be rejected.
    dims_msg = 'ref_labels and sys_labels should either both be 1D'
    with assert_raises_regex(ValueError, dims_msg):
        contingency_matrix(np.zeros(5), np.zeros((5, 2)))

    # Inputs of different lengths must be rejected.
    size_msg = 'ref_labels and sys_labels must have same size'
    with assert_raises_regex(ValueError, size_msg):
        contingency_matrix(np.arange(5), np.arange(6))

    expected = np.array([
        [106, 114, 117],
        [110, 130, 105],
        [92, 118, 108],
    ])

    # Labels in the default form returned by make_labels.
    ref, hyp = make_labels()
    assert_equal(contingency_matrix(ref, hyp), expected)

    # Labels in one-hot form must yield the same table.
    ref, hyp = make_labels(one_hot=True)
    assert_equal(contingency_matrix(ref, hyp), expected)
Ejemplo n.º 8
0
if __name__ == '__main__':
    # Command line interface: two positional RTTM paths plus optional
    # step size, row normalization, and version flags. Options left at
    # their argparse defaults are not spelled out.
    parser = argparse.ArgumentParser(
        usage='%(prog)s [options] ref_rttm sys_rttm',
        description='Score RTTM.')
    parser.add_argument('ref_rttm', help='reference RTTM')
    parser.add_argument('sys_rttm', help='system RTTM')
    parser.add_argument(
        '--step', type=float, default=0.010, metavar='FLOAT',
        help='step size in seconds (Default: %(default)s)')
    parser.add_argument(
        '--norm', action='store_true', help='normalize rows')
    parser.add_argument(
        '--version', action='version', version='%(prog)s ' + VERSION)

    # Invoked with no arguments at all: show the help text and exit with
    # a non-zero status instead of letting argparse report an error.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()

    # Convert both RTTMs to labelings, build the contingency table and
    # print it.
    ref_labels, sys_labels = rttms_to_frames(
        args.ref_rttm, args.sys_rttm, args.step)
    cm, ref_classes, sys_classes = contingency_matrix(ref_labels, sys_labels)
    print_cm(cm, ref_classes, sys_classes, args.norm)