Example #1
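# Assumed imports for this test (exact module paths depend on the project
# layout and catalyst version): OrderedDict from collections, torch, and
# OptimizerCallback, TensorboardLogger, GanExperiment, PhaseManagerCallback,
# PhaseWrapperCallback from the experiment/callback modules used by the repo.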
def test_callback_wrapping():
    """
    Test on callback wrapping for GanExperiment class.
    """
    model = torch.nn.Module()
    dataset = torch.utils.data.Dataset()
    dataloader = torch.utils.data.DataLoader(dataset)
    loaders = OrderedDict()
    loaders["train"] = dataloader
    # Prepare callbacks and state kwargs
    discriminator_loss_key = "loss_d"
    generator_loss_key = "loss_g"
    discriminator_key = "discriminator"
    generator_key = "generator"
    input_callbacks = OrderedDict(
        {
            "optim_d": OptimizerCallback(
                loss_key=discriminator_loss_key,
                optimizer_key=discriminator_key
            ),
            "optim_g": OptimizerCallback(
                loss_key=generator_loss_key, optimizer_key=generator_key
            ),
            "tensorboard": TensorboardLogger(),
        }
    )
    state_kwargs = {
        "discriminator_train_phase": "discriminator_train",
        "discriminator_train_num": 1,
        "generator_train_phase": "generator_train",
        "generator_train_num": 5,
    }
    discriminator_callbacks = ["optim_d"]
    generator_callbacks = ["optim_g"]
    phase2callbacks = {
        state_kwargs["discriminator_train_phase"]: discriminator_callbacks,
        state_kwargs["generator_train_phase"]: generator_callbacks,
    }
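    # state_kwargs describes the phase schedule: one "discriminator_train"
    # batch followed by five "generator_train" batches per cycle, and
    # phase2callbacks maps each phase to the callbacks that are active in it.
    # The assertions below expect GanExperiment to wrap those callbacks in
    # PhaseWrapperCallback and to inject a PhaseManagerCallback.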

    exp = GanExperiment(
        model=model,
        loaders=loaders,
        callbacks=input_callbacks,
        state_kwargs=state_kwargs,
        phase2callbacks=phase2callbacks,
    )

    callbacks = exp.get_callbacks("train")
    assert "optim_d" in callbacks.keys()
    assert "optim_g" in callbacks.keys()
    assert "tensorboard" in callbacks.keys()
    assert isinstance(callbacks["phase_manager"], PhaseManagerCallback)
    assert isinstance(callbacks["optim_d"], PhaseWrapperCallback)
    assert isinstance(callbacks["optim_g"], PhaseWrapperCallback)
    assert isinstance(callbacks["tensorboard"], TensorboardLogger)

Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--obliterate", type=float, default=0, help="Change of obliteration")
    parser.add_argument("-nid", "--negative-image-dir", type=str, default=None, help="Change of obliteration")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--cache", action="store_true")
    parser.add_argument("-dd", "--data-dir", type=str, default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-m", "--model", type=str, default="resnet34", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch Size during training, e.g. -b 64")
    parser.add_argument(
        "-wbs", "--warmup-batch-size", type=int, default=None, help="Batch size to use during the warmup stage"
    )
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument(
        "-es", "--early-stopping", type=int, default=None, help="Maximum number of epochs without improvement"
    )
    parser.add_argument("-fe", "--freeze-encoder", action="store_true", help="Freeze encoder parameters for N epochs")
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")

    parser.add_argument(
        "-l", "--modification-flag-loss", type=str, default=None, action="append", nargs="+"  # [["ce", 1.0]],
    )
    parser.add_argument(
        "--modification-type-loss", type=str, default=None, action="append", nargs="+"  # [["ce", 1.0]],
    )
    parser.add_argument("--embedding-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--feature-maps-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--mask-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--bits-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],

    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="safe", type=str, help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--mixup", action="store_true")
    parser.add_argument("--cutmix", action="store_true")
    parser.add_argument("--tsa", action="store_true")
    parser.add_argument("--fold", default=None, type=int)
    parser.add_argument("-s", "--scheduler", default=None, type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument(
        "--fine-tune", default=0, type=int, help="Number of fine-tuning epochs (SGD with cosine schedule) after the main stage"
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")

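    # Hypothetical invocation (the script name and flag values below are
    # illustrative only):
    #   python train.py -m resnet34 -b 16 -e 100 -lr 1e-3 --fold 0 --fp16 \
    #       --modification-flag-loss ce 1.0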
    args = parser.parse_args()
    set_manual_seed(args.seed)

    assert (
        args.modification_flag_loss or args.modification_type_loss or args.embedding_loss
    ), "At least one of losses must be set"

    modification_flag_loss = args.modification_flag_loss
    modification_type_loss = args.modification_type_loss
    embedding_loss = args.embedding_loss
    feature_maps_loss = args.feature_maps_loss
    mask_loss = args.mask_loss
    bits_loss = args.bits_loss

    freeze_encoder = args.freeze_encoder
    data_dir = args.data_dir
    cache = args.cache
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    model_name: str = args.model
    optimizer_name = args.optimizer
    image_size = (512, 512)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    verbose = args.verbose
    warmup = args.warmup
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    fold = args.fold
    balance = args.balance
    freeze_bn = args.freeze_bn
    train_batch_size = args.batch_size
    mixup = args.mixup
    cutmix = args.cutmix
    tsa = args.tsa
    fine_tune = args.fine_tune
    obliterate_p = args.obliterate
    negative_image_dir = args.negative_image_dir
    warmup_batch_size = args.warmup_batch_size or args.batch_size

    # Compute batch size for validation
    valid_batch_size = train_batch_size
    run_train = num_epochs > 0

    custom_model_kwargs = {}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)

    if embedding_loss is not None:
        custom_model_kwargs["need_embedding"] = True

    model: nn.Module = get_model(model_name, **custom_model_kwargs).cuda()
    required_features = model.required_features

    if mask_loss is not None:
        required_features.append(INPUT_TRUE_MODIFICATION_MASK)

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transferring weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    if freeze_bn:
        from pytorch_toolbelt.optimization.functional import freeze_model

        freeze_model(model, freeze_bn=True)
        print("Freezing bn params")

    main_metric = "loss"
    main_metric_minimize = True

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}_fold{fold}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if mixup:
        checkpoint_prefix += "_mixup"

    if cutmix:
        checkpoint_prefix += "_cutmix"

    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = []

    if show:
        default_callbacks += [ShowPolarBatchesCallback(draw_predictions, metric="loss", minimize=True)]

    # Pretrain/warmup
    if warmup:
        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation=augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=0,
        )

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            feature_maps_loss=feature_maps_loss,
            num_epochs=warmup,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
        )

        callbacks = (
            default_callbacks
            + loss_callbacks
            + [
                OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
                HyperParametersCallback(
                    hparam_dict={
                        "model": model_name,
                        "scheduler": scheduler_name,
                        "optimizer": optimizer_name,
                        "augmentations": augmentations,
                        "size": image_size[0],
                        "weight_decay": weight_decay,
                    }
                ),
            ]
        )

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=warmup_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds, batch_size=warmup_batch_size, num_workers=num_workers, pin_memory=True)

        if freeze_encoder:
            from pytorch_toolbelt.optimization.functional import freeze_model

            freeze_model(model.encoder, freeze_parameters=True, freeze_bn=None)

        optimizer = get_optimizer(
            "Ranger", get_optimizable_parameters(model), weight_decay=weight_decay, learning_rate=3e-4
        )
        scheduler = None
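        # Note: the warmup stage uses Ranger at a fixed LR of 3e-4 and no
        # scheduler, regardless of --optimizer/--learning-rate.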

        print("Train session    :", checkpoint_prefix)
        print("  FP16 mode      :", fp16)
        print("  Fast mode      :", args.fast)
        print("  Epochs         :", num_epochs)
        print("  Workers        :", num_workers)
        print("  Data dir       :", data_dir)
        print("  Log dir        :", log_dir)
        print("  Cache          :", cache)
        print("Data              ")
        print("  Augmentations  :", augmentations)
        print("  Negative images:", negative_image_dir)
        print("  Train size     :", len(loaders["train"]), "batches", len(train_ds), "samples")
        print("  Valid size     :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
        print("  Image size     :", image_size)
        print("  Balance        :", balance)
        print("  Mixup          :", mixup)
        print("  CutMix         :", cutmix)
        print("  TSA            :", tsa)
        print("Model            :", model_name)
        print("  Parameters     :", count_parameters(model))
        print("  Dropout        :", dropout, "(Non-default)" if dropout is not None else "")
        print("Optimizer        :", optimizer_name)
        print("  Learning rate  :", learning_rate)
        print("  Weight decay   :", weight_decay)
        print("  Scheduler      :", scheduler_name)
        print("  Batch sizes    :", train_batch_size, valid_batch_size)
        print("Losses            ")
        print("  Flag           :", modification_flag_loss)
        print("  Type           :", modification_type_loss)
        print("  Embedding      :", embedding_loss)
        print("  Feature maps   :", feature_maps_loss)
        print("  Mask           :", mask_loss)
        print("  Bits           :", bits_loss)

        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "warmup"),
            num_epochs=warmup,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        del optimizer, loaders, runner, callbacks

        best_checkpoint = os.path.join(log_dir, "warmup", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}_warmup.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        # Restore state of best model
        # unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        torch.cuda.empty_cache()
        gc.collect()

    if run_train:
        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation=augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=obliterate_p,
        )

        if negative_image_dir:
            negatives_ds = get_negatives_ds(
                negative_image_dir, fold=fold, features=required_features, max_images=16536
            )
            train_ds = train_ds + negatives_ds
            train_sampler = None  # TODO: Add proper support of sampler
            print("Adding", len(negatives_ds), "negative samples to training set")

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            feature_maps_loss=feature_maps_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            num_epochs=num_epochs,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
        )

        callbacks = (
            default_callbacks
            + loss_callbacks
            + [
                OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
                HyperParametersCallback(
                    hparam_dict={
                        "model": model_name,
                        "scheduler": scheduler_name,
                        "optimizer": optimizer_name,
                        "augmentations": augmentations,
                        "size": image_size[0],
                        "weight_decay": weight_decay,
                    }
                ),
            ]
        )

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds, batch_size=valid_batch_size, num_workers=num_workers, pin_memory=True)

        print("Train session    :", checkpoint_prefix)
        print("  FP16 mode      :", fp16)
        print("  Fast mode      :", args.fast)
        print("  Epochs         :", num_epochs)
        print("  Workers        :", num_workers)
        print("  Data dir       :", data_dir)
        print("  Log dir        :", log_dir)
        print("  Cache          :", cache)
        print("Data              ")
        print("  Augmentations  :", augmentations)
        print("  Obliterate (%) :", obliterate_p)
        print("  Negative images:", negative_image_dir)
        print("  Train size     :", len(loaders["train"]), "batches", len(train_ds), "samples")
        print("  Valid size     :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
        print("  Image size     :", image_size)
        print("  Balance        :", balance)
        print("  Mixup          :", mixup)
        print("  CutMix         :", cutmix)
        print("  TSA            :", tsa)
        print("Model            :", model_name)
        print("  Parameters     :", count_parameters(model))
        print("  Dropout        :", dropout)
        print("Optimizer        :", optimizer_name)
        print("  Learning rate  :", learning_rate)
        print("  Weight decay   :", weight_decay)
        print("  Scheduler      :", scheduler_name)
        print("  Batch sizes    :", train_batch_size, valid_batch_size)
        print("Losses            ")
        print("  Flag           :", modification_flag_loss)
        print("  Type           :", modification_type_loss)
        print("  Embedding      :", embedding_loss)
        print("  Feature maps   :", feature_maps_loss)
        print("  Mask           :", mask_loss)
        print("  Bits           :", bits_loss)

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
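        # CyclicLR is stepped per batch, so attach a SchedulerCallback with
        # mode="batch"; other schedulers are left to step once per epoch.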
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        # model training
        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        del optimizer, loaders, runner, callbacks

        best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")

        # Restore state of best model
        clean_checkpoint(best_checkpoint, model_checkpoint)
        # unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        torch.cuda.empty_cache()
        gc.collect()

    if fine_tune:
        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation="light",
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=obliterate_p,
        )

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            feature_maps_loss=feature_maps_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            num_epochs=fine_tune,
            mixup=False,
            cutmix=False,
            tsa=False,
        )

        callbacks = (
            default_callbacks
            + loss_callbacks
            + [
                OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
                HyperParametersCallback(
                    hparam_dict={
                        "model": model_name,
                        "scheduler": scheduler_name,
                        "optimizer": optimizer_name,
                        "augmentations": augmentations,
                        "size": image_size[0],
                        "weight_decay": weight_decay,
                    }
                ),
            ]
        )

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds, batch_size=valid_batch_size, num_workers=num_workers, pin_memory=True)

        print("Train session    :", checkpoint_prefix)
        print("  FP16 mode      :", fp16)
        print("  Fast mode      :", args.fast)
        print("  Epochs         :", num_epochs)
        print("  Workers        :", num_workers)
        print("  Data dir       :", data_dir)
        print("  Log dir        :", log_dir)
        print("  Cache          :", cache)
        print("Data              ")
        print("  Augmentations  :", augmentations)
        print("  Obliterate (%) :", obliterate_p)
        print("  Negative images:", negative_image_dir)
        print("  Train size     :", len(loaders["train"]), "batches", len(train_ds), "samples")
        print("  Valid size     :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
        print("  Image size     :", image_size)
        print("  Balance        :", balance)
        print("  Mixup          :", mixup)
        print("  CutMix         :", cutmix)
        print("  TSA            :", tsa)
        print("Model            :", model_name)
        print("  Parameters     :", count_parameters(model))
        print("  Dropout        :", dropout)
        print("Optimizer        :", optimizer_name)
        print("  Learning rate  :", learning_rate)
        print("  Weight decay   :", weight_decay)
        print("  Scheduler      :", scheduler_name)
        print("  Batch sizes    :", train_batch_size, valid_batch_size)
        print("Losses            ")
        print("  Flag           :", modification_flag_loss)
        print("  Type           :", modification_type_loss)
        print("  Embedding      :", embedding_loss)
        print("  Feature maps   :", feature_maps_loss)
        print("  Mask           :", mask_loss)
        print("  Bits           :", bits_loss)

        optimizer = get_optimizer(
            "SGD", get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            "cos", optimizer, lr=learning_rate, num_epochs=fine_tune, batches_in_epoch=len(loaders["train"])
        )
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        # model training
        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "finetune"),
            num_epochs=fine_tune,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

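        # Unlike the warmup and main stages, fine-tuning reloads the cleaned
        # best checkpoint back into the in-memory model before exiting.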
        best_checkpoint = os.path.join(log_dir, "finetune", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}_finetune.pth")

        clean_checkpoint(best_checkpoint, model_checkpoint)
        unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        del optimizer, loaders, runner, callbacks

Example #3
def main():
    parser = argparse.ArgumentParser()

    ###########################################################################################
    # Distributed-training related stuff
    parser.add_argument("--local_rank", type=int, default=0)
    ###########################################################################################

    parser.add_argument("-acc",
                        "--accumulation-steps",
                        type=int,
                        default=1,
                        help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd",
        "--data-dir",
        type=str,
        help="Data directory for INRIA sattelite dataset",
        default=os.environ.get("INRIA_DATA_DIR"),
    )
    parser.add_argument("-dd-xview2",
                        "--data-dir-xview2",
                        type=str,
                        required=False,
                        help="Data directory for external xView2 dataset")
    parser.add_argument("-m",
                        "--model",
                        type=str,
                        default="resnet34_fpncat128",
                        help="")
    parser.add_argument("-b",
                        "--batch-size",
                        type=int,
                        default=8,
                        help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e",
                        "--epochs",
                        type=int,
                        default=100,
                        help="Epoch to run")
    # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument("-lr",
                        "--learning-rate",
                        type=float,
                        default=1e-3,
                        help="Initial learning rate")
    parser.add_argument("-l",
                        "--criterion",
                        type=str,
                        required=True,
                        action="append",
                        nargs="+",
                        help="Criterion")
    parser.add_argument("-o",
                        "--optimizer",
                        default="RAdam",
                        help="Name of the optimizer")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        help="Checkpoint filename to use as initial model weights")
    parser.add_argument("-w",
                        "--workers",
                        default=8,
                        type=int,
                        help="Num workers")
    parser.add_argument("-a",
                        "--augmentations",
                        default="hard",
                        type=str,
                        help="")
    parser.add_argument("-tm",
                        "--train-mode",
                        default="random",
                        type=str,
                        help="")
    parser.add_argument("--run-mode", default="fit_predict", type=str, help="")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("-s",
                        "--scheduler",
                        default="multistep",
                        type=str,
                        help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d",
                        "--dropout",
                        default=None,
                        type=float,
                        help="Dropout before head layer")
    parser.add_argument("--opl", action="store_true")
    parser.add_argument(
        "--warmup",
        default=0,
        type=int,
        help="Number of warmup epochs with reduced LR on encoder parameters")
    parser.add_argument("-wd",
                        "--weight-decay",
                        default=0,
                        type=float,
                        help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")

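    # Hypothetical distributed invocation (script name, GPU count and flag
    # values are illustrative only):
    #   python -m torch.distributed.launch --nproc_per_node=4 train_inria.py \
    #       -dd /data/inria -m resnet34_fpncat128 -b 8 -l bce 1.0 --fp16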
    args = parser.parse_args()

    args.is_master = args.local_rank == 0
    args.distributed = False
    fp16 = args.fp16

    if "WORLD_SIZE" in os.environ:
        args.distributed = int(os.environ["WORLD_SIZE"]) > 1
        args.world_size = int(os.environ["WORLD_SIZE"])
        # args.world_size = torch.distributed.get_world_size()

        print("Initializing init_process_group", args.local_rank)

        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        print("Initialized init_process_group", args.local_rank)

    is_master = args.is_master or (not args.distributed)

    distributed_params = {}
    if args.distributed:
        distributed_params = {"rank": args.local_rank, "syncbn": True}

    if args.fp16:
        distributed_params["apex"] = True
        distributed_params["opt_level"] = "O1"

    set_manual_seed(args.seed + args.local_rank)
    catalyst.utils.set_global_seed(args.seed + args.local_rank)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    data_dir = args.data_dir
    if data_dir is None:
        raise ValueError("--data-dir must be set")

    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    online_pseudolabeling = args.opl
    criterions = args.criterion
    verbose = args.verbose
    show = args.show
    use_dsv = args.dsv
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    extra_data_xview2 = args.data_dir_xview2

    run_train = num_epochs > 0
    need_weight_mask = any(c[0] == "wbce" for c in criterions)

    custom_model_kwargs = {}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)

    model: nn.Module = get_model(model_name, **custom_model_kwargs).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    main_metric = "optimized_jaccard"

    current_time = datetime.now().strftime("%y%m%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if online_pseudolabeling:
        checkpoint_prefix += "_opl"

    if extra_data_xview2:
        checkpoint_prefix += "_with_xview2"

    if experiment is not None:
        checkpoint_prefix = experiment

    if args.distributed:
        checkpoint_prefix += f"_local_rank_{args.local_rank}"

    log_dir = os.path.join("runs", checkpoint_prefix)
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = [
        PixelAccuracyCallback(input_key=INPUT_MASK_KEY,
                              output_key=OUTPUT_MASK_KEY),
        # JaccardMetricPerImage(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="jaccard"),
        JaccardMetricPerImageWithOptimalThreshold(input_key=INPUT_MASK_KEY,
                                                  output_key=OUTPUT_MASK_KEY,
                                                  prefix="optimized_jaccard")
    ]

    if is_master:
        default_callbacks += [
            BestMetricCheckpointCallback(target_metric="optimized_jaccard",
                                         target_metric_minimize=False),
            HyperParametersCallback(
                hparam_dict={
                    "model": model_name,
                    "scheduler": scheduler_name,
                    "optimizer": optimizer_name,
                    "augmentations": augmentations,
                    "size": args.size,
                    "weight_decay": weight_decay,
                    "epochs": num_epochs,
                    "dropout": None if dropout is None else float(dropout)
                }),
        ]

    if show and is_master:
        visualize_inria_predictions = partial(
            draw_inria_predictions,
            image_key=INPUT_IMAGE_KEY,
            image_id_key=INPUT_IMAGE_ID_KEY,
            targets_key=INPUT_MASK_KEY,
            outputs_key=OUTPUT_MASK_KEY,
            max_images=16,
        )
        default_callbacks += [
            ShowPolarBatchesCallback(visualize_inria_predictions,
                                     metric="accuracy",
                                     minimize=False),
            ShowPolarBatchesCallback(visualize_inria_predictions,
                                     metric="loss",
                                     minimize=True),
        ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        train_mode=train_mode,
        buildings_only=(train_mode == "tiles"),
        fast=fast,
        need_weight_mask=need_weight_mask,
    )

    if extra_data_xview2 is not None:
        extra_train_ds, _ = get_xview2_extra_dataset(
            extra_data_xview2,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast,
            need_weight_mask=need_weight_mask,
        )

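        # Re-balance sampling between the original INRIA samples (group 0)
        # and the extra xView2 samples (group 1), drawing twice as many
        # samples per epoch as the original sampler.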
        weights = compute_sample_weight("balanced", [0] * len(train_ds) +
                                        [1] * len(extra_train_ds))
        train_sampler = WeightedRandomSampler(weights,
                                              train_sampler.num_samples * 2)

        train_ds = train_ds + extra_train_ds
        print("Using extra data from xView2 with", len(extra_train_ds),
              "samples")

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []

        ignore_index = None
        if online_pseudolabeling:
            ignore_index = UNLABELED_SAMPLE
            unlabeled_label = get_pseudolabeling_dataset(data_dir,
                                                         include_masks=False,
                                                         augmentation=None,
                                                         image_size=image_size)

            unlabeled_train = get_pseudolabeling_dataset(
                data_dir,
                include_masks=True,
                augmentation=augmentations,
                image_size=image_size)

            loaders["label"] = DataLoader(unlabeled_label,
                                          batch_size=batch_size // 2,
                                          num_workers=num_workers,
                                          pin_memory=True)

            if train_sampler is not None:
                num_samples = 2 * train_sampler.num_samples
            else:
                num_samples = 2 * len(train_ds)
            weights = compute_sample_weight("balanced", [0] * len(train_ds) +
                                            [1] * len(unlabeled_label))

            train_sampler = WeightedRandomSampler(weights,
                                                  num_samples,
                                                  replacement=True)
            train_ds = train_ds + unlabeled_train

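            # Online pseudolabeling: the callback below re-labels the
            # unlabeled images every 5 epochs using the model's own
            # predictions, keeping pixels whose probability exceeds 0.7
            # (an interpretation of the callback parameters).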
            callbacks += [
                BCEOnlinePseudolabelingCallback2d(
                    unlabeled_train,
                    pseudolabel_loader="label",
                    prob_threshold=0.7,
                    output_key=OUTPUT_MASK_KEY,
                    unlabeled_class=UNLABELED_SAMPLE,
                    label_frequency=5,
                )
            ]

            print("Using online pseudolabeling with ", len(unlabeled_label),
                  "samples")

        valid_sampler = None
        if args.distributed:
            if train_sampler is not None:
                train_sampler = DistributedSamplerWrapper(train_sampler,
                                                          args.world_size,
                                                          args.local_rank,
                                                          shuffle=True)
            else:
                train_sampler = DistributedSampler(train_ds,
                                                   args.world_size,
                                                   args.local_rank,
                                                   shuffle=True)
            valid_sampler = DistributedSampler(valid_ds,
                                               args.world_size,
                                               args.local_rank,
                                               shuffle=False)

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds,
                                      batch_size=batch_size,
                                      num_workers=num_workers,
                                      pin_memory=True,
                                      sampler=valid_sampler)

        # Create losses
        for loss_name, loss_weight in criterions:
            criterion_callback = CriterionCallback(
                prefix=f"{INPUT_MASK_KEY}/" + loss_name,
                input_key=INPUT_MASK_KEY if loss_name != "wbce" else
                [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY],
                output_key=OUTPUT_MASK_KEY,
                criterion_key=loss_name,
                multiplier=float(loss_weight),
            )

            criterions_dict[loss_name] = get_loss(loss_name,
                                                  ignore_index=ignore_index)
            callbacks.append(criterion_callback)
            losses.append(criterion_callback.prefix)
            print("Using loss", loss_name, loss_weight)

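        # Deep supervision (DSV): auxiliary mask outputs at strides 4/8/16/32
        # are each supervised with a soft BCE loss wrapped in AdaptiveMaskLoss2d.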
        if use_dsv:
            print("Using DSV")
            dsv_criterion_key = "dsv"
            dsv_loss_name = "soft_bce"

            criterions_dict[dsv_criterion_key] = AdaptiveMaskLoss2d(
                get_loss(dsv_loss_name, ignore_index=ignore_index))

            for i, dsv_input in enumerate([
                    OUTPUT_MASK_4_KEY, OUTPUT_MASK_8_KEY, OUTPUT_MASK_16_KEY,
                    OUTPUT_MASK_32_KEY
            ]):
                criterion_callback = CriterionCallback(
                    prefix=f"{dsv_input}/" + dsv_loss_name,
                    input_key=INPUT_MASK_KEY,
                    output_key=dsv_input,
                    criterion_key=dsv_criterion_key,
                    multiplier=1.0,
                )
                callbacks.append(criterion_callback)
                losses.append(criterion_callback.prefix)

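        # Hourglass models add extra supervision heads after each block;
        # loss weights grow linearly ((i + 1) / num_supervision_inputs), so
        # later blocks contribute more to the total loss.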
        if isinstance(model, SupervisedHGSegmentationModel):
            print("Using Hourglass DSV")
            dsv_loss_name = "kl"

            criterions_dict["dsv"] = get_loss(dsv_loss_name,
                                              ignore_index=ignore_index)
            num_supervision_inputs = model.encoder.num_blocks - 1
            dsv_outputs = [
                OUTPUT_MASK_4_KEY + "_after_hg_" + str(i)
                for i in range(num_supervision_inputs)
            ]

            for i, dsv_input in enumerate(dsv_outputs):
                criterion_callback = CriterionCallback(
                    prefix="supervision/" + dsv_input,
                    input_key=INPUT_MASK_KEY,
                    output_key=dsv_input,
                    criterion_key="dsv",
                    multiplier=(i + 1) / num_supervision_inputs,
                )
                callbacks.append(criterion_callback)
                losses.append(criterion_callback.prefix)

        callbacks += [
            MetricAggregationCallback(prefix="loss",
                                      metrics=losses,
                                      mode="sum"),
            OptimizerCallback(accumulation_steps=accumulation_steps,
                              decouple_weight_decay=False),
        ]

        optimizer = get_optimizer(optimizer_name,
                                  get_optimizable_parameters(model),
                                  learning_rate,
                                  weight_decay=weight_decay)
        scheduler = get_scheduler(scheduler_name,
                                  optimizer,
                                  lr=learning_rate,
                                  num_epochs=num_epochs,
                                  batches_in_epoch=len(loaders["train"]))
        if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
            callbacks += [SchedulerCallback(mode="batch")]

        print("Train session    :", checkpoint_prefix)
        print("  FP16 mode      :", fp16)
        print("  Fast mode      :", args.fast)
        print("  Train mode     :", train_mode)
        print("  Epochs         :", num_epochs)
        print("  Workers        :", num_workers)
        print("  Data dir       :", data_dir)
        print("  Log dir        :", log_dir)
        print("  Augmentations  :", augmentations)
        print("  Train size     :", "batches", len(loaders["train"]),
              "dataset", len(train_ds))
        print("  Valid size     :", "batches", len(loaders["valid"]),
              "dataset", len(valid_ds))
        print("Model            :", model_name)
        print("  Parameters     :", count_parameters(model))
        print("  Image size     :", image_size)
        print("Optimizer        :", optimizer_name)
        print("  Learning rate  :", learning_rate)
        print("  Batch size     :", batch_size)
        print("  Criterion      :", criterions)
        print("  Use weight mask:", need_weight_mask)
        if args.distributed:
            print("Distributed")
            print("  World size     :", args.world_size)
            print("  Local rank     :", args.local_rank)
            print("  Is master      :", args.is_master)

        # model training
        runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY,
                                  output_key=None,
                                  device="cuda")
        runner.train(
            fp16=distributed_params,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": vars(args)},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        best_checkpoint = os.path.join(log_dir, "main", "checkpoints",
                                       "best.pth")

        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        unpack_checkpoint(torch.load(model_checkpoint), model=model)

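        # Smoke test: run inference on a sample image with the restored best
        # weights and write the thresholded mask next to the logs.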
        mask = predict(model,
                       read_inria_image("sample_color.jpg"),
                       image_size=image_size,
                       batch_size=args.batch_size)
        mask = ((mask > 0) * 255).astype(np.uint8)
        name = os.path.join(log_dir, "sample_color.jpg")
        cv2.imwrite(name, mask)

        del optimizer, loaders

Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc",
                        "--accumulation-steps",
                        type=int,
                        default=1,
                        help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("-dd",
                        "--data-dir",
                        type=str,
                        required=True,
                        help="Data directory for INRIA sattelite dataset")
    parser.add_argument("-m",
                        "--model",
                        type=str,
                        default="resnet34_fpncat128",
                        help="")
    parser.add_argument("-b",
                        "--batch-size",
                        type=int,
                        default=8,
                        help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e",
                        "--epochs",
                        type=int,
                        default=100,
                        help="Epoch to run")
    # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument("-lr",
                        "--learning-rate",
                        type=float,
                        default=1e-3,
                        help="Initial learning rate")
    parser.add_argument(
        "--disaster-type-loss",
        type=str,
        default=None,  # [["ce", 1.0]],
        action="append",
        nargs="+",
        help="Criterion for classifying disaster type",
    )
    parser.add_argument(
        "--damage-type-loss",
        type=str,
        default=None,  # [["bce", 1.0]],
        action="append",
        nargs="+",
        help=
        "Criterion for classifying presence of buildings with a particular damage type",
    )

    parser.add_argument("-l",
                        "--criterion",
                        type=str,
                        default=None,
                        action="append",
                        nargs="+",
                        help="Criterion")
    parser.add_argument("--mask4",
                        type=str,
                        default=None,
                        action="append",
                        nargs="+",
                        help="Criterion for mask with stride 4")
    parser.add_argument("--mask8",
                        type=str,
                        default=None,
                        action="append",
                        nargs="+",
                        help="Criterion for mask with stride 8")
    parser.add_argument("--mask16",
                        type=str,
                        default=None,
                        action="append",
                        nargs="+",
                        help="Criterion for mask with stride 16")
    parser.add_argument("--mask32",
                        type=str,
                        default=None,
                        action="append",
                        nargs="+",
                        help="Criterion for mask with stride 32")
    parser.add_argument("--embedding", type=str, default=None)

    parser.add_argument("-o",
                        "--optimizer",
                        default="RAdam",
                        help="Name of the optimizer")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        help="Checkpoint filename to use as initial model weights")
    parser.add_argument("-w",
                        "--workers",
                        default=8,
                        type=int,
                        help="Num workers")
    parser.add_argument("-a",
                        "--augmentations",
                        default="safe",
                        type=str,
                        help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("--fold", default=0, type=int)
    parser.add_argument("-s",
                        "--scheduler",
                        default="multistep",
                        type=str,
                        help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d",
                        "--dropout",
                        default=0.0,
                        type=float,
                        help="Dropout before head layer")
    parser.add_argument("-pl", "--pseudolabeling", type=str, required=True)
    parser.add_argument("-wd",
                        "--weight-decay",
                        default=0,
                        type=float,
                        help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--only-buildings", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")
    parser.add_argument("--crops",
                        action="store_true",
                        help="Train on random crops")
    parser.add_argument("--post-transform", action="store_true")

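    # Hypothetical invocation (script name, paths and loss names are
    # illustrative only):
    #   python train_xview2.py -dd /data/xview2 -pl /data/pseudolabels \
    #       -l ce 1.0 --crops --fp16 --fold 0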
    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    segmentation_losses = args.criterion
    verbose = args.verbose
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    fold = args.fold
    balance = args.balance
    only_buildings = args.only_buildings
    freeze_bn = args.freeze_bn
    train_on_crops = args.crops
    enable_post_image_transform = args.post_transform
    disaster_type_loss = args.disaster_type_loss
    train_batch_size = args.batch_size
    embedding_criterion = args.embedding
    damage_type_loss = args.damage_type_loss
    pseudolabels_dir = args.pseudolabeling

    # Compute batch size for validation
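    # When training on crops, validation still runs on full tiles, so the
    # batch size is scaled by crop_area / 1024**2 (the 1024x1024 tile size is
    # an assumption implied by the divisor below).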
    if train_on_crops:
        valid_batch_size = max(1,
                               (train_batch_size *
                                (image_size[0] * image_size[1])) // (1024**2))
    else:
        valid_batch_size = train_batch_size

    run_train = num_epochs > 0

    model: nn.Module = get_model(model_name, dropout=dropout).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    if freeze_bn:
        torch_utils.freeze_bn(model)
        print("Freezing bn params")

    runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None)
    main_metric = "weighted_f1"
    cmd_args = vars(args)

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}_{args.size}_fold{fold}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if pseudolabels_dir:
        checkpoint_prefix += "_pseudo"

    if train_on_crops:
        checkpoint_prefix += "_crops"

    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = [
        CompetitionMetricCallback(input_key=INPUT_MASK_KEY,
                                  output_key=OUTPUT_MASK_KEY,
                                  prefix="weighted_f1"),
        ConfusionMatrixCallback(
            input_key=INPUT_MASK_KEY,
            output_key=OUTPUT_MASK_KEY,
            class_names=[
                "land", "no_damage", "minor_damage", "major_damage",
                "destroyed"
            ],
            ignore_index=UNLABELED_SAMPLE,
        ),
    ]

    if show:
        default_callbacks += [
            ShowPolarBatchesCallback(draw_predictions,
                                     metric=main_metric + "_batch",
                                     minimize=False)
        ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        fast=fast,
        fold=fold,
        balance=balance,
        only_buildings=only_buildings,
        train_on_crops=train_on_crops,
        crops_multiplication_factor=1,
        enable_post_image_transform=enable_post_image_transform,
    )

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []

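        # Pseudolabels here are precomputed offline (loaded from
        # --pseudolabeling) and simply concatenated with the labeled training
        # set; there is no online re-labeling in this script.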
        unlabeled_train = get_pseudolabeling_dataset(
            data_dir,
            include_masks=True,
            image_size=image_size,
            augmentation="medium_nmd",
            train_on_crops=train_on_crops,
            enable_post_image_transform=enable_post_image_transform,
            pseudolabels_dir=pseudolabels_dir,
        )

        train_ds = train_ds + unlabeled_train

        print("Using online pseudolabeling with ", len(unlabeled_train),
              "samples")

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=True,
        )

        loaders["valid"] = DataLoader(valid_ds,
                                      batch_size=valid_batch_size,
                                      num_workers=num_workers,
                                      pin_memory=True)

        # Create losses
        for criterion in segmentation_losses:
            if isinstance(criterion, (list, tuple)) and len(criterion) == 2:
                loss_name, loss_weight = criterion
            else:
                loss_name, loss_weight = criterion[0], 1.0

            cd, criterion, criterion_name = get_criterion_callback(
                loss_name,
                prefix="segmentation",
                input_key=INPUT_MASK_KEY,
                output_key=OUTPUT_MASK_KEY,
                loss_weight=float(loss_weight),
            )
            criterions_dict.update(cd)
            callbacks.append(criterion)
            losses.append(criterion_name)
            print(INPUT_MASK_KEY, "Using loss", loss_name, loss_weight)

        if args.mask4 is not None:
            for criterion in args.mask4:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix="mask4",
                    input_key=INPUT_MASK_KEY,
                    output_key=OUTPUT_MASK_4_KEY,
                    loss_weight=float(loss_weight),
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(OUTPUT_MASK_4_KEY, "Using loss", loss_name, loss_weight)

        if args.mask8 is not None:
            for criterion in args.mask8:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix="mask8",
                    input_key=INPUT_MASK_KEY,
                    output_key=OUTPUT_MASK_8_KEY,
                    loss_weight=float(loss_weight),
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(OUTPUT_MASK_8_KEY, "Using loss", loss_name, loss_weight)

        if args.mask16 is not None:
            for criterion in args.mask16:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix="mask16",
                    input_key=INPUT_MASK_KEY,
                    output_key=OUTPUT_MASK_16_KEY,
                    loss_weight=float(loss_weight),
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(OUTPUT_MASK_16_KEY, "Using loss", loss_name, loss_weight)

        if args.mask32 is not None:
            for criterion in args.mask32:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix="mask32",
                    input_key=INPUT_MASK_KEY,
                    output_key=OUTPUT_MASK_32_KEY,
                    loss_weight=float(loss_weight),
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(OUTPUT_MASK_32_KEY, "Using loss", loss_name, loss_weight)
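
        # Note: the four near-identical mask4/mask8/mask16/mask32 blocks above follow the
        # same pattern and could equivalently be driven by a single loop over
        # [(args.mask4, OUTPUT_MASK_4_KEY, "mask4"), (args.mask8, OUTPUT_MASK_8_KEY, "mask8"),
        #  (args.mask16, OUTPUT_MASK_16_KEY, "mask16"), (args.mask32, OUTPUT_MASK_32_KEY, "mask32")].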

        if disaster_type_loss is not None:
            callbacks += [
                ConfusionMatrixCallback(
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    class_names=DISASTER_TYPES,
                    ignore_index=UNKNOWN_DISASTER_TYPE_CLASS,
                    prefix=f"{DISASTER_TYPE_KEY}/confusion_matrix",
                ),
                AccuracyCallback(
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    prefix=f"{DISASTER_TYPE_KEY}/accuracy",
                    activation="Softmax",
                ),
            ]

            for criterion in disaster_type_loss:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix=DISASTER_TYPE_KEY,
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    loss_weight=float(loss_weight),
                    ignore_index=UNKNOWN_DISASTER_TYPE_CLASS,
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(DISASTER_TYPE_KEY, "Using loss", loss_name, loss_weight)

        if damage_type_loss is not None:
            callbacks += [
                # MultilabelConfusionMatrixCallback(
                #     input_key=DAMAGE_TYPE_KEY,
                #     output_key=DAMAGE_TYPE_KEY,
                #     class_names=DAMAGE_TYPES,
                #     prefix=f"{DAMAGE_TYPE_KEY}/confusion_matrix",
                # ),
                AccuracyCallback(
                    input_key=DAMAGE_TYPE_KEY,
                    output_key=DAMAGE_TYPE_KEY,
                    prefix=f"{DAMAGE_TYPE_KEY}/accuracy",
                    activation="Sigmoid",
                    threshold=0.5,
                )
            ]

            for criterion in damage_type_loss:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix=DAMAGE_TYPE_KEY,
                    input_key=DAMAGE_TYPE_KEY,
                    output_key=DAMAGE_TYPE_KEY,
                    loss_weight=float(loss_weight),
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(DAMAGE_TYPE_KEY, "Using loss", loss_name, loss_weight)

        if embedding_criterion is not None:
            cd, criterion, criterion_name = get_criterion_callback(
                embedding_criterion,
                prefix="embedding",
                input_key=INPUT_MASK_KEY,
                output_key=OUTPUT_EMBEDDING_KEY,
                loss_weight=1.0,
            )
            criterions_dict.update(cd)
            callbacks.append(criterion)
            losses.append(criterion_name)
            print(OUTPUT_EMBEDDING_KEY, "Using loss", embedding_criterion)

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps,
                              decouple_weight_decay=False),
        ]

        optimizer = get_optimizer(optimizer_name,
                                  get_optimizable_parameters(model),
                                  learning_rate,
                                  weight_decay=weight_decay)
        scheduler = get_scheduler(scheduler_name,
                                  optimizer,
                                  lr=learning_rate,
                                  num_epochs=num_epochs,
                                  batches_in_epoch=len(loaders["train"]))
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        print("Train session    :", checkpoint_prefix)
        print("  FP16 mode      :", fp16)
        print("  Fast mode      :", args.fast)
        print("  Epochs         :", num_epochs)
        print("  Workers        :", num_workers)
        print("  Data dir       :", data_dir)
        print("  Log dir        :", log_dir)
        print("Data             ")
        print("  Augmentations  :", augmentations)
        print("  Train size     :", len(loaders["train"]), len(train_ds))
        print("  Valid size     :", len(loaders["valid"]), len(valid_ds))
        print("  Image size     :", image_size)
        print("  Train on crops :", train_on_crops)
        print("  Balance        :", balance)
        print("  Buildings only :", only_buildings)
        print("  Post transform :", enable_post_image_transform)
        print("  Pseudolabels   :", pseudolabels_dir)
        print("Model            :", model_name)
        print("  Parameters     :", count_parameters(model))
        print("  Dropout        :", dropout)
        print("Optimizer        :", optimizer_name)
        print("  Learning rate  :", learning_rate)
        print("  Weight decay   :", weight_decay)
        print("  Scheduler      :", scheduler_name)
        print("  Batch sizes    :", train_batch_size, valid_batch_size)
        print("  Criterion      :", segmentation_losses)
        print("  Damage type    :", damage_type_loss)
        print("  Disaster type  :", disaster_type_loss)
        print(" Embedding      :", embedding_criterion)

        # model training
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "opl"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": cmd_args},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        best_checkpoint = os.path.join(log_dir, "opl", "checkpoints",
                                       "best.pth")

        model_checkpoint = os.path.join(log_dir, "opl", "checkpoints",
                                        f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        del optimizer, loaders
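
# A minimal sketch (an assumption, not this repository's actual helper) of the contract
# that the repeated get_criterion_callback(...) calls above appear to follow: return a
# {key: loss_module} entry for the runner's criterion dict, a CriterionCallback bound to
# the given input/output keys, and the loss key later summed by CriterionAggregatorCallback.
# get_criterion_callback_sketch and the tiny name->module map are illustrative only.
import torch.nn as nn
from catalyst.dl import CriterionCallback


def get_criterion_callback_sketch(loss_name, prefix, input_key, output_key, loss_weight=1.0):
    # The real codebase resolves loss_name through its own get_loss() factory.
    criterion = {"bce": nn.BCEWithLogitsLoss(), "ce": nn.CrossEntropyLoss()}[loss_name]
    loss_key = f"{prefix}/{loss_name}"
    callback = CriterionCallback(
        prefix=loss_key,
        input_key=input_key,
        output_key=output_key,
        criterion_key=loss_key,
        multiplier=float(loss_weight),
    )
    return {loss_key: criterion}, callback, loss_key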
Example #5
0
def run_stage_training(model: Union[TimmRgbModel, YCrCbModel],
                       config: StageConfig,
                       exp_config: ExperimenetConfig,
                       experiment_dir: str):
    # Preparing model
    freeze_model(model, freeze_bn=config.freeze_bn)

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=exp_config.data_dir,
        image_size=config.image_size,
        augmentation=config.augmentations,
        balance=config.balance,
        fast=config.fast,
        fold=exp_config.fold,
        features=model.required_features,
        obliterate_p=config.obliterate_p,
    )

    criterions_dict, loss_callbacks = get_criterions(
        modification_flag=config.modification_flag_loss,
        modification_type=config.modification_type_loss,
        embedding_loss=config.embedding_loss,
        feature_maps_loss=config.feature_maps_loss,
        num_epochs=config.epochs,
        mixup=config.mixup,
        cutmix=config.cutmix,
        tsa=config.tsa,
    )

    callbacks = loss_callbacks + [
        OptimizerCallback(accumulation_steps=config.accumulation_steps,
                          decouple_weight_decay=False),
        HyperParametersCallback(
            hparam_dict={
                "model": exp_config.model_name,
                "scheduler": config.schedule,
                "optimizer": config.optimizer,
                "augmentations": config.augmentations,
                "size": config.image_size[0],
                "weight_decay": config.weight_decay,
            }),
    ]

    if config.show:
        callbacks += [
            ShowPolarBatchesCallback(draw_predictions,
                                     metric="loss",
                                     minimize=True)
        ]

    loaders = collections.OrderedDict()
    loaders["train"] = DataLoader(
        train_ds,
        batch_size=config.train_batch_size,
        num_workers=exp_config.num_workers,
        pin_memory=True,
        drop_last=True,
        shuffle=train_sampler is None,
        sampler=train_sampler,
    )

    loaders["valid"] = DataLoader(valid_ds,
                                  batch_size=config.valid_batch_size,
                                  num_workers=exp_config.num_workers,
                                  pin_memory=True)

    print("Stage            :", config.stage_name)
    print("  FP16 mode      :", config.fp16)
    print("  Fast mode      :", config.fast)
    print("  Epochs         :", config.epochs)
    print("  Workers        :", exp_config.num_workers)
    print("  Data dir       :", exp_config.data_dir)
    print("  Experiment dir :", experiment_dir)
    print("Data              ")
    print("  Augmentations  :", config.augmentations)
    print("  Obliterate (%) :", config.obliterate_p)
    print("  Negative images:", config.negative_image_dir)
    print("  Train size     :", len(loaders["train"]), "batches",
          len(train_ds), "samples")
    print("  Valid size     :", len(loaders["valid"]), "batches",
          len(valid_ds), "samples")
    print("  Image size     :", config.image_size)
    print("  Balance        :", config.balance)
    print("  Mixup          :", config.mixup)
    print("  CutMix         :", config.cutmix)
    print("  TSA            :", config.tsa)
    print("Model            :", exp_config.model_name)
    print("  Parameters     :", count_parameters(model))
    print("  Dropout        :", exp_config.dropout)
    print("Optimizer        :", config.optimizer)
    print("  Learning rate  :", config.learning_rate)
    print("  Weight decay   :", config.weight_decay)
    print("  Scheduler      :", config.schedule)
    print("  Batch sizes    :", config.train_batch_size,
          config.valid_batch_size)
    print("Losses            ")
    print("  Flag           :", config.modification_flag_loss)
    print("  Type           :", config.modification_type_loss)
    print("  Embedding      :", config.embedding_loss)
    print("  Feature maps   :", config.feature_maps_loss)

    optimizer = get_optimizer(
        config.optimizer,
        get_optimizable_parameters(model),
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    scheduler = get_scheduler(
        config.schedule,
        optimizer,
        lr=config.learning_rate,
        num_epochs=config.epochs,
        batches_in_epoch=len(loaders["train"]),
    )
    if isinstance(scheduler, CyclicLR):
        callbacks += [SchedulerCallback(mode="batch")]

    # model training
    runner = SupervisedRunner(input_key=model.required_features,
                              output_key=None)
    runner.train(
        fp16=config.fp16,
        model=model,
        criterion=criterions_dict,
        optimizer=optimizer,
        scheduler=scheduler,
        callbacks=callbacks,
        loaders=loaders,
        logdir=os.path.join(experiment_dir, config.stage_name),
        num_epochs=config.epochs,
        verbose=config.verbose,
        main_metric=config.main_metric,
        minimize_metric=config.main_metric_minimize,
        checkpoint_data={"config": config},
    )

    del optimizer, loaders, callbacks, runner

    best_checkpoint = os.path.join(experiment_dir, config.stage_name,
                                   "checkpoints", "best.pth")
    model_checkpoint = os.path.join(experiment_dir,
                                    f"{exp_config.checkpoint_prefix}.pth")
    clean_checkpoint(best_checkpoint, model_checkpoint)

    # Restore state of best model
    if config.restore_best:
        unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

    # Some memory cleanup
    torch.cuda.empty_cache()
    gc.collect()
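
# A minimal sketch, assuming typical semantics for the freeze_model(..., freeze_bn=True)
# call at the top of run_stage_training above (and for torch_utils.freeze_bn earlier):
# put BatchNorm layers into eval mode and stop updating their affine parameters.
# The real helpers live in this codebase / pytorch_toolbelt and may differ in detail.
import torch.nn as nn

_BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def freeze_bn_sketch(module: nn.Module) -> nn.Module:
    for m in module.modules():
        if isinstance(m, _BN_TYPES):
            m.eval()  # stop running-stats updates
            for p in m.parameters():
                p.requires_grad = False  # freeze affine gamma/beta
    return module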
Example #6
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')

    parser.add_argument('-dd',
                        '--data-dir',
                        type=str,
                        default='data',
                        help='Data directory')

    parser.add_argument('-l',
                        '--loss',
                        type=str,
                        default='label_smooth_cross_entropy')
    parser.add_argument('-t1', '--temper1', type=float, default=0.2)
    parser.add_argument('-t2', '--temper2', type=float, default=4.0)
    parser.add_argument('-optim', '--optimizer', type=str, default='adam')

    parser.add_argument('-prep', '--prep_function', type=str, default='none')

    parser.add_argument('--train_on_different_datasets', action='store_true')
    parser.add_argument('--use-current', action='store_true')
    parser.add_argument('--use-extra', action='store_true')
    parser.add_argument('--use-unlabeled', action='store_true')

    parser.add_argument('--fast', action='store_true')
    parser.add_argument('--mixup', action='store_true')
    parser.add_argument('--balance', action='store_true')
    parser.add_argument('--balance-datasets', action='store_true')

    parser.add_argument('--show', action='store_true')
    parser.add_argument('-v', '--verbose', action='store_true')

    parser.add_argument('-m',
                        '--model',
                        type=str,
                        default='efficientnet-b4',
                        help='')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=8,
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e',
                        '--epochs',
                        type=int,
                        default=100,
                        help='Epoch to run')
    parser.add_argument('-s',
                        '--sizes',
                        default=380,
                        type=int,
                        help='Image size for training & inference')
    parser.add_argument('-f', '--fold', type=int, default=None)
    parser.add_argument('-t', '--transfer', default=None, type=str, help='')
    parser.add_argument('-lr',
                        '--learning_rate',
                        type=float,
                        default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('-a',
                        '--augmentations',
                        default='medium',
                        type=str,
                        help='')
    parser.add_argument('-accum', '--accum-step', type=int, default=1)
    parser.add_argument('-metric', '--metric', type=str, default='accuracy01')

    args = parser.parse_args()

    diff_dataset_train = args.train_on_different_datasets

    data_dir = args.data_dir
    epochs = args.epochs
    batch_size = args.batch_size
    seed = args.seed

    loss_name = args.loss
    optim_name = args.optimizer

    prep_function = args.prep_function

    model_name = args.model
    size = args.sizes
    image_size = (size, size)
    fast = args.fast
    fold = args.fold
    mixup = args.mixup
    balance = args.balance
    balance_datasets = args.balance_datasets
    show_batches = args.show
    verbose = args.verbose
    use_current = args.use_current
    use_extra = args.use_extra
    use_unlabeled = args.use_unlabeled

    learning_rate = args.learning_rate
    augmentations = args.augmentations
    transfer = args.transfer
    accum_step = args.accum_step

    #cosine_loss    accuracy01
    main_metric = args.metric

    print(data_dir)

    num_classes = 5

    assert use_current or use_extra

    print(fold)

    current_time = datetime.now().strftime('%b%d_%H_%M')
    random_name = get_random_name()

    # if folds is None or len(folds) == 0:
    #     folds = [None]

    torch.cuda.empty_cache()
    checkpoint_prefix = f'{model_name}_{size}_{augmentations}'

    if transfer is not None:
        checkpoint_prefix += '_pretrain_from_' + str(transfer)
    else:
        if use_current:
            checkpoint_prefix += '_current'
        if use_extra:
            checkpoint_prefix += '_extra'
        if use_unlabeled:
            checkpoint_prefix += '_unlabeled'
        if fold is not None:
            checkpoint_prefix += f'_fold{fold}'

    directory_prefix = f'{current_time}_{checkpoint_prefix}'
    log_dir = os.path.join('runs', directory_prefix)
    os.makedirs(log_dir, exist_ok=False)

    set_manual_seed(seed)
    model = get_model(model_name)

    if transfer is not None:
        print("Transfering weights from model checkpoint")
        model.load_state_dict(torch.load(transfer)['model_state_dict'])

    model = model.cuda()

    if diff_dataset_train:
        train_on = ['current_train', 'extra_train']
        valid_on = ['unlabeled']
        train_ds, valid_ds, train_sizes = get_datasets_universal(
            train_on=train_on,
            valid_on=valid_on,
            image_size=image_size,
            augmentation=augmentations,
            target_dtype=int,
            prep_function=prep_function)
    else:
        train_ds, valid_ds, train_sizes = get_datasets(
            data_dir=data_dir,
            use_current=use_current,
            use_extra=use_extra,
            image_size=image_size,
            prep_function=prep_function,
            augmentation=augmentations,
            target_dtype=int,
            fold=fold,
            folds=5)

    train_loader, valid_loader = get_dataloaders(train_ds,
                                                 valid_ds,
                                                 batch_size=batch_size,
                                                 train_sizes=train_sizes,
                                                 num_workers=6,
                                                 balance=balance,
                                                 balance_datasets=balance_datasets,
                                                 balance_unlabeled=False)

    loaders = collections.OrderedDict()
    loaders["train"] = train_loader
    loaders["valid"] = valid_loader

    runner = SupervisedRunner(input_key='image')

    criterions = get_loss(loss_name)
    # criterions_tempered = TemperedLogLoss()
    # optimizer = catalyst.contrib.nn.optimizers.radam.RAdam(model.parameters(), lr = learning_rate)
    optimizer = get_optim(optim_name, model, learning_rate)
    # optimizer = catalyst.contrib.nn.optimizers.Adam(model.parameters(), lr = learning_rate)
    # criterions = nn.CrossEntropyLoss()
    # optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25], gamma=0.8)
    # cappa = CappaScoreCallback()

    Q = math.floor(len(train_ds) / batch_size)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Q)
    if main_metric != 'accuracy01':
        callbacks = [
            AccuracyCallback(num_classes=num_classes),
            CosineLossCallback(),
            OptimizerCallback(accumulation_steps=accum_step),
            CheckpointCallback(save_n_best=epochs)
        ]
    else:
        callbacks = [
            AccuracyCallback(num_classes=num_classes),
            OptimizerCallback(accumulation_steps=accum_step),
            CheckpointCallback(save_n_best=epochs)
        ]

    # main_metric = 'accuracy01'

    runner.train(
        fp16=True,
        model=model,
        criterion=criterions,
        optimizer=optimizer,
        scheduler=scheduler,
        callbacks=callbacks,
        loaders=loaders,
        logdir=log_dir,
        num_epochs=epochs,
        verbose=verbose,
        main_metric=main_metric,
        minimize_metric=False,
    )
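
    # Note (assumption about Catalyst defaults): no SchedulerCallback(mode="batch") is
    # registered here, so CosineAnnealingLR is stepped once per epoch; with T_max=Q
    # (batches per epoch) the cosine cycle then spans Q epochs, not one. Add a
    # SchedulerCallback with mode="batch" if per-batch annealing is intended.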
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc",
                        "--accumulation-steps",
                        type=int,
                        default=1,
                        help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("-dd",
                        "--data-dir",
                        type=str,
                        required=True,
                        help="Data directory for INRIA sattelite dataset")
    parser.add_argument("-dd-xview2",
                        "--data-dir-xview2",
                        type=str,
                        required=False,
                        help="Data directory for external xView2 dataset")
    parser.add_argument("-m",
                        "--model",
                        type=str,
                        default="resnet34_fpncat128",
                        help="")
    parser.add_argument("-b",
                        "--batch-size",
                        type=int,
                        default=8,
                        help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e",
                        "--epochs",
                        type=int,
                        default=100,
                        help="Epoch to run")
    # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument("-lr",
                        "--learning-rate",
                        type=float,
                        default=1e-3,
                        help="Initial learning rate")
    parser.add_argument("-l",
                        "--criterion",
                        type=str,
                        required=True,
                        action="append",
                        nargs="+",
                        help="Criterion")
    parser.add_argument("-o",
                        "--optimizer",
                        default="RAdam",
                        help="Name of the optimizer")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        help="Checkpoint filename to use as initial model weights")
    parser.add_argument("-w",
                        "--workers",
                        default=8,
                        type=int,
                        help="Num workers")
    parser.add_argument("-a",
                        "--augmentations",
                        default="hard",
                        type=str,
                        help="")
    parser.add_argument("-tm",
                        "--train-mode",
                        default="random",
                        type=str,
                        help="")
    parser.add_argument("--run-mode", default="fit_predict", type=str, help="")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("-s",
                        "--scheduler",
                        default="multistep",
                        type=str,
                        help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d",
                        "--dropout",
                        default=0.0,
                        type=float,
                        help="Dropout before head layer")
    parser.add_argument("--opl", action="store_true")
    parser.add_argument(
        "--warmup",
        default=0,
        type=int,
        help="Number of warmup epochs with reduced LR on encoder parameters")
    parser.add_argument("-wd",
                        "--weight-decay",
                        default=0,
                        type=float,
                        help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")

    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    online_pseudolabeling = args.opl
    criterions = args.criterion
    verbose = args.verbose
    warmup = args.warmup
    show = args.show
    use_dsv = args.dsv
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    extra_data_xview2 = args.data_dir_xview2

    run_train = num_epochs > 0
    need_weight_mask = any(c[0] == "wbce" for c in criterions)
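    # With action="append" and nargs="+", args.criterion is a list of token lists,
    # e.g. "-l bce 1 -l dice 0.5" -> [["bce", "1"], ["dice", "0.5"]]; the training
    # loops below unpack each entry as (loss_name, loss_weight).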

    model: nn.Module = get_model(model_name, dropout=dropout).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY,
                              output_key=None,
                              device="cuda")
    main_metric = "optimized_jaccard"
    cmd_args = vars(args)

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if online_pseudolabeling:
        checkpoint_prefix += "_opl"

    if extra_data_xview2:
        checkpoint_prefix += "_with_xview2"

    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = [
        PixelAccuracyCallback(input_key=INPUT_MASK_KEY,
                              output_key=OUTPUT_MASK_KEY),
        JaccardMetricPerImage(input_key=INPUT_MASK_KEY,
                              output_key=OUTPUT_MASK_KEY,
                              prefix="jaccard"),
        OptimalThreshold(input_key=INPUT_MASK_KEY,
                         output_key=OUTPUT_MASK_KEY,
                         prefix="optimized_jaccard"),
        # OutputDistributionCallback(output_key=OUTPUT_MASK_KEY, activation=torch.sigmoid),
    ]

    if show:
        visualize_inria_predictions = partial(
            draw_inria_predictions,
            image_key=INPUT_IMAGE_KEY,
            image_id_key=INPUT_IMAGE_ID_KEY,
            targets_key=INPUT_MASK_KEY,
            outputs_key=OUTPUT_MASK_KEY,
        )
        default_callbacks += [
            ShowPolarBatchesCallback(visualize_inria_predictions,
                                     metric="accuracy",
                                     minimize=False)
        ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        train_mode=train_mode,
        fast=fast,
        need_weight_mask=need_weight_mask,
    )

    if extra_data_xview2 is not None:
        extra_train_ds, _ = get_xview2_extra_dataset(
            extra_data_xview2,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast,
            need_weight_mask=need_weight_mask,
        )

        weights = compute_sample_weight("balanced", [0] * len(train_ds) +
                                        [1] * len(extra_train_ds))
        train_sampler = WeightedRandomSampler(weights,
                                              train_sampler.num_samples * 2)

        train_ds = train_ds + extra_train_ds
        print("Using extra data from xView2 with", len(extra_train_ds),
              "samples")

    # Pretrain/warmup
    if warmup:
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []
        ignore_index = None

        for loss_name, loss_weight in criterions:
            criterion_callback = CriterionCallback(
                prefix="seg_loss/" + loss_name,
                input_key=INPUT_MASK_KEY if loss_name != "wbce" else
                [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY],
                output_key=OUTPUT_MASK_KEY,
                criterion_key=loss_name,
                multiplier=float(loss_weight),
            )

            criterions_dict[loss_name] = get_loss(loss_name,
                                                  ignore_index=ignore_index)
            callbacks.append(criterion_callback)
            losses.append(criterion_callback.prefix)
            print("Using loss", loss_name, loss_weight)

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps,
                              decouple_weight_decay=False),
        ]

        parameters = get_lr_decay_parameters(model.named_parameters(),
                                             learning_rate, {"encoder": 0.1})
        optimizer = get_optimizer("RAdam",
                                  parameters,
                                  learning_rate=learning_rate * 0.1)

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds,
                                      batch_size=batch_size,
                                      num_workers=num_workers,
                                      pin_memory=True,
                                      shuffle=False,
                                      drop_last=False)

        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=None,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "warmup"),
            num_epochs=warmup,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": cmd_args},
        )

        del optimizer, loaders

        best_checkpoint = os.path.join(log_dir, "warmup", "checkpoints",
                                       "best.pth")
        model_checkpoint = os.path.join(log_dir, "warmup", "checkpoints",
                                        f"{checkpoint_prefix}_warmup.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        torch.cuda.empty_cache()
        gc.collect()

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []

        ignore_index = None
        if online_pseudolabeling:
            ignore_index = UNLABELED_SAMPLE
            unlabeled_label = get_pseudolabeling_dataset(data_dir,
                                                         include_masks=False,
                                                         augmentation=None,
                                                         image_size=image_size)

            unlabeled_train = get_pseudolabeling_dataset(
                data_dir,
                include_masks=True,
                augmentation=augmentations,
                image_size=image_size)

            loaders["label"] = DataLoader(unlabeled_label,
                                          batch_size=batch_size // 2,
                                          num_workers=num_workers,
                                          pin_memory=True)

            if train_sampler is not None:
                num_samples = 2 * train_sampler.num_samples
            else:
                num_samples = 2 * len(train_ds)
            weights = compute_sample_weight("balanced", [0] * len(train_ds) +
                                            [1] * len(unlabeled_label))

            train_sampler = WeightedRandomSampler(weights,
                                                  num_samples,
                                                  replacement=True)
            train_ds = train_ds + unlabeled_train

            callbacks += [
                BCEOnlinePseudolabelingCallback2d(
                    unlabeled_train,
                    pseudolabel_loader="label",
                    prob_threshold=0.7,
                    output_key=OUTPUT_MASK_KEY,
                    unlabeled_class=UNLABELED_SAMPLE,
                    label_frequency=5,
                )
            ]
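
            # Sketch of the intended flow, inferred from the parameters above (an
            # assumption, not a description of the callback's internals): every
            # label_frequency=5 epochs the "label" loader is predicted, pixels with
            # confidence above prob_threshold=0.7 are written back into unlabeled_train
            # as pseudo-masks, and the rest remain UNLABELED_SAMPLE.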

            print("Using online pseudolabeling with ", len(unlabeled_label),
                  "samples")

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds,
                                      batch_size=batch_size,
                                      num_workers=num_workers,
                                      pin_memory=True)

        # Create losses
        for loss_name, loss_weight in criterions:
            criterion_callback = CriterionCallback(
                prefix="seg_loss/" + loss_name,
                input_key=INPUT_MASK_KEY if loss_name != "wbce" else
                [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY],
                output_key=OUTPUT_MASK_KEY,
                criterion_key=loss_name,
                multiplier=float(loss_weight),
            )

            criterions_dict[loss_name] = get_loss(loss_name,
                                                  ignore_index=ignore_index)
            callbacks.append(criterion_callback)
            losses.append(criterion_callback.prefix)
            print("Using loss", loss_name, loss_weight)

        if use_dsv:
            print("Using DSV")
            dsv_criterion_key = "dsv"
            dsv_loss_name = "soft_bce"

            criterions_dict[dsv_criterion_key] = AdaptiveMaskLoss2d(
                get_loss(dsv_loss_name, ignore_index=ignore_index))

            for dsv_input in [
                    OUTPUT_MASK_4_KEY, OUTPUT_MASK_8_KEY, OUTPUT_MASK_16_KEY,
                    OUTPUT_MASK_32_KEY
            ]:
                criterion_callback = CriterionCallback(
                    prefix="seg_loss_dsv/" + dsv_input,
                    input_key=OUTPUT_MASK_KEY,
                    output_key=dsv_input,
                    criterion_key=dsv_criterion_key,
                    multiplier=1.0,
                )
                callbacks.append(criterion_callback)
                losses.append(criterion_callback.prefix)

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps,
                              decouple_weight_decay=False),
        ]

        optimizer = get_optimizer(optimizer_name,
                                  get_optimizable_parameters(model),
                                  learning_rate,
                                  weight_decay=weight_decay)
        scheduler = get_scheduler(scheduler_name,
                                  optimizer,
                                  lr=learning_rate,
                                  num_epochs=num_epochs,
                                  batches_in_epoch=len(loaders["train"]))
        if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
            callbacks += [SchedulerCallback(mode="batch")]

        print("Train session    :", checkpoint_prefix)
        print("\tFP16 mode      :", fp16)
        print("\tFast mode      :", args.fast)
        print("\tTrain mode     :", train_mode)
        print("\tEpochs         :", num_epochs)
        print("\tWorkers        :", num_workers)
        print("\tData dir       :", data_dir)
        print("\tLog dir        :", log_dir)
        print("\tAugmentations  :", augmentations)
        print("\tTrain size     :", len(loaders["train"]), len(train_ds))
        print("\tValid size     :", len(loaders["valid"]), len(valid_ds))
        print("Model            :", model_name)
        print("\tParameters     :", count_parameters(model))
        print("\tImage size     :", image_size)
        print("Optimizer        :", optimizer_name)
        print("\tLearning rate  :", learning_rate)
        print("\tBatch size     :", batch_size)
        print("\tCriterion      :", criterions)
        print("\tUse weight mask:", need_weight_mask)

        # model training
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": vars(args)},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        best_checkpoint = os.path.join(log_dir, "main", "checkpoints",
                                       "best.pth")

        model_checkpoint = os.path.join(log_dir, "main", "checkpoints",
                                        f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        unpack_checkpoint(torch.load(model_checkpoint), model=model)

        mask = predict(model,
                       read_inria_image("sample_color.jpg"),
                       image_size=image_size,
                       batch_size=args.batch_size)
        mask = ((mask > 0) * 255).astype(np.uint8)
        name = os.path.join(log_dir, "sample_color.jpg")
        cv2.imwrite(name, mask)

        del optimizer, loaders
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--obliterate", type=float, default=0, help="Change of obliteration")
    parser.add_argument("-nid", "--negative-image-dir", type=str, default=None, help="Change of obliteration")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--cache", action="store_true")
    parser.add_argument("-dd", "--data-dir", type=str, default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch Size during training, e.g. -b 64")
    parser.add_argument(
        "-wbs", "--warmup-batch-size", type=int, default=None, help="Batch Size during training, e.g. -b 64"
    )
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument(
        "-es", "--early-stopping", type=int, default=None, help="Maximum number of epochs without improvement"
    )
    parser.add_argument("-fe", "--freeze-encoder", action="store_true", help="Freeze encoder parameters for N epochs")
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")

    parser.add_argument(
        "-l", "--modification-flag-loss", type=str, default=None, action="append", nargs="+"  # [["ce", 1.0]],
    )
    parser.add_argument(
        "--modification-type-loss", type=str, default=None, action="append", nargs="+"  # [["ce", 1.0]],
    )
    parser.add_argument("--embedding-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--feature-maps-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--mask-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--bits-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],

    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="safe", type=str, help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--mixup", action="store_true")
    parser.add_argument("--cutmix", action="store_true")
    parser.add_argument("--tsa", action="store_true")
    parser.add_argument("--fold", default=None, type=int)
    parser.add_argument("-s", "--scheduler", default=None, type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=0, type=float, help="Dropout before head layer")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument(
        "--fine-tune", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")

    args = parser.parse_args()
    set_manual_seed(args.seed)

    assert (
        args.modification_flag_loss or args.modification_type_loss or args.embedding_loss
    ), "At least one of losses must be set"

    modification_flag_loss = args.modification_flag_loss
    modification_type_loss = args.modification_type_loss
    embedding_loss = args.embedding_loss
    feature_maps_loss = args.feature_maps_loss
    mask_loss = args.mask_loss
    bits_loss = args.bits_loss

    data_dir = args.data_dir
    cache = args.cache
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    optimizer_name = args.optimizer
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    verbose = args.verbose
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    balance = args.balance
    freeze_bn = args.freeze_bn
    train_batch_size = args.batch_size
    mixup = args.mixup
    cutmix = args.cutmix
    tsa = args.tsa
    obliterate_p = args.obliterate
    negative_image_dir = args.negative_image_dir

    # Compute batch size for validation
    valid_batch_size = train_batch_size

    current_time = datetime.now().strftime("%b%d_%H_%M")

    main_metric = "loss"
    main_metric_minimize = True

    x_train = np.load(f"embeddings_x_train_Gf3_Hnrmishf2_Hnrmishf1_Kmishf0.npy")
    y_train = np.load(f"embeddings_y_train_Gf3_Hnrmishf2_Hnrmishf1_Kmishf0.npy")

    x_valid = np.load(f"embeddings_x_holdout_Gf3_Hnrmishf2_Hnrmishf1_Kmishf0.npy")
    y_valid = np.load(f"embeddings_y_holdout_Gf3_Hnrmishf2_Hnrmishf1_Kmishf0.npy")

    print(x_train.shape, x_valid.shape)
    print(np.bincount(y_train), np.bincount(y_valid))

    train_ds = StackerDataset(x_train, y_train)
    valid_ds = StackerDataset(x_valid, y_valid)

    criterions_dict, loss_callbacks = get_criterions(
        modification_flag=modification_flag_loss,
        modification_type=modification_type_loss,
        embedding_loss=None,
        feature_maps_loss=None,
        mask_loss=None,
        bits_loss=None,
        num_epochs=num_epochs,
        mixup=mixup,
        cutmix=None,
        tsa=tsa,
    )

    callbacks = loss_callbacks + [
        OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
        HyperParametersCallback(
            hparam_dict={
                "scheduler": scheduler_name,
                "optimizer": optimizer_name,
                "augmentations": augmentations,
                "weight_decay": weight_decay,
            }
        ),
    ]

    loaders = collections.OrderedDict()
    loaders["train"] = DataLoader(
        train_ds, batch_size=train_batch_size, num_workers=num_workers, pin_memory=True, drop_last=True, shuffle=True
    )

    loaders["valid"] = DataLoader(valid_ds, batch_size=valid_batch_size, num_workers=num_workers, pin_memory=True)

    model: nn.Module = StackingModel(x_train.shape[1], dropout=dropout).cuda()

    optimizer = get_optimizer(
        optimizer_name, get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay
    )
    scheduler = get_scheduler(
        scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
    )
    if isinstance(scheduler, CyclicLR):
        callbacks += [SchedulerCallback(mode="batch")]

    checkpoint_prefix = f"{current_time}_stacking"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if mixup:
        checkpoint_prefix += "_mixup"

    if cutmix:
        checkpoint_prefix += "_cutmix"

    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    print("Train session    :", checkpoint_prefix)
    print("  Train size     :", len(loaders["train"]), "batches", len(train_ds), "samples")
    print("  Valid size     :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
    print("  FP16 mode      :", fp16)
    print("  Fast mode      :", args.fast)
    print("  Epochs         :", num_epochs)
    print("  Workers        :", num_workers)
    print("  Data dir       :", data_dir)
    print("  Log dir        :", log_dir)
    print("  Cache          :", cache)
    print("Data              ")
    print("  Augmentations  :", augmentations)
    print("  Obliterate (%) :", obliterate_p)
    print("  Negative images:", negative_image_dir)
    print("  Balance        :", balance)
    print("  Mixup          :", mixup)
    print("  CutMix         :", cutmix)
    print("  TSA            :", tsa)
    # print("Model            :", model_name)
    print("  Parameters     :", count_parameters(model))
    print("  Dropout        :", dropout)
    print("Optimizer        :", optimizer_name)
    print("  Learning rate  :", learning_rate)
    print("  Weight decay   :", weight_decay)
    print("  Scheduler      :", scheduler_name)
    print("  Batch sizes    :", train_batch_size, valid_batch_size)
    print("Losses            ")
    print("  Flag           :", modification_flag_loss)
    print("  Type           :", modification_type_loss)
    print("  Embedding      :", embedding_loss)
    print("  Feature maps   :", feature_maps_loss)
    print("  Mask           :", mask_loss)
    print("  Bits           :", bits_loss)

    # model training
    runner = SupervisedRunner(input_key=[INPUT_EMBEDDING_KEY], output_key=None)
    runner.train(
        fp16=fp16,
        model=model,
        criterion=criterions_dict,
        optimizer=optimizer,
        scheduler=scheduler,
        callbacks=callbacks,
        loaders=loaders,
        logdir=os.path.join(log_dir, "main"),
        num_epochs=num_epochs,
        verbose=verbose,
        main_metric=main_metric,
        minimize_metric=main_metric_minimize,
        checkpoint_data={"cmd_args": vars(args)},
    )

    del optimizer, loaders, runner, callbacks

    best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")
    model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")

    # Restore state of best model
    clean_checkpoint(best_checkpoint, model_checkpoint)
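
# A minimal sketch (an assumption, not the repository's class) of the StackerDataset
# used above: a thin Dataset yielding one precomputed embedding vector and its target
# per item. The dict keys below are illustrative; the real code uses its own shared
# constants (e.g. INPUT_EMBEDDING_KEY) that StackingModel and the runner expect.
import numpy as np
import torch
from torch.utils.data import Dataset


class StackerDatasetSketch(Dataset):
    def __init__(self, x: np.ndarray, y: np.ndarray):
        self.x = torch.from_numpy(x).float()
        self.y = torch.from_numpy(y).float()

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        return {"embedding": self.x[index], "true_modification_flag": self.y[index]}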
Example #9
0
def main(cfg: DictConfig):

    cwd = Path(get_original_cwd())

    # overwrite config if continue training from checkpoint
    resume_cfg = None
    if "resume" in cfg:
        cfg_path = cwd / cfg.resume / ".hydra/config.yaml"
        print(f"Continue from: {cfg.resume}")
        # Overwrite everything except device
        # TODO config merger (perhaps continue training with the same optimizer but other lrs?)
        resume_cfg = OmegaConf.load(cfg_path)
        cfg.model = resume_cfg.model
        if cfg.train.num_epochs == 0:
            cfg.data.scale_factor = resume_cfg.data.scale_factor
        OmegaConf.save(cfg, ".hydra/config.yaml")

    print(OmegaConf.to_yaml(cfg))

    device = set_device_id(cfg.device)
    set_seed(cfg.seed, device=device)

    # Augmentations
    if cfg.data.aug == "auto":
        transforms = albu.load(cwd / "autoalbument/autoconfig.json")
    else:
        transforms = D.get_training_augmentations()

    if OmegaConf.is_missing(cfg.model, "convert_bottleneck"):
        cfg.model.convert_bottleneck = (0, 0, 0)

    # Model
    print(f"Setup model {cfg.model.arch} {cfg.model.encoder_name} "
          f"convert_bn={cfg.model.convert_bn} "
          f"convert_bottleneck={cfg.model.convert_bottleneck} ")
    model = get_segmentation_model(
        arch=cfg.model.arch,
        encoder_name=cfg.model.encoder_name,
        encoder_weights=cfg.model.encoder_weights,
        classes=1,
        convert_bn=cfg.model.convert_bn,
        convert_bottleneck=cfg.model.convert_bottleneck,
        # decoder_attention_type="scse",  # TODO to config
    )
    model = model.to(device)
    model.train()
    print(model)

    # Optimization
    # Reduce LR for pretrained encoder
    layerwise_params = {
        "encoder*":
        dict(lr=cfg.optim.lr_encoder, weight_decay=cfg.optim.wd_encoder)
    }
    model_params = cutils.process_model_params(
        model, layerwise_params=layerwise_params)

    # Select optimizer
    optimizer = get_optimizer(
        name=cfg.optim.name,
        model_params=model_params,
        lr=cfg.optim.lr,
        wd=cfg.optim.wd,
        lookahead=cfg.optim.lookahead,
    )

    criterion = {
        "dice": DiceLoss(),
        # "dice": SoftDiceLoss(mode="binary", smooth=1e-7),
        "iou": IoULoss(),
        "bce": nn.BCEWithLogitsLoss(),
        "lovasz": LovaszLossBinary(),
        "focal_tversky": FocalTverskyLoss(eps=1e-7, alpha=0.7, gamma=0.75),
    }

    # Load states if resuming training
    if "resume" in cfg:
        checkpoint_path = (cwd / cfg.resume / cfg.train.logdir /
                           "checkpoints/best_full.pth")
        if checkpoint_path.exists():
            print(f"\nLoading checkpoint {str(checkpoint_path)}")
            checkpoint = cutils.load_checkpoint(checkpoint_path)
            cutils.unpack_checkpoint(
                checkpoint=checkpoint,
                model=model,
                optimizer=optimizer
                if resume_cfg.optim.name == cfg.optim.name else None,
                criterion=criterion,
            )
        else:
            raise ValueError("Nothing to resume, checkpoint missing")

    # We may only want to validate the resumed model; in that case the training routine below is skipped
    best_th = 0.5

    stats = None
    if cfg.data.stats:
        print(f"Use statistics from file: {cfg.data.stats}")
        stats = cwd / cfg.data.stats

    if cfg.train.num_epochs is not None:
        callbacks = [
            # Each criterion is calculated separately.
            CriterionCallback(input_key="mask",
                              prefix="loss_dice",
                              criterion_key="dice"),
            CriterionCallback(input_key="mask",
                              prefix="loss_iou",
                              criterion_key="iou"),
            CriterionCallback(input_key="mask",
                              prefix="loss_bce",
                              criterion_key="bce"),
            CriterionCallback(input_key="mask",
                              prefix="loss_lovasz",
                              criterion_key="lovasz"),
            CriterionCallback(
                input_key="mask",
                prefix="loss_focal_tversky",
                criterion_key="focal_tversky",
            ),
            # And only then we aggregate everything into one loss.
            MetricAggregationCallback(
                prefix="loss",
                mode="weighted_sum",  # can be "sum", "weighted_sum" or "mean"
                # for a weighted sum, each loss gets its own weight
                metrics={
                    "loss_dice": cfg.loss.dice,
                    "loss_iou": cfg.loss.iou,
                    "loss_bce": cfg.loss.bce,
                    "loss_lovasz": cfg.loss.lovasz,
                    "loss_focal_tversky": cfg.loss.focal_tversky,
                },
            ),
            # metrics
            DiceCallback(input_key="mask"),
            IouCallback(input_key="mask"),
            # gradient accumulation
            OptimizerCallback(accumulation_steps=cfg.optim.accumulate),
            # LR scheduling and early stopping
            SchedulerCallback(reduced_metric="loss_dice",
                              mode=cfg.scheduler.mode),
            EarlyStoppingCallback(**cfg.scheduler.early_stopping,
                                  minimize=False),
            # TODO WandbLogger works poorly with multistage right now
            WandbLogger(project=cfg.project, config=dict(cfg)),
            # CheckpointCallback(save_n_best=cfg.checkpoint.save_n_best),
        ]

        # Training
        runner = SupervisedRunner(device=device,
                                  input_key="image",
                                  input_target_key="mask")

        # TODO Scheduler does not work correctly yet: every stage restarts from the base LR
        scheduler_warm_restart = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[1, 2],
            gamma=10,
        )
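        # With gamma=10 and milestones [1, 2], each scheduler_warm_restart.step()
        # call at the end of a stage multiplies the optimizer LR by 10, so the
        # next stage starts from a higher base LR than the decayed one.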

        for i, (size, num_epochs) in enumerate(
                zip(cfg.data.sizes, cfg.train.num_epochs)):
            scale = size / 1024
            print(
                f"Training stage {i}, scale {scale}, size {size}, epochs {num_epochs}"
            )

            # Datasets
            (
                train_ds,
                valid_ds,
                train_images,
                val_images,
            ) = D.get_train_valid_datasets_from_path(
                # path=(cwd / cfg.data.path),
                path=(cwd / f"data/hubmap-{size}x{size}/"),
                train_ids=cfg.data.train_ids,
                valid_ids=cfg.data.valid_ids,
                seed=cfg.seed,
                valid_split=cfg.data.valid_split,
                mean=cfg.data.mean,
                std=cfg.data.std,
                transforms=transforms,
                stats=stats,
            )

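            # Scale batch size with tile area so the per-batch pixel count
            # stays roughly constant across stages (scale = size / 1024).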
            train_bs = int(cfg.loader.train_bs / (scale**2))
            valid_bs = int(cfg.loader.valid_bs / (scale**2))
            print(
                f"train: {len(train_ds)}; bs {train_bs}",
                f"valid: {len(valid_ds)}, bs {valid_bs}",
            )

            # Data loaders
            data_loaders = D.get_data_loaders(
                train_ds=train_ds,
                valid_ds=valid_ds,
                train_bs=train_bs,
                valid_bs=valid_bs,
                num_workers=cfg.loader.num_workers,
            )

            # Select scheduler
            scheduler = get_scheduler(
                name=cfg.scheduler.type,
                optimizer=optimizer,
                num_epochs=num_epochs * (len(data_loaders["train"]) if
                                         cfg.scheduler.mode == "batch" else 1),
                eta_min=scheduler_warm_restart.get_last_lr()[0] /
                cfg.scheduler.eta_min_factor,
                plateau=cfg.scheduler.plateau,
            )

            runner.train(
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                scheduler=scheduler,
                callbacks=callbacks,
                logdir=cfg.train.logdir,
                loaders=data_loaders,
                num_epochs=num_epochs,
                verbose=True,
                main_metric=cfg.train.main_metric,
                load_best_on_end=True,
                minimize_metric=False,
                check=cfg.check,
                fp16=dict(amp=cfg.amp),
            )

            # Set new initial LR for optimizer after restart
            scheduler_warm_restart.step()
            print(
                f"New LR for warm restart {scheduler_warm_restart.get_last_lr()[0]}"
            )

            # Find optimal threshold for dice score
            model.eval()
            best_th, dices = find_dice_threshold(model, data_loaders["valid"])
            print("Best dice threshold", best_th, np.max(dices[1]))
            np.save(f"dices_{size}.npy", dices)
    else:
        print("Validation only")
        # Datasets
        size = cfg.data.sizes[-1]
        train_ds, valid_ds, _, _ = D.get_train_valid_datasets_from_path(
            # path=(cwd / cfg.data.path),
            path=(cwd / f"data/hubmap-{size}x{size}/"),
            train_ids=cfg.data.train_ids,
            valid_ids=cfg.data.valid_ids,
            seed=cfg.seed,
            valid_split=cfg.data.valid_split,
            mean=cfg.data.mean,
            std=cfg.data.std,
            transforms=transforms,
            stats=stats,
        )

        train_bs = int(cfg.loader.train_bs / (cfg.data.scale_factor**2))
        valid_bs = int(cfg.loader.valid_bs / (cfg.data.scale_factor**2))
        print(
            f"train: {len(train_ds)}; bs {train_bs}",
            f"valid: {len(valid_ds)}, bs {valid_bs}",
        )

        # Data loaders
        data_loaders = D.get_data_loaders(
            train_ds=train_ds,
            valid_ds=valid_ds,
            train_bs=train_bs,
            valid_bs=valid_bs,
            num_workers=cfg.loader.num_workers,
        )

        # Find optimal threshold for dice score
        model.eval()
        best_th, dices = find_dice_threshold(model, data_loaders["valid"])
        print("Best dice threshold", best_th, np.max(dices[1]))
        np.save(f"dices_val.npy", dices)

    #
    # # Load best checkpoint
    # checkpoint_path = Path(cfg.train.logdir) / "checkpoints/best.pth"
    # if checkpoint_path.exists():
    #     print(f"\nLoading checkpoint {str(checkpoint_path)}")
    #     state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))[
    #         "model_state_dict"
    #     ]
    #     model.load_state_dict(state_dict)
    #     del state_dict
    # model = model.to(device)
    # Reload the config to update it with the threshold and metric
    # (otherwise loading does not work)
    cfg = OmegaConf.load(".hydra/config.yaml")
    cfg.threshold = float(best_th)

    # Evaluate on full-size image if valid_ids is non-empty
    df_train = pd.read_csv(cwd / "data/train.csv")
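    # Map image id -> ground-truth RLE mask string for full-size evaluation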
    df_train = {
        r["id"]: r["encoding"]
        for r in df_train.to_dict(orient="record")
    }
    dices = []
    unique_ids = sorted(
        set(
            str(p).split("/")[-1].split("_")[0]
            for p in (cwd / cfg.data.path / "train").iterdir()))
    size = cfg.data.sizes[-1]
    scale = size / 1024
    for image_id in cfg.data.valid_ids:
        image_name = unique_ids[image_id]
        print(f"\nValidate for {image_name}")

        rle_pred, shape = inference_one(
            image_path=(cwd / f"data/train/{image_name}.tiff"),
            target_path=Path("."),
            cfg=cfg,
            model=model,
            scale_factor=scale,
            tile_size=cfg.data.tile_size,
            tile_step=cfg.data.tile_step,
            threshold=best_th,
            save_raw=False,
            tta_mode=None,
            weight="pyramid",
            device=device,
            filter_crops="tissue",
            stats=stats,
        )

        print("Predict", shape)
        pred = rle_decode(rle_pred["predicted"], shape)
        mask = rle_decode(df_train[image_name], shape)
        assert pred.shape == mask.shape, f"pred {pred.shape}, mask {mask.shape}"
        assert pred.shape == shape, f"pred {pred.shape}, expected {shape}"

        dices.append(
            dice(
                torch.from_numpy(pred).type(torch.uint8),
                torch.from_numpy(mask).type(torch.uint8),
                threshold=None,
                activation="none",
            ))
    print("Full image dice:", np.mean(dices))
    OmegaConf.save(cfg, ".hydra/config.yaml")
    return
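find_dice_threshold is likewise not defined in this snippet; a minimal sketch of the threshold sweep it could perform, assuming dict batches with "image"/"mask" keys (as implied by the SupervisedRunner setup above) and a return value of the best threshold plus a (thresholds, scores) pair, matching how dices[1] is used above; names and the return format are assumptions:

import numpy as np
import torch

def find_dice_threshold(model, loader, thresholds=np.linspace(0.1, 0.9, 17)):
    # Hypothetical helper: collect sigmoid probabilities and ground-truth
    # masks from the validation loader, then sweep binarization thresholds
    # and score each with the Dice coefficient.
    device = next(model.parameters()).device
    probs, masks = [], []
    with torch.no_grad():
        for batch in loader:  # assumes dict batches with "image"/"mask" keys
            probs.append(torch.sigmoid(model(batch["image"].to(device))).cpu())
            masks.append(batch["mask"].cpu())
    probs, masks = torch.cat(probs), torch.cat(masks)
    scores = []
    for th in thresholds:
        pred = (probs > th).float()
        inter = (pred * masks).sum()
        scores.append(float(2 * inter / (pred.sum() + masks.sum() + 1e-7)))
    best_th = float(thresholds[int(np.argmax(scores))])
    return best_th, (thresholds, np.array(scores))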