Example #1
def main():
    parser = argparse.ArgumentParser()

    ###########################################################################################
    # Distributed-training related stuff
    parser.add_argument("--local_rank", type=int, default=0)
    ###########################################################################################

    parser.add_argument("-acc",
                        "--accumulation-steps",
                        type=int,
                        default=1,
                        help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--obliterate",
                        type=float,
                        default=0,
                        help="Change of obliteration")
    parser.add_argument("-nid",
                        "--negative-image-dir",
                        type=str,
                        default=None,
                        help="Change of obliteration")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--cache", action="store_true")
    parser.add_argument("-dd",
                        "--data-dir",
                        type=str,
                        default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-dd2",
                        "--data-dir-istego",
                        type=str,
                        default=os.environ.get("KAGGLE_2020_ISTEGO100K"))
    parser.add_argument("-m", "--model", type=str, default="resnet34", help="")
    parser.add_argument("-b",
                        "--batch-size",
                        type=int,
                        default=16,
                        help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e",
                        "--epochs",
                        type=int,
                        default=100,
                        help="Epoch to run")
    parser.add_argument("-es",
                        "--early-stopping",
                        type=int,
                        default=None,
                        help="Maximum number of epochs without improvement")
    parser.add_argument("-fe",
                        "--freeze-encoder",
                        type=int,
                        default=0,
                        help="Freeze encoder parameters for N epochs")
    parser.add_argument("-lr",
                        "--learning-rate",
                        type=float,
                        default=1e-3,
                        help="Initial learning rate")

    parser.add_argument(
        "-l",
        "--modification-flag-loss",
        type=str,
        default=None,
        action="append",
        nargs="+"  # [["ce", 1.0]],
    )
    parser.add_argument(
        "--modification-type-loss",
        type=str,
        default=None,
        action="append",
        nargs="+"  # [["ce", 1.0]],
    )
    parser.add_argument("--embedding-loss",
                        type=str,
                        default=None,
                        action="append",
                        nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--feature-maps-loss",
                        type=str,
                        default=None,
                        action="append",
                        nargs="+")  # [["ce", 1.0]],

    parser.add_argument("-o",
                        "--optimizer",
                        default="RAdam",
                        help="Name of the optimizer")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        help="Checkpoint filename to use as initial model weights")
    parser.add_argument("-w",
                        "--workers",
                        default=8,
                        type=int,
                        help="Num workers")
    parser.add_argument("-a",
                        "--augmentations",
                        default="safe",
                        type=str,
                        help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--mixup", action="store_true")
    parser.add_argument("--cutmix", action="store_true")
    parser.add_argument("--tsa", action="store_true")
    parser.add_argument("--size", default=None, type=int)
    parser.add_argument("--fold", default=None, type=int)
    parser.add_argument("-s", "--scheduler", default=None, type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d",
                        "--dropout",
                        default=0.0,
                        type=float,
                        help="Dropout before head layer")
    parser.add_argument(
        "--warmup",
        default=0,
        type=int,
        help="Number of warmup epochs with reduced LR on encoder parameters")
    parser.add_argument(
        "--fine-tune",
        default=0,
        type=int,
        help="Number of warmup epochs with reduced LR on encoder parameters")
    parser.add_argument("-wd",
                        "--weight-decay",
                        default=0,
                        type=float,
                        help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")

    args = parser.parse_args()
    args.is_master = args.local_rank == 0
    args.distributed = False
    args.world_size = 1  # referenced below even in single-process runs
    if "WORLD_SIZE" in os.environ:
        args.distributed = int(os.environ["WORLD_SIZE"]) > 1
        args.world_size = int(os.environ["WORLD_SIZE"])

        print("Initializing init_process_group", args.local_rank)

        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        print("Initialized init_process_group", args.local_rank)

    set_manual_seed(args.seed)

    assert (args.modification_flag_loss or args.modification_type_loss
            or args.embedding_loss), "At least one of losses must be set"

    modification_flag_loss = args.modification_flag_loss
    modification_type_loss = args.modification_type_loss
    embedding_loss = args.embedding_loss
    feature_maps_loss = args.feature_maps_loss

    data_dir = args.data_dir
    data_dir_istego = args.data_dir_istego
    cache = args.cache
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    model_name: str = args.model
    optimizer_name = args.optimizer
    image_size = (args.size, args.size) if args.size is not None else (512, 512)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    verbose = args.verbose
    warmup = args.warmup
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    fold = args.fold
    balance = args.balance
    freeze_bn = args.freeze_bn
    train_batch_size = args.batch_size
    mixup = args.mixup
    cutmix = args.cutmix
    tsa = args.tsa
    obliterate_p = args.obliterate
    negative_image_dir = args.negative_image_dir

    distributed_params = {"rank": args.local_rank, "syncbn": True}
    if fp16:
        distributed_params["opt_level"] = "O1"

    # Compute batch size for validation
    valid_batch_size = train_batch_size
    run_train = num_epochs > 0

    model: nn.Module = get_model(model_name, dropout=dropout)
    required_features = model.required_features

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transferring weights from model checkpoint",
              transfer_checkpoint)
        checkpoint = torch.load(transfer_checkpoint, map_location="cpu")
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    model = model.cuda()

    if freeze_bn:
        from pytorch_toolbelt.optimization.functional import freeze_model

        freeze_model(model, freeze_bn=True)
        print("Freezing bn params")

    main_metric = "loss"
    main_metric_minimize = True
    cmd_args = vars(args)

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}_fold{fold}_istego100k_local_rank_{args.local_rank}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if mixup:
        checkpoint_prefix += "_mixup"

    if cutmix:
        checkpoint_prefix += "_cutmix"

    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = []

    if show:
        default_callbacks += [
            ShowPolarBatchesCallback(draw_predictions,
                                     metric="loss",
                                     minimize=True)
        ]

    if run_train:
        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            image_size=image_size,
            augmentation=augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=obliterate_p,
        )

        extra_train_ds = get_istego100k_train(data_dir_istego,
                                              fold=fold,
                                              features=required_features,
                                              output_size="random_crop")
        train_ds = train_ds + extra_train_ds

        if negative_image_dir:
            negatives_ds = get_negatives_ds(negative_image_dir,
                                            fold=fold,
                                            local_rank=args.local_rank,
                                            features=required_features,
                                            max_images=25000)
            train_ds = train_ds + negatives_ds
            train_sampler = None  # TODO: Add proper support of sampler
            print("Adding", len(negatives_ds),
                  "negative samples to training set")

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            feature_maps_loss=feature_maps_loss,
            num_epochs=num_epochs,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
        )

        callbacks = (default_callbacks + loss_callbacks + [
            OptimizerCallback(accumulation_steps=accumulation_steps,
                              decouple_weight_decay=False),
            HyperParametersCallback(
                hparam_dict={
                    "model": model_name,
                    "scheduler": scheduler_name,
                    "optimizer": optimizer_name,
                    "augmentations": augmentations,
                    "size": image_size[0],
                    "weight_decay": weight_decay,
                }),
        ])

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=False,
            sampler=DistributedSampler(train_ds, args.world_size,
                                       args.local_rank),
        )

        loaders["valid"] = DataLoader(
            valid_ds,
            batch_size=valid_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=False,
            shuffle=False,
            sampler=DistributedSampler(valid_ds,
                                       args.world_size,
                                       args.local_rank,
                                       shuffle=False),
        )

        print("Train session    :", checkpoint_prefix)
        print("  FP16 mode      :", fp16)
        print("  Fast mode      :", args.fast)
        print("  Epochs         :", num_epochs)
        print("  Workers        :", num_workers)
        print("  Data dir       :", data_dir)
        print("  Log dir        :", log_dir)
        print("  Cache          :", cache)
        print("Data              ")
        print("  Augmentations  :", augmentations)
        print("  Obliterate (%) :", obliterate_p)
        print("  Negative images:", negative_image_dir)
        print("  Train size     :", len(loaders["train"]), "batches",
              len(train_ds), "samples")
        print("  Valid size     :", len(loaders["valid"]), "batches",
              len(valid_ds), "samples")
        print("  Image size     :", image_size)
        print("  Balance        :", balance)
        print("  Mixup          :", mixup)
        print("  CutMix         :", cutmix)
        print("  TSA            :", tsa)
        print("Model            :", model_name)
        print("  Parameters     :", count_parameters(model))
        print("  Dropout        :", dropout)
        print("Optimizer        :", optimizer_name)
        print("  Learning rate  :", learning_rate)
        print("  Weight decay   :", weight_decay)
        print("  Scheduler      :", scheduler_name)
        print("  Batch sizes    :", train_batch_size, valid_batch_size)
        print("Losses            ")
        print("  Flag           :", modification_flag_loss)
        print("  Type           :", modification_type_loss)
        print("  Embedding      :", embedding_loss)
        print("  Feature maps   :", feature_maps_loss)
        print("Distributed")
        print("  World size  :", args.world_size)
        print("  Local rank  :", args.local_rank)
        print("  Is master   :", args.is_master)

        optimizer = get_optimizer(optimizer_name,
                                  get_optimizable_parameters(model),
                                  learning_rate=learning_rate,
                                  weight_decay=weight_decay)
        scheduler = get_scheduler(scheduler_name,
                                  optimizer,
                                  lr=learning_rate,
                                  num_epochs=num_epochs,
                                  batches_in_epoch=len(loaders["train"]))
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        # model training
        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=distributed_params,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        del optimizer, loaders, runner, callbacks

        best_checkpoint = os.path.join(log_dir, "main", "checkpoints",
                                       "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")

        # Restore state of best model
        clean_checkpoint(best_checkpoint, model_checkpoint)
        # unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        torch.cuda.empty_cache()
        gc.collect()
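
A hypothetical launch command for this script under the (legacy) torch.distributed launcher, which sets WORLD_SIZE and passes --local_rank; the file name train_alaska2.py and the GPU count are assumptions, not from the source:

    python -m torch.distributed.launch --nproc_per_node=4 train_alaska2.py \
        --modification-flag-loss ce 1.0 -m resnet34 -b 16
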
Example #2
def main():
    # Give no chance to randomness
    torch.manual_seed(0)
    np.random.seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint", type=str, nargs="+")
    parser.add_argument("-dd",
                        "--data-dir",
                        type=str,
                        default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-b", "--batch-size", type=int, default=1)
    parser.add_argument("-w", "--workers", type=int, default=0)
    parser.add_argument("-d4", "--d4-tta", action="store_true")
    parser.add_argument("-hv", "--hv-tta", action="store_true")
    parser.add_argument("-f", "--force-recompute", action="store_true")
    parser.add_argument("-fp16", "--fp16", action="store_true")

    args = parser.parse_args()

    checkpoint_fnames = args.checkpoint
    data_dir = args.data_dir
    batch_size = args.batch_size
    workers = args.workers
    fp16 = args.fp16
    d4_tta = args.d4_tta
    force_recompute = args.force_recompute
    need_embedding = True

    outputs = [
        OUTPUT_PRED_MODIFICATION_FLAG, OUTPUT_PRED_MODIFICATION_TYPE,
        OUTPUT_PRED_EMBEDDING
    ]
    embedding_suffix = "_w_emb" if need_embedding else ""

    for checkpoint_fname in checkpoint_fnames:
        model, checkpoints, required_features = ensemble_from_checkpoints(
            [checkpoint_fname],
            strict=True,
            outputs=outputs,
            activation=None,
            tta=None,
            need_embedding=need_embedding)

        report_checkpoint(checkpoints[0])

        model = model.cuda()
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model = model.eval()

        if fp16:
            model = model.half()

        train_ds = get_train_except_holdout(data_dir,
                                            features=required_features)
        holdout_ds = get_holdout(data_dir, features=required_features)
        test_ds = get_test_dataset(data_dir, features=required_features)

        if d4_tta:
            model = wrap_model_with_tta(model,
                                        "d4",
                                        inputs=required_features,
                                        outputs=outputs).eval()
            tta_suffix = "_d4_tta"
        else:
            tta_suffix = ""

        # Train
        trn_predictions_pkl = fs.change_extension(
            checkpoint_fname,
            f"_train_predictions{embedding_suffix}{tta_suffix}.pkl")
        if force_recompute or not os.path.exists(trn_predictions_pkl):
            trn_predictions = compute_trn_predictions(model,
                                                      train_ds,
                                                      fp16=fp16,
                                                      batch_size=batch_size,
                                                      workers=workers)
            trn_predictions.to_pickle(trn_predictions_pkl)

        # Holdout
        hld_predictions_pkl = fs.change_extension(
            checkpoint_fname,
            f"_holdout_predictions{embedding_suffix}{tta_suffix}.pkl")
        if force_recompute or not os.path.exists(hld_predictions_pkl):
            hld_predictions = compute_trn_predictions(model,
                                                      holdout_ds,
                                                      fp16=fp16,
                                                      batch_size=batch_size,
                                                      workers=workers)
            hld_predictions.to_pickle(hld_predictions_pkl)

        # Test
        tst_predictions_pkl = fs.change_extension(
            checkpoint_fname,
            f"_test_predictions{embedding_suffix}{tta_suffix}.pkl")
        if force_recompute or not os.path.exists(tst_predictions_pkl):
            tst_predictions = compute_trn_predictions(model,
                                                      test_ds,
                                                      fp16=fp16,
                                                      batch_size=batch_size,
                                                      workers=workers)
            tst_predictions.to_pickle(tst_predictions_pkl)
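
The saved prediction files are pandas pickles; a minimal sketch of reading one back (the file name is illustrative):

    import pandas as pd

    df = pd.read_pickle("model_train_predictions_w_emb_d4_tta.pkl")
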
Example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="unet", help="")
    parser.add_argument("-dd", "--data-dir", type=str, default=None, required=True, help="Data dir")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        required=True,
        help="Checkpoint filename to use as initial model weights",
    )
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch size for inference")
    parser.add_argument("-tta", "--tta", default=None, type=str, help="Type of TTA to use [fliplr, d4]")
    args = parser.parse_args()

    data_dir = args.data_dir
    checkpoint_file = auto_file(args.checkpoint)
    run_dir = os.path.dirname(checkpoint_file)
    out_dir = os.path.join(run_dir, "submit")
    os.makedirs(out_dir, exist_ok=True)

    model, checkpoint = model_from_checkpoint(checkpoint_file, strict=False)
    threshold = checkpoint["epoch_metrics"].get("valid_optimized_jaccard/threshold", 0.5)
    report_checkpoint(checkpoint)
    print("Using threshold", threshold)

    model = nn.Sequential(PickModelOutput(model, OUTPUT_MASK_KEY), nn.Sigmoid())

    if args.tta == "fliplr":
        model = TTAWrapper(model, fliplr_image2mask)
    elif args.tta == "d4":
        model = TTAWrapper(model, d4_image2mask)
    elif args.tta == "ms-d2":
        model = TTAWrapper(model, fliplr_image2mask)
        model = MultiscaleTTAWrapper(model, size_offsets=[-128, -64, 64, 128])
    elif args.tta == "ms-d4":
        model = TTAWrapper(model, d4_image2mask)
        model = MultiscaleTTAWrapper(model, size_offsets=[-128, -64, 64, 128])
    elif args.tta == "ms":
        model = MultiscaleTTAWrapper(model, size_offsets=[-128, -64, 64, 128])
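    # Hedged reading of the wrappers above: "d4" averages predictions over the
    # eight dihedral transforms, and the "ms" variants additionally average
    # over rescaled inputs given by size_offsets.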

    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    model = model.eval()

    # mask = predict(model, read_inria_image("sample_color.jpg"), image_size=(512, 512), batch_size=args.batch_size * torch.cuda.device_count())
    # mask = ((mask > threshold) * 255).astype(np.uint8)
    # name = os.path.join(run_dir, "sample_color.jpg")
    # cv2.imwrite(name, mask)

    test_predictions_dir = os.path.join(out_dir, "test_predictions")
    test_predictions_dir_compressed = os.path.join(out_dir, "test_predictions_compressed")

    if args.tta is not None:
        test_predictions_dir += f"_{args.tta}"
        test_predictions_dir_compressed += f"_{args.tta}"

    os.makedirs(test_predictions_dir, exist_ok=True)
    os.makedirs(test_predictions_dir_compressed, exist_ok=True)

    test_images = find_in_dir(os.path.join(data_dir, "test", "images"))
    for fname in tqdm(test_images, total=len(test_images)):
        predicted_mask_fname = os.path.join(test_predictions_dir, os.path.basename(fname))

        if not os.path.isfile(predicted_mask_fname):
            image = read_inria_image(fname)
            mask = predict(model, image, image_size=(512, 512), batch_size=args.batch_size * torch.cuda.device_count())
            mask = ((mask > threshold) * 255).astype(np.uint8)
            cv2.imwrite(predicted_mask_fname, mask)

        name_compressed = os.path.join(test_predictions_dir_compressed, os.path.basename(fname))
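        # CCITTFAX4 with NBITS=1 produces a compact 1-bit TIFF; presumably this
        # matches the compressed submission format expected by the Inria Aerial
        # Image Labeling benchmark.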
        command = [
            "gdal_translate",
            "--config", "GDAL_PAM_ENABLED", "NO",
            "-co", "COMPRESS=CCITTFAX4",
            "-co", "NBITS=1",
            predicted_mask_fname,
            name_compressed,
        ]
        subprocess.call(command)
Example #4
def main():
    # Give no chance to randomness
    torch.manual_seed(0)
    np.random.seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint", type=str, nargs="+")
    parser.add_argument("-dd", "--data-dir", type=str, default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-b", "--batch-size", type=int, default=1)
    parser.add_argument("-w", "--workers", type=int, default=0)
    parser.add_argument("-d4", "--d4-tta", action="store_true")
    parser.add_argument("-hv", "--hv-tta", action="store_true")
    parser.add_argument("-f", "--force-recompute", action="store_true")
    parser.add_argument("-oof", "--need-oof", action="store_true")

    args = parser.parse_args()

    checkpoint_fnames = args.checkpoint
    data_dir = args.data_dir
    batch_size = args.batch_size
    workers = args.workers

    d4_tta = args.d4_tta
    hv_tta = args.hv_tta
    force_recompute = args.force_recompute
    outputs = [OUTPUT_PRED_MODIFICATION_FLAG, OUTPUT_PRED_MODIFICATION_TYPE]

    for checkpoint_fname in checkpoint_fnames:
        model, checkpoints, required_features = ensemble_from_checkpoints(
            [checkpoint_fname], strict=True, outputs=outputs, activation=None, tta=None
        )

        report_checkpoint(checkpoints[0])

        model = model.cuda()
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model = model.eval()

        # Holdout
        variants = {
            "istego100k_test_same_center_crop": get_istego100k_test_same(
                data_dir, features=required_features, output_size="center_crop"
            ),
            "istego100k_test_same_full": get_istego100k_test_same(
                data_dir, features=required_features, output_size="full"
            ),
            "istego100k_test_other_center_crop": get_istego100k_test_other(
                data_dir, features=required_features, output_size="center_crop"
            ),
            "istego100k_test_other_full": get_istego100k_test_other(
                data_dir, features=required_features, output_size="full"
            ),
            "holdout": get_holdout("d:\datasets\ALASKA2", features=required_features),
        }
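        # Each variant probes a different crop regime; the "full"-resolution
        # variants are evaluated with a reduced batch size below.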

        for name, dataset in variants.items():
            print("Making predictions for ", name, len(dataset))

            predictions_csv = fs.change_extension(checkpoint_fname, f"_{name}_predictions.csv")
            if force_recompute or not os.path.exists(predictions_csv):
                holdout_predictions = compute_oof_predictions(
                    model, dataset, batch_size=max(1, batch_size // 4) if "full" in name else batch_size, workers=workers
                )
                holdout_predictions.to_csv(predictions_csv, index=False)
                holdout_predictions = pd.read_csv(predictions_csv)

                print(name)
                print(
                    "\tbAUC",
                    alaska_weighted_auc(
                        holdout_predictions[INPUT_TRUE_MODIFICATION_FLAG].values,
                        holdout_predictions[OUTPUT_PRED_MODIFICATION_FLAG].apply(sigmoid).values,
                    ),
                )

                print(
                    "\tcAUC",
                    alaska_weighted_auc(
                        holdout_predictions[INPUT_TRUE_MODIFICATION_FLAG].values,
                        holdout_predictions[OUTPUT_PRED_MODIFICATION_TYPE].apply(parse_classifier_probas).values,
                    ),
                )
Example #5
def main():
    parser = argparse.ArgumentParser()

    ###########################################################################################
    # Distributed-training related stuff
    parser.add_argument("--local_rank", type=int, default=0)
    ###########################################################################################

    parser.add_argument("-acc",
                        "--accumulation-steps",
                        type=int,
                        default=1,
                        help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd",
        "--data-dir",
        type=str,
        help="Data directory for INRIA sattelite dataset",
        default=os.environ.get("INRIA_DATA_DIR"),
    )
    parser.add_argument("-dd-xview2",
                        "--data-dir-xview2",
                        type=str,
                        required=False,
                        help="Data directory for external xView2 dataset")
    parser.add_argument("-m",
                        "--model",
                        type=str,
                        default="resnet34_fpncat128",
                        help="")
    parser.add_argument("-b",
                        "--batch-size",
                        type=int,
                        default=8,
                        help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e",
                        "--epochs",
                        type=int,
                        default=100,
                        help="Epoch to run")
    # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument("-lr",
                        "--learning-rate",
                        type=float,
                        default=1e-3,
                        help="Initial learning rate")
    parser.add_argument("-l",
                        "--criterion",
                        type=str,
                        required=True,
                        action="append",
                        nargs="+",
                        help="Criterion")
    parser.add_argument(
        "-l2",
        "--criterion2",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 2 mask",
    )
    parser.add_argument(
        "-l4",
        "--criterion4",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 4 mask",
    )
    parser.add_argument(
        "-l8",
        "--criterion8",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 8 mask",
    )
    parser.add_argument(
        "-l16",
        "--criterion16",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 16 mask",
    )

    parser.add_argument("-o",
                        "--optimizer",
                        default="RAdam",
                        help="Name of the optimizer")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        help="Checkpoint filename to use as initial model weights")
    parser.add_argument("-w",
                        "--workers",
                        default=8,
                        type=int,
                        help="Num workers")
    parser.add_argument("-a",
                        "--augmentations",
                        default="hard",
                        type=str,
                        help="")
    parser.add_argument("-tm",
                        "--train-mode",
                        default="random",
                        type=str,
                        help="")
    parser.add_argument("--run-mode", default="fit_predict", type=str, help="")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("-s",
                        "--scheduler",
                        default="multistep",
                        type=str,
                        help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d",
                        "--dropout",
                        default=None,
                        type=float,
                        help="Dropout before head layer")
    parser.add_argument("--opl", action="store_true")
    parser.add_argument(
        "--warmup",
        default=0,
        type=int,
        help="Number of warmup epochs with reduced LR on encoder parameters")
    parser.add_argument("-wd",
                        "--weight-decay",
                        default=0,
                        type=float,
                        help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")

    args = parser.parse_args()

    args.is_master = args.local_rank == 0
    args.distributed = False
    fp16 = args.fp16

    if "WORLD_SIZE" in os.environ:
        args.distributed = int(os.environ["WORLD_SIZE"]) > 1
        args.world_size = int(os.environ["WORLD_SIZE"])
        # args.world_size = torch.distributed.get_world_size()

        print("Initializing init_process_group", args.local_rank)

        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        print("Initialized init_process_group", args.local_rank)

    is_master = args.is_master or not args.distributed

    if args.distributed:
        distributed_params = {"rank": args.local_rank, "syncbn": True}
        if args.fp16:
            distributed_params["amp"] = True
    elif args.fp16:
        distributed_params = {"amp": True}
    else:
        distributed_params = False
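    # Hedged note: this value goes straight into runner.train(fp16=...) below;
    # judging from context, False means plain fp32, {"amp": True} enables
    # mixed precision, and the rank/syncbn keys configure distributed training.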

    set_manual_seed(args.seed + args.local_rank)
    catalyst.utils.set_global_seed(args.seed + args.local_rank)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    data_dir = args.data_dir
    if data_dir is None:
        raise ValueError("--data-dir must be set")

    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    online_pseudolabeling = args.opl
    criterions = args.criterion
    criterions2 = args.criterion2
    criterions4 = args.criterion4
    criterions8 = args.criterion8
    criterions16 = args.criterion16

    verbose = args.verbose
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    extra_data_xview2 = args.data_dir_xview2

    run_train = num_epochs > 0
    need_weight_mask = any(c[0] == "wbce" for c in criterions)

    custom_model_kwargs = {}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)

    if any([criterions2, criterions4, criterions8, criterions16]):
        custom_model_kwargs["need_supervision_masks"] = True
        print("Enabling supervision masks")

    model: nn.Module = get_model(model_name, **custom_model_kwargs).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    main_metric = "optimized_jaccard"

    current_time = datetime.now().strftime("%y%m%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if online_pseudolabeling:
        checkpoint_prefix += "_opl"

    if extra_data_xview2:
        checkpoint_prefix += "_with_xview2"

    if experiment is not None:
        checkpoint_prefix = experiment

    default_callbacks = [
        PixelAccuracyCallback(input_key=INPUT_MASK_KEY,
                              output_key=OUTPUT_MASK_KEY),
        # JaccardMetricPerImage(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="jaccard"),
        JaccardMetricPerImageWithOptimalThreshold(input_key=INPUT_MASK_KEY,
                                                  output_key=OUTPUT_MASK_KEY,
                                                  prefix="optimized_jaccard"),
    ]

    if is_master:

        default_callbacks += [
            BestMetricCheckpointCallback(target_metric="optimized_jaccard",
                                         target_metric_minimize=False),
            HyperParametersCallback(
                hparam_dict={
                    "model": model_name,
                    "scheduler": scheduler_name,
                    "optimizer": optimizer_name,
                    "augmentations": augmentations,
                    "size": args.size,
                    "weight_decay": weight_decay,
                    "epochs": num_epochs,
                    "dropout": None if dropout is None else float(dropout),
                }),
        ]

        if show:
            visualize_inria_predictions = partial(
                draw_inria_predictions,
                inputs_to_labels=lambda x: x.ge(0.5).squeeze(1),
                outputs_to_labels=lambda x: x.float().sigmoid().ge(0.5).squeeze(1),
                image_key=INPUT_IMAGE_KEY,
                image_id_key=INPUT_IMAGE_ID_KEY,
                targets_key=INPUT_MASK_KEY,
                outputs_key=OUTPUT_MASK_KEY,
                max_images=16,
            )
            default_callbacks += [
                ShowPolarBatchesCallback(visualize_inria_predictions,
                                         metric="accuracy",
                                         minimize=False),
                ShowPolarBatchesCallback(visualize_inria_predictions,
                                         metric="loss",
                                         minimize=True),
            ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        train_mode=train_mode,
        buildings_only=(train_mode == "tiles"),
        fast=fast,
        need_weight_mask=need_weight_mask,
    )

    if extra_data_xview2 is not None:
        extra_train_ds, _ = get_xview2_extra_dataset(
            extra_data_xview2,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast,
            need_weight_mask=need_weight_mask,
        )

        weights = compute_sample_weight("balanced", [0] * len(train_ds) +
                                        [1] * len(extra_train_ds))
        train_sampler = WeightedRandomSampler(weights,
                                              train_sampler.num_samples * 2)

        train_ds = train_ds + extra_train_ds
        print("Using extra data from xView2 with", len(extra_train_ds),
              "samples")

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []

        ignore_index = None
        if online_pseudolabeling:
            ignore_index = UNLABELED_SAMPLE
            unlabeled_label = get_pseudolabeling_dataset(data_dir,
                                                         include_masks=False,
                                                         augmentation=None,
                                                         image_size=image_size)

            unlabeled_train = get_pseudolabeling_dataset(
                data_dir,
                include_masks=True,
                augmentation=augmentations,
                image_size=image_size)

            if args.distributed:
                label_sampler = DistributedSampler(unlabeled_label,
                                                   args.world_size,
                                                   args.local_rank,
                                                   shuffle=False)
            else:
                label_sampler = None

            loaders["infer"] = DataLoader(
                unlabeled_label,
                batch_size=batch_size // 2,
                num_workers=num_workers,
                pin_memory=True,
                sampler=label_sampler,
                drop_last=False,
            )

            if train_sampler is not None:
                num_samples = 2 * train_sampler.num_samples
            else:
                num_samples = 2 * len(train_ds)
            weights = compute_sample_weight("balanced", [0] * len(train_ds) +
                                            [1] * len(unlabeled_label))

            train_sampler = WeightedRandomSampler(weights,
                                                  num_samples,
                                                  replacement=True)
            train_ds = train_ds + unlabeled_train

            callbacks += [
                BCEOnlinePseudolabelingCallback2d(
                    unlabeled_train,
                    pseudolabel_loader="infer",
                    prob_threshold=0.7,
                    output_key=OUTPUT_MASK_KEY,
                    unlabeled_class=UNLABELED_SAMPLE,
                    label_frequency=5,
                )
            ]

            print("Using online pseudolabeling with ", len(unlabeled_label),
                  "samples")

        valid_sampler = None
        if args.distributed:
            if train_sampler is not None:
                train_sampler = DistributedSamplerWrapper(train_sampler,
                                                          args.world_size,
                                                          args.local_rank,
                                                          shuffle=True)
            else:
                train_sampler = DistributedSampler(train_ds,
                                                   args.world_size,
                                                   args.local_rank,
                                                   shuffle=True)
            valid_sampler = DistributedSampler(valid_ds,
                                               args.world_size,
                                               args.local_rank,
                                               shuffle=False)

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds,
                                      batch_size=batch_size,
                                      num_workers=num_workers,
                                      pin_memory=True,
                                      sampler=valid_sampler)

        if model_name in {"U2NETP", "U2NET"}:
            dsv_criterions = criterions
        else:
            dsv_criterions = None

        loss_callbacks, loss_criterions = get_criterions(
            criterions=criterions,
            criterions_stride1_dsv1=dsv_criterions,
            criterions_stride1_dsv2=dsv_criterions,
            criterions_stride1_dsv3=dsv_criterions,
            criterions_stride1_dsv4=dsv_criterions,
            criterions_stride1_dsv5=dsv_criterions,
            criterions_stride1_dsv6=dsv_criterions,
            criterions_stride2=criterions2,
            criterions_stride4=criterions4,
            criterions_stride8=criterions8,
            criterions_stride16=criterions16,
        )
        callbacks += loss_callbacks

        optimizer = get_optimizer(optimizer_name,
                                  get_optimizable_parameters(model),
                                  learning_rate,
                                  weight_decay=weight_decay)
        scheduler = get_scheduler(scheduler_name,
                                  optimizer,
                                  lr=learning_rate,
                                  num_epochs=num_epochs,
                                  batches_in_epoch=len(loaders["train"]))
        if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
            callbacks += [SchedulerCallback(mode="batch")]

        log_dir = os.path.join("runs", checkpoint_prefix)

        if is_master:
            os.makedirs(log_dir, exist_ok=False)
            config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
            with open(config_fname, "w") as f:
                train_session_args = vars(args)
                f.write(json.dumps(train_session_args, indent=2))

            print("Train session    :", checkpoint_prefix)
            print("  FP16 mode      :", fp16)
            print("  Fast mode      :", args.fast)
            print("  Train mode     :", train_mode)
            print("  Epochs         :", num_epochs)
            print("  Workers        :", num_workers)
            print("  Data dir       :", data_dir)
            print("  Log dir        :", log_dir)
            print("  Augmentations  :", augmentations)
            print("  Train size     :", "batches", len(loaders["train"]),
                  "dataset", len(train_ds))
            print("  Valid size     :", "batches", len(loaders["valid"]),
                  "dataset", len(valid_ds))
            print("Model            :", model_name)
            print("  Parameters     :", count_parameters(model))
            print("  Image size     :", image_size)
            print("Optimizer        :", optimizer_name)
            print("  Learning rate  :", learning_rate)
            print("  Batch size     :", batch_size)
            print("  Criterion      :", criterions)
            print("  Use weight mask:", need_weight_mask)
            if args.distributed:
                print("Distributed")
                print("  World size     :", args.world_size)
                print("  Local rank     :", args.local_rank)
                print("  Is master      :", args.is_master)

        # model training
        runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY,
                                  output_key=None,
                                  device="cuda")
        runner.train(
            fp16=distributed_params,
            model=model,
            criterion=loss_criterions,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": vars(args)},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        if is_master:
            best_checkpoint = os.path.join(log_dir, "main", "checkpoints",
                                           "best.pth")

            model_checkpoint = os.path.join(log_dir,
                                            f"{checkpoint_prefix}.pth")
            clean_checkpoint(best_checkpoint, model_checkpoint)

            unpack_checkpoint(torch.load(model_checkpoint), model=model)

            mask = predict(model,
                           read_inria_image("sample_color.jpg"),
                           image_size=image_size,
                           batch_size=args.batch_size)
            mask = ((mask > 0) * 255).astype(np.uint8)
            name = os.path.join(log_dir, "sample_color.jpg")
            cv2.imwrite(name, mask)
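
A hypothetical two-GPU launch for this script (the file name train_inria.py and the bce loss are illustrative; -l is the only required flag):

    python -m torch.distributed.launch --nproc_per_node=2 train_inria.py \
        -l bce 1.0 -m resnet34_fpncat128 -b 8 --fp16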