def main():
    """Train a steganalysis model on ALASKA2 + IStego100K data.

    Command-line entry point: parses args, builds model/datasets/loaders and
    runs a catalyst SupervisedRunner training session, saving the cleaned best
    checkpoint next to the run logs.
    """
    parser = argparse.ArgumentParser()
    ###########################################################################################
    # Distributed-training related stuff
    parser.add_argument("--local_rank", type=int, default=0)
    ###########################################################################################
    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    # BUGFIX: help texts were copy-pasted ("Change of obliteration" on both args)
    parser.add_argument("--obliterate", type=float, default=0, help="Chance of obliteration")
    parser.add_argument("-nid", "--negative-image-dir", type=str, default=None, help="Directory with extra negative (cover) images")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--cache", action="store_true")
    parser.add_argument("-dd", "--data-dir", type=str, default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-dd2", "--data-dir-istego", type=str, default=os.environ.get("KAGGLE_2020_ISTEGO100K"))
    parser.add_argument("-m", "--model", type=str, default="resnet34", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument("-es", "--early-stopping", type=int, default=None, help="Maximum number of epochs without improvement")
    parser.add_argument("-fe", "--freeze-encoder", type=int, default=0, help="Freeze encoder parameters for N epochs")
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")
    # Loss specs accumulate into nested lists like [["ce", 1.0]]
    parser.add_argument("-l", "--modification-flag-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--modification-type-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--embedding-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--feature-maps-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument("-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights")
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="safe", type=str, help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--mixup", action="store_true")
    parser.add_argument("--cutmix", action="store_true")
    parser.add_argument("--tsa", action="store_true")
    parser.add_argument("--size", default=None, type=int)
    parser.add_argument("--fold", default=None, type=int)
    parser.add_argument("-s", "--scheduler", default=None, type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=0.0, type=float, help="Dropout before head layer")
    parser.add_argument("--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters")
    parser.add_argument("--fine-tune", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters")
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")

    args = parser.parse_args()

    args.is_master = args.local_rank == 0
    args.distributed = False
    # BUGFIX: default world_size so the DistributedSampler calls below do not
    # raise AttributeError when WORLD_SIZE is not set (single-process run).
    args.world_size = 1
    if "WORLD_SIZE" in os.environ:
        args.distributed = int(os.environ["WORLD_SIZE"]) > 1
        args.world_size = int(os.environ["WORLD_SIZE"])
        # args.world_size = torch.distributed.get_world_size()

        print("Initializing init_process_group", args.local_rank)
        torch.cuda.set_device(args.local_rank)
        # NOTE(review): init_process_group is commented out, yet DistributedSampler
        # is used below — confirm the process group is initialized elsewhere.
        # torch.distributed.init_process_group(backend="nccl", init_method="env://")
        print("Initialized init_process_group", args.local_rank)

    set_manual_seed(args.seed)

    assert (
        args.modification_flag_loss or args.modification_type_loss or args.embedding_loss
    ), "At least one of losses must be set"

    modification_flag_loss = args.modification_flag_loss
    modification_type_loss = args.modification_type_loss
    embedding_loss = args.embedding_loss
    feature_maps_loss = args.feature_maps_loss
    data_dir = args.data_dir
    data_dir_istego = args.data_dir_istego
    cache = args.cache
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    model_name: str = args.model
    optimizer_name = args.optimizer
    image_size = (args.size, args.size) if args.size is not None else (512, 512)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    verbose = args.verbose
    warmup = args.warmup
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    fold = args.fold
    balance = args.balance
    freeze_bn = args.freeze_bn
    train_batch_size = args.batch_size
    mixup = args.mixup
    cutmix = args.cutmix
    tsa = args.tsa
    obliterate_p = args.obliterate
    negative_image_dir = args.negative_image_dir

    distributed_params = {"rank": args.local_rank, "syncbn": True}
    if fp16:
        distributed_params["opt_level"] = "O1"

    # Compute batch size for validation
    valid_batch_size = train_batch_size
    run_train = num_epochs > 0

    model: nn.Module = get_model(model_name, dropout=dropout)
    required_features = model.required_features

    if args.transfer:
        # Warm-start from a foreign checkpoint (non-strict weight transfer)
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transferring weights from model checkpoint", transfer_checkpoint)
        checkpoint = torch.load(transfer_checkpoint, map_location="cpu")
        pretrained_dict = checkpoint["model_state_dict"]
        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        # Resume exact weights from a previous run of the same model
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)
        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    model = model.cuda()

    if freeze_bn:
        from pytorch_toolbelt.optimization.functional import freeze_model

        freeze_model(model, freeze_bn=True)
        print("Freezing bn params")

    main_metric = "loss"
    main_metric_minimize = True

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}_fold{fold}_istego100k_local_rank_{args.local_rank}"

    if fp16:
        checkpoint_prefix += "_fp16"
    if fast:
        checkpoint_prefix += "_fast"
    if mixup:
        checkpoint_prefix += "_mixup"
    if cutmix:
        checkpoint_prefix += "_cutmix"
    # An explicit experiment name overrides the generated prefix entirely
    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    # exist_ok=False on purpose: refuse to clobber a previous run's logs
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = []
    if show:
        default_callbacks += [ShowPolarBatchesCallback(draw_predictions, metric="loss", minimize=True)]

    if run_train:
        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            image_size=image_size,
            augmentation=augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=obliterate_p,
        )
        extra_train_ds = get_istego100k_train(data_dir_istego, fold=fold, features=required_features, output_size="random_crop")
        train_ds = train_ds + extra_train_ds

        if negative_image_dir:
            negatives_ds = get_negatives_ds(
                negative_image_dir, fold=fold, local_rank=args.local_rank, features=required_features, max_images=25000
            )
            train_ds = train_ds + negatives_ds
            train_sampler = None  # TODO: Add proper support of sampler
            print("Adding", len(negatives_ds), "negative samples to training set")

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            feature_maps_loss=feature_maps_loss,
            num_epochs=num_epochs,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
        )

        callbacks = (
            default_callbacks
            + loss_callbacks
            + [
                OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
                HyperParametersCallback(
                    hparam_dict={
                        "model": model_name,
                        "scheduler": scheduler_name,
                        "optimizer": optimizer_name,
                        "augmentations": augmentations,
                        "size": image_size[0],
                        "weight_decay": weight_decay,
                    }
                ),
            ]
        )

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=False,
            sampler=DistributedSampler(train_ds, args.world_size, args.local_rank),
        )
        loaders["valid"] = DataLoader(
            valid_ds,
            batch_size=valid_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=False,
            shuffle=False,
            sampler=DistributedSampler(valid_ds, args.world_size, args.local_rank, shuffle=False),
        )

        print("Train session :", checkpoint_prefix)
        print(" FP16 mode :", fp16)
        print(" Fast mode :", args.fast)
        print(" Epochs :", num_epochs)
        print(" Workers :", num_workers)
        print(" Data dir :", data_dir)
        print(" Log dir :", log_dir)
        print(" Cache :", cache)
        print("Data ")
        print(" Augmentations :", augmentations)
        print(" Obliterate (%) :", obliterate_p)
        print(" Negative images:", negative_image_dir)
        print(" Train size :", len(loaders["train"]), "batches", len(train_ds), "samples")
        print(" Valid size :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
        print(" Image size :", image_size)
        print(" Balance :", balance)
        print(" Mixup :", mixup)
        print(" CutMix :", cutmix)
        print(" TSA :", tsa)
        print("Model :", model_name)
        print(" Parameters :", count_parameters(model))
        print(" Dropout :", dropout)
        print("Optimizer :", optimizer_name)
        print(" Learning rate :", learning_rate)
        print(" Weight decay :", weight_decay)
        print(" Scheduler :", scheduler_name)
        print(" Batch sizes :", train_batch_size, valid_batch_size)
        print("Losses ")
        print(" Flag :", modification_flag_loss)
        print(" Type :", modification_type_loss)
        print(" Embedding :", embedding_loss)
        print(" Feature maps :", feature_maps_loss)
        print("Distributed")
        print(" World size :", args.world_size)
        print(" Local rank :", args.local_rank)
        print(" Is master :", args.is_master)

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        if isinstance(scheduler, CyclicLR):
            # Cyclic schedulers must step every batch, not every epoch
            callbacks += [SchedulerCallback(mode="batch")]

        # model training
        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=distributed_params,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        del optimizer, loaders, runner, callbacks

        best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")

        # Restore state of best model
        clean_checkpoint(best_checkpoint, model_checkpoint)
        # unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        torch.cuda.empty_cache()
        gc.collect()
def main():
    """Dump per-checkpoint predictions (with embeddings) for the train, holdout
    and test splits, caching each result as a pickle next to the checkpoint.

    Existing pickles are reused unless --force-recompute is given.
    """
    # Give no chance to randomness
    torch.manual_seed(0)
    np.random.seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint", type=str, nargs="+")
    parser.add_argument("-dd", "--data-dir", type=str, default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-b", "--batch-size", type=int, default=1)
    parser.add_argument("-w", "--workers", type=int, default=0)
    parser.add_argument("-d4", "--d4-tta", action="store_true")
    parser.add_argument("-hv", "--hv-tta", action="store_true")
    parser.add_argument("-f", "--force-recompute", action="store_true")
    parser.add_argument("-fp16", "--fp16", action="store_true")
    args = parser.parse_args()

    checkpoint_fnames = args.checkpoint
    data_dir = args.data_dir
    batch_size = args.batch_size
    workers = args.workers
    fp16 = args.fp16
    d4_tta = args.d4_tta
    force_recompute = args.force_recompute
    need_embedding = True

    outputs = [OUTPUT_PRED_MODIFICATION_FLAG, OUTPUT_PRED_MODIFICATION_TYPE, OUTPUT_PRED_EMBEDDING]
    embedding_suffix = "_w_emb" if need_embedding else ""

    for checkpoint_fname in checkpoint_fnames:
        model, checkpoints, required_features = ensemble_from_checkpoints(
            [checkpoint_fname], strict=True, outputs=outputs, activation=None, tta=None, need_embedding=need_embedding
        )
        report_checkpoint(checkpoints[0])

        model = model.cuda()
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model = model.eval()
        if fp16:
            model = model.half()

        train_ds = get_train_except_holdout(data_dir, features=required_features)
        holdout_ds = get_holdout(data_dir, features=required_features)
        test_ds = get_test_dataset(data_dir, features=required_features)

        if d4_tta:
            model = wrap_model_with_tta(model, "d4", inputs=required_features, outputs=outputs).eval()
            tta_suffix = "_d4_tta"
        else:
            tta_suffix = ""

        # Identical recompute-or-reuse logic for every split; the
        # (train, holdout, test) order matches the original script.
        splits = (("train", train_ds), ("holdout", holdout_ds), ("test", test_ds))
        for split_name, split_ds in splits:
            predictions_pkl = fs.change_extension(
                checkpoint_fname, f"_{split_name}_predictions{embedding_suffix}{tta_suffix}.pkl"
            )
            if force_recompute or not os.path.exists(predictions_pkl):
                predictions = compute_trn_predictions(
                    model, split_ds, fp16=fp16, batch_size=batch_size, workers=workers
                )
                predictions.to_pickle(predictions_pkl)
def main():
    """Predict building masks for all INRIA test images with a trained
    checkpoint (optionally TTA-wrapped) and write both raw and
    CCITTFAX4-compressed outputs for submission.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="unet", help="")
    parser.add_argument("-dd", "--data-dir", type=str, default=None, required=True, help="Data dir")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        required=True,
        help="Checkpoint filename to use as initial model weights",
    )
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch size for inference")
    parser.add_argument("-tta", "--tta", default=None, type=str, help="Type of TTA to use [fliplr, d4]")
    args = parser.parse_args()

    data_dir = args.data_dir

    checkpoint_file = auto_file(args.checkpoint)
    run_dir = os.path.dirname(checkpoint_file)
    out_dir = os.path.join(run_dir, "submit")
    os.makedirs(out_dir, exist_ok=True)

    model, checkpoint = model_from_checkpoint(checkpoint_file, strict=False)
    # Binarization threshold tuned during validation; fall back to 0.5
    threshold = checkpoint["epoch_metrics"].get("valid_optimized_jaccard/threshold", 0.5)
    print(report_checkpoint(checkpoint))
    print("Using threshold", threshold)

    # Expose only the mask head, squashed to probabilities
    model = nn.Sequential(PickModelOutput(model, OUTPUT_MASK_KEY), nn.Sigmoid())

    if args.tta == "fliplr":
        model = TTAWrapper(model, fliplr_image2mask)
    elif args.tta == "d4":
        model = TTAWrapper(model, d4_image2mask)
    elif args.tta == "ms-d2":
        model = TTAWrapper(model, fliplr_image2mask)
        model = MultiscaleTTAWrapper(model, size_offsets=[-128, -64, 64, 128])
    elif args.tta == "ms-d4":
        model = TTAWrapper(model, d4_image2mask)
        model = MultiscaleTTAWrapper(model, size_offsets=[-128, -64, 64, 128])
    elif args.tta == "ms":
        model = MultiscaleTTAWrapper(model, size_offsets=[-128, -64, 64, 128])

    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model = model.eval()

    test_predictions_dir = os.path.join(out_dir, "test_predictions")
    test_predictions_dir_compressed = os.path.join(out_dir, "test_predictions_compressed")

    if args.tta is not None:
        test_predictions_dir += f"_{args.tta}"
        test_predictions_dir_compressed += f"_{args.tta}"

    os.makedirs(test_predictions_dir, exist_ok=True)
    os.makedirs(test_predictions_dir_compressed, exist_ok=True)

    test_images = find_in_dir(os.path.join(data_dir, "test", "images"))
    for fname in tqdm(test_images, total=len(test_images)):
        predicted_mask_fname = os.path.join(test_predictions_dir, os.path.basename(fname))
        # Skip images whose raw mask was already produced by a previous run
        if not os.path.isfile(predicted_mask_fname):
            image = read_inria_image(fname)
            mask = predict(model, image, image_size=(512, 512), batch_size=args.batch_size * torch.cuda.device_count())
            mask = ((mask > threshold) * 255).astype(np.uint8)
            cv2.imwrite(predicted_mask_fname, mask)

        name_compressed = os.path.join(test_predictions_dir_compressed, os.path.basename(fname))
        # BUGFIX: invoke gdal_translate with an argument list instead of a
        # shell-interpolated string — robust to spaces in paths and avoids
        # shell injection via file names.
        command = [
            "gdal_translate",
            "--config", "GDAL_PAM_ENABLED", "NO",
            "-co", "COMPRESS=CCITTFAX4",
            "-co", "NBITS=1",
            predicted_mask_fname,
            name_compressed,
        ]
        subprocess.call(command)
def main():
    """Evaluate checkpoints on IStego100K test variants plus the ALASKA2
    holdout, caching predictions as CSV and printing weighted-AUC metrics
    for both the binary flag and the multi-class head.
    """
    # Give no chance to randomness
    torch.manual_seed(0)
    np.random.seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint", type=str, nargs="+")
    parser.add_argument("-dd", "--data-dir", type=str, default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-b", "--batch-size", type=int, default=1)
    parser.add_argument("-w", "--workers", type=int, default=0)
    parser.add_argument("-d4", "--d4-tta", action="store_true")
    parser.add_argument("-hv", "--hv-tta", action="store_true")
    parser.add_argument("-f", "--force-recompute", action="store_true")
    parser.add_argument("-oof", "--need-oof", action="store_true")

    args = parser.parse_args()

    checkpoint_fnames = args.checkpoint
    data_dir = args.data_dir
    batch_size = args.batch_size
    workers = args.workers
    d4_tta = args.d4_tta
    hv_tta = args.hv_tta
    force_recompute = args.force_recompute

    outputs = [OUTPUT_PRED_MODIFICATION_FLAG, OUTPUT_PRED_MODIFICATION_TYPE]

    for checkpoint_fname in checkpoint_fnames:
        model, checkpoints, required_features = ensemble_from_checkpoints(
            [checkpoint_fname], strict=True, outputs=outputs, activation=None, tta=None
        )
        report_checkpoint(checkpoints[0])

        model = model.cuda()
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model = model.eval()

        # Holdout
        variants = {
            "istego100k_test_same_center_crop": get_istego100k_test_same(
                data_dir, features=required_features, output_size="center_crop"
            ),
            "istego100k_test_same_full": get_istego100k_test_same(
                data_dir, features=required_features, output_size="full"
            ),
            "istego100k_test_other_center_crop": get_istego100k_test_other(
                data_dir, features=required_features, output_size="center_crop"
            ),
            "istego100k_test_other_full": get_istego100k_test_other(
                data_dir, features=required_features, output_size="full"
            ),
            # BUGFIX: raw string — "\d"/"\A" in a normal literal are invalid
            # escape sequences (DeprecationWarning now, SyntaxError in future
            # CPython). Value is unchanged.
            # NOTE(review): hard-coded Windows path — consider a CLI argument.
            "holdout": get_holdout(r"d:\datasets\ALASKA2", features=required_features),
        }

        for name, dataset in variants.items():
            print("Making predictions for ", name, len(dataset))
            predictions_csv = fs.change_extension(checkpoint_fname, f"_{name}_predictions.csv")
            if force_recompute or not os.path.exists(predictions_csv):
                # Full-size images need a smaller batch than center crops
                holdout_predictions = compute_oof_predictions(
                    model, dataset, batch_size=batch_size // 4 if "full" in name else batch_size, workers=workers
                )
                holdout_predictions.to_csv(predictions_csv, index=False)

            holdout_predictions = pd.read_csv(predictions_csv)
            print(name)
            print(
                "\tbAUC",
                alaska_weighted_auc(
                    holdout_predictions[INPUT_TRUE_MODIFICATION_FLAG].values,
                    holdout_predictions[OUTPUT_PRED_MODIFICATION_FLAG].apply(sigmoid).values,
                ),
            )
            print(
                "\tcAUC",
                alaska_weighted_auc(
                    holdout_predictions[INPUT_TRUE_MODIFICATION_FLAG].values,
                    holdout_predictions[OUTPUT_PRED_MODIFICATION_TYPE].apply(parse_classifier_probas).values,
                ),
            )
def main():
    """Train an INRIA building-segmentation model.

    Supports distributed training (torch.distributed + nccl), extra xView2
    data, online pseudolabeling, deep-supervision criterions, and runs a
    sample prediction with the best checkpoint when training completes.
    """
    parser = argparse.ArgumentParser()
    ###########################################################################################
    # Distributed-training related stuff
    parser.add_argument("--local_rank", type=int, default=0)
    ###########################################################################################
    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd",
        "--data-dir",
        type=str,
        # BUGFIX: "sattelite" typo in help text
        help="Data directory for INRIA satellite dataset",
        default=os.environ.get("INRIA_DATA_DIR"),
    )
    parser.add_argument(
        "-dd-xview2", "--data-dir-xview2", type=str, required=False, help="Data directory for external xView2 dataset"
    )
    parser.add_argument("-m", "--model", type=str, default="resnet34_fpncat128", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")
    parser.add_argument("-l", "--criterion", type=str, required=True, action="append", nargs="+", help="Criterion")
    parser.add_argument(
        "-l2", "--criterion2", type=str, required=False, action="append", nargs="+", help="Criterion for stride 2 mask"
    )
    parser.add_argument(
        "-l4", "--criterion4", type=str, required=False, action="append", nargs="+", help="Criterion for stride 4 mask"
    )
    parser.add_argument(
        "-l8", "--criterion8", type=str, required=False, action="append", nargs="+", help="Criterion for stride 8 mask"
    )
    parser.add_argument(
        "-l16", "--criterion16", type=str, required=False, action="append", nargs="+", help="Criterion for stride 16 mask"
    )
    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument("-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights")
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="hard", type=str, help="")
    parser.add_argument("-tm", "--train-mode", default="random", type=str, help="")
    parser.add_argument("--run-mode", default="fit_predict", type=str, help="")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer")
    parser.add_argument("--opl", action="store_true")
    parser.add_argument("--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters")
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")

    args = parser.parse_args()

    args.is_master = args.local_rank == 0
    args.distributed = False
    fp16 = args.fp16

    if "WORLD_SIZE" in os.environ:
        args.distributed = int(os.environ["WORLD_SIZE"]) > 1
        args.world_size = int(os.environ["WORLD_SIZE"])
        # args.world_size = torch.distributed.get_world_size()

        print("Initializing init_process_group", args.local_rank)
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        print("Initialized init_process_group", args.local_rank)

    # BUGFIX: logical `or` instead of bitwise `|` on booleans (same result
    # here, but `or` short-circuits and states the intent).
    is_master = args.is_master or (not args.distributed)

    if args.distributed:
        distributed_params = {"rank": args.local_rank, "syncbn": True}
        if args.fp16:
            distributed_params["amp"] = True
    else:
        if args.fp16:
            distributed_params = {}
            distributed_params["amp"] = True
        else:
            distributed_params = False

    # Offset seeds by rank so augmentation streams differ across processes
    set_manual_seed(args.seed + args.local_rank)
    catalyst.utils.set_global_seed(args.seed + args.local_rank)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    data_dir = args.data_dir
    if data_dir is None:
        raise ValueError("--data-dir must be set")

    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    online_pseudolabeling = args.opl
    criterions = args.criterion
    criterions2 = args.criterion2
    criterions4 = args.criterion4
    criterions8 = args.criterion8
    criterions16 = args.criterion16
    verbose = args.verbose
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    extra_data_xview2 = args.data_dir_xview2

    run_train = num_epochs > 0
    # Weighted BCE needs precomputed per-pixel weight masks in the dataset
    need_weight_mask = any(c[0] == "wbce" for c in criterions)

    custom_model_kwargs = {}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)
    if any([criterions2, criterions4, criterions8, criterions16]):
        custom_model_kwargs["need_supervision_masks"] = True
        print("Enabling supervision masks")

    model: nn.Module = get_model(model_name, **custom_model_kwargs).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        # BUGFIX: "Transfering" typo in log message
        print("Transferring weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]
        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)
        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    main_metric = "optimized_jaccard"

    current_time = datetime.now().strftime("%y%m%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}"

    if fp16:
        checkpoint_prefix += "_fp16"
    if fast:
        checkpoint_prefix += "_fast"
    if online_pseudolabeling:
        checkpoint_prefix += "_opl"
    if extra_data_xview2:
        checkpoint_prefix += "_with_xview2"
    if experiment is not None:
        checkpoint_prefix = experiment

    default_callbacks = [
        PixelAccuracyCallback(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY),
        # JaccardMetricPerImage(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="jaccard"),
        JaccardMetricPerImageWithOptimalThreshold(
            input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="optimized_jaccard"
        ),
    ]

    if is_master:
        # Only the master process writes checkpoints / hyperparameter logs
        default_callbacks += [
            BestMetricCheckpointCallback(target_metric="optimized_jaccard", target_metric_minimize=False),
            HyperParametersCallback(
                hparam_dict={
                    "model": model_name,
                    "scheduler": scheduler_name,
                    "optimizer": optimizer_name,
                    "augmentations": augmentations,
                    "size": args.size,
                    "weight_decay": weight_decay,
                    "epochs": num_epochs,
                    "dropout": None if dropout is None else float(dropout),
                }
            ),
        ]

    if show:
        visualize_inria_predictions = partial(
            draw_inria_predictions,
            inputs_to_labels=lambda x: x.ge(0.5).squeeze(1),
            outputs_to_labels=lambda x: x.float().sigmoid().ge(0.5).squeeze(1),
            image_key=INPUT_IMAGE_KEY,
            image_id_key=INPUT_IMAGE_ID_KEY,
            targets_key=INPUT_MASK_KEY,
            outputs_key=OUTPUT_MASK_KEY,
            max_images=16,
        )
        default_callbacks += [
            ShowPolarBatchesCallback(visualize_inria_predictions, metric="accuracy", minimize=False),
            ShowPolarBatchesCallback(visualize_inria_predictions, metric="loss", minimize=True),
        ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        train_mode=train_mode,
        buildings_only=(train_mode == "tiles"),
        fast=fast,
        need_weight_mask=need_weight_mask,
    )

    if extra_data_xview2 is not None:
        extra_train_ds, _ = get_xview2_extra_dataset(
            extra_data_xview2,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast,
            need_weight_mask=need_weight_mask,
        )

        # Re-balance sampling between original and extra data
        weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(extra_train_ds))
        train_sampler = WeightedRandomSampler(weights, train_sampler.num_samples * 2)

        train_ds = train_ds + extra_train_ds
        print("Using extra data from xView2 with", len(extra_train_ds), "samples")

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()
        # (removed unused locals criterions_dict / losses / ignore_index)

        if online_pseudolabeling:
            # Unlabeled data appears twice: once without masks for inference
            # ("infer" loader) and once with masks for training.
            unlabeled_label = get_pseudolabeling_dataset(data_dir, include_masks=False, augmentation=None, image_size=image_size)
            unlabeled_train = get_pseudolabeling_dataset(
                data_dir, include_masks=True, augmentation=augmentations, image_size=image_size
            )

            if args.distributed:
                label_sampler = DistributedSampler(unlabeled_label, args.world_size, args.local_rank, shuffle=False)
            else:
                label_sampler = None

            loaders["infer"] = DataLoader(
                unlabeled_label,
                batch_size=batch_size // 2,
                num_workers=num_workers,
                pin_memory=True,
                sampler=label_sampler,
                drop_last=False,
            )

            if train_sampler is not None:
                num_samples = 2 * train_sampler.num_samples
            else:
                num_samples = 2 * len(train_ds)
            weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(unlabeled_label))

            train_sampler = WeightedRandomSampler(weights, num_samples, replacement=True)
            train_ds = train_ds + unlabeled_train

            callbacks += [
                BCEOnlinePseudolabelingCallback2d(
                    unlabeled_train,
                    pseudolabel_loader="infer",
                    prob_threshold=0.7,
                    output_key=OUTPUT_MASK_KEY,
                    unlabeled_class=UNLABELED_SAMPLE,
                    label_frequency=5,
                )
            ]

            print("Using online pseudolabeling with ", len(unlabeled_label), "samples")

        valid_sampler = None
        if args.distributed:
            if train_sampler is not None:
                train_sampler = DistributedSamplerWrapper(train_sampler, args.world_size, args.local_rank, shuffle=True)
            else:
                train_sampler = DistributedSampler(train_ds, args.world_size, args.local_rank, shuffle=True)
            valid_sampler = DistributedSampler(valid_ds, args.world_size, args.local_rank, shuffle=False)

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(
            valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, sampler=valid_sampler
        )

        # U2NET models expose six deep-supervision heads at stride 1
        if model_name in {"U2NETP", "U2NET"}:
            dsv_criterions = criterions
        else:
            dsv_criterions = None

        loss_callbacks, loss_criterions = get_criterions(
            criterions=criterions,
            criterions_stride1_dsv1=dsv_criterions,
            criterions_stride1_dsv2=dsv_criterions,
            criterions_stride1_dsv3=dsv_criterions,
            criterions_stride1_dsv4=dsv_criterions,
            criterions_stride1_dsv5=dsv_criterions,
            criterions_stride1_dsv6=dsv_criterions,
            criterions_stride2=criterions2,
            criterions_stride4=criterions4,
            criterions_stride8=criterions8,
            criterions_stride16=criterions16,
        )
        callbacks += loss_callbacks

        optimizer = get_optimizer(optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay)
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
            callbacks += [SchedulerCallback(mode="batch")]

        log_dir = os.path.join("runs", checkpoint_prefix)

        if is_master:
            # exist_ok=False on purpose: refuse to clobber a previous run
            os.makedirs(log_dir, exist_ok=False)

            config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
            with open(config_fname, "w") as f:
                train_session_args = vars(args)
                f.write(json.dumps(train_session_args, indent=2))

            print("Train session :", checkpoint_prefix)
            print(" FP16 mode :", fp16)
            print(" Fast mode :", args.fast)
            print(" Train mode :", train_mode)
            print(" Epochs :", num_epochs)
            print(" Workers :", num_workers)
            print(" Data dir :", data_dir)
            print(" Log dir :", log_dir)
            print(" Augmentations :", augmentations)
            print(" Train size :", "batches", len(loaders["train"]), "dataset", len(train_ds))
            print(" Valid size :", "batches", len(loaders["valid"]), "dataset", len(valid_ds))
            print("Model :", model_name)
            print(" Parameters :", count_parameters(model))
            print(" Image size :", image_size)
            print("Optimizer :", optimizer_name)
            print(" Learning rate :", learning_rate)
            print(" Batch size :", batch_size)
            print(" Criterion :", criterions)
            print(" Use weight mask:", need_weight_mask)
            if args.distributed:
                print("Distributed")
                print(" World size :", args.world_size)
                print(" Local rank :", args.local_rank)
                print(" Is master :", args.is_master)

        # model training
        runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None, device="cuda")
        runner.train(
            fp16=distributed_params,
            model=model,
            criterion=loss_criterions,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": vars(args)},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        if is_master:
            best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")

            model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")
            clean_checkpoint(best_checkpoint, model_checkpoint)

            unpack_checkpoint(torch.load(model_checkpoint), model=model)

            mask = predict(model, read_inria_image("sample_color.jpg"), image_size=image_size, batch_size=args.batch_size)
            mask = ((mask > 0) * 255).astype(np.uint8)
            name = os.path.join(log_dir, "sample_color.jpg")
            cv2.imwrite(name, mask)