예제 #1
0
def _build_train_arg_parser():
    """Build the command-line parser for the training entry point.

    Returns:
        argparse.ArgumentParser: parser declaring every option that ``main`` reads.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-acc", "--accumulation-steps", type=int, default=1, help="Number of gradient accumulation steps"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--obliterate", type=float, default=0, help="Chance of obliteration")
    parser.add_argument(
        "-nid",
        "--negative-image-dir",
        type=str,
        default=None,
        help="Directory with negative images to add to the training set",
    )
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--cache", action="store_true")
    parser.add_argument("-dd", "--data-dir", type=str, default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-m", "--model", type=str, default="resnet34", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch Size during training, e.g. -b 64")
    parser.add_argument(
        "-wbs",
        "--warmup-batch-size",
        type=int,
        default=None,
        help="Batch size during the warmup stage (defaults to --batch-size)",
    )
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument(
        "-es", "--early-stopping", type=int, default=None, help="Maximum number of epochs without improvement"
    )
    parser.add_argument(
        "-fe", "--freeze-encoder", action="store_true", help="Freeze encoder parameters during the warmup stage"
    )
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")

    # Every loss option may be repeated and takes one or more tokens per use,
    # e.g. `-l ce 1.0 -l focal 0.5` parses to [["ce", "1.0"], ["focal", "0.5"]].
    loss_options = (
        ("-l", "--modification-flag-loss"),
        ("--modification-type-loss",),
        ("--embedding-loss",),
        ("--feature-maps-loss",),
        ("--mask-loss",),
        ("--bits-loss",),
    )
    for option_flags in loss_options:
        parser.add_argument(*option_flags, type=str, default=None, action="append", nargs="+")

    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="safe", type=str, help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--mixup", action="store_true")
    parser.add_argument("--cutmix", action="store_true")
    parser.add_argument("--tsa", action="store_true")
    parser.add_argument("--fold", default=None, type=int)
    parser.add_argument("-s", "--scheduler", default=None, type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument(
        "--fine-tune", default=0, type=int, help="Number of fine-tuning epochs to run after the main stage"
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")
    return parser


def main():
    """Command-line entry point: train a model in up to three consecutive stages.

    Stages, each skipped when its epoch count is zero:

    1. warmup (``--warmup``): "Ranger" optimizer at a fixed 3e-4 learning rate,
       no scheduler, obliteration disabled, optional encoder freezing.
    2. main (``--epochs``): user-selected optimizer and scheduler, optional
       extra negative images.
    3. fine-tune (``--fine-tune``): "SGD" with the "cos" scheduler, "light"
       augmentations, mixup/cutmix/tsa disabled.

    All artefacts go to ``runs/<prefix>/``; the directory must not exist yet
    (``os.makedirs(..., exist_ok=False)`` raises ``FileExistsError`` otherwise).
    The parsed arguments are dumped to ``<prefix>.json`` in that directory and
    the best checkpoint of each stage is exported next to it.

    Exits through ``parser.error`` (status 2) when none of the flag, type or
    embedding losses is given.
    """
    parser = _build_train_arg_parser()
    args = parser.parse_args()
    set_manual_seed(args.seed)

    # Explicit validation instead of `assert`, which is stripped under `python -O`.
    if not (args.modification_flag_loss or args.modification_type_loss or args.embedding_loss):
        parser.error("At least one of losses must be set")

    modification_flag_loss = args.modification_flag_loss
    modification_type_loss = args.modification_type_loss
    embedding_loss = args.embedding_loss
    feature_maps_loss = args.feature_maps_loss
    mask_loss = args.mask_loss
    bits_loss = args.bits_loss

    freeze_encoder = args.freeze_encoder
    data_dir = args.data_dir
    cache = args.cache
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    model_name: str = args.model
    optimizer_name = args.optimizer
    image_size = (512, 512)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    verbose = args.verbose
    warmup = args.warmup
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    fold = args.fold
    balance = args.balance
    freeze_bn = args.freeze_bn
    train_batch_size = args.batch_size
    mixup = args.mixup
    cutmix = args.cutmix
    tsa = args.tsa
    fine_tune = args.fine_tune
    obliterate_p = args.obliterate
    negative_image_dir = args.negative_image_dir
    warmup_batch_size = args.warmup_batch_size or args.batch_size

    # Validation uses the same batch size as training.
    valid_batch_size = train_batch_size
    run_train = num_epochs > 0

    custom_model_kwargs = {}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)

    if embedding_loss is not None:
        custom_model_kwargs["need_embedding"] = True

    model: nn.Module = get_model(model_name, **custom_model_kwargs).cuda()
    # NOTE(review): this is the model's own list object, so the append below
    # also changes `model.required_features` - confirm that this is intended.
    required_features = model.required_features

    if mask_loss is not None:
        required_features.append(INPUT_TRUE_MODIFICATION_MASK)

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transferring weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    if freeze_bn:
        from pytorch_toolbelt.optimization.functional import freeze_model

        freeze_model(model, freeze_bn=True)
        print("Freezing bn params")

    main_metric = "loss"
    main_metric_minimize = True

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}_fold{fold}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if mixup:
        checkpoint_prefix += "_mixup"

    if cutmix:
        checkpoint_prefix += "_cutmix"

    # An explicit experiment name replaces the generated prefix entirely.
    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    # exist_ok=False: an already existing run directory raises FileExistsError.
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        json.dump(vars(args), f, indent=2)

    default_callbacks = []

    if show:
        default_callbacks += [ShowPolarBatchesCallback(draw_predictions, metric="loss", minimize=True)]

    def make_callbacks(loss_callbacks, stage_augmentations):
        """Return a fresh callback list for one stage (shared + loss + optimizer + hparams)."""
        return (
            default_callbacks
            + loss_callbacks
            + [
                OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
                HyperParametersCallback(
                    hparam_dict={
                        "model": model_name,
                        "scheduler": scheduler_name,
                        "optimizer": optimizer_name,
                        "augmentations": stage_augmentations,
                        "size": image_size[0],
                        "weight_decay": weight_decay,
                    }
                ),
            ]
        )

    def make_loaders(train_ds, valid_ds, train_sampler, train_bs, valid_bs):
        """Return the ordered train/valid loader mapping for one stage."""
        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_bs,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            # Shuffle only when no explicit sampler drives the sample order.
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )
        loaders["valid"] = DataLoader(valid_ds, batch_size=valid_bs, num_workers=num_workers, pin_memory=True)
        return loaders

    def print_session_summary(
        loaders,
        train_ds,
        valid_ds,
        *,
        epochs,
        stage_augmentations,
        stage_obliterate_p,
        stage_negative_dir,
        batch_sizes,
        stage_optimizer,
        stage_lr,
        stage_scheduler,
        stage_mixup,
        stage_cutmix,
        stage_tsa,
    ):
        """Print the configuration that is actually in effect for one stage."""
        print("Train session    :", checkpoint_prefix)
        print("  FP16 mode      :", fp16)
        print("  Fast mode      :", args.fast)
        print("  Epochs         :", epochs)
        print("  Workers        :", num_workers)
        print("  Data dir       :", data_dir)
        print("  Log dir        :", log_dir)
        print("  Cache          :", cache)
        print("Data              ")
        print("  Augmentations  :", stage_augmentations)
        print("  Obliterate (%) :", stage_obliterate_p)
        print("  Negative images:", stage_negative_dir)
        print("  Train size     :", len(loaders["train"]), "batches", len(train_ds), "samples")
        print("  Valid size     :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
        print("  Image size     :", image_size)
        print("  Balance        :", balance)
        print("  Mixup          :", stage_mixup)
        print("  CutMix         :", stage_cutmix)
        print("  TSA            :", stage_tsa)
        print("Model            :", model_name)
        print("  Parameters     :", count_parameters(model))
        print("  Dropout        :", dropout)
        print("Optimizer        :", stage_optimizer)
        print("  Learning rate  :", stage_lr)
        print("  Weight decay   :", weight_decay)
        print("  Scheduler      :", stage_scheduler)
        print("  Batch sizes    :", *batch_sizes)
        print("Losses            ")
        print("  Flag           :", modification_flag_loss)
        print("  Type           :", modification_type_loss)
        print("  Embedding      :", embedding_loss)
        print("  Feature maps   :", feature_maps_loss)
        print("  Mask           :", mask_loss)
        print("  Bits           :", bits_loss)

    # ------------------------------------------------------------------
    # Stage 1: pretrain / warmup
    # ------------------------------------------------------------------
    if warmup:
        warmup_lr = 3e-4

        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation=augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=0,
        )

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            feature_maps_loss=feature_maps_loss,
            num_epochs=warmup,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
        )

        callbacks = make_callbacks(loss_callbacks, augmentations)
        loaders = make_loaders(train_ds, valid_ds, train_sampler, warmup_batch_size, warmup_batch_size)

        if freeze_encoder:
            from pytorch_toolbelt.optimization.functional import freeze_model

            # NOTE(review): nothing in this function unfreezes the encoder
            # again before the later stages - confirm this is intended.
            freeze_model(model.encoder, freeze_parameters=True, freeze_bn=None)

        optimizer = get_optimizer(
            "Ranger", get_optimizable_parameters(model), weight_decay=weight_decay, learning_rate=warmup_lr
        )
        scheduler = None

        print_session_summary(
            loaders,
            train_ds,
            valid_ds,
            epochs=warmup,
            stage_augmentations=augmentations,
            stage_obliterate_p=0,
            stage_negative_dir=None,
            batch_sizes=(warmup_batch_size, warmup_batch_size),
            stage_optimizer="Ranger",
            stage_lr=warmup_lr,
            stage_scheduler=None,
            stage_mixup=mixup,
            stage_cutmix=cutmix,
            stage_tsa=tsa,
        )

        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "warmup"),
            num_epochs=warmup,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        del optimizer, loaders, runner, callbacks

        # Export the best warmup checkpoint; it is not loaded back into `model` here.
        best_checkpoint = os.path.join(log_dir, "warmup", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}_warmup.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        torch.cuda.empty_cache()
        gc.collect()

    # ------------------------------------------------------------------
    # Stage 2: main training
    # ------------------------------------------------------------------
    if run_train:
        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation=augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=obliterate_p,
        )

        if negative_image_dir:
            negatives_ds = get_negatives_ds(
                negative_image_dir, fold=fold, features=required_features, max_images=16536
            )
            train_ds = train_ds + negatives_ds
            train_sampler = None  # TODO: Add proper support of sampler
            print("Adding", len(negatives_ds), "negative samples to training set")

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            feature_maps_loss=feature_maps_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            num_epochs=num_epochs,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
        )

        callbacks = make_callbacks(loss_callbacks, augmentations)
        loaders = make_loaders(train_ds, valid_ds, train_sampler, train_batch_size, valid_batch_size)

        print_session_summary(
            loaders,
            train_ds,
            valid_ds,
            epochs=num_epochs,
            stage_augmentations=augmentations,
            stage_obliterate_p=obliterate_p,
            stage_negative_dir=negative_image_dir,
            batch_sizes=(train_batch_size, valid_batch_size),
            stage_optimizer=optimizer_name,
            stage_lr=learning_rate,
            stage_scheduler=scheduler_name,
            stage_mixup=mixup,
            stage_cutmix=cutmix,
            stage_tsa=tsa,
        )

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        # A CyclicLR scheduler is stepped once per batch.
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        del optimizer, loaders, runner, callbacks

        # Export the best main-stage checkpoint; it is not loaded back into `model` here.
        best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        torch.cuda.empty_cache()
        gc.collect()

    # ------------------------------------------------------------------
    # Stage 3: fine-tuning
    # ------------------------------------------------------------------
    if fine_tune:
        fine_tune_augmentations = "light"

        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation=fine_tune_augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=obliterate_p,
        )

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            feature_maps_loss=feature_maps_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            num_epochs=fine_tune,
            mixup=False,
            cutmix=False,
            tsa=False,
        )

        # Record the augmentation level this stage really trains with.
        callbacks = make_callbacks(loss_callbacks, fine_tune_augmentations)
        loaders = make_loaders(train_ds, valid_ds, train_sampler, train_batch_size, valid_batch_size)

        print_session_summary(
            loaders,
            train_ds,
            valid_ds,
            epochs=fine_tune,
            stage_augmentations=fine_tune_augmentations,
            stage_obliterate_p=obliterate_p,
            stage_negative_dir=None,
            batch_sizes=(train_batch_size, valid_batch_size),
            stage_optimizer="SGD",
            stage_lr=learning_rate,
            stage_scheduler="cos",
            stage_mixup=False,
            stage_cutmix=False,
            stage_tsa=False,
        )

        optimizer = get_optimizer(
            "SGD", get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            "cos", optimizer, lr=learning_rate, num_epochs=fine_tune, batches_in_epoch=len(loaders["train"])
        )
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "finetune"),
            num_epochs=fine_tune,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        best_checkpoint = os.path.join(log_dir, "finetune", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}_finetune.pth")

        # Unlike the earlier stages, the best fine-tune weights are loaded back into `model`.
        clean_checkpoint(best_checkpoint, model_checkpoint)
        unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        del optimizer, loaders, runner, callbacks
def main():
    parser = argparse.ArgumentParser()

    ###########################################################################################
    # Distributed-training related stuff
    parser.add_argument("--local_rank", type=int, default=0)
    ###########################################################################################

    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd",
        "--data-dir",
        type=str,
        help="Data directory for INRIA sattelite dataset",
        default=os.environ.get("INRIA_DATA_DIR"),
    )
    parser.add_argument(
        "-dd-xview2", "--data-dir-xview2", type=str, required=False, help="Data directory for external xView2 dataset"
    )
    parser.add_argument("-m", "--model", type=str, default="b6_unet32_s2", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")
    parser.add_argument("-l", "--criterion", type=str, required=True, action="append", nargs="+", help="Criterion")
    parser.add_argument(
        "-l2",
        "--criterion2",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 2 mask",
    )
    parser.add_argument(
        "-l4",
        "--criterion4",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 4 mask",
    )
    parser.add_argument(
        "-l8",
        "--criterion8",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 8 mask",
    )
    parser.add_argument(
        "-l16",
        "--criterion16",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 16 mask",
    )

    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="hard", type=str, help="")
    parser.add_argument("-tm", "--train-mode", default="random", type=str, help="")
    parser.add_argument("--run-mode", default="fit_predict", type=str, help="")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer")
    parser.add_argument("--opl", action="store_true")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")

    args = parser.parse_args()

    args.is_master = args.local_rank == 0
    args.distributed = False
    fp16 = args.fp16

    if "WORLD_SIZE" in os.environ:
        args.distributed = int(os.environ["WORLD_SIZE"]) > 1
        args.world_size = int(os.environ["WORLD_SIZE"])
        # args.world_size = torch.distributed.get_world_size()

        print("Initializing init_process_group", args.local_rank)

        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        print("Initialized init_process_group", args.local_rank)

    is_master = args.is_master | (not args.distributed)

    if args.distributed:
        distributed_params = {"rank": args.local_rank, "syncbn": True}
        if args.fp16:
            distributed_params["amp"] = True
    else:
        if args.fp16:
            distributed_params = {}
            distributed_params["amp"] = True
        else:
            distributed_params = False

    set_manual_seed(args.seed + args.local_rank)
    catalyst.utils.set_global_seed(args.seed + args.local_rank)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

    data_dir = args.data_dir
    if data_dir is None:
        raise ValueError("--data-dir must be set")

    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    online_pseudolabeling = args.opl
    criterions = args.criterion
    criterions2 = args.criterion2
    criterions4 = args.criterion4
    criterions8 = args.criterion8
    criterions16 = args.criterion16

    verbose = args.verbose
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    extra_data_xview2 = args.data_dir_xview2

    run_train = num_epochs > 0
    need_weight_mask = any(c[0] == "wbce" for c in criterions)

    custom_model_kwargs = {"full_size_mask": False}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)

    if any([criterions2, criterions4, criterions8, criterions16]):
        custom_model_kwargs["need_supervision_masks"] = True
        print("Enabling supervision masks")

    model: nn.Module = get_model(model_name, num_classes=16, **custom_model_kwargs).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    main_metric = "jaccard"

    current_time = datetime.now().strftime("%y%m%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if online_pseudolabeling:
        checkpoint_prefix += "_opl"

    if extra_data_xview2:
        checkpoint_prefix += "_with_xview2"

    if experiment is not None:
        checkpoint_prefix = experiment

    default_callbacks = [
        JaccardMetricPerImage(
            input_key=INPUT_MASK_KEY,
            output_key=OUTPUT_MASK_KEY,
            prefix="jaccard",
            inputs_to_labels=depth2mask,
            outputs_to_labels=decode_depth_mask,
        ),
    ]

    if is_master:

        default_callbacks += [
            BestMetricCheckpointCallback(target_metric="jaccard", target_metric_minimize=False),
            HyperParametersCallback(
                hparam_dict={
                    "model": model_name,
                    "scheduler": scheduler_name,
                    "optimizer": optimizer_name,
                    "augmentations": augmentations,
                    "size": args.size,
                    "weight_decay": weight_decay,
                    "epochs": num_epochs,
                    "dropout": None if dropout is None else float(dropout),
                }
            ),
        ]

        if show:
            visualize_inria_predictions = partial(
                draw_inria_predictions,
                image_key=INPUT_IMAGE_KEY,
                image_id_key=INPUT_IMAGE_ID_KEY,
                targets_key=INPUT_MASK_KEY,
                outputs_key=OUTPUT_MASK_KEY,
                inputs_to_labels=depth2mask,
                outputs_to_labels=decode_depth_mask,
                max_images=16,
            )
            default_callbacks += [
                ShowPolarBatchesCallback(visualize_inria_predictions, metric="accuracy", minimize=False),
                ShowPolarBatchesCallback(visualize_inria_predictions, metric="loss", minimize=True),
            ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        train_mode=train_mode,
        buildings_only=(train_mode == "tiles"),
        fast=fast,
        need_weight_mask=need_weight_mask,
        make_mask_target_fn=mask_to_ce_target,
    )

    if extra_data_xview2 is not None:
        extra_train_ds, _ = get_xview2_extra_dataset(
            extra_data_xview2,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast,
            need_weight_mask=need_weight_mask,
        )

        weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(extra_train_ds))
        train_sampler = WeightedRandomSampler(weights, train_sampler.num_samples * 2)

        train_ds = train_ds + extra_train_ds
        print("Using extra data from xView2 with", len(extra_train_ds), "samples")

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []

        ignore_index = None
        if online_pseudolabeling:
            ignore_index = UNLABELED_SAMPLE
            unlabeled_label = get_pseudolabeling_dataset(
                data_dir, include_masks=False, augmentation=None, image_size=image_size
            )

            unlabeled_train = get_pseudolabeling_dataset(
                data_dir, include_masks=True, augmentation=augmentations, image_size=image_size
            )

            if args.distributed:
                label_sampler = DistributedSampler(unlabeled_label, args.world_size, args.local_rank, shuffle=False)
            else:
                label_sampler = None

            loaders["infer"] = DataLoader(
                unlabeled_label,
                batch_size=batch_size // 2,
                num_workers=num_workers,
                pin_memory=True,
                sampler=label_sampler,
                drop_last=False,
            )

            if train_sampler is not None:
                num_samples = 2 * train_sampler.num_samples
            else:
                num_samples = 2 * len(train_ds)
            weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(unlabeled_label))

            train_sampler = WeightedRandomSampler(weights, num_samples, replacement=True)
            train_ds = train_ds + unlabeled_train

            callbacks += [
                BCEOnlinePseudolabelingCallback2d(
                    unlabeled_train,
                    pseudolabel_loader="infer",
                    prob_threshold=0.7,
                    output_key=OUTPUT_MASK_KEY,
                    unlabeled_class=UNLABELED_SAMPLE,
                    label_frequency=5,
                )
            ]

            print("Using online pseudolabeling with ", len(unlabeled_label), "samples")

        valid_sampler = None
        if args.distributed:
            if train_sampler is not None:
                train_sampler = DistributedSamplerWrapper(
                    train_sampler, args.world_size, args.local_rank, shuffle=True
                )
            else:
                train_sampler = DistributedSampler(train_ds, args.world_size, args.local_rank, shuffle=True)
            valid_sampler = DistributedSampler(valid_ds, args.world_size, args.local_rank, shuffle=False)

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(
            valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, sampler=valid_sampler
        )

        loss_callbacks, loss_criterions = get_criterions(
            criterions, criterions2, criterions4, criterions8, criterions16
        )
        callbacks += loss_callbacks

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
            callbacks += [SchedulerCallback(mode="batch")]

        log_dir = os.path.join("runs", checkpoint_prefix)

        if is_master:
            os.makedirs(log_dir, exist_ok=False)
            config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
            with open(config_fname, "w") as f:
                train_session_args = vars(args)
                f.write(json.dumps(train_session_args, indent=2))

            print("Train session    :", checkpoint_prefix)
            print("  FP16 mode      :", fp16)
            print("  Fast mode      :", args.fast)
            print("  Train mode     :", train_mode)
            print("  Epochs         :", num_epochs)
            print("  Workers        :", num_workers)
            print("  Data dir       :", data_dir)
            print("  Log dir        :", log_dir)
            print("  Augmentations  :", augmentations)
            print("  Train size     :", "batches", len(loaders["train"]), "dataset", len(train_ds))
            print("  Valid size     :", "batches", len(loaders["valid"]), "dataset", len(valid_ds))
            print("Model            :", model_name)
            print("  Parameters     :", count_parameters(model))
            print("  Image size     :", image_size)
            print("Optimizer        :", optimizer_name)
            print("  Learning rate  :", learning_rate)
            print("  Batch size     :", batch_size)
            print("  Criterion      :", criterions)
            print("  Use weight mask:", need_weight_mask)
            if args.distributed:
                print("Distributed")
                print("  World size     :", args.world_size)
                print("  Local rank     :", args.local_rank)
                print("  Is master      :", args.is_master)

        # model training
        runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None, device="cuda")
        runner.train(
            fp16=distributed_params,
            model=model,
            criterion=loss_criterions,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": vars(args)},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        if is_master:
            best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")

            model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")
            clean_checkpoint(best_checkpoint, model_checkpoint)

            unpack_checkpoint(torch.load(model_checkpoint), model=model)

            mask = predict(
                model, read_inria_image("sample_color.jpg"), image_size=image_size, batch_size=args.batch_size
            )
            mask = ((mask > 0) * 255).astype(np.uint8)
            name = os.path.join(log_dir, "sample_color.jpg")
            cv2.imwrite(name, mask)
def _split_loss_spec(spec):
    """Split one command-line loss specification into ``(name, weight)``.

    Loss options are declared with ``action="append", nargs="+"``, so every
    specification arrives as a list of string tokens such as ``["bce"]`` or
    ``["bce", "0.5"]``.

    Args:
        spec: A list/tuple of tokens, or a bare loss name.

    Returns:
        A ``(loss_name, loss_weight)`` pair. The weight is returned as given
        (usually a string) when exactly two tokens are supplied and defaults
        to ``1.0`` otherwise; callers convert it with ``float()``.
    """
    if isinstance(spec, (list, tuple)):
        if len(spec) == 2:
            loss_name, loss_weight = spec
            return loss_name, loss_weight
        # A single token (or an unexpected token count): use the first token
        # as the loss name with the default weight.
        return spec[0], 1.0
    return spec, 1.0


def _register_losses(
    loss_specs, log_key, prefix, input_key, output_key, criterions_dict, callbacks, losses, **criterion_kwargs
):
    """Create criterion callbacks for every loss specification of one model output.

    Args:
        loss_specs: List of loss specifications as produced by argparse, or
            ``None`` when the option was not given (nothing is registered).
        log_key: Label printed in front of the "Using loss" message.
        prefix: Prefix forwarded to ``get_criterion_callback``.
        input_key: Target key forwarded to ``get_criterion_callback``.
        output_key: Model output key forwarded to ``get_criterion_callback``.
        criterions_dict: Dict updated in place with the created criterions.
        callbacks: List extended in place with the created criterion callbacks.
        losses: List extended in place with the created criterion names.
        **criterion_kwargs: Extra keyword arguments forwarded to
            ``get_criterion_callback`` (e.g. ``ignore_index``).
    """
    if loss_specs is None:
        return

    for spec in loss_specs:
        loss_name, loss_weight = _split_loss_spec(spec)
        cd, criterion_callback, criterion_name = get_criterion_callback(
            loss_name,
            prefix=prefix,
            input_key=input_key,
            output_key=output_key,
            loss_weight=float(loss_weight),
            **criterion_kwargs
        )
        criterions_dict.update(cd)
        callbacks.append(criterion_callback)
        losses.append(criterion_name)
        print(log_key, "Using loss", loss_name, loss_weight)


def main():
    """Train a model on the labeled dataset extended with pseudolabeled samples.

    The function parses command-line arguments, builds the model (optionally
    transferring weights and/or loading a checkpoint), creates a fresh
    ``runs/<prefix>`` log directory with a JSON dump of the arguments, builds
    the datasets and, when ``--epochs`` is positive, registers the requested
    losses, trains with a ``SupervisedRunner`` and finally exports a cleaned
    copy of the best checkpoint.

    Raises:
        FileExistsError: If the log directory for this session already exists
            (``os.makedirs`` is called with ``exist_ok=False``).
        SystemExit: On invalid command-line arguments, including the case
            where training is requested but no loss option was given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd", "--data-dir", type=str, required=True, help="Data directory for INRIA sattelite dataset"
    )
    parser.add_argument("-m", "--model", type=str, default="resnet34_fpncat128", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")

    # Every loss option may be repeated; each occurrence is "<name> [<weight>]".
    loss_options = (
        (("--disaster-type-loss",), "Criterion for classifying disaster type"),
        (("--damage-type-loss",), "Criterion for classifying presence of building with particular damage type"),
        (("-l", "--criterion"), "Criterion"),
        (("--mask4",), "Criterion for mask with stride 4"),
        (("--mask8",), "Criterion for mask with stride 8"),
        (("--mask16",), "Criterion for mask with stride 16"),
        (("--mask32",), "Criterion for mask with stride 32"),
    )
    for option_flags, option_help in loss_options:
        parser.add_argument(*option_flags, type=str, default=None, action="append", nargs="+", help=option_help)
    parser.add_argument("--embedding", type=str, default=None)

    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="safe", type=str, help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("--fold", default=0, type=int)
    parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=0.0, type=float, help="Dropout before head layer")
    parser.add_argument("-pl", "--pseudolabeling", type=str, required=True)
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--only-buildings", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")
    parser.add_argument("--crops", action="store_true", help="Train on random crops")
    parser.add_argument("--post-transform", action="store_true")

    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    segmentation_losses = args.criterion
    verbose = args.verbose
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    fold = args.fold
    balance = args.balance
    only_buildings = args.only_buildings
    freeze_bn = args.freeze_bn
    train_on_crops = args.crops
    enable_post_image_transform = args.post_transform
    disaster_type_loss = args.disaster_type_loss
    train_batch_size = args.batch_size
    embedding_criterion = args.embedding
    damage_type_loss = args.damage_type_loss
    pseudolabels_dir = args.pseudolabeling

    # Compute batch size for validation: when training on crops, scale the
    # batch size by the ratio of the crop area to a 1024x1024 area.
    if train_on_crops:
        valid_batch_size = max(1, (train_batch_size * (image_size[0] * image_size[1])) // (1024 ** 2))
    else:
        valid_batch_size = train_batch_size

    run_train = num_epochs > 0

    # Fail early (before creating the log directory or the model) if training
    # was requested without any loss to optimize.
    requested_losses = (
        segmentation_losses,
        args.mask4,
        args.mask8,
        args.mask16,
        args.mask32,
        disaster_type_loss,
        damage_type_loss,
        embedding_criterion,
    )
    if run_train and all(requested is None for requested in requested_losses):
        parser.error("at least one loss must be specified (e.g. -l/--criterion) when --epochs > 0")

    model: nn.Module = get_model(model_name, dropout=dropout).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    if freeze_bn:
        torch_utils.freeze_bn(model)
        print("Freezing bn params")

    runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None)
    main_metric = "weighted_f1"
    cmd_args = vars(args)

    # Session name: timestamp + model + size + fold and feature suffixes,
    # unless an explicit experiment name overrides it.
    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}_{args.size}_fold{fold}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if pseudolabels_dir:
        checkpoint_prefix += "_pseudo"

    if train_on_crops:
        checkpoint_prefix += "_crops"

    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    # exist_ok=False: refuse to reuse the directory of a previous session.
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        f.write(json.dumps(cmd_args, indent=2))

    default_callbacks = [
        CompetitionMetricCallback(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="weighted_f1"),
        ConfusionMatrixCallback(
            input_key=INPUT_MASK_KEY,
            output_key=OUTPUT_MASK_KEY,
            class_names=["land", "no_damage", "minor_damage", "major_damage", "destroyed"],
            ignore_index=UNLABELED_SAMPLE,
        ),
    ]

    if show:
        default_callbacks += [
            ShowPolarBatchesCallback(draw_predictions, metric=main_metric + "_batch", minimize=False)
        ]

    # The sampler returned by get_datasets is not used here: the training set
    # is extended with pseudolabeled samples below and the train loader
    # shuffles the combined dataset instead.
    train_ds, valid_ds, _ = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        fast=fast,
        fold=fold,
        balance=balance,
        only_buildings=only_buildings,
        train_on_crops=train_on_crops,
        crops_multiplication_factor=1,
        enable_post_image_transform=enable_post_image_transform,
    )

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []

        unlabeled_train = get_pseudolabeling_dataset(
            data_dir,
            include_masks=True,
            image_size=image_size,
            augmentation="medium_nmd",
            train_on_crops=train_on_crops,
            enable_post_image_transform=enable_post_image_transform,
            pseudolabels_dir=pseudolabels_dir,
        )

        train_ds = train_ds + unlabeled_train

        print("Using online pseudolabeling with ", len(unlabeled_train), "samples")

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=True,
        )

        loaders["valid"] = DataLoader(
            valid_ds, batch_size=valid_batch_size, num_workers=num_workers, pin_memory=True
        )

        # Create losses. Main segmentation output first, then the auxiliary
        # (deep supervision) mask outputs in order of increasing stride.
        _register_losses(
            segmentation_losses,
            INPUT_MASK_KEY,
            prefix="segmentation",
            input_key=INPUT_MASK_KEY,
            output_key=OUTPUT_MASK_KEY,
            criterions_dict=criterions_dict,
            callbacks=callbacks,
            losses=losses,
        )

        auxiliary_mask_heads = (
            (args.mask4, "mask4", OUTPUT_MASK_4_KEY),
            (args.mask8, "mask8", OUTPUT_MASK_8_KEY),
            (args.mask16, "mask16", OUTPUT_MASK_16_KEY),
            (args.mask32, "mask32", OUTPUT_MASK_32_KEY),
        )
        for head_losses, head_prefix, head_output_key in auxiliary_mask_heads:
            _register_losses(
                head_losses,
                head_output_key,
                prefix=head_prefix,
                input_key=INPUT_MASK_KEY,
                output_key=head_output_key,
                criterions_dict=criterions_dict,
                callbacks=callbacks,
                losses=losses,
            )

        if disaster_type_loss is not None:
            # Metric callbacks are registered before the criterion callbacks.
            callbacks += [
                ConfusionMatrixCallback(
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    class_names=DISASTER_TYPES,
                    ignore_index=UNKNOWN_DISASTER_TYPE_CLASS,
                    prefix=f"{DISASTER_TYPE_KEY}/confusion_matrix",
                ),
                AccuracyCallback(
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    prefix=f"{DISASTER_TYPE_KEY}/accuracy",
                    activation="Softmax",
                ),
            ]
            _register_losses(
                disaster_type_loss,
                DISASTER_TYPE_KEY,
                prefix=DISASTER_TYPE_KEY,
                input_key=DISASTER_TYPE_KEY,
                output_key=DISASTER_TYPE_KEY,
                criterions_dict=criterions_dict,
                callbacks=callbacks,
                losses=losses,
                ignore_index=UNKNOWN_DISASTER_TYPE_CLASS,
            )

        if damage_type_loss is not None:
            callbacks += [
                AccuracyCallback(
                    input_key=DAMAGE_TYPE_KEY,
                    output_key=DAMAGE_TYPE_KEY,
                    prefix=f"{DAMAGE_TYPE_KEY}/accuracy",
                    activation="Sigmoid",
                    threshold=0.5,
                )
            ]
            _register_losses(
                damage_type_loss,
                DAMAGE_TYPE_KEY,
                prefix=DAMAGE_TYPE_KEY,
                input_key=DAMAGE_TYPE_KEY,
                output_key=DAMAGE_TYPE_KEY,
                criterions_dict=criterions_dict,
                callbacks=callbacks,
                losses=losses,
            )

        if embedding_criterion is not None:
            cd, criterion, criterion_name = get_criterion_callback(
                embedding_criterion,
                prefix="embedding",
                input_key=INPUT_MASK_KEY,
                output_key=OUTPUT_EMBEDDING_KEY,
                loss_weight=1.0,
            )
            criterions_dict.update(cd)
            callbacks.append(criterion)
            losses.append(criterion_name)
            print(OUTPUT_EMBEDDING_KEY, "Using loss", embedding_criterion)

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
        ]

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        if isinstance(scheduler, CyclicLR):
            # CyclicLR is stepped once per batch rather than once per epoch.
            callbacks += [SchedulerCallback(mode="batch")]

        print("Train session    :", checkpoint_prefix)
        print("  FP16 mode      :", fp16)
        print("  Fast mode      :", args.fast)
        print("  Epochs         :", num_epochs)
        print("  Workers        :", num_workers)
        print("  Data dir       :", data_dir)
        print("  Log dir        :", log_dir)
        print("Data             ")
        print("  Augmentations  :", augmentations)
        print("  Train size     :", len(loaders["train"]), len(train_ds))
        print("  Valid size     :", len(loaders["valid"]), len(valid_ds))
        print("  Image size     :", image_size)
        print("  Train on crops :", train_on_crops)
        print("  Balance        :", balance)
        print("  Buildings only :", only_buildings)
        print("  Post transform :", enable_post_image_transform)
        print("  Pseudolabels   :", pseudolabels_dir)
        print("Model            :", model_name)
        print("  Parameters     :", count_parameters(model))
        print("  Dropout        :", dropout)
        print("Optimizer        :", optimizer_name)
        print("  Learning rate  :", learning_rate)
        print("  Weight decay   :", weight_decay)
        print("  Scheduler      :", scheduler_name)
        print("  Batch sizes    :", train_batch_size, valid_batch_size)
        print("  Criterion      :", segmentation_losses)
        print("  Damage type    :", damage_type_loss)
        print("  Disaster type  :", disaster_type_loss)
        print(" Embedding      :", embedding_criterion)

        # The same directory is used for training and for locating the best
        # checkpoint afterwards, so the two can never drift apart.
        train_log_dir = os.path.join(log_dir, "opl")

        # model training
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=train_log_dir,
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": cmd_args},
        )

        # Training is finished. Export a cleaned copy of the best checkpoint
        # next to the checkpoints written by the runner.
        checkpoints_dir = os.path.join(train_log_dir, "checkpoints")
        best_checkpoint = os.path.join(checkpoints_dir, "best.pth")
        model_checkpoint = os.path.join(checkpoints_dir, f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        del optimizer, loaders
def _build_segmentation_criterions(criterions, ignore_index=None):
    """Create loss modules and criterion callbacks for the requested losses.

    Args:
        criterions: Iterable of ``(loss_name, loss_weight)`` pairs, as collected
            from the repeated ``-l/--criterion`` command-line option.
        ignore_index: Forwarded unchanged to ``get_loss`` for every loss.

    Returns:
        A tuple ``(criterions_dict, criterion_callbacks, loss_prefixes)`` where
        ``criterions_dict`` maps a loss name to its loss object,
        ``criterion_callbacks`` holds one ``CriterionCallback`` per loss and
        ``loss_prefixes`` lists the callback prefixes in the same order.
    """
    criterions_dict = {}
    criterion_callbacks = []
    loss_prefixes = []

    for loss_name, loss_weight in criterions:
        # "wbce" is the only loss that additionally receives the weight mask.
        if loss_name == "wbce":
            input_key = [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY]
        else:
            input_key = INPUT_MASK_KEY

        criterion_callback = CriterionCallback(
            prefix="seg_loss/" + loss_name,
            input_key=input_key,
            output_key=OUTPUT_MASK_KEY,
            criterion_key=loss_name,
            multiplier=float(loss_weight),
        )

        criterions_dict[loss_name] = get_loss(loss_name, ignore_index=ignore_index)
        criterion_callbacks.append(criterion_callback)
        loss_prefixes.append(criterion_callback.prefix)
        print("Using loss", loss_name, loss_weight)

    return criterions_dict, criterion_callbacks, loss_prefixes


def main():
    """Command-line entry point: train a segmentation model and save its weights.

    The function parses the command line, builds the model (optionally
    transferring weights and/or restoring a checkpoint), creates the run
    directory ``runs/<checkpoint_prefix>`` and stores the parsed arguments
    there as JSON. It then builds the datasets (optionally mixed with the
    extra xView2 data), runs an optional warm-up stage and, when the number
    of epochs is positive, the main training stage. After the main stage the
    cleaned best checkpoint is loaded back into the model and a prediction
    for ``sample_color.jpg`` is written into the run directory.

    Raises:
        FileExistsError: If the run directory already exists
            (``os.makedirs`` is called with ``exist_ok=False``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd", "--data-dir", type=str, required=True, help="Data directory for INRIA sattelite dataset"
    )
    parser.add_argument(
        "-dd-xview2",
        "--data-dir-xview2",
        type=str,
        required=False,
        help="Data directory for external xView2 dataset",
    )
    parser.add_argument("-m", "--model", type=str, default="resnet34_fpncat128", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")
    # Each "-l NAME WEIGHT" occurrence appends one [NAME, WEIGHT] list.
    parser.add_argument("-l", "--criterion", type=str, required=True, action="append", nargs="+", help="Criterion")
    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="hard", type=str, help="")
    parser.add_argument("-tm", "--train-mode", default="random", type=str, help="")
    parser.add_argument("--run-mode", default="fit_predict", type=str, help="")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=0.0, type=float, help="Dropout before head layer")
    parser.add_argument("--opl", action="store_true")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")

    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (args.size, args.size)
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    online_pseudolabeling = args.opl
    criterions = args.criterion
    verbose = args.verbose
    warmup = args.warmup
    show = args.show
    use_dsv = args.dsv
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    extra_data_xview2 = args.data_dir_xview2

    # The main stage is skipped entirely when zero epochs are requested.
    run_train = num_epochs > 0
    need_weight_mask = any(c[0] == "wbce" for c in criterions)

    model: nn.Module = get_model(model_name, dropout=dropout).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None, device="cuda")
    main_metric = "optimized_jaccard"
    cmd_args = vars(args)

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if online_pseudolabeling:
        checkpoint_prefix += "_opl"

    if extra_data_xview2:
        checkpoint_prefix += "_with_xview2"

    # An explicit experiment name replaces the generated prefix completely.
    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    # exist_ok=False: refuse to write into an already existing run directory.
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        json.dump(cmd_args, f, indent=2)

    default_callbacks = [
        PixelAccuracyCallback(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY),
        JaccardMetricPerImage(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="jaccard"),
        OptimalThreshold(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="optimized_jaccard"),
        # OutputDistributionCallback(output_key=OUTPUT_MASK_KEY, activation=torch.sigmoid),
    ]

    if show:
        visualize_inria_predictions = partial(
            draw_inria_predictions,
            image_key=INPUT_IMAGE_KEY,
            image_id_key=INPUT_IMAGE_ID_KEY,
            targets_key=INPUT_MASK_KEY,
            outputs_key=OUTPUT_MASK_KEY,
        )
        default_callbacks += [
            ShowPolarBatchesCallback(visualize_inria_predictions, metric="accuracy", minimize=False)
        ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        train_mode=train_mode,
        fast=fast,
        need_weight_mask=need_weight_mask,
    )

    if extra_data_xview2 is not None:
        extra_train_ds, _ = get_xview2_extra_dataset(
            extra_data_xview2,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast,
            need_weight_mask=need_weight_mask,
        )

        # The sampler is treated as optional elsewhere in this function
        # (the loaders shuffle when it is None), so fall back to the dataset
        # length instead of dereferencing a missing sampler.
        if train_sampler is not None:
            num_samples = 2 * train_sampler.num_samples
        else:
            num_samples = 2 * len(train_ds)

        # One class label per data source, so both sources get balanced weights.
        weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(extra_train_ds))
        train_sampler = WeightedRandomSampler(weights, num_samples)

        train_ds = train_ds + extra_train_ds
        print("Using extra data from xView2 with", len(extra_train_ds), "samples")

    # Pretrain/warmup
    if warmup:
        callbacks = default_callbacks.copy()
        criterions_dict, criterion_callbacks, losses = _build_segmentation_criterions(criterions, ignore_index=None)
        callbacks += criterion_callbacks

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
        ]

        # Warm-up always uses RAdam at a tenth of the requested learning rate,
        # with an extra 0.1 factor passed for the "encoder" parameter group.
        parameters = get_lr_decay_parameters(model.named_parameters(), learning_rate, {"encoder": 0.1})
        optimizer = get_optimizer("RAdam", parameters, learning_rate=learning_rate * 0.1)

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(
            valid_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            shuffle=False,
            drop_last=False,
        )

        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=None,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "warmup"),
            num_epochs=warmup,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": cmd_args},
        )

        del optimizer, loaders

        best_checkpoint = os.path.join(log_dir, "warmup", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, "warmup", "checkpoints", f"{checkpoint_prefix}_warmup.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        # Release warm-up allocations before the main stage starts.
        torch.cuda.empty_cache()
        gc.collect()

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()

        ignore_index = None
        if online_pseudolabeling:
            ignore_index = UNLABELED_SAMPLE
            unlabeled_label = get_pseudolabeling_dataset(
                data_dir, include_masks=False, augmentation=None, image_size=image_size
            )

            unlabeled_train = get_pseudolabeling_dataset(
                data_dir, include_masks=True, augmentation=augmentations, image_size=image_size
            )

            loaders["label"] = DataLoader(
                unlabeled_label, batch_size=batch_size // 2, num_workers=num_workers, pin_memory=True
            )

            if train_sampler is not None:
                num_samples = 2 * train_sampler.num_samples
            else:
                num_samples = 2 * len(train_ds)

            # The weights must line up with the concatenated dataset built
            # below, so they are sized from the dataset that is actually
            # appended (unlabeled_train).
            weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(unlabeled_train))

            train_sampler = WeightedRandomSampler(weights, num_samples, replacement=True)
            train_ds = train_ds + unlabeled_train

            callbacks += [
                BCEOnlinePseudolabelingCallback2d(
                    unlabeled_train,
                    pseudolabel_loader="label",
                    prob_threshold=0.7,
                    output_key=OUTPUT_MASK_KEY,
                    unlabeled_class=UNLABELED_SAMPLE,
                    label_frequency=5,
                )
            ]

            print("Using online pseudolabeling with ", len(unlabeled_label), "samples")

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True)

        # Create losses
        criterions_dict, criterion_callbacks, losses = _build_segmentation_criterions(
            criterions, ignore_index=ignore_index
        )
        callbacks += criterion_callbacks

        if use_dsv:
            print("Using DSV")
            # Dedicated key: `criterions` must keep the parsed loss list so the
            # session summary below reports the configured criterions.
            dsv_criterion_key = "dsv"
            dsv_loss_name = "soft_bce"

            criterions_dict[dsv_criterion_key] = AdaptiveMaskLoss2d(
                get_loss(dsv_loss_name, ignore_index=ignore_index)
            )

            for dsv_input in [OUTPUT_MASK_4_KEY, OUTPUT_MASK_8_KEY, OUTPUT_MASK_16_KEY, OUTPUT_MASK_32_KEY]:
                # NOTE(review): input_key is OUTPUT_MASK_KEY here while the main
                # losses use INPUT_MASK_KEY - confirm this is intended.
                criterion_callback = CriterionCallback(
                    prefix="seg_loss_dsv/" + dsv_input,
                    input_key=OUTPUT_MASK_KEY,
                    output_key=dsv_input,
                    criterion_key=dsv_criterion_key,
                    multiplier=1.0,
                )
                callbacks.append(criterion_callback)
                losses.append(criterion_callback.prefix)

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
        ]

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name,
            optimizer,
            lr=learning_rate,
            num_epochs=num_epochs,
            batches_in_epoch=len(loaders["train"]),
        )
        # These scheduler types get a per-batch scheduler callback.
        if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
            callbacks += [SchedulerCallback(mode="batch")]

        print("Train session    :", checkpoint_prefix)
        print("\tFP16 mode      :", fp16)
        print("\tFast mode      :", args.fast)
        print("\tTrain mode     :", train_mode)
        print("\tEpochs         :", num_epochs)
        print("\tWorkers        :", num_workers)
        print("\tData dir       :", data_dir)
        print("\tLog dir        :", log_dir)
        print("\tAugmentations  :", augmentations)
        print("\tTrain size     :", len(loaders["train"]), len(train_ds))
        print("\tValid size     :", len(loaders["valid"]), len(valid_ds))
        print("Model            :", model_name)
        print("\tParameters     :", count_parameters(model))
        print("\tImage size     :", image_size)
        print("Optimizer        :", optimizer_name)
        print("\tLearning rate  :", learning_rate)
        print("\tBatch size     :", batch_size)
        print("\tCriterion      :", criterions)
        print("\tUse weight mask:", need_weight_mask)

        # model training
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": cmd_args},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")

        model_checkpoint = os.path.join(log_dir, "main", "checkpoints", f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        unpack_checkpoint(torch.load(model_checkpoint), model=model)

        mask = predict(
            model,
            read_inria_image("sample_color.jpg"),
            image_size=image_size,
            batch_size=args.batch_size,
        )
        # Positive predictions become 255, everything else 0, for the saved image.
        mask = ((mask > 0) * 255).astype(np.uint8)
        name = os.path.join(log_dir, "sample_color.jpg")
        cv2.imwrite(name, mask)

        del optimizer, loaders