def test_densenet():
    from torchvision.models import densenet121

    net1 = E.DenseNet121Encoder(pretrained=False)
    net2 = densenet121(pretrained=False)
    net2.classifier = None
    print(count_parameters(net1), count_parameters(net2))

def test_cls_models(model_name):
    model = get_model(model_name=model_name, num_classes=4).eval()
    print(model_name, count_parameters(model))
    x = torch.rand((1, 3, 224, 224))
    output = model(x)
    assert output['logits'].size(1) == 4
    assert output['features'].size(1) == model.features_size

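# Tests in this module that take arguments (test_cls_models, test_models, test_encoders,
# test_jit_trace, ...) are presumably driven by pytest parametrization whose decorators are not
# shown here. A minimal illustrative sketch; the function name and the model names below are
# hypothetical placeholders, not the original test configuration.
import pytest


@pytest.mark.parametrize("model_name", ["resnet34", "densenet121"])  # placeholder names
def test_cls_models_parametrized_example(model_name):
    # Delegates to the test above to show how a parametrized invocation would look.
    test_cls_models(model_name)
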
def test_unet_decoder():
    encoder = E.Resnet18Encoder(pretrained=False, layers=[0, 1, 2, 3, 4])
    decoder = D.UNetDecoder(encoder.channels)
    x = torch.randn((16, 3, 256, 256))
    model = nn.Sequential(encoder, decoder)
    output = model(x)
    print(count_parameters(encoder))
    print(count_parameters(decoder))
    for o in output:
        print(o.size(), o.mean(), o.std())

def test_unet_encoder_decoder():
    encoder = E.UnetEncoder(3, 32, 5)
    decoder = D.UNetDecoder(encoder.channels)
    x = torch.randn((2, 3, 256, 256))
    model = nn.Sequential(encoder, decoder).eval()
    output = model(x)
    print(count_parameters(encoder))
    print(count_parameters(decoder))
    for o in output:
        print(o.size(), o.mean(), o.std())

def test_inceptionv4_encoder():
    backbone = inceptionv4(pretrained=False)
    backbone.last_linear = None
    net = E.InceptionV4Encoder(pretrained=False, layers=[0, 1, 2, 3, 4]).cuda()
    print(count_parameters(backbone))
    print(count_parameters(net))
    x = torch.randn((4, 3, 512, 512)).cuda()
    out = net(x)
    for fm in out:
        print(fm.size())

def test_b4_effunet32_s2():
    model = maybe_cuda(b4_effunet32_s2())
    x = maybe_cuda(torch.rand((2, 3, 512, 512)))
    output = model(x)
    print(count_parameters(model))
    for key, value in output.items():
        print(key, value.size(), value.mean(), value.std())

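# Many tests here wrap tensors and models in `maybe_cuda` so they also run on CPU-only machines.
# A minimal sketch of what that helper presumably looks like (assumption; the actual helper is
# defined elsewhere in the original test suite):
def maybe_cuda(x):
    # Move a tensor or module to the GPU when one is available; otherwise return it unchanged.
    return x.cuda() if torch.cuda.is_available() else x
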
def test_models(model_name):
    model = maybe_cuda(get_model(model_name, pretrained=False).eval())
    x = maybe_cuda(torch.rand((2, 3, 512, 512)))
    output = model(x)
    print(model_name, count_parameters(model))
    for key, value in output.items():
        print(key, value.size(), value.mean(), value.std())

def test_efficientb4_fpncatv2_256():
    d = efficientb4_fpncatv2_256().cuda().eval()
    print(count_parameters(d))
    images = torch.rand(4, 6, 512, 512).cuda()
    i = d(images)
    print(i[OUTPUT_MASK_KEY].size())

def test_inceptionv4_fpncatv2_256():
    d = inceptionv4_fpncatv2_256().cuda().eval()
    print(count_parameters(d))
    images = torch.rand(2, 6, 512, 512).cuda()
    i = d(images)
    print(i[OUTPUT_MASK_KEY].size())

def test_efficient_net():
    from pytorch_toolbelt.utils.torch_utils import count_parameters

    num_classes = 1001
    x = torch.randn((1, 3, 600, 600)).cuda()

    for model_fn in [
        efficient_net_b0,
        # efficient_net_b1,
        # efficient_net_b2,
        # efficient_net_b3,
        # efficient_net_b4,
        # efficient_net_b5,
        # efficient_net_b6,
        # efficient_net_b7
    ]:
        print("=======", model_fn.__name__, "=======")
        model = model_fn(num_classes).eval().cuda()
        print(count_parameters(model))
        print(model)
        print()
        print()
        output = model(x)
        print(output.size())

def test_efficient_net_group_norm():
    from pytorch_toolbelt.utils.torch_utils import count_parameters

    num_classes = 1001
    x = torch.randn((1, 3, 600, 600)).cuda()

    for model_fn in [
        efficient_net_b0,
        efficient_net_b1,
        efficient_net_b2,
        efficient_net_b3,
        efficient_net_b4,
        efficient_net_b5,
        efficient_net_b6,
        efficient_net_b7,
    ]:
        print("=======", model_fn.__name__, "=======")
        agn_params = {"num_groups": 8, "activation": ACT_HARD_SWISH}
        model = model_fn(num_classes, abn_block=AGN, abn_params=agn_params).eval().cuda()
        print(count_parameters(model))
        # print(model)
        print()
        print()
        output = model(x)
        print(output.size())
        torch.cuda.empty_cache()

def test_hrnet18_runet64():
    net = hrnet18_runet64().eval()
    print(net)
    print(count_parameters(net))
    input = torch.rand((2, 3, 256, 256))
    output = net(input)
    mask = output[OUTPUT_MASK_KEY]

def test_b6_unet32_s2_rdtc():
    model = b6_unet32_s2_rdtc(need_supervision_masks=True)
    model = maybe_cuda(model.eval())
    x = maybe_cuda(torch.rand((2, 3, 512, 512)))
    output = model(x)
    print(count_parameters(model))
    for key, value in output.items():
        print(key, value.size(), value.mean(), value.std())

def test_U2NET():
    model = U2NET()
    model = maybe_cuda(model.eval())
    x = maybe_cuda(torch.rand((2, 3, 512, 512)))
    output = model(x)
    print(count_parameters(model))
    for key, value in output.items():
        print(key, value.size(), value.mean(), value.std())

def test_fpn_cat():
    channels = [256, 512, 1024, 2048]
    sizes = [64, 32, 16, 8]
    net = FPNCatDecoder(channels, 5).eval()
    input = [torch.randn(4, ch, sz, sz) for sz, ch in zip(sizes, channels)]
    output = net(input)
    print(output.size(), output.mean(), output.std())
    print(count_parameters(net))

def test_jit_trace(encoder, encoder_params):
    model = encoder(**encoder_params).eval()
    print(model.__class__.__name__, count_parameters(model))
    print(model.strides)
    print(model.channels)
    dummy_input = torch.rand((1, 3, 256, 256))
    dummy_input = maybe_cuda(dummy_input)
    model = maybe_cuda(model)
    model = torch.jit.trace(model, dummy_input, check_trace=True)

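# Illustrative extension (not part of the original test): once torch.jit.trace succeeds, the
# traced module can be round-tripped through torch.jit.save / torch.jit.load to confirm the
# serialized artifact still runs. The helper name and file path are placeholders.
def _roundtrip_traced_module(traced_model, dummy_input, path="traced_encoder.pt"):
    torch.jit.save(traced_model, path)
    restored = torch.jit.load(path)
    return restored(dummy_input)
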
def test_wider_resnet():
    for fn in [
        wider_resnet_16_a2,
        wider_resnet_16,
        wider_resnet_20,
        wider_resnet_20_a2,
        wider_resnet_38,
        wider_resnet_38_a2,
    ]:
        net = fn().eval()
        print(count_parameters(net))
        x = torch.randn((1, 3, 512, 512))
        out = net(x)
        for o in out:
            print(o.size())

def test_inception_unet_like_selim():
    d = inceptionv4_unet_v2().cuda().eval()
    print(count_parameters(d))
    print(d.decoder.decoder_features)
    print(d.decoder.bottlenecks)
    print(d.decoder.decoder_stages)
    images = torch.rand(4, 6, 512, 512).cuda()
    i = d(images)
    print(i[OUTPUT_MASK_KEY].size())

def test_fpn_sum_with_encoder():
    x = torch.randn((16, 3, 256, 256))
    encoder = E.Resnet18Encoder(pretrained=False)
    decoder = FPNSumDecoder(encoder.channels, 128)
    model = nn.Sequential(encoder, decoder)
    output = model(x)
    print(count_parameters(decoder))
    for o in output:
        print(o.size(), o.mean(), o.std())

def test_fpn_sum():
    channels = [256, 512, 1024, 2048]
    sizes = [64, 32, 16, 8]
    decoder = FPNSumDecoder(channels, 5).eval()
    x = [torch.randn(4, ch, sz, sz) for sz, ch in zip(sizes, channels)]
    outputs = decoder(x)
    print(count_parameters(decoder))
    for o in outputs:
        print(o.size(), o.mean(), o.std())

def test_fpn_cat_with_dsv():
    channels = [256, 512, 1024, 2048]
    sizes = [64, 32, 16, 8]
    net = FPNCatDecoder(channels, output_channels=5, dsv_channels=5).eval()
    input = [torch.randn(4, ch, sz, sz) for sz, ch in zip(sizes, channels)]
    output, dsv_masks = net(input)
    print(output.size(), output.mean(), output.std())
    for dsv in dsv_masks:
        print(dsv.size(), dsv.mean(), dsv.std())
        assert dsv.size(1) == 5
    print(count_parameters(net))

def test_encoders(encoder: EncoderModule, encoder_params):
    with torch.no_grad():
        net = encoder(**encoder_params).eval()
        print(net.__class__.__name__, count_parameters(net))
        input = torch.rand((4, 3, 512, 512))
        input = maybe_cuda(input)
        net = maybe_cuda(net)
        output = net(input)
        assert len(output) == len(net.output_filters)
        for feature_map, expected_stride, expected_channels in zip(output, net.output_strides, net.output_filters):
            assert feature_map.size(1) == expected_channels
            assert feature_map.size(2) * expected_stride == 512
            assert feature_map.size(3) * expected_stride == 512

def test_encoders(encoder: E.EncoderModule, encoder_params):
    net = encoder(**encoder_params).eval()
    print(net.__class__.__name__, count_parameters(net))
    print(net.output_strides)
    print(net.out_channels)
    x = torch.rand((4, 3, 256, 256))
    x = maybe_cuda(x)
    net = maybe_cuda(net)
    output = net(x)
    assert len(output) == len(net.channels)
    for feature_map, expected_stride, expected_channels in zip(output, net.strides, net.channels):
        assert feature_map.size(1) == expected_channels
        assert feature_map.size(2) * expected_stride == 256
        assert feature_map.size(3) * expected_stride == 256

def test_hourglass_encoder(encoder, encoder_params):
    net = encoder(**encoder_params).eval()
    print(repr(net), count_parameters(net))
    print("Strides ", net.strides)
    print("Channels", net.channels)
    x = torch.rand((4, 3, 256, 256))
    x = maybe_cuda(x)
    net = maybe_cuda(net)
    output = net(x)
    assert len(output) == len(net.channels)
    for feature_map, expected_stride, expected_channels in zip(output, net.strides, net.channels):
        assert feature_map.size(1) == expected_channels
        assert feature_map.size(2) * expected_stride == 256
        assert feature_map.size(3) * expected_stride == 256

def test_unet_encoder():
    net = E.UnetEncoder().eval()
    print(net.__class__.__name__, count_parameters(net))
    print(net.output_strides)
    print(net.channels)
    x = torch.rand((4, 3, 256, 256))
    x = maybe_cuda(x)
    net = maybe_cuda(net)
    output = net(x)
    assert len(output) == len(net.output_filters)
    for feature_map, expected_stride, expected_channels in zip(output, net.output_strides, net.output_filters):
        print(feature_map.size(), feature_map.mean(), feature_map.std())
        assert feature_map.size(1) == expected_channels
        assert feature_map.size(2) * expected_stride == 256
        assert feature_map.size(3) * expected_stride == 256

def test_supervised_hourglass_encoder(encoder, encoder_params):
    net = encoder(**encoder_params).eval()
    print(net.__class__.__name__, count_parameters(net))
    print(net.output_strides)
    print(net.out_channels)
    x = torch.rand((4, 3, 256, 256))
    x = maybe_cuda(x)
    net = maybe_cuda(net)
    output, supervision = net(x)
    assert len(output) == len(net.output_filters)
    assert len(supervision) == len(net.output_filters) - 2
    for feature_map, expected_stride, expected_channels in zip(output, net.output_strides, net.output_filters):
        assert feature_map.size(1) == expected_channels
        assert feature_map.size(2) * expected_stride == 256
        assert feature_map.size(3) * expected_stride == 256

def test_onnx_export(encoder, encoder_params):
    import onnx

    model = encoder(**encoder_params).eval()
    print(model.__class__.__name__, count_parameters(model))
    print(model.strides)
    print(model.channels)
    dummy_input = torch.rand((1, 3, 256, 256))
    dummy_input = maybe_cuda(dummy_input)
    model = maybe_cuda(model)
    input_names = ["image"]
    output_names = [f"feature_map_{i}" for i in range(len(model.channels))]
    torch.onnx.export(
        model, dummy_input, "tmp.onnx", verbose=True, input_names=input_names, output_names=output_names
    )
    model = onnx.load("tmp.onnx")
    onnx.checker.check_model(model)

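# Illustrative extension (not part of the original test): the exported "tmp.onnx" graph can also
# be executed with onnxruntime and compared against the eager PyTorch outputs. The helper name
# and tolerance are placeholders.
def _check_onnx_outputs(torch_model, dummy_input, onnx_path="tmp.onnx", atol=1e-4):
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    onnx_outputs = session.run(None, {"image": dummy_input.cpu().numpy()})
    torch_outputs = torch_model(dummy_input)
    for onnx_fm, torch_fm in zip(onnx_outputs, torch_outputs):
        np.testing.assert_allclose(onnx_fm, torch_fm.detach().cpu().numpy(), atol=atol)
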
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--obliterate", type=float, default=0, help="Chance of obliteration")
    parser.add_argument("-nid", "--negative-image-dir", type=str, default=None, help="Directory with negative images")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--cache", action="store_true")
    parser.add_argument("-dd", "--data-dir", type=str, default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-m", "--model", type=str, default="resnet34", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch Size during training, e.g. -b 64")
    parser.add_argument(
        "-wbs", "--warmup-batch-size", type=int, default=None, help="Batch size during warmup, e.g. -wbs 64"
    )
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument(
        "-es", "--early-stopping", type=int, default=None, help="Maximum number of epochs without improvement"
    )
    parser.add_argument("-fe", "--freeze-encoder", action="store_true", help="Freeze encoder parameters during warmup")
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")
    parser.add_argument(
        "-l", "--modification-flag-loss", type=str, default=None, action="append", nargs="+"  # [["ce", 1.0]],
    )
    parser.add_argument(
        "--modification-type-loss", type=str, default=None, action="append", nargs="+"  # [["ce", 1.0]],
    )
    parser.add_argument("--embedding-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--feature-maps-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--mask-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--bits-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="safe", type=str, help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--mixup", action="store_true")
    parser.add_argument("--cutmix", action="store_true")
    parser.add_argument("--tsa", action="store_true")
    parser.add_argument("--fold", default=None, type=int)
    parser.add_argument("-s", "--scheduler", default=None, type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument(
        "--fine-tune", default=0, type=int, help="Number of fine-tune epochs after the main training stage"
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--balance", action="store_true")
parser.add_argument("--freeze-bn", action="store_true") args = parser.parse_args() set_manual_seed(args.seed) assert ( args.modification_flag_loss or args.modification_type_loss or args.embedding_loss ), "At least one of losses must be set" modification_flag_loss = args.modification_flag_loss modification_type_loss = args.modification_type_loss embedding_loss = args.embedding_loss feature_maps_loss = args.feature_maps_loss mask_loss = args.mask_loss bits_loss = args.bits_loss freeze_encoder = args.freeze_encoder data_dir = args.data_dir cache = args.cache num_workers = args.workers num_epochs = args.epochs learning_rate = args.learning_rate model_name: str = args.model optimizer_name = args.optimizer image_size = (512, 512) fast = args.fast augmentations = args.augmentations fp16 = args.fp16 scheduler_name = args.scheduler experiment = args.experiment dropout = args.dropout verbose = args.verbose warmup = args.warmup show = args.show accumulation_steps = args.accumulation_steps weight_decay = args.weight_decay fold = args.fold balance = args.balance freeze_bn = args.freeze_bn train_batch_size = args.batch_size mixup = args.mixup cutmix = args.cutmix tsa = args.tsa fine_tune = args.fine_tune obliterate_p = args.obliterate negative_image_dir = args.negative_image_dir warmup_batch_size = args.warmup_batch_size or args.batch_size # Compute batch size for validation valid_batch_size = train_batch_size run_train = num_epochs > 0 custom_model_kwargs = {} if dropout is not None: custom_model_kwargs["dropout"] = float(dropout) if embedding_loss is not None: custom_model_kwargs["need_embedding"] = True model: nn.Module = get_model(model_name, **custom_model_kwargs).cuda() required_features = model.required_features if mask_loss is not None: required_features.append(INPUT_TRUE_MODIFICATION_MASK) if args.transfer: transfer_checkpoint = fs.auto_file(args.transfer) print("Transferring weights from model checkpoint", transfer_checkpoint) checkpoint = load_checkpoint(transfer_checkpoint) pretrained_dict = checkpoint["model_state_dict"] transfer_weights(model, pretrained_dict) if args.checkpoint: checkpoint = load_checkpoint(fs.auto_file(args.checkpoint)) unpack_checkpoint(checkpoint, model=model) print("Loaded model weights from:", args.checkpoint) report_checkpoint(checkpoint) if freeze_bn: from pytorch_toolbelt.optimization.functional import freeze_model freeze_model(model, freeze_bn=True) print("Freezing bn params") main_metric = "loss" main_metric_minimize = True current_time = datetime.now().strftime("%b%d_%H_%M") checkpoint_prefix = f"{current_time}_{args.model}_fold{fold}" if fp16: checkpoint_prefix += "_fp16" if fast: checkpoint_prefix += "_fast" if mixup: checkpoint_prefix += "_mixup" if cutmix: checkpoint_prefix += "_cutmix" if experiment is not None: checkpoint_prefix = experiment log_dir = os.path.join("runs", checkpoint_prefix) os.makedirs(log_dir, exist_ok=False) config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json") with open(config_fname, "w") as f: train_session_args = vars(args) f.write(json.dumps(train_session_args, indent=2)) default_callbacks = [] if show: default_callbacks += [ShowPolarBatchesCallback(draw_predictions, metric="loss", minimize=True)] # Pretrain/warmup if warmup: train_ds, valid_ds, train_sampler = get_datasets( data_dir=data_dir, augmentation=augmentations, balance=balance, fast=fast, fold=fold, features=required_features, obliterate_p=0, ) criterions_dict, loss_callbacks = get_criterions( modification_flag=modification_flag_loss, 
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            feature_maps_loss=feature_maps_loss,
            num_epochs=warmup,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
        )

        callbacks = (
            default_callbacks
            + loss_callbacks
            + [
                OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
                HyperParametersCallback(
                    hparam_dict={
                        "model": model_name,
                        "scheduler": scheduler_name,
                        "optimizer": optimizer_name,
                        "augmentations": augmentations,
                        "size": image_size[0],
                        "weight_decay": weight_decay,
                    }
                ),
            ]
        )

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=warmup_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )
        loaders["valid"] = DataLoader(valid_ds, batch_size=warmup_batch_size, num_workers=num_workers, pin_memory=True)

        if freeze_encoder:
            from pytorch_toolbelt.optimization.functional import freeze_model

            freeze_model(model.encoder, freeze_parameters=True, freeze_bn=None)

        optimizer = get_optimizer(
            "Ranger", get_optimizable_parameters(model), weight_decay=weight_decay, learning_rate=3e-4
        )
        scheduler = None

        print("Train session :", checkpoint_prefix)
        print(" FP16 mode :", fp16)
        print(" Fast mode :", args.fast)
        print(" Epochs :", num_epochs)
        print(" Workers :", num_workers)
        print(" Data dir :", data_dir)
        print(" Log dir :", log_dir)
        print(" Cache :", cache)
        print("Data ")
        print(" Augmentations :", augmentations)
        print(" Negative images:", negative_image_dir)
        print(" Train size :", len(loaders["train"]), "batches", len(train_ds), "samples")
        print(" Valid size :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
        print(" Image size :", image_size)
        print(" Balance :", balance)
        print(" Mixup :", mixup)
        print(" CutMix :", cutmix)
        print(" TSA :", tsa)
        print("Model :", model_name)
        print(" Parameters :", count_parameters(model))
        print(" Dropout :", dropout, "(Non-default)" if dropout is not None else "")
        print("Optimizer :", optimizer_name)
        print(" Learning rate :", learning_rate)
        print(" Weight decay :", weight_decay)
        print(" Scheduler :", scheduler_name)
        print(" Batch sizes :", train_batch_size, valid_batch_size)
        print("Losses ")
        print(" Flag :", modification_flag_loss)
        print(" Type :", modification_type_loss)
        print(" Embedding :", embedding_loss)
        print(" Feature maps :", feature_maps_loss)
        print(" Mask :", mask_loss)
        print(" Bits :", bits_loss)

        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "warmup"),
            num_epochs=warmup,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        del optimizer, loaders, runner, callbacks

        best_checkpoint = os.path.join(log_dir, "warmup", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}_warmup.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        # Restore state of best model
        # unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        torch.cuda.empty_cache()
        gc.collect()

    if run_train:
        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation=augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=obliterate_p,
        )

        if negative_image_dir:
            negatives_ds = get_negatives_ds(
                negative_image_dir,
                fold=fold,
                features=required_features,
                max_images=16536,
            )
            train_ds = train_ds + negatives_ds
            train_sampler = None  # TODO: Add proper support of sampler
            print("Adding", len(negatives_ds), "negative samples to training set")

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            feature_maps_loss=feature_maps_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            num_epochs=num_epochs,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
        )

        callbacks = (
            default_callbacks
            + loss_callbacks
            + [
                OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
                HyperParametersCallback(
                    hparam_dict={
                        "model": model_name,
                        "scheduler": scheduler_name,
                        "optimizer": optimizer_name,
                        "augmentations": augmentations,
                        "size": image_size[0],
                        "weight_decay": weight_decay,
                    }
                ),
            ]
        )

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )
        loaders["valid"] = DataLoader(valid_ds, batch_size=valid_batch_size, num_workers=num_workers, pin_memory=True)

        print("Train session :", checkpoint_prefix)
        print(" FP16 mode :", fp16)
        print(" Fast mode :", args.fast)
        print(" Epochs :", num_epochs)
        print(" Workers :", num_workers)
        print(" Data dir :", data_dir)
        print(" Log dir :", log_dir)
        print(" Cache :", cache)
        print("Data ")
        print(" Augmentations :", augmentations)
        print(" Obliterate (%) :", obliterate_p)
        print(" Negative images:", negative_image_dir)
        print(" Train size :", len(loaders["train"]), "batches", len(train_ds), "samples")
        print(" Valid size :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
        print(" Image size :", image_size)
        print(" Balance :", balance)
        print(" Mixup :", mixup)
        print(" CutMix :", cutmix)
        print(" TSA :", tsa)
        print("Model :", model_name)
        print(" Parameters :", count_parameters(model))
        print(" Dropout :", dropout)
        print("Optimizer :", optimizer_name)
        print(" Learning rate :", learning_rate)
        print(" Weight decay :", weight_decay)
        print(" Scheduler :", scheduler_name)
        print(" Batch sizes :", train_batch_size, valid_batch_size)
        print("Losses ")
        print(" Flag :", modification_flag_loss)
        print(" Type :", modification_type_loss)
        print(" Embedding :", embedding_loss)
        print(" Feature maps :", feature_maps_loss)
        print(" Mask :", mask_loss)
        print(" Bits :", bits_loss)

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        # model training
        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        del optimizer, loaders, runner, callbacks

        best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")

        # Restore state of best model
        clean_checkpoint(best_checkpoint, model_checkpoint)
        # unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)
        torch.cuda.empty_cache()
        gc.collect()

    if fine_tune:
        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation="light",
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=obliterate_p,
        )

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            feature_maps_loss=feature_maps_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            num_epochs=fine_tune,
            mixup=False,
            cutmix=False,
            tsa=False,
        )

        callbacks = (
            default_callbacks
            + loss_callbacks
            + [
                OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
                HyperParametersCallback(
                    hparam_dict={
                        "model": model_name,
                        "scheduler": scheduler_name,
                        "optimizer": optimizer_name,
                        "augmentations": augmentations,
                        "size": image_size[0],
                        "weight_decay": weight_decay,
                    }
                ),
            ]
        )

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )
        loaders["valid"] = DataLoader(valid_ds, batch_size=valid_batch_size, num_workers=num_workers, pin_memory=True)

        print("Train session :", checkpoint_prefix)
        print(" FP16 mode :", fp16)
        print(" Fast mode :", args.fast)
        print(" Epochs :", num_epochs)
        print(" Workers :", num_workers)
        print(" Data dir :", data_dir)
        print(" Log dir :", log_dir)
        print(" Cache :", cache)
        print("Data ")
        print(" Augmentations :", augmentations)
        print(" Obliterate (%) :", obliterate_p)
        print(" Negative images:", negative_image_dir)
        print(" Train size :", len(loaders["train"]), "batches", len(train_ds), "samples")
        print(" Valid size :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
        print(" Image size :", image_size)
        print(" Balance :", balance)
        print(" Mixup :", mixup)
        print(" CutMix :", cutmix)
        print(" TSA :", tsa)
        print("Model :", model_name)
        print(" Parameters :", count_parameters(model))
        print(" Dropout :", dropout)
        print("Optimizer :", optimizer_name)
        print(" Learning rate :", learning_rate)
        print(" Weight decay :", weight_decay)
        print(" Scheduler :", scheduler_name)
        print(" Batch sizes :", train_batch_size, valid_batch_size)
        print("Losses ")
        print(" Flag :", modification_flag_loss)
        print(" Type :", modification_type_loss)
        print(" Embedding :", embedding_loss)
        print(" Feature maps :", feature_maps_loss)
        print(" Mask :", mask_loss)
        print(" Bits :", bits_loss)

        optimizer = get_optimizer(
            "SGD", get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            "cos", optimizer, lr=learning_rate, num_epochs=fine_tune, batches_in_epoch=len(loaders["train"])
        )
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        # model training
        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "finetune"),
            num_epochs=fine_tune,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        best_checkpoint = os.path.join(log_dir, "finetune", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}_finetune.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)
        unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        del optimizer, loaders, runner, callbacks

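# Illustrative usage (assumption, not taken from the original repository). The script above is
# presumably run as a module entry point; the file name and flag values below are placeholders,
# only the flag names come from the argparse definitions in main():
#
#   if __name__ == "__main__":
#       main()
#
#   python train.py -m resnet34 -b 16 -e 50 --fold 0 --fp16 \
#       --modification-flag-loss ce 1.0 --modification-type-loss ce 0.1
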
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd", "--data-dir", type=str, required=True, help="Data directory for INRIA satellite dataset"
    )
    parser.add_argument("-m", "--model", type=str, default="resnet34_fpncat128", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")
    parser.add_argument(
        "--disaster-type-loss",
        type=str,
        default=None,  # [["ce", 1.0]],
        action="append",
        nargs="+",
        help="Criterion for classifying disaster type",
    )
    parser.add_argument(
        "--damage-type-loss",
        type=str,
        default=None,  # [["bce", 1.0]],
        action="append",
        nargs="+",
        help="Criterion for classifying presence of building with particular damage type",
    )
    parser.add_argument("-l", "--criterion", type=str, default=None, action="append", nargs="+", help="Criterion")
    parser.add_argument(
        "--mask4", type=str, default=None, action="append", nargs="+", help="Criterion for mask with stride 4"
    )
    parser.add_argument(
        "--mask8", type=str, default=None, action="append", nargs="+", help="Criterion for mask with stride 8"
    )
    parser.add_argument(
        "--mask16", type=str, default=None, action="append", nargs="+", help="Criterion for mask with stride 16"
    )
    parser.add_argument(
        "--mask32", type=str, default=None, action="append", nargs="+", help="Criterion for mask with stride 32"
    )
    parser.add_argument("--embedding", type=str, default=None)
    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="safe", type=str, help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("--fold", default=0, type=int)
    parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=0.0, type=float, help="Dropout before head layer")
    parser.add_argument("-pl", "--pseudolabeling", type=str, required=True)
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--only-buildings", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")
    parser.add_argument("--crops", action="store_true", help="Train on random crops")
help="Train on random crops") parser.add_argument("--post-transform", action="store_true") args = parser.parse_args() set_manual_seed(args.seed) data_dir = args.data_dir num_workers = args.workers num_epochs = args.epochs learning_rate = args.learning_rate model_name = args.model optimizer_name = args.optimizer image_size = args.size, args.size fast = args.fast augmentations = args.augmentations fp16 = args.fp16 scheduler_name = args.scheduler experiment = args.experiment dropout = args.dropout segmentation_losses = args.criterion verbose = args.verbose show = args.show accumulation_steps = args.accumulation_steps weight_decay = args.weight_decay fold = args.fold balance = args.balance only_buildings = args.only_buildings freeze_bn = args.freeze_bn train_on_crops = args.crops enable_post_image_transform = args.post_transform disaster_type_loss = args.disaster_type_loss train_batch_size = args.batch_size embedding_criterion = args.embedding damage_type_loss = args.damage_type_loss pseudolabels_dir = args.pseudolabeling # Compute batch size for validaion if train_on_crops: valid_batch_size = max(1, (train_batch_size * (image_size[0] * image_size[1])) // (1024**2)) else: valid_batch_size = train_batch_size run_train = num_epochs > 0 model: nn.Module = get_model(model_name, dropout=dropout).cuda() if args.transfer: transfer_checkpoint = fs.auto_file(args.transfer) print("Transfering weights from model checkpoint", transfer_checkpoint) checkpoint = load_checkpoint(transfer_checkpoint) pretrained_dict = checkpoint["model_state_dict"] transfer_weights(model, pretrained_dict) if args.checkpoint: checkpoint = load_checkpoint(fs.auto_file(args.checkpoint)) unpack_checkpoint(checkpoint, model=model) print("Loaded model weights from:", args.checkpoint) report_checkpoint(checkpoint) if freeze_bn: torch_utils.freeze_bn(model) print("Freezing bn params") runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None) main_metric = "weighted_f1" cmd_args = vars(args) current_time = datetime.now().strftime("%b%d_%H_%M") checkpoint_prefix = f"{current_time}_{args.model}_{args.size}_fold{fold}" if fp16: checkpoint_prefix += "_fp16" if fast: checkpoint_prefix += "_fast" if pseudolabels_dir: checkpoint_prefix += "_pseudo" if train_on_crops: checkpoint_prefix += "_crops" if experiment is not None: checkpoint_prefix = experiment log_dir = os.path.join("runs", checkpoint_prefix) os.makedirs(log_dir, exist_ok=False) config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json") with open(config_fname, "w") as f: train_session_args = vars(args) f.write(json.dumps(train_session_args, indent=2)) default_callbacks = [ CompetitionMetricCallback(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="weighted_f1"), ConfusionMatrixCallback( input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, class_names=[ "land", "no_damage", "minor_damage", "major_damage", "destroyed" ], ignore_index=UNLABELED_SAMPLE, ), ] if show: default_callbacks += [ ShowPolarBatchesCallback(draw_predictions, metric=main_metric + "_batch", minimize=False) ] train_ds, valid_ds, train_sampler = get_datasets( data_dir=data_dir, image_size=image_size, augmentation=augmentations, fast=fast, fold=fold, balance=balance, only_buildings=only_buildings, train_on_crops=train_on_crops, crops_multiplication_factor=1, enable_post_image_transform=enable_post_image_transform, ) if run_train: loaders = collections.OrderedDict() callbacks = default_callbacks.copy() criterions_dict = {} losses = [] unlabeled_train = get_pseudolabeling_dataset( 
            data_dir,
            include_masks=True,
            image_size=image_size,
            augmentation="medium_nmd",
            train_on_crops=train_on_crops,
            enable_post_image_transform=enable_post_image_transform,
            pseudolabels_dir=pseudolabels_dir,
        )

        train_ds = train_ds + unlabeled_train
        print("Using online pseudolabeling with ", len(unlabeled_train), "samples")

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=True,
        )
        loaders["valid"] = DataLoader(valid_ds, batch_size=valid_batch_size, num_workers=num_workers, pin_memory=True)

        # Create losses
        for criterion in segmentation_losses:
            if isinstance(criterion, (list, tuple)) and len(criterion) == 2:
                loss_name, loss_weight = criterion
            else:
                loss_name, loss_weight = criterion[0], 1.0

            cd, criterion, criterion_name = get_criterion_callback(
                loss_name,
                prefix="segmentation",
                input_key=INPUT_MASK_KEY,
                output_key=OUTPUT_MASK_KEY,
                loss_weight=float(loss_weight),
            )
            criterions_dict.update(cd)
            callbacks.append(criterion)
            losses.append(criterion_name)
            print(INPUT_MASK_KEY, "Using loss", loss_name, loss_weight)

        if args.mask4 is not None:
            for criterion in args.mask4:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix="mask4",
                    input_key=INPUT_MASK_KEY,
                    output_key=OUTPUT_MASK_4_KEY,
                    loss_weight=float(loss_weight),
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(OUTPUT_MASK_4_KEY, "Using loss", loss_name, loss_weight)

        if args.mask8 is not None:
            for criterion in args.mask8:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix="mask8",
                    input_key=INPUT_MASK_KEY,
                    output_key=OUTPUT_MASK_8_KEY,
                    loss_weight=float(loss_weight),
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(OUTPUT_MASK_8_KEY, "Using loss", loss_name, loss_weight)

        if args.mask16 is not None:
            for criterion in args.mask16:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix="mask16",
                    input_key=INPUT_MASK_KEY,
                    output_key=OUTPUT_MASK_16_KEY,
                    loss_weight=float(loss_weight),
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(OUTPUT_MASK_16_KEY, "Using loss", loss_name, loss_weight)

        if args.mask32 is not None:
            for criterion in args.mask32:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix="mask32",
                    input_key=INPUT_MASK_KEY,
                    output_key=OUTPUT_MASK_32_KEY,
                    loss_weight=float(loss_weight),
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(OUTPUT_MASK_32_KEY, "Using loss", loss_name, loss_weight)

        if disaster_type_loss is not None:
            callbacks += [
                ConfusionMatrixCallback(
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    class_names=DISASTER_TYPES,
                    ignore_index=UNKNOWN_DISASTER_TYPE_CLASS,
                    prefix=f"{DISASTER_TYPE_KEY}/confusion_matrix",
                ),
                AccuracyCallback(
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    prefix=f"{DISASTER_TYPE_KEY}/accuracy",
                    activation="Softmax",
                ),
            ]

            for criterion in disaster_type_loss:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix=DISASTER_TYPE_KEY,
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    loss_weight=float(loss_weight),
                    ignore_index=UNKNOWN_DISASTER_TYPE_CLASS,
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(DISASTER_TYPE_KEY, "Using loss", loss_name, loss_weight)

        if damage_type_loss is not None:
            callbacks += [
                # MultilabelConfusionMatrixCallback(
                #     input_key=DAMAGE_TYPE_KEY,
                #     output_key=DAMAGE_TYPE_KEY,
                #     class_names=DAMAGE_TYPES,
                #     prefix=f"{DAMAGE_TYPE_KEY}/confusion_matrix",
                # ),
                AccuracyCallback(
                    input_key=DAMAGE_TYPE_KEY,
                    output_key=DAMAGE_TYPE_KEY,
                    prefix=f"{DAMAGE_TYPE_KEY}/accuracy",
                    activation="Sigmoid",
                    threshold=0.5,
                )
            ]

            for criterion in damage_type_loss:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix=DAMAGE_TYPE_KEY,
                    input_key=DAMAGE_TYPE_KEY,
                    output_key=DAMAGE_TYPE_KEY,
                    loss_weight=float(loss_weight),
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(DAMAGE_TYPE_KEY, "Using loss", loss_name, loss_weight)

        if embedding_criterion is not None:
            cd, criterion, criterion_name = get_criterion_callback(
                embedding_criterion,
                prefix="embedding",
                input_key=INPUT_MASK_KEY,
                output_key=OUTPUT_EMBEDDING_KEY,
                loss_weight=1.0,
            )
            criterions_dict.update(cd)
            callbacks.append(criterion)
            losses.append(criterion_name)
            print(OUTPUT_EMBEDDING_KEY, "Using loss", embedding_criterion)

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
        ]

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        print("Train session :", checkpoint_prefix)
        print(" FP16 mode :", fp16)
        print(" Fast mode :", args.fast)
        print(" Epochs :", num_epochs)
        print(" Workers :", num_workers)
        print(" Data dir :", data_dir)
        print(" Log dir :", log_dir)
        print("Data ")
        print(" Augmentations :", augmentations)
        print(" Train size :", len(loaders["train"]), len(train_ds))
        print(" Valid size :", len(loaders["valid"]), len(valid_ds))
        print(" Image size :", image_size)
        print(" Train on crops :", train_on_crops)
        print(" Balance :", balance)
        print(" Buildings only :", only_buildings)
        print(" Post transform :", enable_post_image_transform)
        print(" Pseudolabels :", pseudolabels_dir)
        print("Model :", model_name)
        print(" Parameters :", count_parameters(model))
        print(" Dropout :", dropout)
        print("Optimizer :", optimizer_name)
        print(" Learning rate :", learning_rate)
        print(" Weight decay :", weight_decay)
        print(" Scheduler :", scheduler_name)
        print(" Batch sizes :", train_batch_size, valid_batch_size)
        print(" Criterion :", segmentation_losses)
        print(" Damage type :", damage_type_loss)
        print(" Disaster type :", disaster_type_loss)
        print(" Embedding :", embedding_criterion)

        # model training
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
logdir=os.path.join(log_dir, "opl"), num_epochs=num_epochs, verbose=verbose, main_metric=main_metric, minimize_metric=False, checkpoint_data={"cmd_args": cmd_args}, ) # Training is finished. Let's run predictions using best checkpoint weights best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth") model_checkpoint = os.path.join(log_dir, "main", "checkpoints", f"{checkpoint_prefix}.pth") clean_checkpoint(best_checkpoint, model_checkpoint) del optimizer, loaders
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('--mixup', action='store_true')
    parser.add_argument('--balance', action='store_true')
    parser.add_argument('--balance-datasets', action='store_true')
    parser.add_argument('--swa', action='store_true')
    parser.add_argument('--show', action='store_true')
    parser.add_argument('--use-idrid', action='store_true')
    parser.add_argument('--use-messidor', action='store_true')
    parser.add_argument('--use-aptos2015', action='store_true')
    parser.add_argument('--use-aptos2019', action='store_true')
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--coarse', action='store_true')
    parser.add_argument('-acc', '--accumulation-steps', type=int, default=1, help='Number of batches to process')
    parser.add_argument('-dd', '--data-dir', type=str, default='data', help='Data directory')
    parser.add_argument('-m', '--model', type=str, default='resnet18_gap', help='')
    parser.add_argument('-b', '--batch-size', type=int, default=8, help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e', '--epochs', type=int, default=100, help='Epoch to run')
    parser.add_argument('-es', '--early-stopping', type=int, default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-f', '--fold', action='append', type=int, default=None)
    parser.add_argument('-ft', '--fine-tune', default=0, type=int)
    parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4, help='Initial learning rate')
    parser.add_argument('--criterion-reg', type=str, default=None, nargs='+', help='Criterion')
    parser.add_argument('--criterion-ord', type=str, default=None, nargs='+', help='Criterion')
    parser.add_argument('--criterion-cls', type=str, default=['ce'], nargs='+', help='Criterion')
    parser.add_argument('-l1', type=float, default=0, help='L1 regularization loss')
    parser.add_argument('-l2', type=float, default=0, help='L2 regularization loss')
    parser.add_argument('-o', '--optimizer', default='Adam', help='Name of the optimizer')
    parser.add_argument('-p', '--preprocessing', default=None, help='Preprocessing method')
    parser.add_argument('-c', '--checkpoint', type=str, default=None,
                        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w', '--workers', default=multiprocessing.cpu_count(), type=int, help='Num workers')
    parser.add_argument('-a', '--augmentations', default='medium', type=str, help='')
    parser.add_argument('-tta', '--tta', default=None, type=str, help='Type of TTA to use [fliplr, d4]')
    parser.add_argument('-t', '--transfer', default=None, type=str, help='')
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('-s', '--scheduler', default='multistep', type=str, help='')
    parser.add_argument('--size', default=512, type=int, help='Image size for training & inference')
    parser.add_argument('-wd', '--weight-decay', default=0, type=float, help='L2 weight decay')
    parser.add_argument('-wds', '--weight-decay-step', default=None, type=float,
                        help='L2 weight decay step to add after each epoch')
    parser.add_argument('-d', '--dropout', default=0.0, type=float, help='Dropout before head layer')
    parser.add_argument('--warmup', default=0, type=int,
                        help='Number of warmup epochs with 0.1 of the initial LR and frozen encoder')
    parser.add_argument('-x', '--experiment', default=None, type=str, help='Experiment name')

    args = parser.parse_args()

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    l1 = args.l1
    l2 = args.l2
    early_stopping = args.early_stopping
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (args.size, args.size)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    fine_tune = args.fine_tune
    criterion_reg_name = args.criterion_reg
    criterion_cls_name = args.criterion_cls
    criterion_ord_name = args.criterion_ord
    folds = args.fold
    mixup = args.mixup
    balance = args.balance
    balance_datasets = args.balance_datasets
    use_swa = args.swa
    show_batches = args.show
    scheduler_name = args.scheduler
    verbose = args.verbose
    weight_decay = args.weight_decay
    use_idrid = args.use_idrid
    use_messidor = args.use_messidor
    use_aptos2015 = args.use_aptos2015
    use_aptos2019 = args.use_aptos2019
    warmup = args.warmup
    dropout = args.dropout
    use_unsupervised = False
    experiment = args.experiment
    preprocessing = args.preprocessing
    weight_decay_step = args.weight_decay_step
    coarse_grading = args.coarse
    class_names = get_class_names(coarse_grading)

    assert use_aptos2015 or use_aptos2019 or use_idrid or use_messidor

    current_time = datetime.now().strftime('%b%d_%H_%M')
    random_name = get_random_name()

    if folds is None or len(folds) == 0:
        folds = [None]

    for fold in folds:
        torch.cuda.empty_cache()
        checkpoint_prefix = f'{model_name}_{args.size}_{augmentations}'

        if preprocessing is not None:
            checkpoint_prefix += f'_{preprocessing}'
        if use_aptos2019:
            checkpoint_prefix += '_aptos2019'
        if use_aptos2015:
            checkpoint_prefix += '_aptos2015'
        if use_messidor:
            checkpoint_prefix += '_messidor'
        if use_idrid:
            checkpoint_prefix += '_idrid'
        if coarse_grading:
            checkpoint_prefix += '_coarse'
        if fold is not None:
            checkpoint_prefix += f'_fold{fold}'

        checkpoint_prefix += f'_{random_name}'

        if experiment is not None:
            checkpoint_prefix = experiment

        directory_prefix = f'{current_time}/{checkpoint_prefix}'
        log_dir = os.path.join('runs', directory_prefix)
        os.makedirs(log_dir, exist_ok=False)

        config_fname = os.path.join(log_dir, f'{checkpoint_prefix}.json')
        with open(config_fname, 'w') as f:
            train_session_args = vars(args)
            f.write(json.dumps(train_session_args, indent=2))

        set_manual_seed(args.seed)
        num_classes = len(class_names)
        model = get_model(model_name, num_classes=num_classes, dropout=dropout).cuda()

        if args.transfer:
            transfer_checkpoint = fs.auto_file(args.transfer)
            print("Transferring weights from model checkpoint", transfer_checkpoint)
            checkpoint = load_checkpoint(transfer_checkpoint)
            pretrained_dict = checkpoint['model_state_dict']

            for name, value in pretrained_dict.items():
                try:
                    model.load_state_dict(collections.OrderedDict([(name, value)]), strict=False)
                except Exception as e:
                    print(e)

            report_checkpoint(checkpoint)

        if args.checkpoint:
            checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        train_ds, valid_ds, train_sizes = get_datasets(
            data_dir=data_dir,
            use_aptos2019=use_aptos2019,
            use_aptos2015=use_aptos2015,
            use_idrid=use_idrid,
            use_messidor=use_messidor,
            use_unsupervised=False,
            coarse_grading=coarse_grading,
            image_size=image_size,
            augmentation=augmentations,
            preprocessing=preprocessing,
            target_dtype=int,
            fold=fold,
            folds=4,
        )

        train_loader, valid_loader = get_dataloaders(
            train_ds,
            valid_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            train_sizes=train_sizes,
            balance=balance,
            balance_datasets=balance_datasets,
            balance_unlabeled=False,
        )

        loaders = collections.OrderedDict()
loaders["train"] = train_loader loaders["valid"] = valid_loader print('Datasets :', data_dir) print(' Train size :', len(train_loader), len(train_loader.dataset)) print(' Valid size :', len(valid_loader), len(valid_loader.dataset)) print(' Aptos 2019 :', use_aptos2019) print(' Aptos 2015 :', use_aptos2015) print(' IDRID :', use_idrid) print(' Messidor :', use_messidor) print('Train session :', directory_prefix) print(' FP16 mode :', fp16) print(' Fast mode :', fast) print(' Mixup :', mixup) print(' Balance cls. :', balance) print(' Balance ds. :', balance_datasets) print(' Warmup epoch :', warmup) print(' Train epochs :', num_epochs) print(' Fine-tune ephs :', fine_tune) print(' Workers :', num_workers) print(' Fold :', fold) print(' Log dir :', log_dir) print(' Augmentations :', augmentations) print('Model :', model_name) print(' Parameters :', count_parameters(model)) print(' Image size :', image_size) print(' Dropout :', dropout) print(' Classes :', class_names, num_classes) print('Optimizer :', optimizer_name) print(' Learning rate :', learning_rate) print(' Batch size :', batch_size) print(' Criterion (cls):', criterion_cls_name) print(' Criterion (reg):', criterion_reg_name) print(' Criterion (ord):', criterion_ord_name) print(' Scheduler :', scheduler_name) print(' Weight decay :', weight_decay, weight_decay_step) print(' L1 reg. :', l1) print(' L2 reg. :', l2) print(' Early stopping :', early_stopping) # model training callbacks = [] criterions = {} main_metric = 'cls/kappa' if criterion_reg_name is not None: cb, crits = get_reg_callbacks(criterion_reg_name, class_names=class_names, show=show_batches) callbacks += cb criterions.update(crits) if criterion_ord_name is not None: cb, crits = get_ord_callbacks(criterion_ord_name, class_names=class_names, show=show_batches) callbacks += cb criterions.update(crits) if criterion_cls_name is not None: cb, crits = get_cls_callbacks(criterion_cls_name, num_classes=num_classes, num_epochs=num_epochs, class_names=class_names, show=show_batches) callbacks += cb criterions.update(crits) if l1 > 0: callbacks += [ LPRegularizationCallback(start_wd=l1, end_wd=l1, schedule=None, prefix='l1', p=1) ] if l2 > 0: callbacks += [ LPRegularizationCallback(start_wd=l2, end_wd=l2, schedule=None, prefix='l2', p=2) ] callbacks += [CustomOptimizerCallback()] runner = SupervisedRunner(input_key='image') # Pretrain/warmup if warmup: set_trainable(model.encoder, False, False) optimizer = get_optimizer('Adam', get_optimizable_parameters(model), learning_rate=learning_rate * 0.1) runner.train(fp16=fp16, model=model, criterion=criterions, optimizer=optimizer, scheduler=None, callbacks=callbacks, loaders=loaders, logdir=os.path.join(log_dir, 'warmup'), num_epochs=warmup, verbose=verbose, main_metric=main_metric, minimize_metric=False, checkpoint_data={"cmd_args": vars(args)}) del optimizer # Main train if num_epochs: set_trainable(model.encoder, True, False) optimizer = get_optimizer(optimizer_name, get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay) if use_swa: from torchcontrib.optim import SWA optimizer = SWA(optimizer, swa_start=len(train_loader), swa_freq=512) scheduler = get_scheduler(scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(train_loader)) # Additional callbacks that specific to main stage only added here to copy of callbacks main_stage_callbacks = callbacks if early_stopping: es_callback = EarlyStoppingCallback(early_stopping, min_delta=1e-4, metric=main_metric, 
                main_stage_callbacks = callbacks + [es_callback]

            runner.train(
                fp16=fp16,
                model=model,
                criterion=criterions,
                optimizer=optimizer,
                scheduler=scheduler,
                callbacks=main_stage_callbacks,
                loaders=loaders,
                logdir=os.path.join(log_dir, 'main'),
                num_epochs=num_epochs,
                verbose=verbose,
                main_metric=main_metric,
                minimize_metric=False,
                checkpoint_data={"cmd_args": vars(args)},
            )
            del optimizer, scheduler

            best_checkpoint = os.path.join(log_dir, 'main', 'checkpoints', 'best.pth')
            model_checkpoint = os.path.join(log_dir, 'main', 'checkpoints', f'{checkpoint_prefix}.pth')
            clean_checkpoint(best_checkpoint, model_checkpoint)

            # Restoring best model from checkpoint
            checkpoint = load_checkpoint(best_checkpoint)
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        # Stage 3 - Fine tuning
        if fine_tune:
            set_trainable(model.encoder, False, False)
            optimizer = get_optimizer(optimizer_name, get_optimizable_parameters(model), learning_rate=learning_rate)
            scheduler = get_scheduler(
                'multistep', optimizer, lr=learning_rate, num_epochs=fine_tune, batches_in_epoch=len(train_loader)
            )

            runner.train(
                fp16=fp16,
                model=model,
                criterion=criterions,
                optimizer=optimizer,
                scheduler=scheduler,
                callbacks=callbacks,
                loaders=loaders,
                logdir=os.path.join(log_dir, 'finetune'),
                num_epochs=fine_tune,
                verbose=verbose,
                main_metric=main_metric,
                minimize_metric=False,
                checkpoint_data={"cmd_args": vars(args)},
            )

            best_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints', 'best.pth')
            model_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints', f'{checkpoint_prefix}.pth')
            clean_checkpoint(best_checkpoint, model_checkpoint)
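
# Illustrative invocation of the retinopathy-grading script above (assumption, not taken from the
# original repository); the file name and values are placeholders, only the flag names come from
# the argparse definitions, and at least one --use-* dataset flag is required by the assert:
#
#   python train_cls.py -m resnet18_gap --size 512 -b 8 -e 100 \
#       --use-aptos2019 --use-idrid --criterion-cls ce -f 0 --fp16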