def test_simpler_classifier():
    """Integration test: train and test a SimpleClassifier on dummy data.

    Exercises feature extraction together with more than one visualisation in
    a single run: a PCA plot of the features taken after layer ``l1`` and a
    confusion matrix. Afterwards it checks that ``loss``, ``top1acc`` and
    ``top3acc`` are reported by the model, and that both figures exist under
    ``<log_dir>/figures/test``.
    """
    ClassifierOnDummyData = SimpleClassifier.with_dataset(
        DummyClassificationDataLoader)
    app = Main(ClassifierOnDummyData)
    args, _ = app.argparse(run=False).parse_known_args()

    # Overrides applied on top of the parsed defaults, in this order.
    overrides = {
        "batch_size": 10,
        "train": True,
        "test": True,
        "max_epochs": 1,
        "limit_train_batches": 100,
        "limit_val_batches": 10,
        "limit_test_batches": 10,
        "test_confusion_matrix": True,
        "visualise_features": "pca",
        "extract_features_after_layer": "l1",
        "loss": "cross_entropy",
    }
    for name, value in overrides.items():
        setattr(args, name, value)

    app.main(args)

    # metrics() is queried afresh for each expected key.
    for metric in ("loss", "top1acc", "top3acc"):
        assert metric in app.runner.trainer.model.metrics()

    def test_figure(filename):
        # Location of a test-phase figure inside this run's log directory.
        return Path(app.log_dir) / "figures" / "test" / filename

    assert test_figure("l1_pca.png").is_file()
    assert test_figure("confusion_matrix.png").is_file()
def main_and_args() -> Tuple[Main, AttributeDict]:
    """Build a ``Main`` for ``ExDummyModule`` and return it with its parsed args.

    The parsed defaults are overridden for a short test run. The run uses one
    epoch, ``gpus=0``, ``num_workers=1``, ``batch_size=4`` and id
    ``"automated_test"``. ``checkpoint_callback`` is enabled and
    ``optimization_metric`` is set to ``"loss"``. Ensemble testing, per-step
    checkpointing, LR monitoring, the LR finder and batch-size scaling are all
    set to 0.

    Returns:
        The ``Main`` instance and the (mutated) parsed argument namespace.
    """
    app = Main(ExDummyModule)
    args, _ = app.argparse(run=False).parse_known_args()

    # (attribute, value) pairs applied in order on top of the parsed defaults.
    test_settings = (
        ("max_epochs", 1),
        ("gpus", 0),
        ("checkpoint_callback", True),
        ("optimization_metric", "loss"),
        ("id", "automated_test"),
        ("test_ensemble", 0),
        ("checkpoint_every_n_steps", 0),
        ("monitor_lr", 0),
        ("auto_lr_find", 0),
        ("auto_scale_batch_size", 0),
        ("num_workers", 1),
        ("batch_size", 4),
    )
    for attr, value in test_settings:
        setattr(args, attr, value)

    return app, args
def test_simple_classifier():
    """Train and test ``SimpleClassifier`` for one short epoch with the wandb backend.

    Afterwards it checks that ``loss``, ``top1acc`` and ``top3acc`` are
    reported by the model, and that a confusion-matrix figure exists under
    ``<log_dir>/figures/test``.
    """
    app = Main(SimpleClassifier)
    args, _ = app.argparse(run=False).parse_known_args()

    # Overrides applied on top of the parsed defaults, in this order.
    overrides = {
        "train": True,
        "test": True,
        "max_epochs": 1,
        "limit_train_batches": 100,
        "limit_val_batches": 10,
        "limit_test_batches": 10,
        "test_confusion_matrix": True,
        # NOTE(review): presumably needs wandb installed/configured in the
        # test environment — confirm this is intended for CI.
        "logging_backend": "wandb",
    }
    for name, value in overrides.items():
        setattr(args, name, value)

    app.main(args)

    # metrics() is queried afresh for each expected key.
    for metric in ("loss", "top1acc", "top3acc"):
        assert metric in app.runner.trainer.model.metrics()

    figures_dir = Path(app.log_dir) / "figures" / "test"
    assert (figures_dir / "confusion_matrix.png").is_file()
def test_default_id(self):
    """Check that the ``id`` parsed for ``DummyModule`` defaults to ``"unnamed"``."""
    parsed, _ = Main(DummyModule).argparse(run=False).parse_known_args()
    assert parsed.id == "unnamed"