Beispiel #1
0
    def test_basic(self):
        """Smoke-run AbnormalClassificationTask on the pseudo dataset.

        Builds a PseudoModel and a TrainingPseudoDataSet split 75/25 into
        train/validation loaders. It then trains with SGD (lr=0.03,
        momentum=0.9), a default BatchAugmenter, loss logging every 100
        steps, step timing and checkpointing, under two LR schedulers.

        NOTE(review): ``max_iter`` is set to 100_000_000, so this presumably
        behaves as a long-running training harness rather than a quick unit
        test. The validation hook is constructed but never registered. The
        final assertion is trivially true: the test only fails if something
        before it raises.
        """
        model = PseudoModel()
        pseudo_data = TrainingPseudoDataSet()
        pseudo_data.load_records()
        split_ratios = [0.75, 0.25]
        train_dl, val_dl = pseudo_data.partition_data(split_ratios,
                                                      TrainingPseudoDataSet)
        # Disabled alternative setup: SimpleNet on TrainingAbnormalDataSet,
        # loaded with load_records(keep_annotations=True), split the same way.

        augmenter = BatchAugmenter()
        # Disabled: augmenter.compose([MixUpImageWithAnnotations(probability=1.0)])

        optimizer = SGD(model.parameters(), lr=0.03, momentum=0.9)
        # Disabled alternative optimizer:
        # Adam(model.parameters(), lr=0.03, betas=(0.9, 0.999), weight_decay=0.01)
        task = AbnormalClassificationTask(model, train_dl, optimizer,
                                          batch_augmenter=augmenter)
        task.max_iter = 100_000_000

        # Built but intentionally left unregistered below.
        val_hook = PeriodicStepFuncHook(400000,
                                        lambda: task.validation(val_dl, model))
        checkpoint_hook = CheckpointHook(100000, "test", 1000000, 5)

        warmup = LRScheduler.LinearWarmup(0, 3000)
        constant_lr = LRScheduler.LambdaLR(0, 3000, lambda step: 1.0)

        for hook in (LogTrainingLoss(frequency=100), StepTimer(), checkpoint_hook):
            task.register_hook(hook)
        # Also available but disabled: val_hook,
        # TrainingVisualizationHook(batch=False)

        # The constant LambdaLR is registered before the linear warmup.
        for lr_scheduler in (constant_lr, warmup):
            task.register_lrschedulers(lr_scheduler)

        task.begin_or_resume()

        assert 1 == 1


if __name__ == "__main__":
    # Script entry point: train Res50 on TrainingAbnormalDataSet with Adam
    # (lr=1e-4) and mean-aggregated losses for 25000 iterations. Validation
    # runs every 5000 steps and checkpoints are taken every 1000 steps.
    from src.data.abnormal_dataset import TrainingAbnormalDataSet
    from src.training_tasks.tasks.AbnormalClassificationTask import AbnormalClassificationTask
    from src.utils.hooks import StepTimer, PeriodicStepFuncHook, LogTrainingLoss
    from torch import optim

    from src.training_tasks import BackpropAggregators

    # NOTE(review): Res50 and CheckpointHook are not imported in this block;
    # presumably they are imported at the top of the file. Confirm.
    model = Res50()

    dataset = TrainingAbnormalDataSet()
    dataset.load_records(keep_annotations=False)
    train_dl, val_dl = dataset.partition_data([0.75, 0.25], TrainingAbnormalDataSet)

    adam = optim.Adam(model.parameters(), lr=0.0001)
    task = AbnormalClassificationTask(
        model,
        train_dl,
        adam,
        backward_agg=BackpropAggregators.MeanLosses,
    )
    task.max_iter = 25000

    val_hook = PeriodicStepFuncHook(5000, lambda: task.validation(val_dl, model))
    checkpoint_hook = CheckpointHook(1000, "resnet50_test3")

    for hook in (LogTrainingLoss(), StepTimer(), val_hook, checkpoint_hook):
        task.register_hook(hook)

    task.begin_or_resume()