def test_callback_wrapping():
    """
    Check how GanExperiment exposes the callbacks it was given.

    Callbacks listed in ``phase2callbacks`` are expected to come back as
    ``PhaseWrapperCallback`` instances, a ``PhaseManagerCallback`` is expected
    under the ``"phase_manager"`` key, and callbacks that belong to no phase
    are expected to keep their original type.
    """
    # A bare module and an empty dataset are all the experiment is given here.
    loaders = OrderedDict(
        [("train", torch.utils.data.DataLoader(torch.utils.data.Dataset()))]
    )
    model = torch.nn.Module()

    # One optimizer callback per GAN phase plus a logger that has no phase.
    input_callbacks = OrderedDict()
    input_callbacks["optim_d"] = OptimizerCallback(
        loss_key="loss_d", optimizer_key="discriminator"
    )
    input_callbacks["optim_g"] = OptimizerCallback(
        loss_key="loss_g", optimizer_key="generator"
    )
    input_callbacks["tensorboard"] = TensorboardLogger()

    state_kwargs = dict(
        discriminator_train_phase="discriminator_train",
        discriminator_train_num=1,
        generator_train_phase="generator_train",
        generator_train_num=5,
    )
    phase2callbacks = {
        state_kwargs["discriminator_train_phase"]: ["optim_d"],
        state_kwargs["generator_train_phase"]: ["optim_g"],
    }

    experiment = GanExperiment(
        model=model,
        loaders=loaders,
        callbacks=input_callbacks,
        state_kwargs=state_kwargs,
        phase2callbacks=phase2callbacks,
    )
    callbacks = experiment.get_callbacks("train")

    # Every user-supplied callback must still be reachable under its own name.
    known_names = callbacks.keys()
    for name in ("optim_d", "optim_g", "tensorboard"):
        assert name in known_names

    expected_types = (
        ("phase_manager", PhaseManagerCallback),
        ("optim_d", PhaseWrapperCallback),
        ("optim_g", PhaseWrapperCallback),
        ("tensorboard", TensorboardLogger),
    )
    for name, callback_type in expected_types:
        assert isinstance(callbacks[name], callback_type)
def main():
    """
    Command-line entry point: train a model in up to three consecutive stages.

    Stages (each one is optional and they run in this order):

    * warmup (``--warmup N``): N epochs with the "Ranger" optimizer at a fixed
      learning rate of 3e-4 and no scheduler, using ``--warmup-batch-size``;
      ``--freeze-encoder`` freezes ``model.encoder`` before this stage.
    * main (``--epochs N``): N epochs with the optimizer and scheduler chosen on
      the command line; ``--negative-image-dir`` appends extra samples.
    * fine-tune (``--fine-tune N``): N epochs with "SGD", the "cos" scheduler and
      "light" augmentations, with mixup, cutmix and tsa disabled.

    Side effects: creates ``runs/<prefix>`` (``os.makedirs`` raises
    ``FileExistsError`` if it already exists), writes the parsed arguments
    there as JSON, and after every stage copies that stage's ``best.pth``
    into the run directory through ``clean_checkpoint``.

    Exits through ``parser.error`` when none of the flag, type or embedding
    losses is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--obliterate", type=float, default=0, help="Chance of obliteration")
    parser.add_argument(
        "-nid",
        "--negative-image-dir",
        type=str,
        default=None,
        help="Directory with negative images added to the training set",
    )
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--cache", action="store_true")
    parser.add_argument("-dd", "--data-dir", type=str, default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-m", "--model", type=str, default="resnet34", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch Size during training, e.g. -b 64")
    parser.add_argument(
        "-wbs",
        "--warmup-batch-size",
        type=int,
        default=None,
        help="Batch size during warmup (defaults to --batch-size)",
    )
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument(
        "-es", "--early-stopping", type=int, default=None, help="Maximum number of epochs without improvement"
    )
    parser.add_argument(
        "-fe", "--freeze-encoder", action="store_true", help="Freeze encoder parameters during warmup"
    )
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")

    # Every loss option collects repeated "name weight" groups, e.g. [["ce", 1.0]].
    loss_options = (
        ("-l", "--modification-flag-loss"),
        ("--modification-type-loss",),
        ("--embedding-loss",),
        ("--feature-maps-loss",),
        ("--mask-loss",),
        ("--bits-loss",),
    )
    for option_names in loss_options:
        parser.add_argument(*option_names, type=str, default=None, action="append", nargs="+")

    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="safe", type=str, help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--mixup", action="store_true")
    parser.add_argument("--cutmix", action="store_true")
    parser.add_argument("--tsa", action="store_true")
    parser.add_argument("--fold", default=None, type=int)
    parser.add_argument("-s", "--scheduler", default=None, type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument(
        "--fine-tune",
        default=0,
        type=int,
        help="Number of fine-tuning epochs (SGD optimizer, cos scheduler, light augmentations)",
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")

    args = parser.parse_args()
    set_manual_seed(args.seed)

    # Validate with parser.error instead of assert: asserts vanish under `python -O`.
    if not (args.modification_flag_loss or args.modification_type_loss or args.embedding_loss):
        parser.error("At least one of losses must be set")

    modification_flag_loss = args.modification_flag_loss
    modification_type_loss = args.modification_type_loss
    embedding_loss = args.embedding_loss
    feature_maps_loss = args.feature_maps_loss
    mask_loss = args.mask_loss
    bits_loss = args.bits_loss
    freeze_encoder = args.freeze_encoder
    data_dir = args.data_dir
    cache = args.cache
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    model_name: str = args.model
    optimizer_name = args.optimizer
    image_size = (512, 512)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    verbose = args.verbose
    warmup = args.warmup
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    fold = args.fold
    balance = args.balance
    freeze_bn = args.freeze_bn
    train_batch_size = args.batch_size
    mixup = args.mixup
    cutmix = args.cutmix
    tsa = args.tsa
    fine_tune = args.fine_tune
    obliterate_p = args.obliterate
    negative_image_dir = args.negative_image_dir
    warmup_batch_size = args.warmup_batch_size or args.batch_size

    # Validation uses the same batch size as training.
    valid_batch_size = train_batch_size

    run_train = num_epochs > 0

    custom_model_kwargs = {}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)
    if embedding_loss is not None:
        custom_model_kwargs["need_embedding"] = True

    model: nn.Module = get_model(model_name, **custom_model_kwargs).cuda()

    # Work on a copy so that adding the mask key never mutates the model's own list.
    required_features = list(model.required_features)
    if mask_loss is not None:
        required_features.append(INPUT_TRUE_MODIFICATION_MASK)

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transferring weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]
        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)
        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    if freeze_bn:
        from pytorch_toolbelt.optimization.functional import freeze_model

        freeze_model(model, freeze_bn=True)
        print("Freezing bn params")

    main_metric = "loss"
    main_metric_minimize = True

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}_fold{fold}"
    for enabled, suffix in ((fp16, "_fp16"), (fast, "_fast"), (mixup, "_mixup"), (cutmix, "_cutmix")):
        if enabled:
            checkpoint_prefix += suffix
    # An explicit experiment name replaces the generated prefix entirely.
    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    # exist_ok=False: refuse to write into the directory of an earlier run.
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = []
    if show:
        default_callbacks += [ShowPolarBatchesCallback(draw_predictions, metric="loss", minimize=True)]

    # NOTE(review): these hyper-parameters describe the session as requested on the
    # command line; warmup and fine-tune log them too although they use their own
    # optimizer/augmentations. Confirm that this is intended.
    hparams = {
        "model": model_name,
        "scheduler": scheduler_name,
        "optimizer": optimizer_name,
        "augmentations": augmentations,
        "size": image_size[0],
        "weight_decay": weight_decay,
    }

    def load_datasets(stage_augmentations, stage_obliterate):
        """Return (train_ds, valid_ds, train_sampler) for one stage."""
        return get_datasets(
            data_dir=data_dir,
            augmentation=stage_augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=stage_obliterate,
        )

    def make_criterions(epochs, stage_mixup, stage_cutmix, stage_tsa):
        """Return (criterions_dict, loss_callbacks) for one stage."""
        return get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            feature_maps_loss=feature_maps_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            num_epochs=epochs,
            mixup=stage_mixup,
            cutmix=stage_cutmix,
            tsa=stage_tsa,
        )

    def make_callbacks(loss_callbacks):
        """Build a fresh callback list for one stage."""
        return (
            default_callbacks
            + loss_callbacks
            + [
                OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
                HyperParametersCallback(hparam_dict=dict(hparams)),
            ]
        )

    def make_loaders(train_ds, valid_ds, train_sampler, train_bs, valid_bs):
        """Build the ordered train/valid loaders for one stage."""
        stage_loaders = collections.OrderedDict()
        stage_loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_bs,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            # DataLoader rejects shuffle=True together with a sampler.
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )
        stage_loaders["valid"] = DataLoader(valid_ds, batch_size=valid_bs, num_workers=num_workers, pin_memory=True)
        return stage_loaders

    def print_summary(
        stage_loaders,
        train_ds,
        valid_ds,
        *,
        epochs,
        stage_augmentations,
        batch_sizes,
        stage_optimizer,
        stage_lr,
        stage_scheduler,
        stage_mixup,
        stage_cutmix,
        stage_tsa,
        stage_obliterate=None,
        mark_dropout=False,
    ):
        """Print the settings one stage actually trains with."""
        print("Train session :", checkpoint_prefix)
        print(" FP16 mode :", fp16)
        print(" Fast mode :", args.fast)
        print(" Epochs :", epochs)
        print(" Workers :", num_workers)
        print(" Data dir :", data_dir)
        print(" Log dir :", log_dir)
        print(" Cache :", cache)
        print("Data ")
        print(" Augmentations :", stage_augmentations)
        if stage_obliterate is not None:
            print(" Obliterate (%) :", stage_obliterate)
        print(" Negative images:", negative_image_dir)
        print(" Train size :", len(stage_loaders["train"]), "batches", len(train_ds), "samples")
        print(" Valid size :", len(stage_loaders["valid"]), "batches", len(valid_ds), "samples")
        print(" Image size :", image_size)
        print(" Balance :", balance)
        print(" Mixup :", stage_mixup)
        print(" CutMix :", stage_cutmix)
        print(" TSA :", stage_tsa)
        print("Model :", model_name)
        print(" Parameters :", count_parameters(model))
        if mark_dropout:
            print(" Dropout :", dropout, "(Non-default)" if dropout is not None else "")
        else:
            print(" Dropout :", dropout)
        print("Optimizer :", stage_optimizer)
        print(" Learning rate :", stage_lr)
        print(" Weight decay :", weight_decay)
        print(" Scheduler :", stage_scheduler)
        print(" Batch sizes :", *batch_sizes)
        print("Losses ")
        print(" Flag :", modification_flag_loss)
        print(" Type :", modification_type_loss)
        print(" Embedding :", embedding_loss)
        print(" Feature maps :", feature_maps_loss)
        print(" Mask :", mask_loss)
        print(" Bits :", bits_loss)

    def run_stage(stage, epochs, criterions_dict, optimizer, scheduler, callbacks, stage_loaders):
        """Train one stage under ``log_dir/<stage>`` and return the path of its best checkpoint."""
        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=stage_loaders,
            logdir=os.path.join(log_dir, stage),
            num_epochs=epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )
        return os.path.join(log_dir, stage, "checkpoints", "best.pth")

    # Pretrain/warmup
    if warmup:
        warmup_optimizer_name = "Ranger"
        warmup_learning_rate = 3e-4

        train_ds, valid_ds, train_sampler = load_datasets(augmentations, 0)
        criterions_dict, loss_callbacks = make_criterions(warmup, mixup, cutmix, tsa)
        callbacks = make_callbacks(loss_callbacks)
        loaders = make_loaders(train_ds, valid_ds, train_sampler, warmup_batch_size, warmup_batch_size)

        if freeze_encoder:
            from pytorch_toolbelt.optimization.functional import freeze_model

            # NOTE(review): nothing later in this function unfreezes the encoder, so
            # the following stages appear to keep it frozen as well - confirm.
            freeze_model(model.encoder, freeze_parameters=True, freeze_bn=None)

        optimizer = get_optimizer(
            warmup_optimizer_name,
            get_optimizable_parameters(model),
            weight_decay=weight_decay,
            learning_rate=warmup_learning_rate,
        )
        scheduler = None

        print_summary(
            loaders,
            train_ds,
            valid_ds,
            epochs=warmup,
            stage_augmentations=augmentations,
            batch_sizes=(warmup_batch_size, warmup_batch_size),
            stage_optimizer=warmup_optimizer_name,
            stage_lr=warmup_learning_rate,
            stage_scheduler=scheduler,
            stage_mixup=mixup,
            stage_cutmix=cutmix,
            stage_tsa=tsa,
            mark_dropout=True,
        )

        best_checkpoint = run_stage("warmup", warmup, criterions_dict, optimizer, scheduler, callbacks, loaders)
        del optimizer, loaders, callbacks

        # The in-memory model is not explicitly reloaded from this checkpoint here.
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}_warmup.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        torch.cuda.empty_cache()
        gc.collect()

    if run_train:
        train_ds, valid_ds, train_sampler = load_datasets(augmentations, obliterate_p)

        if negative_image_dir:
            negatives_ds = get_negatives_ds(
                negative_image_dir, fold=fold, features=required_features, max_images=16536
            )
            train_ds = train_ds + negatives_ds
            train_sampler = None  # TODO: Add proper support of sampler
            print("Adding", len(negatives_ds), "negative samples to training set")

        criterions_dict, loss_callbacks = make_criterions(num_epochs, mixup, cutmix, tsa)
        callbacks = make_callbacks(loss_callbacks)
        loaders = make_loaders(train_ds, valid_ds, train_sampler, train_batch_size, valid_batch_size)

        print_summary(
            loaders,
            train_ds,
            valid_ds,
            epochs=num_epochs,
            stage_augmentations=augmentations,
            batch_sizes=(train_batch_size, valid_batch_size),
            stage_optimizer=optimizer_name,
            stage_lr=learning_rate,
            stage_scheduler=scheduler_name,
            stage_mixup=mixup,
            stage_cutmix=cutmix,
            stage_tsa=tsa,
            stage_obliterate=obliterate_p,
        )

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        # CyclicLR has to be stepped after every batch rather than every epoch.
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        # model training
        best_checkpoint = run_stage("main", num_epochs, criterions_dict, optimizer, scheduler, callbacks, loaders)
        del optimizer, loaders, callbacks

        # The in-memory model is not explicitly reloaded from this checkpoint here.
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        torch.cuda.empty_cache()
        gc.collect()

    if fine_tune:
        fine_tune_augmentations = "light"
        fine_tune_optimizer_name = "SGD"
        fine_tune_scheduler_name = "cos"

        train_ds, valid_ds, train_sampler = load_datasets(fine_tune_augmentations, obliterate_p)
        criterions_dict, loss_callbacks = make_criterions(fine_tune, False, False, False)
        callbacks = make_callbacks(loss_callbacks)
        loaders = make_loaders(train_ds, valid_ds, train_sampler, train_batch_size, valid_batch_size)

        print_summary(
            loaders,
            train_ds,
            valid_ds,
            epochs=fine_tune,
            stage_augmentations=fine_tune_augmentations,
            batch_sizes=(train_batch_size, valid_batch_size),
            stage_optimizer=fine_tune_optimizer_name,
            stage_lr=learning_rate,
            stage_scheduler=fine_tune_scheduler_name,
            stage_mixup=False,
            stage_cutmix=False,
            stage_tsa=False,
            stage_obliterate=obliterate_p,
        )

        optimizer = get_optimizer(
            fine_tune_optimizer_name,
            get_optimizable_parameters(model),
            learning_rate=learning_rate,
            weight_decay=weight_decay,
        )
        scheduler = get_scheduler(
            fine_tune_scheduler_name,
            optimizer,
            lr=learning_rate,
            num_epochs=fine_tune,
            batches_in_epoch=len(loaders["train"]),
        )
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        # model training
        best_checkpoint = run_stage("finetune", fine_tune, criterions_dict, optimizer, scheduler, callbacks, loaders)

        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}_finetune.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)
        # Only this last stage loads the cleaned best checkpoint back into the model.
        unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        del optimizer, loaders, callbacks
def main():
    """
    Command-line entry point: train a segmentation model and predict a sample image.

    The function parses the command line, optionally initialises
    ``torch.distributed`` (when ``WORLD_SIZE`` is present in the environment),
    builds the model, datasets, losses and callbacks, and - when ``--epochs``
    is positive - trains with ``SupervisedRunner``. After training it copies
    the best checkpoint into the run directory, loads it into the model, runs
    ``predict`` on ``sample_color.jpg`` and writes the binarised mask next to
    the logs.

    Side effects: creates ``runs/<prefix>`` (``os.makedirs`` raises
    ``FileExistsError`` if it already exists) and writes the parsed arguments
    there as JSON.

    Raises:
        ValueError: if no data directory is given through ``--data-dir`` or the
            ``INRIA_DATA_DIR`` environment variable.
    """
    parser = argparse.ArgumentParser()

    ###########################################################################################
    # Distributed-training related stuff
    parser.add_argument("--local_rank", type=int, default=0)
    ###########################################################################################

    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd",
        "--data-dir",
        type=str,
        help="Data directory for INRIA sattelite dataset",
        default=os.environ.get("INRIA_DATA_DIR"),
    )
    parser.add_argument(
        "-dd-xview2",
        "--data-dir-xview2",
        type=str,
        required=False,
        help="Data directory for external xView2 dataset",
    )
    parser.add_argument("-m", "--model", type=str, default="resnet34_fpncat128", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")
    parser.add_argument("-l", "--criterion", type=str, required=True, action="append", nargs="+", help="Criterion")
    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="hard", type=str, help="")
    parser.add_argument("-tm", "--train-mode", default="random", type=str, help="")
    parser.add_argument("--run-mode", default="fit_predict", type=str, help="")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer")
    parser.add_argument("--opl", action="store_true")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")

    args = parser.parse_args()

    args.is_master = args.local_rank == 0
    args.distributed = False
    fp16 = args.fp16

    if "WORLD_SIZE" in os.environ:
        args.distributed = int(os.environ["WORLD_SIZE"]) > 1
        args.world_size = int(os.environ["WORLD_SIZE"])

        print("Initializing init_process_group", args.local_rank)
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        print("Initialized init_process_group", args.local_rank)

    # A non-distributed run is always its own master.
    is_master = args.is_master or (not args.distributed)

    distributed_params = {}
    if args.distributed:
        distributed_params = {"rank": args.local_rank, "syncbn": True}

    if args.fp16:
        distributed_params["apex"] = True
        distributed_params["opt_level"] = "O1"

    # Offset the seed by the rank so that every process gets its own seed.
    set_manual_seed(args.seed + args.local_rank)
    catalyst.utils.set_global_seed(args.seed + args.local_rank)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    data_dir = args.data_dir
    if data_dir is None:
        raise ValueError("--data-dir must be set")

    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    online_pseudolabeling = args.opl
    criterions = args.criterion
    verbose = args.verbose
    show = args.show
    use_dsv = args.dsv
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    extra_data_xview2 = args.data_dir_xview2

    run_train = num_epochs > 0
    # Only the "wbce" criterion consumes the per-pixel weight mask.
    need_weight_mask = any(c[0] == "wbce" for c in criterions)

    custom_model_kwargs = {}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)

    model: nn.Module = get_model(model_name, **custom_model_kwargs).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]
        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)
        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    main_metric = "optimized_jaccard"

    current_time = datetime.now().strftime("%y%m%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}"
    if fp16:
        checkpoint_prefix += "_fp16"
    if fast:
        checkpoint_prefix += "_fast"
    if online_pseudolabeling:
        checkpoint_prefix += "_opl"
    if extra_data_xview2:
        checkpoint_prefix += "_with_xview2"
    # An explicit experiment name replaces the generated prefix entirely.
    if experiment is not None:
        checkpoint_prefix = experiment
    # Every rank logs into its own directory.
    if args.distributed:
        checkpoint_prefix += f"_local_rank_{args.local_rank}"

    log_dir = os.path.join("runs", checkpoint_prefix)
    # exist_ok=False: refuse to write into the directory of an earlier run.
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = [
        PixelAccuracyCallback(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY),
        JaccardMetricPerImageWithOptimalThreshold(
            input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="optimized_jaccard"
        ),
    ]

    if is_master:
        default_callbacks += [
            BestMetricCheckpointCallback(target_metric="optimized_jaccard", target_metric_minimize=False),
            HyperParametersCallback(
                hparam_dict={
                    "model": model_name,
                    "scheduler": scheduler_name,
                    "optimizer": optimizer_name,
                    "augmentations": augmentations,
                    "size": args.size,
                    "weight_decay": weight_decay,
                    "epochs": num_epochs,
                    "dropout": None if dropout is None else float(dropout),
                }
            ),
        ]

    if show and is_master:
        visualize_inria_predictions = partial(
            draw_inria_predictions,
            image_key=INPUT_IMAGE_KEY,
            image_id_key=INPUT_IMAGE_ID_KEY,
            targets_key=INPUT_MASK_KEY,
            outputs_key=OUTPUT_MASK_KEY,
            max_images=16,
        )
        default_callbacks += [
            ShowPolarBatchesCallback(visualize_inria_predictions, metric="accuracy", minimize=False),
            ShowPolarBatchesCallback(visualize_inria_predictions, metric="loss", minimize=True),
        ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        train_mode=train_mode,
        buildings_only=(train_mode == "tiles"),
        fast=fast,
        need_weight_mask=need_weight_mask,
    )

    if extra_data_xview2 is not None:
        extra_train_ds, _ = get_xview2_extra_dataset(
            extra_data_xview2,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast,
            need_weight_mask=need_weight_mask,
        )

        # get_datasets may return no sampler; fall back to the dataset size, the
        # same way the pseudolabeling branch below does.
        if train_sampler is not None:
            num_samples = 2 * train_sampler.num_samples
        else:
            num_samples = 2 * len(train_ds)

        weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(extra_train_ds))
        train_sampler = WeightedRandomSampler(weights, num_samples)
        train_ds = train_ds + extra_train_ds
        print("Using extra data from xView2 with", len(extra_train_ds), "samples")

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []

        ignore_index = None
        if online_pseudolabeling:
            ignore_index = UNLABELED_SAMPLE
            unlabeled_label = get_pseudolabeling_dataset(
                data_dir, include_masks=False, augmentation=None, image_size=image_size
            )
            unlabeled_train = get_pseudolabeling_dataset(
                data_dir, include_masks=True, augmentation=augmentations, image_size=image_size
            )

            loaders["label"] = DataLoader(
                unlabeled_label, batch_size=batch_size // 2, num_workers=num_workers, pin_memory=True
            )

            if train_sampler is not None:
                num_samples = 2 * train_sampler.num_samples
            else:
                num_samples = 2 * len(train_ds)

            # NOTE(review): the weights are sized from `unlabeled_label` while the
            # samples come from `unlabeled_train`; presumably both have the same
            # length - confirm.
            weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(unlabeled_label))
            train_sampler = WeightedRandomSampler(weights, num_samples, replacement=True)
            train_ds = train_ds + unlabeled_train

            callbacks += [
                BCEOnlinePseudolabelingCallback2d(
                    unlabeled_train,
                    pseudolabel_loader="label",
                    prob_threshold=0.7,
                    output_key=OUTPUT_MASK_KEY,
                    unlabeled_class=UNLABELED_SAMPLE,
                    label_frequency=5,
                )
            ]

            print("Using online pseudolabeling with ", len(unlabeled_label), "samples")

        valid_sampler = None
        if args.distributed:
            if train_sampler is not None:
                train_sampler = DistributedSamplerWrapper(
                    train_sampler, args.world_size, args.local_rank, shuffle=True
                )
            else:
                train_sampler = DistributedSampler(train_ds, args.world_size, args.local_rank, shuffle=True)
            valid_sampler = DistributedSampler(valid_ds, args.world_size, args.local_rank, shuffle=False)

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            # DataLoader rejects shuffle=True together with a sampler.
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )
        loaders["valid"] = DataLoader(
            valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, sampler=valid_sampler
        )

        # Create losses
        for loss_name, loss_weight in criterions:
            criterion_callback = CriterionCallback(
                prefix=f"{INPUT_MASK_KEY}/" + loss_name,
                input_key=INPUT_MASK_KEY if loss_name != "wbce" else [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY],
                output_key=OUTPUT_MASK_KEY,
                criterion_key=loss_name,
                multiplier=float(loss_weight),
            )

            criterions_dict[loss_name] = get_loss(loss_name, ignore_index=ignore_index)
            callbacks.append(criterion_callback)
            losses.append(criterion_callback.prefix)
            print("Using loss", loss_name, loss_weight)

        if use_dsv:
            print("Using DSV")
            # Kept in its own variable: reusing `criterions` here would replace the
            # parsed criterion list that the summary below prints.
            dsv_criterion_key = "dsv"
            dsv_loss_name = "soft_bce"

            criterions_dict[dsv_criterion_key] = AdaptiveMaskLoss2d(
                get_loss(dsv_loss_name, ignore_index=ignore_index)
            )

            for dsv_input in (OUTPUT_MASK_4_KEY, OUTPUT_MASK_8_KEY, OUTPUT_MASK_16_KEY, OUTPUT_MASK_32_KEY):
                criterion_callback = CriterionCallback(
                    prefix=f"{dsv_input}/" + dsv_loss_name,
                    input_key=INPUT_MASK_KEY,
                    output_key=dsv_input,
                    criterion_key=dsv_criterion_key,
                    multiplier=1.0,
                )
                callbacks.append(criterion_callback)
                losses.append(criterion_callback.prefix)

        if isinstance(model, SupervisedHGSegmentationModel):
            print("Using Hourglass DSV")
            dsv_loss_name = "kl"

            # NOTE(review): this reuses the "dsv" criterion key, so together with
            # --dsv it replaces the criterion registered above - confirm the two
            # are never meant to be combined.
            criterions_dict["dsv"] = get_loss(dsv_loss_name, ignore_index=ignore_index)
            num_supervision_inputs = model.encoder.num_blocks - 1
            dsv_outputs = [OUTPUT_MASK_4_KEY + "_after_hg_" + str(i) for i in range(num_supervision_inputs)]

            for i, dsv_input in enumerate(dsv_outputs):
                criterion_callback = CriterionCallback(
                    prefix="supervision/" + dsv_input,
                    input_key=INPUT_MASK_KEY,
                    output_key=dsv_input,
                    criterion_key="dsv",
                    # The multiplier grows linearly with the output index, up to 1.0.
                    multiplier=(i + 1) / num_supervision_inputs,
                )
                callbacks.append(criterion_callback)
                losses.append(criterion_callback.prefix)

        callbacks += [
            MetricAggregationCallback(prefix="loss", metrics=losses, mode="sum"),
            OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
        ]

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        # These schedulers have to be stepped after every batch rather than every epoch.
        if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
            callbacks += [SchedulerCallback(mode="batch")]

        print("Train session :", checkpoint_prefix)
        print(" FP16 mode :", fp16)
        print(" Fast mode :", args.fast)
        print(" Train mode :", train_mode)
        print(" Epochs :", num_epochs)
        print(" Workers :", num_workers)
        print(" Data dir :", data_dir)
        print(" Log dir :", log_dir)
        print(" Augmentations :", augmentations)
        print(" Train size :", "batches", len(loaders["train"]), "dataset", len(train_ds))
        print(" Valid size :", "batches", len(loaders["valid"]), "dataset", len(valid_ds))
        print("Model :", model_name)
        print(" Parameters :", count_parameters(model))
        print(" Image size :", image_size)
        print("Optimizer :", optimizer_name)
        print(" Learning rate :", learning_rate)
        print(" Batch size :", batch_size)
        print(" Criterion :", criterions)
        print(" Use weight mask:", need_weight_mask)
        if args.distributed:
            print("Distributed")
            print(" World size :", args.world_size)
            print(" Local rank :", args.local_rank)
            print(" Is master :", args.is_master)

        # model training
        runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None, device="cuda")
        runner.train(
            fp16=distributed_params,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": vars(args)},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        unpack_checkpoint(torch.load(model_checkpoint), model=model)

        mask = predict(
            model, read_inria_image("sample_color.jpg"), image_size=image_size, batch_size=args.batch_size
        )
        # Binarise the prediction: positive values become 255, everything else 0.
        mask = ((mask > 0) * 255).astype(np.uint8)
        name = os.path.join(log_dir, "sample_color.jpg")
        cv2.imwrite(name, mask)

        del optimizer, loaders
def _parse_loss_spec(spec):
    """Split one command-line loss specification into ``(name, weight)``.

    Loss options are declared with ``action="append", nargs="+"``, so every
    occurrence arrives as a list of strings: ``["ce"]`` or ``["ce", "0.5"]``.
    A bare string is accepted as well. The weight defaults to 1.0 when it is
    omitted.

    Raises:
        ValueError: if the specification is empty or has more than two items.
    """
    if isinstance(spec, str):
        return spec, 1.0
    if len(spec) == 1:
        return spec[0], 1.0
    if len(spec) == 2:
        loss_name, loss_weight = spec
        return loss_name, float(loss_weight)
    raise ValueError(f"Expected a loss specification of the form NAME [WEIGHT], got {spec!r}")


def _register_losses(
    specs, criterions_dict, callbacks, losses, *, prefix, input_key, output_key, label, **criterion_kwargs
):
    """Create a criterion callback for every loss specification in *specs*.

    Updates ``criterions_dict``, ``callbacks`` and ``losses`` in place. ``specs``
    may be ``None`` (option not given), in which case nothing is registered.
    Extra keyword arguments are forwarded to ``get_criterion_callback``.
    """
    for spec in specs or ():
        loss_name, loss_weight = _parse_loss_spec(spec)
        cd, criterion_callback, criterion_name = get_criterion_callback(
            loss_name,
            prefix=prefix,
            input_key=input_key,
            output_key=output_key,
            loss_weight=loss_weight,
            **criterion_kwargs,
        )
        criterions_dict.update(cd)
        callbacks.append(criterion_callback)
        losses.append(criterion_name)
        print(label, "Using loss", loss_name, loss_weight)


def main():
    """Train a segmentation model on labeled data plus pseudolabeled samples.

    Parses command-line options, builds the model (optionally loading
    transferred weights or a checkpoint), creates the run directory under
    ``runs/``, dumps the parsed arguments there as JSON, assembles
    datasets, loaders, loss callbacks, optimizer and scheduler, and runs
    training through ``SupervisedRunner``. After training, the best checkpoint
    is passed to ``clean_checkpoint`` and stored next to it under the run's
    checkpoint prefix.

    Every loss option takes ``NAME [WEIGHT]`` and may be repeated.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd", "--data-dir", type=str, required=True, help="Data directory for INRIA sattelite dataset"
    )
    parser.add_argument("-m", "--model", type=str, default="resnet34_fpncat128", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")
    parser.add_argument(
        "--disaster-type-loss",
        type=str,
        default=None,  # e.g. [["ce", 1.0]]
        action="append",
        nargs="+",
        help="Criterion for classifying disaster type",
    )
    parser.add_argument(
        "--damage-type-loss",
        type=str,
        default=None,  # e.g. [["bce", 1.0]]
        action="append",
        nargs="+",
        help="Criterion for classifying presence of building with particular damage type",
    )
    parser.add_argument("-l", "--criterion", type=str, default=None, action="append", nargs="+", help="Criterion")
    parser.add_argument(
        "--mask4", type=str, default=None, action="append", nargs="+", help="Criterion for mask with stride 4"
    )
    parser.add_argument(
        "--mask8", type=str, default=None, action="append", nargs="+", help="Criterion for mask with stride 8"
    )
    parser.add_argument(
        "--mask16", type=str, default=None, action="append", nargs="+", help="Criterion for mask with stride 16"
    )
    parser.add_argument(
        "--mask32", type=str, default=None, action="append", nargs="+", help="Criterion for mask with stride 32"
    )
    parser.add_argument("--embedding", type=str, default=None)
    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="safe", type=str, help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("--fold", default=0, type=int)
    parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=0.0, type=float, help="Dropout before head layer")
    parser.add_argument("-pl", "--pseudolabeling", type=str, required=True)
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--only-buildings", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")
    parser.add_argument("--crops", action="store_true", help="Train on random crops")
    parser.add_argument("--post-transform", action="store_true")

    args = parser.parse_args()

    # The segmentation criterion is iterated unconditionally when training, so
    # reject a missing one here instead of failing later with a TypeError after
    # the model and the run directory have already been created.
    if args.epochs > 0 and not args.criterion:
        parser.error("-l/--criterion is required when training (--epochs > 0)")

    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    segmentation_losses = args.criterion
    verbose = args.verbose
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    fold = args.fold
    balance = args.balance
    only_buildings = args.only_buildings
    freeze_bn = args.freeze_bn
    train_on_crops = args.crops
    enable_post_image_transform = args.post_transform
    disaster_type_loss = args.disaster_type_loss
    train_batch_size = args.batch_size
    embedding_criterion = args.embedding
    damage_type_loss = args.damage_type_loss
    pseudolabels_dir = args.pseudolabeling

    # Compute batch size for validation: when training on crops the batch is
    # rescaled by the ratio of the crop area to a 1024x1024 image.
    if train_on_crops:
        valid_batch_size = max(1, (train_batch_size * (image_size[0] * image_size[1])) // (1024**2))
    else:
        valid_batch_size = train_batch_size

    run_train = num_epochs > 0

    model: nn.Module = get_model(model_name, dropout=dropout).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]
        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)
        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    if freeze_bn:
        torch_utils.freeze_bn(model)
        print("Freezing bn params")

    runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None)
    main_metric = "weighted_f1"
    cmd_args = vars(args)

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}_{args.size}_fold{fold}"

    if fp16:
        checkpoint_prefix += "_fp16"
    if fast:
        checkpoint_prefix += "_fast"
    if pseudolabels_dir:
        checkpoint_prefix += "_pseudo"
    if train_on_crops:
        checkpoint_prefix += "_crops"
    if experiment is not None:
        # An explicit experiment name replaces the generated prefix entirely.
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    # exist_ok=False: refuse to write into the directory of an earlier run.
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        f.write(json.dumps(cmd_args, indent=2))

    default_callbacks = [
        CompetitionMetricCallback(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="weighted_f1"),
        ConfusionMatrixCallback(
            input_key=INPUT_MASK_KEY,
            output_key=OUTPUT_MASK_KEY,
            class_names=["land", "no_damage", "minor_damage", "major_damage", "destroyed"],
            ignore_index=UNLABELED_SAMPLE,
        ),
    ]

    if show:
        default_callbacks += [
            ShowPolarBatchesCallback(draw_predictions, metric=main_metric + "_batch", minimize=False)
        ]

    # The returned sampler is not used: the train loader below shuffles the
    # concatenation of labeled and pseudolabeled samples instead.
    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        fast=fast,
        fold=fold,
        balance=balance,
        only_buildings=only_buildings,
        train_on_crops=train_on_crops,
        crops_multiplication_factor=1,
        enable_post_image_transform=enable_post_image_transform,
    )

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []

        unlabeled_train = get_pseudolabeling_dataset(
            data_dir,
            include_masks=True,
            image_size=image_size,
            augmentation="medium_nmd",
            train_on_crops=train_on_crops,
            enable_post_image_transform=enable_post_image_transform,
            pseudolabels_dir=pseudolabels_dir,
        )

        train_ds = train_ds + unlabeled_train
        print("Using online pseudolabeling with ", len(unlabeled_train), "samples")

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=True,
        )

        loaders["valid"] = DataLoader(
            valid_ds, batch_size=valid_batch_size, num_workers=num_workers, pin_memory=True
        )

        # Create losses. Registration order is preserved: main segmentation
        # loss first, then the auxiliary mask outputs from stride 4 to 32.
        _register_losses(
            segmentation_losses,
            criterions_dict,
            callbacks,
            losses,
            prefix="segmentation",
            input_key=INPUT_MASK_KEY,
            output_key=OUTPUT_MASK_KEY,
            label=INPUT_MASK_KEY,
        )

        auxiliary_mask_losses = [
            ("mask4", args.mask4, OUTPUT_MASK_4_KEY),
            ("mask8", args.mask8, OUTPUT_MASK_8_KEY),
            ("mask16", args.mask16, OUTPUT_MASK_16_KEY),
            ("mask32", args.mask32, OUTPUT_MASK_32_KEY),
        ]
        for mask_prefix, mask_specs, mask_output_key in auxiliary_mask_losses:
            _register_losses(
                mask_specs,
                criterions_dict,
                callbacks,
                losses,
                prefix=mask_prefix,
                input_key=INPUT_MASK_KEY,
                output_key=mask_output_key,
                label=mask_output_key,
            )

        if disaster_type_loss is not None:
            callbacks += [
                ConfusionMatrixCallback(
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    class_names=DISASTER_TYPES,
                    ignore_index=UNKNOWN_DISASTER_TYPE_CLASS,
                    prefix=f"{DISASTER_TYPE_KEY}/confusion_matrix",
                ),
                AccuracyCallback(
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    prefix=f"{DISASTER_TYPE_KEY}/accuracy",
                    activation="Softmax",
                ),
            ]
            _register_losses(
                disaster_type_loss,
                criterions_dict,
                callbacks,
                losses,
                prefix=DISASTER_TYPE_KEY,
                input_key=DISASTER_TYPE_KEY,
                output_key=DISASTER_TYPE_KEY,
                label=DISASTER_TYPE_KEY,
                ignore_index=UNKNOWN_DISASTER_TYPE_CLASS,
            )

        if damage_type_loss is not None:
            callbacks += [
                AccuracyCallback(
                    input_key=DAMAGE_TYPE_KEY,
                    output_key=DAMAGE_TYPE_KEY,
                    prefix=f"{DAMAGE_TYPE_KEY}/accuracy",
                    activation="Sigmoid",
                    threshold=0.5,
                )
            ]
            _register_losses(
                damage_type_loss,
                criterions_dict,
                callbacks,
                losses,
                prefix=DAMAGE_TYPE_KEY,
                input_key=DAMAGE_TYPE_KEY,
                output_key=DAMAGE_TYPE_KEY,
                label=DAMAGE_TYPE_KEY,
            )

        if embedding_criterion is not None:
            cd, criterion, criterion_name = get_criterion_callback(
                embedding_criterion,
                prefix="embedding",
                input_key=INPUT_MASK_KEY,
                output_key=OUTPUT_EMBEDDING_KEY,
                loss_weight=1.0,
            )
            criterions_dict.update(cd)
            callbacks.append(criterion)
            losses.append(criterion_name)
            print(OUTPUT_EMBEDDING_KEY, "Using loss", embedding_criterion)

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
        ]

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name,
            optimizer,
            lr=learning_rate,
            num_epochs=num_epochs,
            batches_in_epoch=len(loaders["train"]),
        )
        if isinstance(scheduler, CyclicLR):
            # Cyclic schedules are stepped once per batch.
            callbacks += [SchedulerCallback(mode="batch")]

        print("Train session :", checkpoint_prefix)
        print(" FP16 mode :", fp16)
        print(" Fast mode :", args.fast)
        print(" Epochs :", num_epochs)
        print(" Workers :", num_workers)
        print(" Data dir :", data_dir)
        print(" Log dir :", log_dir)
        print("Data ")
        print(" Augmentations :", augmentations)
        print(" Train size :", len(loaders["train"]), len(train_ds))
        print(" Valid size :", len(loaders["valid"]), len(valid_ds))
        print(" Image size :", image_size)
        print(" Train on crops :", train_on_crops)
        print(" Balance :", balance)
        print(" Buildings only :", only_buildings)
        print(" Post transform :", enable_post_image_transform)
        print(" Pseudolabels :", pseudolabels_dir)
        print("Model :", model_name)
        print(" Parameters :", count_parameters(model))
        print(" Dropout :", dropout)
        print("Optimizer :", optimizer_name)
        print(" Learning rate :", learning_rate)
        print(" Weight decay :", weight_decay)
        print(" Scheduler :", scheduler_name)
        print(" Batch sizes :", train_batch_size, valid_batch_size)
        print(" Criterion :", segmentation_losses)
        print(" Damage type :", damage_type_loss)
        print(" Disaster type :", disaster_type_loss)
        print(" Embedding :", embedding_criterion)

        # One directory for both the training log and the checkpoint lookup
        # below, so the two can never point at different locations.
        stage_dir = os.path.join(log_dir, "opl")

        # model training
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=stage_dir,
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": cmd_args},
        )

        # Training is finished. Export the best checkpoint under the run's name.
        best_checkpoint = os.path.join(stage_dir, "checkpoints", "best.pth")
        model_checkpoint = os.path.join(stage_dir, "checkpoints", f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        del optimizer, loaders
def run_stage_training(model: Union[TimmRgbModel, YCrCbModel], config: StageConfig, exp_config: ExperimenetConfig, experiment_dir: str):
    """Run a single training stage and export its best checkpoint.

    Args:
        model: Network to train; its ``required_features`` attribute selects
            the dataset features and the runner input keys.
        config: Settings of this stage (data, losses, optimizer, schedule,
            logging and checkpoint options).
        exp_config: Settings shared by the whole experiment (data directory,
            fold, worker count, model name, checkpoint prefix).
        experiment_dir: Directory that receives the stage log directory and
            the exported checkpoint.

    The stage is logged to ``<experiment_dir>/<stage_name>``. Afterwards the
    best checkpoint is passed to ``clean_checkpoint`` and written to
    ``<experiment_dir>/<checkpoint_prefix>.pth``; when ``config.restore_best``
    is set, those weights are loaded back into ``model``.
    """
    # Preparing model
    freeze_model(model, freeze_bn=config.freeze_bn)

    train_dataset, valid_dataset, sampler = get_datasets(
        data_dir=exp_config.data_dir,
        image_size=config.image_size,
        augmentation=config.augmentations,
        balance=config.balance,
        fast=config.fast,
        fold=exp_config.fold,
        features=model.required_features,
        obliterate_p=config.obliterate_p,
    )

    criterions, loss_callbacks = get_criterions(
        modification_flag=config.modification_flag_loss,
        modification_type=config.modification_type_loss,
        embedding_loss=config.embedding_loss,
        feature_maps_loss=config.feature_maps_loss,
        num_epochs=config.epochs,
        mixup=config.mixup,
        cutmix=config.cutmix,
        tsa=config.tsa,
    )

    callbacks = loss_callbacks + [
        OptimizerCallback(accumulation_steps=config.accumulation_steps, decouple_weight_decay=False),
        HyperParametersCallback(
            hparam_dict={
                "model": exp_config.model_name,
                "scheduler": config.schedule,
                "optimizer": config.optimizer,
                "augmentations": config.augmentations,
                "size": config.image_size[0],
                "weight_decay": config.weight_decay,
            }
        ),
    ]
    if config.show:
        callbacks.append(ShowPolarBatchesCallback(draw_predictions, metric="loss", minimize=True))

    # Shuffle only when no sampler is supplied.
    loaders = collections.OrderedDict(
        [
            (
                "train",
                DataLoader(
                    train_dataset,
                    batch_size=config.train_batch_size,
                    num_workers=exp_config.num_workers,
                    pin_memory=True,
                    drop_last=True,
                    shuffle=sampler is None,
                    sampler=sampler,
                ),
            ),
            (
                "valid",
                DataLoader(
                    valid_dataset,
                    batch_size=config.valid_batch_size,
                    num_workers=exp_config.num_workers,
                    pin_memory=True,
                ),
            ),
        ]
    )
    num_train_batches = len(loaders["train"])
    num_valid_batches = len(loaders["valid"])

    # Summary of the stage configuration
    print("Stage :", config.stage_name)
    print(" FP16 mode :", config.fp16)
    print(" Fast mode :", config.fast)
    print(" Epochs :", config.epochs)
    print(" Workers :", exp_config.num_workers)
    print(" Data dir :", exp_config.data_dir)
    print(" Experiment dir :", experiment_dir)
    print("Data ")
    print(" Augmentations :", config.augmentations)
    print(" Obliterate (%) :", config.obliterate_p)
    print(" Negative images:", config.negative_image_dir)
    print(" Train size :", num_train_batches, "batches", len(train_dataset), "samples")
    print(" Valid size :", num_valid_batches, "batches", len(valid_dataset), "samples")
    print(" Image size :", config.image_size)
    print(" Balance :", config.balance)
    print(" Mixup :", config.mixup)
    print(" CutMix :", config.cutmix)
    print(" TSA :", config.tsa)
    print("Model :", exp_config.model_name)
    print(" Parameters :", count_parameters(model))
    print(" Dropout :", exp_config.dropout)
    print("Optimizer :", config.optimizer)
    print(" Learning rate :", config.learning_rate)
    print(" Weight decay :", config.weight_decay)
    print(" Scheduler :", config.schedule)
    print(" Batch sizes :", config.train_batch_size, config.valid_batch_size)
    print("Losses ")
    print(" Flag :", config.modification_flag_loss)
    print(" Type :", config.modification_type_loss)
    print(" Embedding :", config.embedding_loss)
    print(" Feature maps :", config.feature_maps_loss)

    optimizer = get_optimizer(
        config.optimizer,
        get_optimizable_parameters(model),
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    scheduler = get_scheduler(
        config.schedule,
        optimizer,
        lr=config.learning_rate,
        num_epochs=config.epochs,
        batches_in_epoch=num_train_batches,
    )
    if isinstance(scheduler, CyclicLR):
        # Cyclic schedules are stepped once per batch.
        callbacks.append(SchedulerCallback(mode="batch"))

    stage_dir = os.path.join(experiment_dir, config.stage_name)

    # model training
    runner = SupervisedRunner(input_key=model.required_features, output_key=None)
    runner.train(
        fp16=config.fp16,
        model=model,
        criterion=criterions,
        optimizer=optimizer,
        scheduler=scheduler,
        callbacks=callbacks,
        loaders=loaders,
        logdir=stage_dir,
        num_epochs=config.epochs,
        verbose=config.verbose,
        main_metric=config.main_metric,
        minimize_metric=config.main_metric_minimize,
        checkpoint_data={"config": config},
    )

    # Drop references to the training objects before handling checkpoints.
    del optimizer, loaders, callbacks, runner

    best_checkpoint = os.path.join(stage_dir, "checkpoints", "best.pth")
    model_checkpoint = os.path.join(experiment_dir, f"{exp_config.checkpoint_prefix}.pth")
    clean_checkpoint(best_checkpoint, model_checkpoint)

    # Restore state of best model
    if config.restore_best:
        unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

    # Some memory cleanup
    torch.cuda.empty_cache()
    gc.collect()
def main():
    """Train an image classifier with Catalyst's ``SupervisedRunner``.

    Parses command-line options, creates a fresh run directory under
    ``runs/``, builds the model (optionally loading weights from ``--transfer``),
    prepares datasets and loaders, and trains with a cosine-annealing learning
    rate schedule. Checkpoints for every epoch are kept
    (``save_n_best=epochs``).

    At least one of ``--use-current`` / ``--use-extra`` must be given; otherwise
    the program exits with a usage error.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('-dd', '--data-dir', type=str, default='data', help='Data directory')
    parser.add_argument('-l', '--loss', type=str, default='label_smooth_cross_entropy')
    parser.add_argument('-t1', '--temper1', type=float, default=0.2)
    parser.add_argument('-t2', '--temper2', type=float, default=4.0)
    parser.add_argument('-optim', '--optimizer', type=str, default='adam')
    parser.add_argument('-prep', '--prep_function', type=str, default='none')
    parser.add_argument('--train_on_different_datasets', action='store_true')
    parser.add_argument('--use-current', action='store_true')
    parser.add_argument('--use-extra', action='store_true')
    parser.add_argument('--use-unlabeled', action='store_true')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('--mixup', action='store_true')
    parser.add_argument('--balance', action='store_true')
    parser.add_argument('--balance-datasets', action='store_true')
    parser.add_argument('--show', action='store_true')
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('-m', '--model', type=str, default='efficientnet-b4', help='')
    parser.add_argument('-b', '--batch-size', type=int, default=8, help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e', '--epochs', type=int, default=100, help='Epoch to run')
    parser.add_argument('-s', '--sizes', default=380, type=int, help='Image size for training & inference')
    parser.add_argument('-f', '--fold', type=int, default=None)
    parser.add_argument('-t', '--transfer', default=None, type=str, help='')
    parser.add_argument('-lr', '--learning_rate', type=float, default=1e-4, help='Initial learning rate')
    parser.add_argument('-a', '--augmentations', default='medium', type=str, help='')
    parser.add_argument('-accum', '--accum-step', type=int, default=1)
    parser.add_argument('-metric', '--metric', type=str, default='accuracy01')

    args = parser.parse_args()

    # Validate up front with a usage error; an ``assert`` would be stripped
    # under ``python -O`` and would only fire after the debug output below.
    if not (args.use_current or args.use_extra):
        parser.error('at least one of --use-current / --use-extra is required')

    # NOTE(review): --temper1/--temper2, --fast, --mixup, --show, --balance and
    # --balance-datasets are accepted but never read below; the loaders are
    # always created with balance=True and balance_datasets=True.
    diff_dataset_train = args.train_on_different_datasets
    data_dir = args.data_dir
    epochs = args.epochs
    batch_size = args.batch_size
    seed = args.seed
    loss_name = args.loss
    optim_name = args.optimizer
    prep_function = args.prep_function
    model_name = args.model
    # Plain int: a stray trailing comma here used to turn this into a 1-tuple
    # that leaked into the run directory name as "(380,)".
    size = args.sizes
    image_size = (size, size)
    fold = args.fold
    verbose = args.verbose
    use_current = args.use_current
    use_extra = args.use_extra
    use_unlabeled = args.use_unlabeled
    learning_rate = args.learning_rate
    augmentations = args.augmentations
    transfer = args.transfer
    accum_step = args.accum_step
    # e.g. cosine_loss or accuracy01
    main_metric = args.metric
    num_classes = 5

    print(size)
    print(image_size)
    print(data_dir)
    print(fold)

    current_time = datetime.now().strftime('%b%d_%H_%M')

    torch.cuda.empty_cache()

    checkpoint_prefix = f'{model_name}_{size}_{augmentations}'
    if transfer is not None:
        checkpoint_prefix += '_pretrain_from_' + str(transfer)
    else:
        if use_current:
            checkpoint_prefix += '_current'
        if use_extra:
            checkpoint_prefix += '_extra'
        if use_unlabeled:
            checkpoint_prefix += '_unlabeled'
    if fold is not None:
        checkpoint_prefix += f'_fold{fold}'

    directory_prefix = f'{current_time}_{checkpoint_prefix}'
    log_dir = os.path.join('runs', directory_prefix)
    # exist_ok=False: refuse to write into the directory of an earlier run.
    os.makedirs(log_dir, exist_ok=False)

    set_manual_seed(seed)
    model = get_model(model_name)

    if transfer is not None:
        print("Transfering weights from model checkpoint")
        model.load_state_dict(torch.load(transfer)['model_state_dict'])

    model = model.cuda()

    if diff_dataset_train:
        train_on = ['current_train', 'extra_train']
        valid_on = ['unlabeled']
        train_ds, valid_ds, train_sizes = get_datasets_universal(
            train_on=train_on,
            valid_on=valid_on,
            image_size=image_size,
            augmentation=augmentations,
            target_dtype=int,
            prep_function=prep_function,
        )
    else:
        train_ds, valid_ds, train_sizes = get_datasets(
            data_dir=data_dir,
            use_current=use_current,
            use_extra=use_extra,
            image_size=image_size,
            prep_function=prep_function,
            augmentation=augmentations,
            target_dtype=int,
            fold=fold,
            folds=5,
        )

    train_loader, valid_loader = get_dataloaders(
        train_ds,
        valid_ds,
        batch_size=batch_size,
        train_sizes=train_sizes,
        num_workers=6,
        balance=True,
        balance_datasets=True,
        balance_unlabeled=False,
    )

    loaders = collections.OrderedDict()
    loaders["train"] = train_loader
    loaders["valid"] = valid_loader

    runner = SupervisedRunner(input_key='image')

    criterions = get_loss(loss_name)
    optimizer = get_optim(optim_name, model, learning_rate)

    # The annealing period is the number of full batches in the training set;
    # clamp to 1 so a dataset smaller than one batch does not yield T_max == 0.
    annealing_period = max(1, len(train_ds) // batch_size)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=annealing_period)

    callbacks = [AccuracyCallback(num_classes=num_classes)]
    if main_metric != 'accuracy01':
        # Any non-default main metric also enables the cosine-loss callback.
        callbacks.append(CosineLossCallback())
    callbacks += [
        OptimizerCallback(accumulation_steps=accum_step),
        CheckpointCallback(save_n_best=epochs),
    ]

    runner.train(
        fp16=True,
        model=model,
        criterion=criterions,
        optimizer=optimizer,
        scheduler=scheduler,
        callbacks=callbacks,
        loaders=loaders,
        logdir=log_dir,
        num_epochs=epochs,
        verbose=verbose,
        main_metric=main_metric,
        minimize_metric=False,
    )
def main(): parser = argparse.ArgumentParser() parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process") parser.add_argument("--seed", type=int, default=42, help="Random seed") parser.add_argument("-v", "--verbose", action="store_true") parser.add_argument("--fast", action="store_true") parser.add_argument("-dd", "--data-dir", type=str, required=True, help="Data directory for INRIA sattelite dataset") parser.add_argument("-dd-xview2", "--data-dir-xview2", type=str, required=False, help="Data directory for external xView2 dataset") parser.add_argument("-m", "--model", type=str, default="resnet34_fpncat128", help="") parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64") parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run") # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement') # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs') # parser.add_argument('-ft', '--fine-tune', action='store_true') parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate") parser.add_argument("-l", "--criterion", type=str, required=True, action="append", nargs="+", help="Criterion") parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer") parser.add_argument( "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights") parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers") parser.add_argument("-a", "--augmentations", default="hard", type=str, help="") parser.add_argument("-tm", "--train-mode", default="random", type=str, help="") parser.add_argument("--run-mode", default="fit_predict", type=str, help="") parser.add_argument("--transfer", default=None, type=str, help="") 
parser.add_argument("--fp16", action="store_true") parser.add_argument("--size", default=512, type=int) parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="") parser.add_argument("-x", "--experiment", default=None, type=str, help="") parser.add_argument("-d", "--dropout", default=0.0, type=float, help="Dropout before head layer") parser.add_argument("--opl", action="store_true") parser.add_argument( "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters") parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay") parser.add_argument("--show", action="store_true") parser.add_argument("--dsv", action="store_true") args = parser.parse_args() set_manual_seed(args.seed) data_dir = args.data_dir num_workers = args.workers num_epochs = args.epochs batch_size = args.batch_size learning_rate = args.learning_rate model_name = args.model optimizer_name = args.optimizer image_size = args.size, args.size fast = args.fast augmentations = args.augmentations train_mode = args.train_mode fp16 = args.fp16 scheduler_name = args.scheduler experiment = args.experiment dropout = args.dropout online_pseudolabeling = args.opl criterions = args.criterion verbose = args.verbose warmup = args.warmup show = args.show use_dsv = args.dsv accumulation_steps = args.accumulation_steps weight_decay = args.weight_decay extra_data_xview2 = args.data_dir_xview2 run_train = num_epochs > 0 need_weight_mask = any(c[0] == "wbce" for c in criterions) model: nn.Module = get_model(model_name, dropout=dropout).cuda() if args.transfer: transfer_checkpoint = fs.auto_file(args.transfer) print("Transfering weights from model checkpoint", transfer_checkpoint) checkpoint = load_checkpoint(transfer_checkpoint) pretrained_dict = checkpoint["model_state_dict"] transfer_weights(model, pretrained_dict) if args.checkpoint: checkpoint = load_checkpoint(fs.auto_file(args.checkpoint)) unpack_checkpoint(checkpoint, 
model=model) print("Loaded model weights from:", args.checkpoint) report_checkpoint(checkpoint) runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None, device="cuda") main_metric = "optimized_jaccard" cmd_args = vars(args) current_time = datetime.now().strftime("%b%d_%H_%M") checkpoint_prefix = f"{current_time}_{args.model}" if fp16: checkpoint_prefix += "_fp16" if fast: checkpoint_prefix += "_fast" if online_pseudolabeling: checkpoint_prefix += "_opl" if extra_data_xview2: checkpoint_prefix += "_with_xview2" if experiment is not None: checkpoint_prefix = experiment log_dir = os.path.join("runs", checkpoint_prefix) os.makedirs(log_dir, exist_ok=False) config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json") with open(config_fname, "w") as f: train_session_args = vars(args) f.write(json.dumps(train_session_args, indent=2)) default_callbacks = [ PixelAccuracyCallback(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY), JaccardMetricPerImage(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="jaccard"), OptimalThreshold(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="optimized_jaccard"), # OutputDistributionCallback(output_key=OUTPUT_MASK_KEY, activation=torch.sigmoid), ] if show: visualize_inria_predictions = partial( draw_inria_predictions, image_key=INPUT_IMAGE_KEY, image_id_key=INPUT_IMAGE_ID_KEY, targets_key=INPUT_MASK_KEY, outputs_key=OUTPUT_MASK_KEY, ) default_callbacks += [ ShowPolarBatchesCallback(visualize_inria_predictions, metric="accuracy", minimize=False) ] train_ds, valid_ds, train_sampler = get_datasets( data_dir=data_dir, image_size=image_size, augmentation=augmentations, train_mode=train_mode, fast=fast, need_weight_mask=need_weight_mask, ) if extra_data_xview2 is not None: extra_train_ds, _ = get_xview2_extra_dataset( extra_data_xview2, image_size=image_size, augmentation=augmentations, fast=fast, need_weight_mask=need_weight_mask, ) weights = compute_sample_weight("balanced", [0] * len(train_ds) + 
[1] * len(extra_train_ds)) train_sampler = WeightedRandomSampler(weights, train_sampler.num_samples * 2) train_ds = train_ds + extra_train_ds print("Using extra data from xView2 with", len(extra_train_ds), "samples") # Pretrain/warmup if warmup: callbacks = default_callbacks.copy() criterions_dict = {} losses = [] ignore_index = None for loss_name, loss_weight in criterions: criterion_callback = CriterionCallback( prefix="seg_loss/" + loss_name, input_key=INPUT_MASK_KEY if loss_name != "wbce" else [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY], output_key=OUTPUT_MASK_KEY, criterion_key=loss_name, multiplier=float(loss_weight), ) criterions_dict[loss_name] = get_loss(loss_name, ignore_index=ignore_index) callbacks.append(criterion_callback) losses.append(criterion_callback.prefix) print("Using loss", loss_name, loss_weight) callbacks += [ CriterionAggregatorCallback(prefix="loss", loss_keys=losses), OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False), ] parameters = get_lr_decay_parameters(model.named_parameters(), learning_rate, {"encoder": 0.1}) optimizer = get_optimizer("RAdam", parameters, learning_rate=learning_rate * 0.1) loaders = collections.OrderedDict() loaders["train"] = DataLoader( train_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, drop_last=True, shuffle=train_sampler is None, sampler=train_sampler, ) loaders["valid"] = DataLoader(valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, shuffle=False, drop_last=False) runner.train( fp16=fp16, model=model, criterion=criterions_dict, optimizer=optimizer, scheduler=None, callbacks=callbacks, loaders=loaders, logdir=os.path.join(log_dir, "warmup"), num_epochs=warmup, verbose=verbose, main_metric=main_metric, minimize_metric=False, checkpoint_data={"cmd_args": cmd_args}, ) del optimizer, loaders best_checkpoint = os.path.join(log_dir, "warmup", "checkpoints", "best.pth") model_checkpoint = os.path.join(log_dir, "warmup", "checkpoints", 
f"{checkpoint_prefix}_warmup.pth") clean_checkpoint(best_checkpoint, model_checkpoint) torch.cuda.empty_cache() gc.collect() if run_train: loaders = collections.OrderedDict() callbacks = default_callbacks.copy() criterions_dict = {} losses = [] ignore_index = None if online_pseudolabeling: ignore_index = UNLABELED_SAMPLE unlabeled_label = get_pseudolabeling_dataset(data_dir, include_masks=False, augmentation=None, image_size=image_size) unlabeled_train = get_pseudolabeling_dataset( data_dir, include_masks=True, augmentation=augmentations, image_size=image_size) loaders["label"] = DataLoader(unlabeled_label, batch_size=batch_size // 2, num_workers=num_workers, pin_memory=True) if train_sampler is not None: num_samples = 2 * train_sampler.num_samples else: num_samples = 2 * len(train_ds) weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(unlabeled_label)) train_sampler = WeightedRandomSampler(weights, num_samples, replacement=True) train_ds = train_ds + unlabeled_train callbacks += [ BCEOnlinePseudolabelingCallback2d( unlabeled_train, pseudolabel_loader="label", prob_threshold=0.7, output_key=OUTPUT_MASK_KEY, unlabeled_class=UNLABELED_SAMPLE, label_frequency=5, ) ] print("Using online pseudolabeling with ", len(unlabeled_label), "samples") loaders["train"] = DataLoader( train_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, drop_last=True, shuffle=train_sampler is None, sampler=train_sampler, ) loaders["valid"] = DataLoader(valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True) # Create losses for loss_name, loss_weight in criterions: criterion_callback = CriterionCallback( prefix="seg_loss/" + loss_name, input_key=INPUT_MASK_KEY if loss_name != "wbce" else [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY], output_key=OUTPUT_MASK_KEY, criterion_key=loss_name, multiplier=float(loss_weight), ) criterions_dict[loss_name] = get_loss(loss_name, ignore_index=ignore_index) callbacks.append(criterion_callback) 
losses.append(criterion_callback.prefix) print("Using loss", loss_name, loss_weight) if use_dsv: print("Using DSV") criterions = "dsv" dsv_loss_name = "soft_bce" criterions_dict[criterions] = AdaptiveMaskLoss2d( get_loss(dsv_loss_name, ignore_index=ignore_index)) for i, dsv_input in enumerate([ OUTPUT_MASK_4_KEY, OUTPUT_MASK_8_KEY, OUTPUT_MASK_16_KEY, OUTPUT_MASK_32_KEY ]): criterion_callback = CriterionCallback( prefix="seg_loss_dsv/" + dsv_input, input_key=OUTPUT_MASK_KEY, output_key=dsv_input, criterion_key=criterions, multiplier=1.0, ) callbacks.append(criterion_callback) losses.append(criterion_callback.prefix) callbacks += [ CriterionAggregatorCallback(prefix="loss", loss_keys=losses), OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False), ] optimizer = get_optimizer(optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay) scheduler = get_scheduler(scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])) if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)): callbacks += [SchedulerCallback(mode="batch")] print("Train session :", checkpoint_prefix) print("\tFP16 mode :", fp16) print("\tFast mode :", args.fast) print("\tTrain mode :", train_mode) print("\tEpochs :", num_epochs) print("\tWorkers :", num_workers) print("\tData dir :", data_dir) print("\tLog dir :", log_dir) print("\tAugmentations :", augmentations) print("\tTrain size :", len(loaders["train"]), len(train_ds)) print("\tValid size :", len(loaders["valid"]), len(valid_ds)) print("Model :", model_name) print("\tParameters :", count_parameters(model)) print("\tImage size :", image_size) print("Optimizer :", optimizer_name) print("\tLearning rate :", learning_rate) print("\tBatch size :", batch_size) print("\tCriterion :", criterions) print("\tUse weight mask:", need_weight_mask) # model training runner.train( fp16=fp16, model=model, criterion=criterions_dict, optimizer=optimizer, 
scheduler=scheduler, callbacks=callbacks, loaders=loaders, logdir=os.path.join(log_dir, "main"), num_epochs=num_epochs, verbose=verbose, main_metric=main_metric, minimize_metric=False, checkpoint_data={"cmd_args": vars(args)}, ) # Training is finished. Let's run predictions using best checkpoint weights best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth") model_checkpoint = os.path.join(log_dir, "main", "checkpoints", f"{checkpoint_prefix}.pth") clean_checkpoint(best_checkpoint, model_checkpoint) unpack_checkpoint(torch.load(model_checkpoint), model=model) mask = predict(model, read_inria_image("sample_color.jpg"), image_size=image_size, batch_size=args.batch_size) mask = ((mask > 0) * 255).astype(np.uint8) name = os.path.join(log_dir, "sample_color.jpg") cv2.imwrite(name, mask) del optimizer, loaders
def main() -> None:
    """Train a stacking model on precomputed embeddings.

    Parses command-line options, loads the train/holdout embedding arrays
    from ``.npy`` files in the current working directory, builds a
    ``StackingModel`` and trains it with a Catalyst ``SupervisedRunner``.
    The parsed arguments are written as JSON into ``runs/<prefix>/`` and the
    best checkpoint is post-processed with ``clean_checkpoint``.

    Several options (e.g. ``--checkpoint``, ``--transfer``, ``--fold``,
    ``--warmup``, ``--fine-tune``, ``--show``, ``--early-stopping``,
    ``--freeze-encoder``, ``--freeze-bn``, ``--warmup-batch-size``) are
    accepted but not read by this function beyond being stored in the saved
    config.

    Raises:
        SystemExit: via ``parser.error`` when none of the flag/type/embedding
            losses is given on the command line.
        FileExistsError: from ``os.makedirs`` when the run directory already
            exists (deliberate, so earlier runs are not overwritten).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--obliterate", type=float, default=0, help="Chance of obliteration")
    parser.add_argument("-nid", "--negative-image-dir", type=str, default=None, help="Directory with negative images")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--cache", action="store_true")
    parser.add_argument("-dd", "--data-dir", type=str, default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch Size during training, e.g. -b 64")
    parser.add_argument(
        "-wbs", "--warmup-batch-size", type=int, default=None, help="Batch Size during warmup, e.g. -wbs 64"
    )
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument(
        "-es", "--early-stopping", type=int, default=None, help="Maximum number of epochs without improvement"
    )
    parser.add_argument("-fe", "--freeze-encoder", action="store_true", help="Freeze encoder parameters for N epochs")
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")
    # Each loss option may be repeated; every occurrence takes one or more
    # string values (e.g. name and weight), so the parsed value is a list of lists.
    parser.add_argument(
        "-l", "--modification-flag-loss", type=str, default=None, action="append", nargs="+"  # [["ce", 1.0]],
    )
    parser.add_argument(
        "--modification-type-loss", type=str, default=None, action="append", nargs="+"  # [["ce", 1.0]],
    )
    parser.add_argument("--embedding-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--feature-maps-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--mask-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--bits-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="safe", type=str, help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--mixup", action="store_true")
    parser.add_argument("--cutmix", action="store_true")
    parser.add_argument("--tsa", action="store_true")
    parser.add_argument("--fold", default=None, type=int)
    parser.add_argument("-s", "--scheduler", default=None, type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=0, type=float, help="Dropout before head layer")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument("--fine-tune", default=0, type=int, help="Number of fine-tuning epochs")
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")

    args = parser.parse_args()
    set_manual_seed(args.seed)

    # Validate with parser.error rather than `assert`: asserts are stripped
    # when Python runs with -O, which would silently skip this check.
    if not (args.modification_flag_loss or args.modification_type_loss or args.embedding_loss):
        parser.error("At least one of losses must be set")

    modification_flag_loss = args.modification_flag_loss
    modification_type_loss = args.modification_type_loss
    embedding_loss = args.embedding_loss
    feature_maps_loss = args.feature_maps_loss
    mask_loss = args.mask_loss
    bits_loss = args.bits_loss

    data_dir = args.data_dir
    cache = args.cache
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    optimizer_name = args.optimizer
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    verbose = args.verbose
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    balance = args.balance
    train_batch_size = args.batch_size
    mixup = args.mixup
    cutmix = args.cutmix
    tsa = args.tsa
    obliterate_p = args.obliterate
    negative_image_dir = args.negative_image_dir

    # Compute batch size for validation
    valid_batch_size = train_batch_size

    current_time = datetime.now().strftime("%b%d_%H_%M")
    main_metric = "loss"
    main_metric_minimize = True

    # All four embedding files share the same tag; they are read from the
    # current working directory.
    embeddings_tag = "Gf3_Hnrmishf2_Hnrmishf1_Kmishf0"
    x_train = np.load(f"embeddings_x_train_{embeddings_tag}.npy")
    y_train = np.load(f"embeddings_y_train_{embeddings_tag}.npy")
    x_valid = np.load(f"embeddings_x_holdout_{embeddings_tag}.npy")
    y_valid = np.load(f"embeddings_y_holdout_{embeddings_tag}.npy")

    print(x_train.shape, x_valid.shape)
    print(np.bincount(y_train), np.bincount(y_valid))

    train_ds = StackerDataset(x_train, y_train)
    valid_ds = StackerDataset(x_valid, y_valid)

    # NOTE(review): only the flag/type losses, mixup and tsa are forwarded.
    # The embedding/feature-maps/mask/bits losses and cutmix are passed as None
    # even though the corresponding CLI options are parsed (and --cutmix still
    # changes the checkpoint prefix below) - confirm this is intended.
    criterions_dict, loss_callbacks = get_criterions(
        modification_flag=modification_flag_loss,
        modification_type=modification_type_loss,
        embedding_loss=None,
        feature_maps_loss=None,
        mask_loss=None,
        bits_loss=None,
        num_epochs=num_epochs,
        mixup=mixup,
        cutmix=None,
        tsa=tsa,
    )

    callbacks = loss_callbacks + [
        OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
        HyperParametersCallback(
            hparam_dict={
                "scheduler": scheduler_name,
                "optimizer": optimizer_name,
                "augmentations": augmentations,
                "weight_decay": weight_decay,
            }
        ),
    ]

    loaders = collections.OrderedDict()
    loaders["train"] = DataLoader(
        train_ds, batch_size=train_batch_size, num_workers=num_workers, pin_memory=True, drop_last=True, shuffle=True
    )
    loaders["valid"] = DataLoader(valid_ds, batch_size=valid_batch_size, num_workers=num_workers, pin_memory=True)

    # Input width of the model is the second dimension of the train embeddings.
    model: nn.Module = StackingModel(x_train.shape[1], dropout=dropout).cuda()

    optimizer = get_optimizer(
        optimizer_name, get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay
    )
    scheduler = get_scheduler(
        scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
    )
    # Cyclic schedulers are stepped per batch.
    if isinstance(scheduler, CyclicLR):
        callbacks += [SchedulerCallback(mode="batch")]

    checkpoint_prefix = f"{current_time}_stacking"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if mixup:
        checkpoint_prefix += "_mixup"

    if cutmix:
        checkpoint_prefix += "_cutmix"

    # An explicit experiment name replaces the generated prefix entirely.
    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    # exist_ok=False: fail instead of writing into an existing run directory.
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        json.dump(train_session_args, f, indent=2)

    print("Train session :", checkpoint_prefix)
    print(" Train size :", len(loaders["train"]), "batches", len(train_ds), "samples")
    print(" Valid size :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
    print(" FP16 mode :", fp16)
    print(" Fast mode :", args.fast)
    print(" Epochs :", num_epochs)
    print(" Workers :", num_workers)
    print(" Data dir :", data_dir)
    print(" Log dir :", log_dir)
    print(" Cache :", cache)
    print("Data ")
    print(" Augmentations :", augmentations)
    print(" Obliterate (%) :", obliterate_p)
    print(" Negative images:", negative_image_dir)
    print(" Balance :", balance)
    print(" Mixup :", mixup)
    print(" CutMix :", cutmix)
    print(" TSA :", tsa)
    # print("Model :", model_name)
    print(" Parameters :", count_parameters(model))
    print(" Dropout :", dropout)
    print("Optimizer :", optimizer_name)
    print(" Learning rate :", learning_rate)
    print(" Weight decay :", weight_decay)
    print(" Scheduler :", scheduler_name)
    print(" Batch sizes :", train_batch_size, valid_batch_size)
    print("Losses ")
    print(" Flag :", modification_flag_loss)
    print(" Type :", modification_type_loss)
    print(" Embedding :", embedding_loss)
    print(" Feature maps :", feature_maps_loss)
    print(" Mask :", mask_loss)
    print(" Bits :", bits_loss)

    # model training
    runner = SupervisedRunner(input_key=[INPUT_EMBEDDING_KEY], output_key=None)
    runner.train(
        fp16=fp16,
        model=model,
        criterion=criterions_dict,
        optimizer=optimizer,
        scheduler=scheduler,
        callbacks=callbacks,
        loaders=loaders,
        logdir=os.path.join(log_dir, "main"),
        num_epochs=num_epochs,
        verbose=verbose,
        main_metric=main_metric,
        minimize_metric=main_metric_minimize,
        checkpoint_data={"cmd_args": vars(args)},
    )

    del optimizer, loaders, runner, callbacks

    best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")
    model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")

    # Restore state of best model
    clean_checkpoint(best_checkpoint, model_checkpoint)
def main(cfg: DictConfig):
    """Train (optionally in several image-size stages) and validate a segmentation model.

    The function is driven entirely by ``cfg`` (an OmegaConf ``DictConfig``;
    presumably supplied by Hydra, given ``get_original_cwd`` and the
    ``.hydra/config.yaml`` paths - confirm against the decorator/caller).

    Steps performed:
      1. When ``cfg`` has a ``resume`` key, load the earlier run's config and
         copy its ``model`` section into ``cfg``.
      2. Build augmentations, model, optimizer and the loss dictionary.
      3. When resuming, restore model / optimizer / criterion state from
         ``checkpoints/best_full.pth`` of the resumed run.
      4. If ``cfg.train.num_epochs`` is not ``None``, train one stage per
         ``(size, num_epochs)`` pair; otherwise only search for the best dice
         threshold on the validation loader.
      5. Store the found threshold in ``.hydra/config.yaml`` and evaluate dice
         on the full-size images listed in ``cfg.data.valid_ids``.

    Args:
        cfg: Run configuration. Sections read here include ``model``,
            ``optim``, ``scheduler``, ``loss``, ``data``, ``loader``,
            ``train``, plus ``device``, ``seed``, ``project``, ``check``,
            ``amp`` and optionally ``resume``.

    Raises:
        ValueError: when ``resume`` is set but the checkpoint file is missing.
    """
    cwd = Path(get_original_cwd())

    # overwrite config if continue training from checkpoint
    resume_cfg = None
    if "resume" in cfg:
        cfg_path = cwd / cfg.resume / ".hydra/config.yaml"
        print(f"Continue from: {cfg.resume}")
        # Overwrite everything except device
        # TODO config merger (perhaps continue training with the same optimizer but other lrs?)
        resume_cfg = OmegaConf.load(cfg_path)
        cfg.model = resume_cfg.model
        # NOTE(review): num_epochs is compared with 0 here, tested against None
        # and iterated with zip() further below - confirm the intended type.
        if cfg.train.num_epochs == 0:
            cfg.data.scale_factor = resume_cfg.data.scale_factor
        # NOTE(review): assumed to belong to the resume branch (persists the
        # overwritten config) - confirm against the original layout.
        OmegaConf.save(cfg, ".hydra/config.yaml")

    print(OmegaConf.to_yaml(cfg))

    device = set_device_id(cfg.device)
    set_seed(cfg.seed, device=device)

    # Augmentations
    if cfg.data.aug == "auto":
        transforms = albu.load(cwd / "autoalbument/autoconfig.json")
    else:
        transforms = D.get_training_augmentations()

    if OmegaConf.is_missing(cfg.model, "convert_bottleneck"):
        cfg.model.convert_bottleneck = (0, 0, 0)

    # Model
    print(f"Setup model {cfg.model.arch} {cfg.model.encoder_name} "
          f"convert_bn={cfg.model.convert_bn} "
          f"convert_bottleneck={cfg.model.convert_bottleneck} ")
    model = get_segmentation_model(
        arch=cfg.model.arch,
        encoder_name=cfg.model.encoder_name,
        encoder_weights=cfg.model.encoder_weights,
        classes=1,
        convert_bn=cfg.model.convert_bn,
        convert_bottleneck=cfg.model.convert_bottleneck,
        # decoder_attention_type="scse",  # TODO to config
    )
    model = model.to(device)
    model.train()
    print(model)

    # Optimization
    # Reduce LR for pretrained encoder
    layerwise_params = {
        "encoder*": dict(lr=cfg.optim.lr_encoder, weight_decay=cfg.optim.wd_encoder)
    }
    model_params = cutils.process_model_params(
        model, layerwise_params=layerwise_params)

    # Select optimizer
    optimizer = get_optimizer(
        name=cfg.optim.name,
        model_params=model_params,
        lr=cfg.optim.lr,
        wd=cfg.optim.wd,
        lookahead=cfg.optim.lookahead,
    )

    criterion = {
        "dice": DiceLoss(),
        # "dice": SoftDiceLoss(mode="binary", smooth=1e-7),
        "iou": IoULoss(),
        "bce": nn.BCEWithLogitsLoss(),
        "lovasz": LovaszLossBinary(),
        "focal_tversky": FocalTverskyLoss(eps=1e-7, alpha=0.7, gamma=0.75),
    }

    # Load states if resuming training
    if "resume" in cfg:
        checkpoint_path = (cwd / cfg.resume / cfg.train.logdir /
                           "checkpoints/best_full.pth")
        if checkpoint_path.exists():
            print(f"\nLoading checkpoint {str(checkpoint_path)}")
            checkpoint = cutils.load_checkpoint(checkpoint_path)
            cutils.unpack_checkpoint(
                checkpoint=checkpoint,
                model=model,
                # Optimizer state is only restored when the optimizer type is unchanged.
                optimizer=optimizer
                if resume_cfg.optim.name == cfg.optim.name else None,
                criterion=criterion,
            )
        else:
            raise ValueError("Nothing to resume, checkpoint missing")

    # We could only want to validate resume, in this case skip training routine
    best_th = 0.5

    stats = None
    if cfg.data.stats:
        print(f"Use statistics from file: {cfg.data.stats}")
        stats = cwd / cfg.data.stats

    if cfg.train.num_epochs is not None:
        callbacks = [
            # Each criterion is calculated separately.
            CriterionCallback(input_key="mask",
                              prefix="loss_dice",
                              criterion_key="dice"),
            CriterionCallback(input_key="mask",
                              prefix="loss_iou",
                              criterion_key="iou"),
            CriterionCallback(input_key="mask",
                              prefix="loss_bce",
                              criterion_key="bce"),
            CriterionCallback(input_key="mask",
                              prefix="loss_lovasz",
                              criterion_key="lovasz"),
            CriterionCallback(
                input_key="mask",
                prefix="loss_focal_tversky",
                criterion_key="focal_tversky",
            ),
            # And only then we aggregate everything into one loss.
            MetricAggregationCallback(
                prefix="loss",
                mode="weighted_sum",  # can be "sum", "weighted_sum" or "mean"
                # because we want weighted sum, we need to add scale for each loss
                metrics={
                    "loss_dice": cfg.loss.dice,
                    "loss_iou": cfg.loss.iou,
                    "loss_bce": cfg.loss.bce,
                    "loss_lovasz": cfg.loss.lovasz,
                    "loss_focal_tversky": cfg.loss.focal_tversky,
                },
            ),
            # metrics
            DiceCallback(input_key="mask"),
            IouCallback(input_key="mask"),
            # gradient accumulation
            OptimizerCallback(accumulation_steps=cfg.optim.accumulate),
            # early stopping
            SchedulerCallback(reduced_metric="loss_dice",
                              mode=cfg.scheduler.mode),
            EarlyStoppingCallback(**cfg.scheduler.early_stopping,
                                  minimize=False),
            # TODO WandbLogger works poorly with multistage right now
            WandbLogger(project=cfg.project, config=dict(cfg)),
            # CheckpointCallback(save_n_best=cfg.checkpoint.save_n_best),
        ]

        # Training
        runner = SupervisedRunner(device=device,
                                  input_key="image",
                                  input_target_key="mask")
        # TODO Scheduler does not work now, every stage restarts from base lr
        scheduler_warm_restart = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[1, 2],
            gamma=10,
        )

        # One training stage per (image size, epoch count) pair.
        for i, (size, num_epochs) in enumerate(
                zip(cfg.data.sizes, cfg.train.num_epochs)):
            scale = size / 1024
            print(
                f"Training stage {i}, scale {scale}, size {size}, epochs {num_epochs}"
            )

            # Datasets
            (
                train_ds,
                valid_ds,
                train_images,
                val_images,
            ) = D.get_train_valid_datasets_from_path(
                # path=(cwd / cfg.data.path),
                path=(cwd / f"data/hubmap-{size}x{size}/"),
                train_ids=cfg.data.train_ids,
                valid_ids=cfg.data.valid_ids,
                seed=cfg.seed,
                valid_split=cfg.data.valid_split,
                mean=cfg.data.mean,
                std=cfg.data.std,
                transforms=transforms,
                stats=stats,
            )

            # Batch size shrinks with the square of the scale (i.e. with image area).
            train_bs = int(cfg.loader.train_bs / (scale**2))
            valid_bs = int(cfg.loader.valid_bs / (scale**2))

            print(
                f"train: {len(train_ds)}; bs {train_bs}",
                f"valid: {len(valid_ds)}, bs {valid_bs}",
            )

            # Data loaders
            data_loaders = D.get_data_loaders(
                train_ds=train_ds,
                valid_ds=valid_ds,
                train_bs=train_bs,
                valid_bs=valid_bs,
                num_workers=cfg.loader.num_workers,
            )

            # Select scheduler
            scheduler = get_scheduler(
                name=cfg.scheduler.type,
                optimizer=optimizer,
                # In "batch" mode the schedule length is counted in batches, not epochs.
                num_epochs=num_epochs * (len(data_loaders["train"]) if
                                         cfg.scheduler.mode == "batch" else 1),
                eta_min=scheduler_warm_restart.get_last_lr()[0] /
                cfg.scheduler.eta_min_factor,
                plateau=cfg.scheduler.plateau,
            )

            runner.train(
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                scheduler=scheduler,
                callbacks=callbacks,
                logdir=cfg.train.logdir,
                loaders=data_loaders,
                num_epochs=num_epochs,
                verbose=True,
                main_metric=cfg.train.main_metric,
                load_best_on_end=True,
                minimize_metric=False,
                check=cfg.check,
                fp16=dict(amp=cfg.amp),
            )

            # Set new initial LR for optimizer after restart
            scheduler_warm_restart.step()
            print(
                f"New LR for warm restart {scheduler_warm_restart.get_last_lr()[0]}"
            )

            # Find optimal threshold for dice score
            # NOTE(review): assumed to run once per stage (the output file name
            # depends on `size`) - confirm against the original layout.
            model.eval()
            best_th, dices = find_dice_threshold(model, data_loaders["valid"])
            print("Best dice threshold", best_th, np.max(dices[1]))
            np.save(f"dices_{size}.npy", dices)
    else:
        print("Validation only")
        # Datasets
        size = cfg.data.sizes[-1]
        # The training branch unpacks four values from this same call; take the
        # first two and ignore any extras so both call sites agree.
        train_ds, valid_ds, *_ = D.get_train_valid_datasets_from_path(
            # path=(cwd / cfg.data.path),
            path=(cwd / f"data/hubmap-{size}x{size}/"),
            train_ids=cfg.data.train_ids,
            valid_ids=cfg.data.valid_ids,
            seed=cfg.seed,
            valid_split=cfg.data.valid_split,
            mean=cfg.data.mean,
            std=cfg.data.std,
            transforms=transforms,
            stats=stats,
        )

        train_bs = int(cfg.loader.train_bs / (cfg.data.scale_factor**2))
        valid_bs = int(cfg.loader.valid_bs / (cfg.data.scale_factor**2))

        print(
            f"train: {len(train_ds)}; bs {train_bs}",
            f"valid: {len(valid_ds)}, bs {valid_bs}",
        )

        # Data loaders
        data_loaders = D.get_data_loaders(
            train_ds=train_ds,
            valid_ds=valid_ds,
            train_bs=train_bs,
            valid_bs=valid_bs,
            num_workers=cfg.loader.num_workers,
        )

        # Find optimal threshold for dice score
        model.eval()
        best_th, dices = find_dice_threshold(model, data_loaders["valid"])
        print("Best dice threshold", best_th, np.max(dices[1]))
        np.save("dices_val.npy", dices)

    # # # Load best checkpoint
    # checkpoint_path = Path(cfg.train.logdir) / "checkpoints/best.pth"
    # if checkpoint_path.exists():
    #     print(f"\nLoading checkpoint {str(checkpoint_path)}")
    #     state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))[
    #         "model_state_dict"
    #     ]
    #     model.load_state_dict(state_dict)
    #     del state_dict
    # model = model.to(device)

    # Load config for updating with threshold and metric
    # (otherwise loading do not work)
    cfg = OmegaConf.load(".hydra/config.yaml")
    cfg.threshold = float(best_th)

    # Evaluate on full-size image if valid_ids is non-empty
    df_train = pd.read_csv(cwd / "data/train.csv")
    # "records" is the documented orient name; the abbreviation "record" is
    # rejected by newer pandas releases.
    df_train = {
        r["id"]: r["encoding"]
        for r in df_train.to_dict(orient="records")
    }
    image_dices = []
    # Image id is the file-name part before the first underscore; Path.name
    # keeps this independent of the OS path separator.
    unique_ids = sorted({
        p.name.split("_")[0]
        for p in (cwd / cfg.data.path / "train").iterdir()
    })
    size = cfg.data.sizes[-1]
    scale = size / 1024
    for image_id in cfg.data.valid_ids:
        # valid_ids hold positions in the sorted list of unique image ids.
        image_name = unique_ids[image_id]
        print(f"\nValidate for {image_name}")
        rle_pred, shape = inference_one(
            image_path=(cwd / f"data/train/{image_name}.tiff"),
            target_path=Path("."),
            cfg=cfg,
            model=model,
            scale_factor=scale,
            tile_size=cfg.data.tile_size,
            tile_step=cfg.data.tile_step,
            threshold=best_th,
            save_raw=False,
            tta_mode=None,
            weight="pyramid",
            device=device,
            filter_crops="tissue",
            stats=stats,
        )

        print("Predict", shape)
        pred = rle_decode(rle_pred["predicted"], shape)
        mask = rle_decode(df_train[image_name], shape)
        assert pred.shape == mask.shape, f"pred {pred.shape}, mask {mask.shape}"
        assert pred.shape == shape, f"pred {pred.shape}, expected {shape}"

        image_dices.append(
            dice(
                torch.from_numpy(pred).type(torch.uint8),
                torch.from_numpy(mask).type(torch.uint8),
                threshold=None,
                activation="none",
            ))
    print("Full image dice:", np.mean(image_dices))
    OmegaConf.save(cfg, ".hydra/config.yaml")