Example #1
def test_callbacks_learner(data, model):
    this_tests(Callback)

    # single callback in learner constructor
    learn = Learner(data, model, metrics=accuracy, callback_fns=DummyCallback)
    with CaptureStdout() as cs: learn.fit_one_cycle(2)
    check_dummy_metric(cs.out)

    # list of callbacks in learner constructor
    learn = Learner(data, model, metrics=accuracy, callback_fns=[DummyCallback])
    with CaptureStdout() as cs: learn.fit_one_cycle(2)
    check_dummy_metric(cs.out)

    # single callback append
    learn = Learner(data, model, metrics=accuracy)
    learn.callbacks.append(DummyCallback(learn))
    with CaptureStdout() as cs: learn.fit_one_cycle(2)
    check_dummy_metric(cs.out)

    # appending a list of callbacks: list.append() adds the whole list as a single
    # element instead of extending the list of callbacks, so this is expected to fail
    learn = Learner(data, model, metrics=[accuracy])
    learn.callbacks.append([DummyCallback(learn)])
    error = ''
    try:
        with CaptureStdout() as cs: learn.fit_one_cycle(2)
    except Exception as e:
        error = str(e)
    error_pat = "'list' object has no attribute 'on_train_begin'"
    assert error_pat in error, f"expecting '{error_pat}' in the exception:\n{error}"
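The DummyCallback and check_dummy_metric helpers used above are defined elsewhere in the test suite. A minimal sketch of what they might look like, assuming the fastai v1 LearnerCallback/add_metrics API (the 'dummy' column name and the dummy_base_val constant mirror Example #4; everything else is an assumption, not the original implementation):

# Hypothetical helpers, not part of the original example; assumes fastai v1.
from fastai.basic_train import LearnerCallback
from fastai.callback import add_metrics

dummy_base_val = 9999  # arbitrary, easy-to-spot value

class DummyCallback(LearnerCallback):
    "Report a constant extra metric column named 'dummy' at the end of every epoch."
    def on_train_begin(self, **kwargs):
        # register the extra column header with the recorder
        self.learn.recorder.add_metric_names(['dummy'])
    def on_epoch_end(self, last_metrics, **kwargs):
        # append the constant value to this epoch's metrics row
        return add_metrics(last_metrics, dummy_base_val)

def check_dummy_metric(out):
    "Assert that the 'dummy' column header and its value appear in the captured output."
    for s in ['dummy', f'{dummy_base_val}.00']:
        assert s in out, f"expecting '{s}' in:\n{out}"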
Example #2
def test_callbacks_fit(data, model):
    learn = Learner(data, model, metrics=accuracy)

    for func in ['fit', 'fit_one_cycle']:
        fit_func = getattr(learn, func)
        this_tests(fit_func)

        # single callback
        with CaptureStdout() as cs: fit_func(2, callbacks=DummyCallback(learn))
        check_dummy_metric(cs.out)

        # list of callbacks
        with CaptureStdout() as cs: fit_func(2, callbacks=[DummyCallback(learn)])
        check_dummy_metric(cs.out)
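Every example here relies on a CaptureStdout context manager from the test utilities to grab the table fastai prints while training. A minimal sketch of such a helper, assuming all it needs to do is expose the captured text as cs.out (a stand-in, not the original implementation):

import io, sys

class CaptureStdout:
    "Context manager that collects everything printed to stdout into `self.out`."
    def __enter__(self):
        self._buf, self._old = io.StringIO(), sys.stdout
        sys.stdout = self._buf           # redirect prints into the buffer
        return self
    def __exit__(self, *exc):
        sys.stdout = self._old           # restore the real stdout
        self.out = self._buf.getvalue()  # captured text, read as `cs.out`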
Example #3
def test_gan_trainer(gan_learner):
    this_tests(GANTrainer)
    gan_trainer = gan_learner.gan_trainer
    with CaptureStdout() as cs: gan_learner.fit(1, 1e-4)
    assert gan_trainer.imgs
    assert gan_trainer.gen_mode
    assert gan_trainer.titles
Example #4
def test_custom_metric_class():
    learn = fake_learner(3,2)
    learn.metrics.append(DummyMetric())
    with CaptureStdout() as cs: learn.fit_one_cycle(2)
    # expecting the 'dummy' column header and the per-epoch values the metric class reports
    for s in ['dummy', f'{dummy_base_val}.00', f'{dummy_base_val**2}.00']:
        assert s in cs.out, f"{s} is in the output:\n{cs.out}"
Example #5
def test_peak_mem_metric():
    learn = fake_learner()
    learn.callbacks.append(PeakMemMetric(learn))
    with CaptureStdout() as cs:
        learn.fit_one_cycle(3, max_lr=1e-2)
    for s in ['cpu', 'used', 'peak', 'gpu']:
        assert s in cs.out, f"expecting '{s}' in \n{cs.out}"
    # in epochs 2-3 it shouldn't allocate any additional general or GPU RAM
    for s in ['0         0         0         0']:
        assert s in cs.out, f"expecting '{s}' in \n{cs.out}"
Example #6
def stop_after_n_batches_run_n_check(learn):
    with CaptureStdout() as cs:
        learn.fit_one_cycle(3, max_lr=1e-2)
    for s in ['train_loss', 'valid_loss']:
        assert s in cs.out, f"expecting '{s}' in \n{cs.out}"
    # test that training is stopped at epoch 0
    assert "\n0" in cs.out, "expecting epoch 0"
    assert "\n1" not in cs.out, "epoch 1 shouldn't run"

    # test that only 2 batches were run
    assert len(learn.recorder.losses) == 2
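For this helper to see exactly two recorded losses, the learner it receives presumably carries a callback that aborts the fit after a couple of batches. A rough sketch of such a callback, assuming the fastai v1 convention that on_batch_end may return a dict of state flags like stop_epoch/stop_training (fastai ships something similar as StopAfterNBatches in fastai.callbacks.misc):

# Illustrative batch-limiting callback (an assumption, not the one used by the test).
from fastai.callback import Callback

class StopAfterNBatchesSketch(Callback):
    "Abort training once `n_batches` batches of the first epoch have run."
    def __init__(self, n_batches=2):
        self.n_batches = n_batches
    def on_batch_end(self, iteration, **kwargs):
        if iteration == self.n_batches - 1:
            # ask the callback handler to end the epoch and the whole fit
            return {'stop_epoch': True, 'stop_training': True, 'skip_validate': True}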
Example #7
def test_peak_mem_metric():
    learn = fake_learner()
    learn.callbacks.append(CpuPeakMemMetric(learn))
    this_tests(CpuPeakMemMetric)
    with CaptureStdout() as cs:
        learn.fit_one_cycle(3, max_lr=1e-2)
    for s in ["cpu used", "cpu_peak"]:
        assert s in cs.out, f"expecting '{s}' in \n{cs.out}"
    # XXX: needs a better test to assert some numbers here (at least >0)
    # in epochs 2-3 it shouldn't allocate any additional general or CPU RAM
    for s in ["0         0"]:
        assert s in cs.out, f"expecting '{s}' in \n{cs.out}"
Example #8
def stop_after_n_batches_run_n_check(learn, bs, run_n_batches_exp):
    has_batches = len(learn.data.train_ds)//bs
    with CaptureStdout() as cs:
        learn.fit_one_cycle(3, max_lr=1e-2)
    for s in ['train_loss', 'valid_loss']:
        assert s in cs.out, f"expecting '{s}' in \n{cs.out}"

    # test that training is stopped at epoch 0
    assert "\n0" in cs.out, "expecting epoch 0"
    assert "\n1" not in cs.out, "epoch 1 shouldn't run"

    # test that only run_n_batches_exp batches were run
    run_n_batches_got = len(learn.recorder.losses)
    assert run_n_batches_got == run_n_batches_exp, f"should have run only {run_n_batches_exp} batches, but ran {run_n_batches_got}"
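How this parametrized helper gets called is not shown. A hypothetical caller (names and values are illustrative only) could install a batch-limiting callback such as the StopAfterNBatchesSketch above and pass the expected batch count through:

# Hypothetical test wiring, not from the source.
def test_stop_after_n_batches(data, model):
    bs, n = 8, 2  # illustrative values only
    learn = Learner(data, model, metrics=accuracy,
                    callbacks=[StopAfterNBatchesSketch(n_batches=n)])
    stop_after_n_batches_run_n_check(learn, bs, n)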
Example #9
def test_logger():
    learn = fake_learner()
    learn.metrics = [accuracy, error_rate]
    learn.callback_fns.append(callbacks.CSVLogger)
    with CaptureStdout() as cs: learn.fit_one_cycle(3)
    csv_df = learn.csv_logger.read_logged_file()
    stdout_df = convert_into_dataframe(cs.out)
    csv_df.drop(columns=['time'], inplace=True)
    pd.testing.assert_frame_equal(csv_df, stdout_df, check_exact=False, check_less_precise=2)
    recorder_df = create_metrics_dataframe(learn)
    # XXX: there is a bug in pandas:
    # https://github.com/pandas-dev/pandas/issues/25068#issuecomment-460014120
    # which quite often makes this comparison fail on CI.
    # once it's resolved, the setting can go back to check_less_precise=True (or
    # better, =3); until then =2 works, but the check is less strict.
    pd.testing.assert_frame_equal(csv_df, recorder_df, check_exact=False, check_less_precise=2)
Example #10
def test_logger():
    learn = fake_learner()
    learn.metrics = [accuracy, error_rate]
    learn.callback_fns.append(callbacks.CSVLogger)
    with CaptureStdout() as cs:
        learn.fit_one_cycle(3)
    csv_df = learn.csv_logger.read_logged_file()
    recorder_df = create_metrics_dataframe(learn)
    pd.testing.assert_frame_equal(csv_df,
                                  recorder_df,
                                  check_exact=False,
                                  check_less_precise=True)
    stdout_df = convert_into_dataframe(cs.out)
    pd.testing.assert_frame_equal(csv_df,
                                  stdout_df,
                                  check_exact=False,
                                  check_less_precise=True)