Ejemplo n.º 1
0
def test_simpler_classifier():
    """Train and test a ``SimpleClassifier`` bound to a dummy dataset loader.

    Serves as an integration test of feature extraction together with
    multiple concurrent visualisations: a PCA plot of the features taken
    after layer ``l1`` and a test-set confusion matrix.
    """
    SimplerClassifier = SimpleClassifier.with_dataset(
        DummyClassificationDataLoader)
    m = Main(SimplerClassifier)
    parser = m.argparse(run=False)
    # Parse an explicit empty argv: under pytest, sys.argv contains pytest's
    # own options and test paths, which must not leak into the run config.
    args, _ = parser.parse_known_args([])
    args.batch_size = 10
    args.train = True
    args.test = True
    args.max_epochs = 1
    args.limit_train_batches = 100
    args.limit_val_batches = 10
    args.limit_test_batches = 10
    args.test_confusion_matrix = True
    args.visualise_features = "pca"
    args.extract_features_after_layer = "l1"
    args.loss = "cross_entropy"
    m.main(args)

    metrics = m.runner.trainer.model.metrics()
    for name in ("loss", "top1acc", "top3acc"):
        assert name in metrics, f"missing metric {name!r}"

    figures_dir = Path(m.log_dir) / "figures" / "test"
    assert (figures_dir / "l1_pca.png").is_file()
    assert (figures_dir / "confusion_matrix.png").is_file()
Ejemplo n.º 2
0
def main_and_args() -> Tuple[Main, AttributeDict]:
    """Build a ``Main`` for ``ExDummyModule`` plus a minimal CPU-only config.

    Returns:
        The ``Main`` instance and the parsed arguments with overrides set for
        a short, single-epoch run (no GPUs, no ensembles, no LR/batch-size
        auto-tuning, batch size 4, one worker).
    """
    m = Main(ExDummyModule)
    parser = m.argparse(run=False)
    # Parse an explicit empty argv so pytest's own command-line options and
    # test paths in sys.argv are not picked up by the project parser.
    args, _ = parser.parse_known_args([])
    args.max_epochs = 1
    args.gpus = 0
    args.checkpoint_callback = True
    args.optimization_metric = "loss"
    args.id = "automated_test"
    args.test_ensemble = 0
    args.checkpoint_every_n_steps = 0
    args.monitor_lr = 0
    args.auto_lr_find = 0
    args.auto_scale_batch_size = 0
    args.num_workers = 1
    args.batch_size = 4

    return m, args
Ejemplo n.º 3
0
def test_simple_classifier():
    """Train and test ``SimpleClassifier`` with the wandb logging backend.

    Checks that the expected metrics are registered on the trained model and
    that a test-set confusion matrix figure is written to the log directory.
    """
    m = Main(SimpleClassifier)
    parser = m.argparse(run=False)
    # Parse an explicit empty argv: under pytest, sys.argv contains pytest's
    # own options and test paths, which must not leak into the run config.
    args, _ = parser.parse_known_args([])
    args.train = True
    args.test = True
    args.max_epochs = 1
    args.limit_train_batches = 100
    args.limit_val_batches = 10
    args.limit_test_batches = 10
    args.test_confusion_matrix = True
    args.logging_backend = "wandb"
    m.main(args)

    metrics = m.runner.trainer.model.metrics()
    for name in ("loss", "top1acc", "top3acc"):
        assert name in metrics, f"missing metric {name!r}"

    figures_dir = Path(m.log_dir) / "figures" / "test"
    assert (figures_dir / "confusion_matrix.png").is_file()
Ejemplo n.º 4
0
 def test_default_id(self):
     """Test that a default id is given when no ``--id`` is passed."""
     m = Main(DummyModule)
     parser = m.argparse(run=False)
     # Parse an explicit empty argv so the default is checked regardless of
     # whatever options/paths pytest itself was invoked with.
     args, _ = parser.parse_known_args([])
     assert args.id == "unnamed"