def test_basic(self):
    """Smoke-test the abnormal-classification training pipeline end to end.

    Builds a pseudo model/dataset pair, wires up logging/checkpoint hooks and
    two LR schedulers, then starts (or resumes) the training task. Several
    alternative configurations are kept below as commented-out experiments.
    """
    net = PseudoModel()
    pseudo_data = TrainingPseudoDataSet()
    pseudo_data.load_records()
    train_split, val_split = pseudo_data.partition_data([0.75, 0.25], TrainingPseudoDataSet)

    # Real-data variant kept for reference:
    # model = SimpleNet()
    # dataset = TrainingAbnormalDataSet()
    # dataset.load_records(keep_annotations=True)
    # train_dl, val_dl = dataset.partition_data([0.75, 0.25], TrainingAbnormalDataSet)

    augmenter = BatchAugmenter()
    # augmenter.compose([MixUpImageWithAnnotations(probability=1.0)])

    optimizer = SGD(net.parameters(), lr=0.03, momentum=0.9)
    task = AbnormalClassificationTask(net, train_split, optimizer, batch_augmenter=augmenter)
    # Adam variant kept for reference:
    # task = AbnormalClassificationTask(model, train_dl, Adam(model.parameters(), lr=0.03, betas=(0.9, 0.999), weight_decay=0.01), batch_augmenter=batch_aug)
    task.max_iter = 100_000_000
    # task = TrainingTask()

    # NOTE(review): the validation hook is constructed but its registration
    # below is commented out — confirm that is intentional.
    validation_hook = PeriodicStepFuncHook(400000, lambda: task.validation(val_split, net))
    ckpt_hook = CheckpointHook(100000, "test", 1000000, 5)
    warmup_schedule = LRScheduler.LinearWarmup(0, 3000)
    identity_schedule = LRScheduler.LambdaLR(0, 3000, lambda step: 1.0)

    task.register_hook(LogTrainingLoss(frequency=100))
    task.register_hook(StepTimer())
    # task.register_hook(validation_hook)
    task.register_hook(ckpt_hook)
    # task.register_hook(TrainingVisualizationHook(batch=False))

    task.register_lrschedulers(identity_schedule)
    task.register_lrschedulers(warmup_schedule)

    task.begin_or_resume()
    assert 1 == 1
if __name__ == "__main__":
    from src.data.abnormal_dataset import TrainingAbnormalDataSet
    from src.training_tasks.tasks.AbnormalClassificationTask import AbnormalClassificationTask
    from src.utils.hooks import StepTimer, PeriodicStepFuncHook, LogTrainingLoss
    from torch import optim
    from src.training_tasks import BackpropAggregators

    # Train a ResNet-50 abnormal-classification model on a 75/25 split,
    # validating every 5000 steps and checkpointing every 1000.
    net = Res50()
    records = TrainingAbnormalDataSet()
    records.load_records(keep_annotations=False)
    train_split, val_split = records.partition_data([0.75, 0.25], TrainingAbnormalDataSet)

    optimizer = optim.Adam(net.parameters(), lr=0.0001)
    task = AbnormalClassificationTask(net, train_split, optimizer, backward_agg=BackpropAggregators.MeanLosses)
    task.max_iter = 25000

    validation_hook = PeriodicStepFuncHook(5000, lambda: task.validation(val_split, net))
    ckpt_hook = CheckpointHook(1000, "resnet50_test3")

    # Registration order preserved: loss logging, timing, validation, checkpointing.
    for hook in (LogTrainingLoss(), StepTimer(), validation_hook, ckpt_hook):
        task.register_hook(hook)

    task.begin_or_resume()
# Multiclass-detection training setup.
# NOTE(review): `dataloader`, `model`, `config`, and `optim` are defined
# earlier in the file (outside this chunk) — verify against the full script.
train_dl, val_dl = dataloader.partition_data([0.75, 0.25], TrainingMulticlassDataset)
# Steps per epoch in optimizer steps, given gradient accumulation to an
# artificial batch size.
steps_per_epoch = len(train_dl) // config.artificial_batch_size

batch_aug = BatchAugmenter()
# batch_aug.compose([MixUpImageWithAnnotations(probability=0.5)])

task = MulticlassDetectionTask(model, train_dl, optim.Adam(model.parameters(), lr=0.00001), backward_agg=BackpropAggregators.MeanLosses, batch_augmenter=batch_aug)

# Loss exploded with this optimizer
# task = MulticlassDetectionTask(model, train_dl, optim.SGD(model.parameters(), lr=0.003, momentum=0.9), backward_agg=BackpropAggregators.MeanLosses, batch_augmenter=batch_aug)

task.max_iter = steps_per_epoch * 2500

# Run both train-set and val-set validation, and checkpoint, every 500 steps;
# every checkpoint is permanent (keep_last_n_checkpoints=0).
validation_iteration = 500
train_acc_hook = PeriodicStepFuncHook(validation_iteration , lambda: task.validation(train_dl, model))
val_hook = PeriodicStepFuncHook(validation_iteration, lambda: task.validation(val_dl, model))
checkpoint_hook = CheckpointHook(validation_iteration, "TimmModel_PostmAPBug_Test1", permanent_checkpoints=validation_iteration, keep_last_n_checkpoints=0)

# Step-decay LR multipliers and the step boundaries at which they change.
# NOTE(review): `lr_steps` has 5 entries but only 3 boundaries, so index 3
# (0.001) is never returned — the fall-through goes straight to
# lr_steps[-1] (0.0001) after epoch 150. Confirm whether a fourth boundary
# was intended or 0.001 is a leftover.
lr_steps = [1.0, 0.1, 0.01, 0.001, 0.0001]
steps = [ steps_per_epoch * 50, steps_per_epoch * 100, steps_per_epoch * 150]

def lr_stepper(current_step):
    # Return the LR multiplier for the first boundary not yet passed;
    # once all boundaries are exceeded, use the final multiplier.
    idx = 0
    for step in steps:
        if current_step <= step:
            return lr_steps[idx]
        idx += 1
    return lr_steps[-1]
batch_aug = BatchAugmenter() batch_aug.compose([MixUpImage(probability=0.75)]) task = AbnormalClassificationTask( model, train_dl, optim.Adam(model.parameters(), lr=0.0001), backward_agg=BackpropAggregators.MeanLosses, batch_augmenter=batch_aug) # task = AbnormalClassificationTask(model, train_dl, optim.SGD(model.parameters(), lr=0.003, momentum=0.9), # backward_agg=BackpropAggregators.MeanLosses, batch_augmenter=batch_aug) task.max_iter = steps_per_epoch * 25 val_hook = PeriodicStepFuncHook( steps_per_epoch, lambda: task.tim__validation(val_dl, model)) train_acc_hook = PeriodicStepFuncHook( steps_per_epoch, lambda: task.tim__validation(train_dl, model)) checkpoint_hook = CheckpointHook(steps_per_epoch, "TimmModel_BiTRes_X1_TestFive", permanent_checkpoints=steps_per_epoch, keep_last_n_checkpoints=5) lr_steps = [1.0, 0.1, 0.01, 0.001] # steps = [ steps_per_epoch * 10, steps_per_epoch * 15, steps_per_epoch * 20] steps = [-1, -1, steps_per_epoch * 3] def lr_stepper(current_step): idx = 0 for step in steps: