コード例 #1
0
def test_ClassificationInterpretation(learn):
    """Check `ClassificationInterpretation` built from the `learn` fixture.

    Asserts that the confusion matrix is a numpy array whose entries sum to
    the number of items in the validation dataset, and that `most_confused()`
    returns at most two entries whose first two fields are exactly the labels
    '3' and '7' (in either order).
    """
    interp = ClassificationInterpretation.from_learner(learn)
    cm = interp.confusion_matrix()
    assert isinstance(cm, np.ndarray)
    assert cm.sum() == len(learn.data.valid_ds)
    conf = interp.most_confused()
    expect = {'3', '7'}
    # Zero, one or two entries are all acceptable; every reported entry must
    # name exactly the labels in `expect`.
    assert len(conf) <= 2 and all(set(c[:2]) == expect for c in conf), (
        f"conf={conf}")
コード例 #2
0
def interpret(learn):
    """Render interpretation plots for `learn` and save them under `PATH`.

    Writes the confusion-matrix plot to ``confusion_matrix.jpg`` and the
    `plot_top_losses(8)` figure to ``top_losses.jpg``.
    """
    interpretation = ClassificationInterpretation.from_learner(learn)
    # plt.savefig writes pyplot's *current* figure, so each file is saved
    # immediately after the call that draws it.
    plots = (
        (lambda: interpretation.plot_confusion_matrix(), 'confusion_matrix.jpg'),
        (lambda: interpretation.plot_top_losses(8), 'top_losses.jpg'),
    )
    for draw, filename in plots:
        draw()
        plt.savefig(PATH / filename)
コード例 #3
0
def test_ClassificationInterpretation(learn):
    """Validate the confusion matrix and `most_confused()` of the fixture learner.

    The confusion matrix must be a numpy array totalling the validation-set
    size; `most_confused()` may report up to two entries, each naming exactly
    the labels '3' and '7' in its first two fields.
    """
    this_tests(ClassificationInterpretation)
    interp = ClassificationInterpretation.from_learner(learn)
    assert isinstance(interp.confusion_matrix(), np.ndarray)
    assert interp.confusion_matrix().sum() == len(learn.data.valid_ds)
    conf = interp.most_confused()
    expect = {'3', '7'}
    within_limit = len(conf) <= 2
    assert within_limit and all(set(entry[:2]) == expect for entry in conf), \
        f"conf={conf}"
コード例 #4
0
ファイル: test_vision_train.py プロジェクト: tterava/fastai
def test_ClassificationInterpretation(learn):
    """Exercise `ClassificationInterpretation` end to end, including plots.

    Asserts that the confusion matrix and the `top_losses()` output are
    consistent with the validation dataset size. It also calls
    `most_confused()` and both plotting helpers, so any exception they raise
    fails the test, and displays each figure with `plt.show()`.
    """
    interp = ClassificationInterpretation.from_learner(learn)
    n_valid = len(learn.data.valid_ds)

    cm = interp.confusion_matrix()
    assert isinstance(cm, np.ndarray)
    assert cm.sum() == n_valid, f"cm.sum()={cm.sum()}, n_valid={n_valid}"
    interp.plot_confusion_matrix()
    plt.show()

    # Smoke call only: which pairs are reported depends on the trained model.
    interp.most_confused()

    losses, idxs = interp.top_losses()
    assert len(losses) == len(idxs) == n_valid
    interp.plot_top_losses(4)
    plt.show()
コード例 #5
0
def test_ClassificationInterpretation(learn):
    """Validate confusion-matrix totals and the `most_confused()` report.

    Accepted `most_confused()` results: empty, one entry, or two entries,
    where each entry's first two fields are exactly the labels '3' and '7'.
    """
    this_tests(ClassificationInterpretation)
    interp = ClassificationInterpretation.from_learner(learn)
    assert isinstance(interp.confusion_matrix(), np.ndarray)
    assert interp.confusion_matrix().sum() == len(learn.data.valid_ds)
    conf = interp.most_confused()
    expect = {'3', '7'}
    n_conf = len(conf)
    if n_conf == 0:
        ok = True
    elif n_conf == 1:
        ok = set(conf[0][:2]) == expect
    elif n_conf == 2:
        ok = set(conf[0][:2]) == set(conf[1][:2]) == expect
    else:
        ok = False
    assert ok, f"conf={conf}"
コード例 #6
0
def test_interp(learn):
    """Check that `top_losses()` yields one loss and one index per validation item."""
    this_tests(ClassificationInterpretation.from_learner)
    interpretation = ClassificationInterpretation.from_learner(learn)
    top = interpretation.top_losses()
    losses, idxs = top
    n_items = len(learn.data.valid_ds)
    assert n_items == len(losses) == len(idxs)
コード例 #7
0
def test_confusion_tabular(learn):
    """Check the type and total count of the learner's confusion matrix.

    The matrix must be a numpy array whose entries sum to the number of
    items in the validation dataset.
    """
    interp = ClassificationInterpretation.from_learner(learn)
    # Register the API under test before asserting, so the registration
    # still happens when an assertion below fails.
    this_tests(interp.confusion_matrix)
    cm = interp.confusion_matrix()
    assert isinstance(cm, np.ndarray)
    assert cm.sum() == len(learn.data.valid_ds)
コード例 #8
0
def _learner_interpret(learn:Learner, ds_type:DatasetType=DatasetType.Valid, tta=False):
    "Build a `ClassificationInterpretation` for `learn` on the `ds_type` split, forwarding the `tta` flag."
    interp_kwargs = {'ds_type': ds_type, 'tta': tta}
    return ClassificationInterpretation.from_learner(learn, **interp_kwargs)
コード例 #9
0
def test_interp(learn):
    "`top_losses()` must return as many losses and indices as there are validation items."
    this_tests(ClassificationInterpretation.from_learner)
    top_losses = ClassificationInterpretation.from_learner(learn).top_losses()
    losses, idxs = top_losses
    expected = len(learn.data.valid_ds)
    assert expected == len(losses), f"{expected} != {len(losses)}" if False else None or True
    assert len(losses) == len(idxs)
コード例 #10
0
ファイル: learner.py プロジェクト: nargroves/fastai_splunk
def _learner_interpret(learn: Learner,
                       ds_type: DatasetType = DatasetType.Valid):
    "Create a `ClassificationInterpretation` from `learn` for the dataset chosen by `ds_type`."
    build = ClassificationInterpretation.from_learner
    return build(learn, ds_type=ds_type)
コード例 #11
0
def test_confusion_tabular(learn, out=True):
    """Print and plot the confusion matrix; display the plot only when `out` is truthy."""
    interpretation = ClassificationInterpretation.from_learner(learn)
    assert isinstance(interpretation.confusion_matrix(), np.ndarray)
    print(interpretation.confusion_matrix())
    interpretation.plot_confusion_matrix()
    if not out:
        return
    plt.show()
コード例 #12
0
def test_confusion_tabular(learn):
    """Check the type and total count of the learner's confusion matrix.

    The matrix must be a numpy array whose entries sum to the number of
    items in the validation dataset.
    """
    interp = ClassificationInterpretation.from_learner(learn)
    # Register the API under test before asserting, so the registration
    # still happens when an assertion below fails.
    this_tests(interp.confusion_matrix)
    cm = interp.confusion_matrix()
    assert isinstance(cm, np.ndarray)
    assert cm.sum() == len(learn.data.valid_ds)