def test_simpler_classifier():
    """End-to-end run of SimpleClassifier bound to a dummy dataset.

    Serves as an integration test combining feature extraction (after
    layer ``l1``, visualised with PCA) with the test confusion matrix,
    i.e. several visualisations produced in the same run. Afterwards it
    checks that the expected metrics are exposed and both figures exist.
    """
    classifier_cls = SimpleClassifier.with_dataset(DummyClassificationDataLoader)
    m = Main(classifier_cls)
    args, _ = m.argparse(run=False).parse_known_args()

    # Short train/val/test run with both visualisations switched on.
    settings = {
        "batch_size": 10,
        "train": True,
        "test": True,
        "max_epochs": 1,
        "limit_train_batches": 100,
        "limit_val_batches": 10,
        "limit_test_batches": 10,
        "test_confusion_matrix": True,
        "visualise_features": "pca",
        "extract_features_after_layer": "l1",
        "loss": "cross_entropy",
    }
    for key, value in settings.items():
        setattr(args, key, value)

    m.main(args)

    for metric in ("loss", "top1acc", "top3acc"):
        assert metric in m.runner.trainer.model.metrics()

    test_figures = Path(m.log_dir) / "figures" / "test"
    assert (test_figures / "l1_pca.png").is_file()
    assert (test_figures / "confusion_matrix.png").is_file()
def main_and_args() -> Tuple[Main, AttributeDict]:
    """Create a ``Main`` for ``ExDummyModule`` plus its parsed arguments.

    The arguments parsed from the command line are overridden for a short
    automated run: one epoch, ``gpus=0``, checkpoint callback on,
    optimisation on ``loss``, id ``automated_test``, one worker, batch
    size 4, and test ensembling, step checkpointing, LR monitoring,
    LR finding and batch-size scaling all disabled.

    Returns:
        A ``(main, args)`` pair ready to be passed to ``main.main(args)``.
    """
    app = Main(ExDummyModule)
    args, _ = app.argparse(run=False).parse_known_args()

    overrides = {
        "max_epochs": 1,
        "gpus": 0,
        "checkpoint_callback": True,
        "optimization_metric": "loss",
        "id": "automated_test",
        "test_ensemble": 0,
        "checkpoint_every_n_steps": 0,
        "monitor_lr": 0,
        "auto_lr_find": 0,
        "auto_scale_batch_size": 0,
        "num_workers": 1,
        "batch_size": 4,
    }
    for name, value in overrides.items():
        setattr(args, name, value)

    return app, args
def test_simple_classifier():
    """Train and test SimpleClassifier briefly with the wandb backend.

    Verifies that the loss and top-1/top-3 accuracy metrics are reported
    and that the test confusion-matrix figure is written to the log dir.
    """
    m = Main(SimpleClassifier)
    args, _ = m.argparse(run=False).parse_known_args()

    # One short epoch, with the batch counts capped for speed.
    settings = {
        "train": True,
        "test": True,
        "max_epochs": 1,
        "limit_train_batches": 100,
        "limit_val_batches": 10,
        "limit_test_batches": 10,
        "test_confusion_matrix": True,
        "logging_backend": "wandb",
    }
    for key, value in settings.items():
        setattr(args, key, value)

    m.main(args)

    for metric in ("loss", "top1acc", "top3acc"):
        assert metric in m.runner.trainer.model.metrics()

    confusion_png = Path(m.log_dir) / "figures" / "test" / "confusion_matrix.png"
    assert confusion_png.is_file()
def test_help(self, capsys, caplog):
    """Test that command help works.

    Parsing an empty argument list must be quiet. At most one warning,
    containing "Missing log", is tolerated, and nothing may be written
    to stdout or stderr. ``--help`` must exit via ``SystemExit`` and print
    the full option listing to stdout only.
    """
    # Empty argv: no output, and at most a single "Missing log" warning.
    caplog.clear()
    with caplog.at_level(logging.WARNING):
        Main(DummyModule).argparse([])
    assert not caplog.messages or (
        len(caplog.messages) == 1 and "Missing log" in caplog.messages[0]
    )
    out, err = capsys.readouterr()
    assert out == ""
    assert err == ""

    # --help: the parser exits, nothing is logged, and stderr stays empty.
    caplog.clear()
    with pytest.raises(SystemExit), caplog.at_level(logging.WARNING):
        Main(DummyModule).argparse(["--help"])
    assert not caplog.messages
    help_msg, err = capsys.readouterr()
    assert err == ""
    assert len(help_msg) > 0

    # Every expected option must be mentioned in the help text.
    expected_flags = (
        # Flow args
        "--hparamsearch",
        "--train",
        "--test",
        "--profile_model",
        # General args
        "--id",
        "--logging_backend",
        "--optimization_metric",
        # PyTorch Lightning args
        "--logger",
        "--gpus",
        "--accumulate_grad_batches",
        "--max_epochs",
        "--limit_train_batches",
        "--precision",
        "--resume_from_checkpoint",
        "--benchmark",
        "--auto_lr_find",
        "--auto_scale_batch_size",
        # Module args
        "--loss",
        "--learning_rate",
        "--batch_size",
        "--hidden_dim",
    )
    for flag in expected_flags:
        assert flag in help_msg
def test_default_id(self):
    """When no ``--id`` is supplied, the parsed ``id`` falls back to "unnamed"."""
    parsed, _unknown = Main(DummyModule).argparse(run=False).parse_known_args()
    assert parsed.id == "unnamed"