def test_callbacks_learner(data, model):
    """Check that callbacks attached to a `Learner` in each supported way are run.

    Covers:
      * a single callback fn passed via ``callback_fns=``,
      * a list of callback fns passed via ``callback_fns=``,
      * a callback instance appended to ``learn.callbacks``,
      * appending a *list* to ``learn.callbacks``, which is expected to fail.
    """
    this_tests(Callback)

    # single callback in learner constructor
    learn = Learner(data, model, metrics=accuracy, callback_fns=DummyCallback)
    with CaptureStdout() as cs:
        learn.fit_one_cycle(2)
    check_dummy_metric(cs.out)

    # list of callbacks in learner constructor
    learn = Learner(data, model, metrics=accuracy, callback_fns=[DummyCallback])
    with CaptureStdout() as cs:
        learn.fit_one_cycle(2)
    check_dummy_metric(cs.out)

    # single callback append
    learn = Learner(data, model, metrics=accuracy)
    learn.callbacks.append(DummyCallback(learn))
    with CaptureStdout() as cs:
        learn.fit_one_cycle(2)
    check_dummy_metric(cs.out)

    # list of callbacks append: python's list.append adds the list itself as one
    # element, so append([x]) does not register x. Fitting is expected to fail
    # when the training loop looks up `on_train_begin` on that list.
    learn = Learner(data, model, metrics=[accuracy])
    learn.callbacks.append([DummyCallback(learn)])
    error = ''
    try:
        with CaptureStdout():
            learn.fit_one_cycle(2)
    except Exception as e:
        error = str(e)
    error_pat = "'list' object has no attribute 'on_train_begin'"
    # Separate check so a missing exception is reported clearly instead of as an
    # empty-string pattern mismatch.
    assert error, "expected fit_one_cycle to raise when a list is appended to learn.callbacks"
    assert error_pat in error, f"expecting '{error_pat}' in the exception:\n{error}"
def test_callbacks_fit(data, model):
    """Callbacks passed straight to `fit` / `fit_one_cycle` run, both singly and in a list."""
    learn = Learner(data, model, metrics=accuracy)
    # first a bare callback, then the same kind of callback wrapped in a list
    as_given = lambda cb: cb
    as_list = lambda cb: [cb]
    for method_name in ['fit', 'fit_one_cycle']:
        run = getattr(learn, method_name)
        this_tests(run)
        for wrap in (as_given, as_list):
            with CaptureStdout() as captured:
                run(2, callbacks=wrap(DummyCallback(learn)))
            check_dummy_metric(captured.out)
def test_gan_trainer(gan_learner):
    """After `gan_learner.fit(1, 1e-4)`, the GAN trainer's `imgs`, `gen_mode` and `titles` are truthy."""
    this_tests(GANTrainer)
    trainer = gan_learner.gan_trainer
    # capture stdout only to keep the progress output quiet; it is not inspected
    with CaptureStdout():
        gan_learner.fit(1, 1e-4)
    for attr_name in ('imgs', 'gen_mode', 'titles'):
        assert getattr(trainer, attr_name)
def test_custom_metric_class():
    """A user-defined metric class appended to `learn.metrics` shows up in the fit output.

    The expected values come from the module-level `dummy_base_val`. Presumably
    `DummyMetric` (defined elsewhere) reports `dummy_base_val` and then its square
    over the two epochs — TODO confirm against its definition.
    """
    learn = fake_learner(3, 2)
    learn.metrics.append(DummyMetric())
    with CaptureStdout() as cs:
        learn.fit_one_cycle(2)
    # expecting column header 'dummy', and the metrics per class definition
    expected = ['dummy', f'{dummy_base_val}.00', f'{dummy_base_val**2}.00']
    for s in expected:
        # message matches the "expecting ... in" convention used by the other tests
        assert s in cs.out, f"expecting '{s}' in \n{cs.out}"
def test_peak_mem_metric():
    """`PeakMemMetric` adds cpu/gpu used/peak columns to the printed training table.

    NOTE(review): another `test_peak_mem_metric` (the CpuPeakMemMetric version)
    appears later in this file and rebinds this name. Only that later definition
    is visible to pytest while both remain.
    """
    learn = fake_learner()
    learn.callbacks.append(PeakMemMetric(learn))
    with CaptureStdout() as captured:
        learn.fit_one_cycle(3, max_lr=1e-2)
    for column in ('cpu', 'used', 'peak', 'gpu'):
        assert column in captured.out, f"expecting '{column}' in \n{captured.out}"
    # epochs 2-3 it shouldn't allocate more general or GPU RAM
    no_growth = '0 0 0 0'
    assert no_growth in captured.out, f"expecting '{no_growth}' in \n{captured.out}"
def stop_after_n_batches_run_n_check(learn):
    """Fit `learn` and check that training stopped in epoch 0 after exactly 2 batches.

    NOTE(review): a later `stop_after_n_batches_run_n_check(learn, bs,
    run_n_batches_exp)` in this file rebinds this name, so this one-argument
    version is shadowed while both remain.
    """
    with CaptureStdout() as captured:
        learn.fit_one_cycle(3, max_lr=1e-2)
    for header in ('train_loss', 'valid_loss'):
        assert header in captured.out, f"expecting '{header}' in \n{captured.out}"
    # epoch rows start on a new line: row 0 must exist, row 1 must not
    assert "\n0" in captured.out, "expecting epoch0"
    assert "\n1" not in captured.out, "epoch 1 shouldn't run"
    # the recorder keeps one loss per batch run
    recorded = len(learn.recorder.losses)
    assert recorded == 2
def test_peak_mem_metric():
    """`CpuPeakMemMetric` adds "cpu used"/"cpu_peak" columns to the training table.

    The test also checks that a "0 0" pair shows up somewhere in the output. Per
    the comment below, epochs 2-3 are expected not to allocate more RAM.
    """
    learn = fake_learner()
    learn.callbacks.append(CpuPeakMemMetric(learn))
    this_tests(CpuPeakMemMetric)
    with CaptureStdout() as captured:
        learn.fit_one_cycle(3, max_lr=1e-2)
    for column in ("cpu used", "cpu_peak"):
        assert column in captured.out, f"expecting '{column}' in \n{captured.out}"
    # XXX: needs a better test to assert some numbers here (at least >0)
    # epochs 2-3 it shouldn't allocate more general or CPU RAM
    no_growth = "0 0"
    assert no_growth in captured.out, f"expecting '{no_growth}' in \n{captured.out}"
def stop_after_n_batches_run_n_check(learn, bs, run_n_batches_exp):
    """Fit `learn` and check it stopped in epoch 0 after `run_n_batches_exp` batches.

    Args:
        learn: the learner to fit. Presumably the caller has configured it to stop
            after a fixed number of batches — TODO confirm at call sites.
        bs: batch size, presumably the one `learn.data` was built with. It is used
            only to estimate how many batches a full epoch would have, for the
            failure message.
        run_n_batches_exp: expected number of batches run, measured as the number
            of losses recorded by `learn.recorder`.
    """
    # approximate batches per epoch; exact count depends on how the loader
    # handles the last partial batch
    has_batches = len(learn.data.train_ds) // bs
    with CaptureStdout() as cs:
        learn.fit_one_cycle(3, max_lr=1e-2)
    for s in ['train_loss', 'valid_loss']:
        assert s in cs.out, f"expecting '{s}' in \n{cs.out}"
    # test that epochs are stopped at epoch 0
    assert "\n0" in cs.out, "expecting epoch0"
    assert "\n1" not in cs.out, "epoch 1 shouldn't run"
    # test that only run_n_batches_exp batches were run
    run_n_batches_got = len(learn.recorder.losses)
    assert run_n_batches_got == run_n_batches_exp, (
        f"should have run only {run_n_batches_exp} of ~{has_batches} batches per epoch, "
        f"but got {run_n_batches_got}")
def test_logger():
    """The CSVLogger file agrees with the printed table and with the recorder's metrics.

    The 'time' column is dropped from the logged frame before comparing.

    NOTE(review): a later `test_logger` in this file rebinds this name, so this
    version is shadowed while both remain.
    """
    learn = fake_learner()
    learn.metrics = [accuracy, error_rate]
    learn.callback_fns.append(callbacks.CSVLogger)
    with CaptureStdout() as cs:
        learn.fit_one_cycle(3)
    csv_df = learn.csv_logger.read_logged_file()
    stdout_df = convert_into_dataframe(cs.out)
    # `columns=` already selects the column axis; no `axis`/`inplace` needed
    csv_df = csv_df.drop(columns=['time'])
    pd.testing.assert_frame_equal(csv_df, stdout_df, check_exact=False, check_less_precise=2)
    recorder_df = create_metrics_dataframe(learn)
    # XXX: there is a bug in pandas:
    # https://github.com/pandas-dev/pandas/issues/25068#issuecomment-460014120
    # which quite often fails on CI.
    # once it's resolved can change the setting back to check_less_precise=True (or better =3), until then using =2 as it works, but this check is less good.
    pd.testing.assert_frame_equal(csv_df, recorder_df, check_exact=False, check_less_precise=2)
def test_logger():
    """The CSVLogger file matches the recorder's metrics, then the printed table.

    NOTE(review): `check_less_precise` is deprecated in newer pandas in favour of
    `rtol`/`atol`.
    """
    learn = fake_learner()
    learn.metrics = [accuracy, error_rate]
    learn.callback_fns.append(callbacks.CSVLogger)
    with CaptureStdout() as captured:
        learn.fit_one_cycle(3)
    logged_df = learn.csv_logger.read_logged_file()
    # Build each reference frame only when its comparison runs, recorder first.
    reference_builders = (
        lambda: create_metrics_dataframe(learn),
        lambda: convert_into_dataframe(captured.out),
    )
    for build_reference in reference_builders:
        pd.testing.assert_frame_equal(logged_df, build_reference(),
                                      check_exact=False, check_less_precise=True)