コード例 #1
0
def test_that_retrain_calls_eval_method_correctly(mocker):
    """Check that ``retrain`` scores the classifier through ``eval_method``.

    Two of the three feature rows are labelled and ``retrain`` is called once.
    The evaluation callable must be invoked exactly once with the classifier,
    the labelled feature rows and their labels, and the returned
    ``test_score`` must be shown (two decimals) in ``model_performance``.
    """
    mock_classifier = mocker.Mock()
    # Assign (rather than compare) so the classifier explicitly exposes the
    # ``fit`` / ``predict_proba`` methods retrain is expected to use.
    mock_classifier.fit = mocker.Mock()
    mock_classifier.predict_proba = mocker.Mock()

    mock_eval_method = mocker.Mock(
        return_value={"test_score": np.array([0.8])})

    test_array = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
    widget = SemiSupervisor(
        features=test_array,
        classifier=mock_classifier,
        eval_method=mock_eval_method,
    )

    widget._annotation_loop.send({"source": "", "value": "dummy label 1"})
    widget._annotation_loop.send({"source": "", "value": "dummy label 2"})
    widget.retrain()

    assert mock_eval_method.call_count == 1

    # Positional arguments are (classifier, labelled features, labels).
    call_arguments = mock_eval_method.call_args[0]
    assert call_arguments[0] is mock_classifier
    assert (call_arguments[1] == test_array[:2, :]).all()
    assert pytest.helpers.same_elements(call_arguments[2],
                                        ["dummy label 1", "dummy label 2"])
    assert widget.model_performance.value == "Score: 0.80"
コード例 #2
0
def test_that_eval_method_is_set_correctly(mocker):
    """Check how the ``eval_method`` argument is resolved by the constructor.

    * Without the argument, the stored object wraps ``cross_validate`` and
      exposes it via ``.func`` (presumably a ``functools.partial`` — confirm).
    * A callable is stored as-is.
    * A non-callable value is rejected with ``ValueError``.
    """
    custom_eval = mocker.Mock()

    default_widget = SemiSupervisor()
    assert default_widget.eval_method.func is cross_validate

    custom_widget = SemiSupervisor(eval_method=custom_eval)
    assert custom_widget.eval_method is custom_eval

    with pytest.raises(ValueError):
        SemiSupervisor(eval_method="not a callable")
コード例 #3
0
def test_that_retrain_with_no_labels_sets_warnings(mocker):
    """Check that ``retrain`` refuses to evaluate until two labels exist.

    With zero and then one label, ``retrain`` must not call ``eval_method``
    and must display the "not enough labels" message. After a second label
    it evaluates exactly once and shows the formatted score.
    """
    mock_classifier = mocker.Mock()
    # Assign (rather than compare) so the classifier explicitly exposes the
    # ``fit`` / ``predict_proba`` methods retrain is expected to use.
    mock_classifier.fit = mocker.Mock()
    mock_classifier.predict_proba = mocker.Mock()

    mock_eval_method = mocker.Mock(
        return_value={"test_score": np.array([0.8])})

    test_array = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
    widget = SemiSupervisor(
        features=test_array,
        classifier=mock_classifier,
        eval_method=mock_eval_method,
    )

    # No labels yet.
    widget.retrain()
    assert (widget.model_performance.value ==
            "Score: Not enough labels to retrain.")

    widget._annotation_loop.send({"source": "", "value": "dummy label 1"})

    # A single label is still not enough.
    widget.retrain()
    assert (widget.model_performance.value ==
            "Score: Not enough labels to retrain.")

    widget._annotation_loop.send({"source": "", "value": "dummy label 2"})

    widget.retrain()

    # Only this last retrain should have reached the evaluation callable.
    assert mock_eval_method.call_count == 1
    assert widget.model_performance.value == "Score: 0.80"
コード例 #4
0
def test_that_supplied_labels_are_passed_to_queue(mocker):
    """Check that features and labels given to the widget reach the queue.

    ``SimpleLabellingQueue`` is patched at the name ``semisupervisor`` looks
    up, so constructing the widget records the arguments the queue got.
    """
    features = np.array([[1, 2, 3], [1, 2, 3]])
    labels = np.array(["hi", "hello"])
    queue_cls = mocker.patch(
        "superintendent.semisupervisor.SimpleLabellingQueue")

    SemiSupervisor(features=features, labels=labels)

    queue_cls.assert_called_once_with(features, labels)
コード例 #5
0
def test_that_added_labels_are_returned_correctly(mocker):
    """Check that labels sent through the annotation loop are reported by
    ``new_labels`` in the order they were submitted."""
    widget = SemiSupervisor(features=np.array([[1, 2, 3], [1, 2, 3]]))

    for label in ("dummy label", "dummy label 2"):
        widget._annotation_loop.send({"source": "", "value": label})

    assert widget.new_labels == ["dummy label", "dummy label 2"]
コード例 #6
0
def test_that_sending_undo_into_iterator_calls_undo_on_queue(mocker):
    """Check that an ``__undo__`` event is forwarded to the queue's ``undo``.

    One undo event is expected to result in two ``undo`` calls on the queue.
    """
    # NOTE(review): presumably one call puts back the item currently on
    # display and the other reverts the previous label — confirm against
    # SemiSupervisor's annotation loop.
    undo_spy = mocker.patch(
        "superintendent.queueing.SimpleLabellingQueue.undo")

    widget = SemiSupervisor(features=np.array([[1, 2, 3], [1, 2, 3]]))
    undo_event = {"source": "__undo__", "value": None}
    widget._annotation_loop.send(undo_event)

    assert undo_spy.call_count == 2
コード例 #7
0
def test_that_sending_labels_into_iterator_submits_them_to_queue(mocker):
    """Check that a label sent into the annotation loop is submitted to the
    queue, with ``0`` as the first argument (presumably the id of the first
    item)."""
    submit_spy = mocker.patch(
        "superintendent.queueing.SimpleLabellingQueue.submit")
    features = np.array([[1, 2, 3], [1, 2, 3]])

    widget = SemiSupervisor(features=features)
    label_event = {"source": "", "value": "dummy label"}
    widget._annotation_loop.send(label_event)

    submit_spy.assert_called_once_with(0, "dummy label")
コード例 #8
0
def test_that_the_control_widget_calls_apply_annotation(mocker):
    """Check that submitting an option button of the input widget (simulated
    via ``_when_submitted``) submits that option as the label."""
    submit_spy = mocker.patch(
        "superintendent.queueing.SimpleLabellingQueue.submit")
    features = np.array([[1, 2, 3], [1, 2, 3]])

    widget = SemiSupervisor(features=features, options=["dummy label"])
    controls = widget.input_widget
    dummy_button = controls.control_elements.buttons["dummy label"]
    controls._when_submitted(dummy_button)

    submit_spy.assert_called_once_with(0, "dummy label")
コード例 #9
0
def test_that_the_event_manager_is_closed(mocker):
    """Check that, with keyboard shortcuts enabled, the ``ipyevents`` event
    manager is closed exactly once after both items have been labelled."""
    close_spy = mocker.patch.object(ipyevents.Event, "close")

    widget = SemiSupervisor(
        features=np.array([[1, 2, 3], [1, 2, 3]]),
        keyboard_shortcuts=True,
    )
    for _ in range(2):
        widget._annotation_loop.send({"source": "", "value": "dummy label"})

    assert close_spy.call_count == 1
コード例 #10
0
def test_that_reorder_is_set_correctly(mocker):
    """Check how the ``reorder`` argument is resolved by the constructor.

    * No argument: ``reorder`` is ``None``.
    * A known strategy name (``"entropy"``): the matching function from
      ``superintendent.prioritisation``.
    * A callable: stored unchanged.
    * An unknown name raises ``NotImplementedError``; a non-string,
      non-callable value such as ``1`` raises ``ValueError``.
    """
    custom_reorder = mocker.Mock()

    default_widget = SemiSupervisor()
    assert default_widget.reorder is None

    named_widget = SemiSupervisor(reorder="entropy")
    assert named_widget.reorder is superintendent.prioritisation.entropy

    custom_widget = SemiSupervisor(reorder=custom_reorder)
    assert custom_widget.reorder is custom_reorder

    with pytest.raises(NotImplementedError):
        SemiSupervisor(reorder="dummy function name")

    with pytest.raises(ValueError):
        SemiSupervisor(reorder=1)
コード例 #11
0
def test_that_sending_skip_calls_no_queue_method(mocker):
    """Check that a ``__skip__`` event neither submits a label to nor undoes
    anything on the queue."""
    undo_spy = mocker.patch(
        "superintendent.queueing.SimpleLabellingQueue.undo")
    submit_spy = mocker.patch(
        "superintendent.queueing.SimpleLabellingQueue.submit")

    widget = SemiSupervisor(features=np.array([[1, 2, 3], [1, 2, 3]]))
    skip_event = {"source": "__skip__", "value": None}
    widget._annotation_loop.send(skip_event)

    undo_spy.assert_not_called()
    submit_spy.assert_not_called()
コード例 #12
0
def test_that_classifier_is_set_correctly(mocker):
    """Check that a supplied classifier is stored and enables retraining.

    Passing a classifier must store it on the widget, create an
    ``ipywidgets.Button`` as ``retrain_button`` and register
    ``widget.retrain`` as one of the registered button click handlers.
    """
    mock_classifier = mocker.Mock()
    # Assign (rather than compare) so the classifier explicitly exposes
    # ``fit`` and ``predict_proba``.
    mock_classifier.fit = mocker.Mock()
    mock_classifier.predict_proba = mocker.Mock()

    # Patched on the class, so on_click registrations from every Button
    # created by the widget are recorded here.
    mock_on_click = mocker.patch("ipywidgets.Button.on_click")

    widget = SemiSupervisor(classifier=mock_classifier)

    assert widget.classifier is mock_classifier
    assert hasattr(widget, "retrain_button")
    assert isinstance(widget.retrain_button, ipywidgets.Button)
    # Other buttons may register handlers too; only require retrain's.
    mock_on_click.assert_any_call(widget.retrain)
コード例 #13
0
def test_that_progressbar_value_is_updated_and_render_finished_called(mocker):
    """Check that the progress bar tracks the labelled fraction and that
    ``_render_finished`` runs once after the last item is labelled."""
    render_finished_spy = mocker.patch(
        "superintendent.semisupervisor.SemiSupervisor._render_finished")

    widget = SemiSupervisor(features=np.array([[1, 2, 3], [1, 2, 3]]))
    assert widget.progressbar.value == 0

    # Two items, so each label advances the bar by one half.
    for expected_progress in (0.5, 1):
        widget._annotation_loop.send({"source": "", "value": "dummy label"})
        assert widget.progressbar.value == expected_progress

    assert render_finished_spy.call_count == 1
コード例 #14
0
def test_that_retrain_calls_reorder_correctly(mocker):
    """Check that ``retrain`` passes predicted probabilities and the
    configured ``shuffle_prop`` to the ``reorder`` callable.

    ``fit`` and ``predict_proba`` of the real ``LogisticRegression`` are
    patched on the instance, so no model is actually trained and the
    probabilities handed to ``reorder`` are known in advance.
    """
    fake_probabilities = np.array([[0.2, 0.3], [0.1, 0.4]])
    eval_spy = mocker.Mock(return_value={"test_score": np.array([0.8])})
    reorder_spy = mocker.Mock(return_value=[0, 1])

    widget = SemiSupervisor(
        features=np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]),
        classifier=LogisticRegression(),
        eval_method=eval_spy,
        reorder=reorder_spy,
        shuffle_prop=0.2,
    )

    classifier = widget.classifier
    mocker.patch.object(classifier, "fit",
                        return_value=LogisticRegression())
    mocker.patch.object(classifier, "predict_proba",
                        return_value=fake_probabilities)

    for label in ("dummy label 1", "dummy label 2"):
        widget._annotation_loop.send({"source": "", "value": label})
    widget.retrain()

    assert reorder_spy.call_count == 1

    positional, keywords = reorder_spy.call_args_list[0]
    assert (positional[0] == fake_probabilities).all()
    assert keywords["shuffle_prop"] == 0.2
コード例 #15
0
def test_that_creating_a_widget_works():
    """Smoke test: a ``SemiSupervisor`` can be built with no arguments."""
    SemiSupervisor()
コード例 #16
0
def test_that_shuffle_prop_is_set_correctly(shuffle_prop):
    """Check that ``shuffle_prop`` is stored unchanged on the widget.

    NOTE(review): ``shuffle_prop`` must come from a fixture or a
    parametrisation decorator that is not visible here — confirm.
    """
    stored = SemiSupervisor(shuffle_prop=shuffle_prop).shuffle_prop
    assert stored == shuffle_prop
コード例 #17
0
def test_that_calling_retrain_without_classifier_breaks():
    """Check that ``retrain`` raises ``ValueError`` when no classifier is set.

    Only the ``retrain`` call sits inside ``pytest.raises``. A ``ValueError``
    raised while building the widget therefore fails the test instead of
    making it pass by accident.
    """
    test_array = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
    widget = SemiSupervisor(features=test_array)
    with pytest.raises(ValueError):
        widget.retrain()