# Code example #1
# 0
# File: test_belief_sklearn.py  Project: steppi/indra
def test_missing_source():
    """Check that stmts_to_matrix rejects a source missing from source_list.

    The test statements come from 'signor', which is deliberately left out
    of the scorer's source list, so building the feature matrix must fail.
    """
    lr = LogisticRegression()
    source_list = ['reach', 'sparser']
    cw = CountsScorer(lr, source_list)
    # Should error because test stmts are from signor and signor
    # is not in list. Without an explicit check, the expected error would
    # make the test error out and an unexpected success would pass, so
    # assert the failure here.
    # NOTE(review): ValueError is assumed, matching the extra_evidence tests
    # in this file -- confirm against CountsScorer.stmts_to_matrix.
    try:
        cw.stmts_to_matrix(test_stmts)
    except ValueError:
        pass
    else:
        raise AssertionError(
            'stmts_to_matrix should raise ValueError when a statement '
            'source (signor) is not in source_list')
# Code example #2
# 0
# File: test_belief_sklearn.py  Project: steppi/indra
def test_extra_evidence_content():
    """Should raise ValueError if extra_evidence list entries are not
    Evidence objects or empty lists."""
    lr = LogisticRegression()
    source_list = ['reach', 'sparser', 'signor']
    cs = CountsScorer(lr, source_list)
    # Every entry except the last is a list holding an int instead of
    # Evidence objects. The trailing empty list makes the total length
    # equal len(test_stmts), so only the entry contents are invalid.
    extra_ev = ([[5]] * (len(test_stmts) - 1)) + [[]]
    # The result is irrelevant; the call itself must fail.
    try:
        cs.stmts_to_matrix(test_stmts, extra_evidence=extra_ev)
    except ValueError:
        pass
    else:
        raise AssertionError(
            'stmts_to_matrix should raise ValueError when extra_evidence '
            'entries are not Evidence objects or empty lists')
# Code example #3
# 0
# File: test_belief_sklearn.py  Project: steppi/indra
def test_extra_evidence_length():
    """Should raise ValueError because the extra_evidence list is not the
    same length as the list of statements."""
    lr = LogisticRegression()
    source_list = ['reach', 'sparser', 'signor']
    cs = CountsScorer(lr, source_list)
    # A single-entry list; presumably test_stmts holds more than one
    # statement, so the lengths disagree -- TODO confirm against fixture.
    extra_ev = [[5]]
    # The result is irrelevant; the call itself must fail.
    try:
        cs.stmts_to_matrix(test_stmts, extra_evidence=extra_ev)
    except ValueError:
        pass
    else:
        raise AssertionError(
            'stmts_to_matrix should raise ValueError when extra_evidence '
            'and the statement list differ in length')
# Code example #4
# 0
# File: test_belief_sklearn.py  Project: steppi/indra
def test_stmts_to_matrix():
    """Check the shape and contents of the matrix from stmts_to_matrix.

    Without statement types there is one column per source in source_list.
    With ``use_stmt_type=True`` there is also one column per entry in the
    scorer's ``stmt_type_map``.
    """
    lr = LogisticRegression()
    source_list = ['reach', 'sparser', 'signor']
    cw = CountsScorer(lr, source_list)
    x_arr = cw.stmts_to_matrix(test_stmts)
    assert isinstance(x_arr, np.ndarray), 'x_arr should be a numpy array'
    assert x_arr.shape == (len(test_stmts), len(source_list)), \
        'stmt matrix dimensions should match test stmts'
    # Compare sorted column sums instead of sets. A set collapses duplicates,
    # so wrong sums such as [0, n, n] would also have matched {0, n}.
    # Sorting keeps the check independent of column order.
    col_sums = sorted(x_arr.sum(axis=0).tolist())
    assert col_sums == [0, 0, len(test_stmts)], \
        'Signor col should be 1 in every row, other cols 0.'
    # Try again with statement type
    cw = CountsScorer(lr, source_list, use_stmt_type=True)
    num_types = len(cw.stmt_type_map)
    x_arr = cw.stmts_to_matrix(test_stmts)
    assert x_arr.shape == (len(test_stmts), len(source_list) + num_types), \
        'matrix should have a col for sources and other cols for every ' \
        'statement type.'